Merge Android 12
Bug: 202323961
Merged-In: I435267199a30b921b02e9ee11c8aab6e56d8d988
Change-Id: I04b2c2a297f86f5769fdf1bf57304277d7016e03
diff --git a/OWNERS b/OWNERS
index fb052ec..81cfdb8 100644
--- a/OWNERS
+++ b/OWNERS
@@ -1,7 +1,7 @@
# Default code reviewers picked from top 3 or more developers.
# Please update this list if you find better candidates.
+tonymak@google.com
+toki@google.com
zilka@google.com
mns@google.com
-toki@google.com
jalt@google.com
-tonymak@google.com
diff --git a/TEST_MAPPING b/TEST_MAPPING
index 3c8e10b..93ea6d6 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -10,6 +10,26 @@
},
{
"name": "libtextclassifier_tests"
+ },
+ {
+ "name": "libtextclassifier_java_tests"
+ },
+ {
+ "name": "TextClassifierNotificationTests"
+ }
+ ],
+ "mainline-presubmit": [
+ {
+ "name": "TextClassifierNotificationTests[com.google.android.extservices.apex]"
+ },
+ {
+ "name": "TextClassifierServiceTest[com.google.android.extservices.apex]"
+ },
+ {
+ "name": "libtextclassifier_tests[com.google.android.extservices.apex]"
+ },
+ {
+ "name": "libtextclassifier_java_tests[com.google.android.extservices.apex]"
}
]
}
\ No newline at end of file
diff --git a/abseil-cpp/Android.bp b/abseil-cpp/Android.bp
index 52d2575..a3635f3 100644
--- a/abseil-cpp/Android.bp
+++ b/abseil-cpp/Android.bp
@@ -45,6 +45,7 @@
"com.android.extservices",
],
sdk_version: "current",
+ min_sdk_version: "30",
stl: "libc++_static",
exclude_srcs: [
"**/*_test.cc",
diff --git a/coverage/Android.bp b/coverage/Android.bp
new file mode 100644
index 0000000..943458d
--- /dev/null
+++ b/coverage/Android.bp
@@ -0,0 +1,22 @@
+package {
+ // See: http://go/android-license-faq
+ // A large-scale-change added 'default_applicable_licenses' to import
+ // all of the 'license_kinds' from "external_libtextclassifier_license"
+ // to get the below license kinds:
+ // SPDX-license-identifier-Apache-2.0
+ default_applicable_licenses: ["external_libtextclassifier_license"],
+}
+
+android_library {
+ name: "TextClassifierCoverageLib",
+
+ srcs: ["src/**/*.java"],
+
+ static_libs: [
+ "androidx.test.ext.junit",
+ "androidx.test.rules",
+ ],
+
+ min_sdk_version: "30",
+ sdk_version: "current",
+}
diff --git a/coverage/AndroidManifest.xml b/coverage/AndroidManifest.xml
new file mode 100644
index 0000000..1897421
--- /dev/null
+++ b/coverage/AndroidManifest.xml
@@ -0,0 +1,9 @@
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier.testing">
+
+ <uses-sdk android:minSdkVersion="30" />
+
+ <application>
+ </application>
+
+</manifest>
\ No newline at end of file
diff --git a/coverage/src/com/android/textclassifier/testing/SignalMaskInfo.java b/coverage/src/com/android/textclassifier/testing/SignalMaskInfo.java
new file mode 100644
index 0000000..73ed185
--- /dev/null
+++ b/coverage/src/com/android/textclassifier/testing/SignalMaskInfo.java
@@ -0,0 +1,128 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+package com.android.textclassifier.testing;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Class for reading a process' signal masks from the /proc filesystem. Looks for the
+ * BLOCKED, CAUGHT, IGNORED and PENDING masks from /proc/self/status, each of which is a
+ * 64 bit bitmask with one bit per signal.
+ *
+ * Maintains a map from SignalMaskInfo.Type to the bitmask. The {@code isValid} method
+ * will only return true if all 4 masks were successfully parsed. Provides lookup
+ * methods per signal, e.g. {@code isPending(signum)} which will throw
+ * {@code IllegalStateException} if the current data is not valid.
+ */
+public class SignalMaskInfo {
+ private enum Type {
+ BLOCKED("SigBlk"),
+ CAUGHT("SigCgt"),
+ IGNORED("SigIgn"),
+ PENDING("SigPnd");
+ // The tag for this mask in /proc/self/status
+ private final String tag;
+
+ Type(String tag) {
+ this.tag = tag + ":\t";
+ }
+
+ public String getTag() {
+ return tag;
+ }
+
+ public static Map<Type, Long> parseProcinfo(String path) {
+ Map<Type, Long> map = new HashMap<>();
+ try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ for (Type mask : values()) {
+ long value = mask.tryToParse(line);
+ if (value >= 0) {
+ map.put(mask, value);
+ }
+ }
+ }
+ } catch (NumberFormatException | IOException e) {
+ // Ignored - the map will end up being invalid instead.
+ }
+ return map;
+ }
+
+ private long tryToParse(String line) {
+ if (line.startsWith(tag)) {
+ return Long.valueOf(line.substring(tag.length()), 16);
+ } else {
+ return -1;
+ }
+ }
+ }
+
+ private static final String PROCFS_PATH = "/proc/self/status";
+ private Map<Type, Long> maskMap = null;
+
+ SignalMaskInfo() {
+ refresh();
+ }
+
+ public void refresh() {
+ maskMap = Type.parseProcinfo(PROCFS_PATH);
+ }
+
+ public boolean isValid() {
+ return (maskMap != null && maskMap.size() == Type.values().length);
+ }
+
+ public boolean isCaught(int signal) {
+ return isSignalInMask(signal, Type.CAUGHT);
+ }
+
+ public boolean isBlocked(int signal) {
+ return isSignalInMask(signal, Type.BLOCKED);
+ }
+
+ public boolean isPending(int signal) {
+ return isSignalInMask(signal, Type.PENDING);
+ }
+
+ public boolean isIgnored(int signal) {
+ return isSignalInMask(signal, Type.IGNORED);
+ }
+
+ private void checkValid() {
+ if (!isValid()) {
+ throw new IllegalStateException();
+ }
+ }
+
+ private boolean isSignalInMask(int signal, Type mask) {
+ long bit = 1L << (signal - 1);
+ return (getSignalMask(mask) & bit) != 0;
+ }
+
+ private long getSignalMask(Type mask) {
+ checkValid();
+ Long value = maskMap.get(mask);
+ if (value == null) {
+ throw new IllegalStateException();
+ }
+ return value;
+ }
+}
diff --git a/coverage/src/com/android/textclassifier/testing/TextClassifierInstrumentationListener.java b/coverage/src/com/android/textclassifier/testing/TextClassifierInstrumentationListener.java
new file mode 100644
index 0000000..a78ca98
--- /dev/null
+++ b/coverage/src/com/android/textclassifier/testing/TextClassifierInstrumentationListener.java
@@ -0,0 +1,117 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.textclassifier.testing;
+
+import android.system.ErrnoException;
+import android.system.Os;
+import android.util.Log;
+
+import androidx.test.internal.runner.listener.InstrumentationRunListener;
+
+import org.junit.runner.Result;
+
+/**
+ * To make native coverage measurement possible.
+ */
+public class TextClassifierInstrumentationListener extends InstrumentationRunListener {
+ private static final String LOG_TAG = "androidtc";
+ // Signal used to trigger a dump of Clang coverage information.
+ // See {@code maybeDumpNativeCoverage} below.
+ private static final int COVERAGE_SIGNAL = 37;
+
+ @Override
+ public void testRunFinished(Result result) throws Exception {
+ maybeDumpNativeCoverage();
+ super.testRunFinished(result);
+ }
+
+ /**
+ * If this test process is instrumented for native coverage, then trigger a dump
+ * of the coverage data and wait until either we detect the dumping has finished or 60 seconds,
+ * whichever is shorter.
+ *
+ * Background: Coverage builds install a signal handler for signal 37 which flushes coverage
+ * data to disk, which may take a few seconds. Tests running as an app process will get
+ * killed with SIGKILL once the app code exits, even if the coverage handler is still running.
+ *
+ * Method: If a handler is installed for signal 37, then assume this is a coverage run and
+ * send signal 37. The handler is non-reentrant and so signal 37 will then be blocked until
+ * the handler completes. So after we send the signal, we loop checking the blocked status
+ * for signal 37 until we hit the 60 second deadline. If the signal is blocked then sleep for
+ * 2 seconds, and if it becomes unblocked then the handler exited so we can return early.
+ * If the signal is not blocked at the start of the loop then most likely the handler has
+ * not yet been invoked. This should almost never happen as it should get blocked on delivery
+ * when we call {@code Os.kill()}, so sleep for a shorter duration (100ms) and try again. There
+ * is a race condition here where the handler is delayed but then runs for less than 100ms and
+ * gets missed, in which case this method will loop with 100ms sleeps until the deadline.
+ *
+ * In the case where the handler runs for more than 60 seconds, the test process will be allowed
+ * to exit so coverage information may be incomplete.
+ *
+ * There is no API for determining signal dispositions, so this method uses the
+ * {@link SignalMaskInfo} class to read the data from /proc. If there is an error parsing
+ * the /proc data then this method will also loop until the 60s deadline passes.
+ */
+ private void maybeDumpNativeCoverage() {
+ SignalMaskInfo siginfo = new SignalMaskInfo();
+ if (!siginfo.isValid()) {
+ Log.e(LOG_TAG, "Invalid signal info");
+ return;
+ }
+
+ if (!siginfo.isCaught(COVERAGE_SIGNAL)) {
+ // Process is not instrumented for coverage
+ Log.i(LOG_TAG, "Not dumping coverage, no handler installed");
+ return;
+ }
+
+ Log.i(LOG_TAG,
+ String.format("Sending coverage dump signal %d to pid %d uid %d", COVERAGE_SIGNAL,
+ Os.getpid(), Os.getuid()));
+ try {
+ Os.kill(Os.getpid(), COVERAGE_SIGNAL);
+ } catch (ErrnoException e) {
+ Log.e(LOG_TAG, "Unable to send coverage signal", e);
+ return;
+ }
+
+ long start = System.currentTimeMillis();
+ long deadline = start + 60 * 1000L;
+ while (System.currentTimeMillis() < deadline) {
+ siginfo.refresh();
+ try {
+ if (siginfo.isValid() && siginfo.isBlocked(COVERAGE_SIGNAL)) {
+ // Signal is currently blocked so assume a handler is running
+ Thread.sleep(2000L);
+ siginfo.refresh();
+ if (siginfo.isValid() && !siginfo.isBlocked(COVERAGE_SIGNAL)) {
+ // Coverage handler exited while we were asleep
+ Log.i(LOG_TAG,
+ String.format("Coverage dump detected finished after %dms",
+ System.currentTimeMillis() - start));
+ break;
+ }
+ } else {
+ // Coverage signal handler not yet started or invalid siginfo
+ Thread.sleep(100L);
+ }
+ } catch (InterruptedException e) {
+ // ignored
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/java/Android.bp b/java/Android.bp
index 655692b..ca34a66 100644
--- a/java/Android.bp
+++ b/java/Android.bp
@@ -29,7 +29,7 @@
static_libs: ["TextClassifierServiceLib"],
jni_libs: ["libtextclassifier"],
sdk_version: "system_current",
- min_sdk_version: "28",
+ min_sdk_version: "30",
certificate: "platform",
optimize: {
proguard_flags_files: ["proguard.flags"],
@@ -42,8 +42,11 @@
name: "TextClassifierServiceLib",
static_libs: ["TextClassifierServiceLibNoManifest"],
sdk_version: "system_current",
- min_sdk_version: "28",
+ min_sdk_version: "30",
manifest: "AndroidManifest.xml",
+ aaptflags: [
+ "-0 .model",
+ ],
}
// Similar to TextClassifierServiceLib, but without the AndroidManifest.
@@ -60,12 +63,17 @@
"error_prone_annotations",
],
sdk_version: "system_current",
- min_sdk_version: "28",
+ min_sdk_version: "30",
+ aaptflags: [
+ "-0 .model",
+ ],
+
}
java_library {
name: "textclassifier-statsd",
sdk_version: "system_current",
+ min_sdk_version: "30",
srcs: [
":statslog-textclassifier-java-gen",
],
@@ -75,7 +83,7 @@
name: "statslog-textclassifier-java-gen",
tools: ["stats-log-api-gen"],
cmd: "$(location stats-log-api-gen) --java $(out) --module textclassifier" +
- " --javaPackage com.android.textclassifier --javaClass TextClassifierStatsLog" +
- " --minApiLevel 30",
- out: ["com/android/textclassifier/TextClassifierStatsLog.java"],
+ " --javaPackage com.android.textclassifier.common.statsd" +
+ " --javaClass TextClassifierStatsLog --minApiLevel 30",
+ out: ["com/android/textclassifier/common/statsd/TextClassifierStatsLog.java"],
}
diff --git a/java/AndroidManifest.xml b/java/AndroidManifest.xml
index 9f02689..8ef323c 100644
--- a/java/AndroidManifest.xml
+++ b/java/AndroidManifest.xml
@@ -18,7 +18,7 @@
-->
<!--
- This manifest file is for the standalone TCS used for testing.
+ This manifest file is for the tcs library.
The TCS is typically shipped as part of ExtServices and is configured
in ExtServices's manifest.
-->
@@ -27,21 +27,23 @@
android:versionCode="1"
android:versionName="1.0.0">
- <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="29"/>
+ <uses-sdk android:minSdkVersion="30" android:targetSdkVersion="30"/>
+ <uses-permission android:name="android.permission.QUERY_ALL_PACKAGES" />
<uses-permission android:name="android.permission.ACCESS_COARSE_LOCATION" />
- <application android:label="@string/tcs_app_name"
- android:icon="@drawable/tcs_app_icon"
- android:extractNativeLibs="false">
+
+ <application>
+
<service
android:exported="true"
+ android:directBootAware="false"
android:name=".DefaultTextClassifierService"
android:permission="android.permission.BIND_TEXTCLASSIFIER_SERVICE">
<intent-filter>
<action android:name="android.service.textclassifier.TextClassifierService"/>
</intent-filter>
</service>
-
</application>
+
</manifest>
diff --git a/java/assets/textclassifier/actions_suggestions.universal.model b/java/assets/textclassifier/actions_suggestions.universal.model
new file mode 100755
index 0000000..f74fed4
--- /dev/null
+++ b/java/assets/textclassifier/actions_suggestions.universal.model
Binary files differ
diff --git a/java/assets/textclassifier/annotator.universal.model b/java/assets/textclassifier/annotator.universal.model
new file mode 100755
index 0000000..09f1e0b
--- /dev/null
+++ b/java/assets/textclassifier/annotator.universal.model
Binary files differ
diff --git a/java/assets/textclassifier/lang_id.model b/java/assets/textclassifier/lang_id.model
new file mode 100644
index 0000000..e94dada
--- /dev/null
+++ b/java/assets/textclassifier/lang_id.model
Binary files differ
diff --git a/java/lint-baseline.xml b/java/lint-baseline.xml
deleted file mode 100644
index 6f91923..0000000
--- a/java/lint-baseline.xml
+++ /dev/null
@@ -1,15 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<issues format="5" by="lint 4.1.0" client="cli" variant="all" version="4.1.0">
-
- <issue
- id="NewApi"
- message="Call requires API level R (current min is 29): `android.view.textclassifier.TextClassificationSessionId#getValue`"
- errorLine1=" return TextClassificationSessionId.unflattenFromString(sessionId.getValue());"
- errorLine2=" ~~~~~~~~">
- <location
- file="external/libtextclassifier/java/src/com/android/textclassifier/common/statsd/TextClassificationSessionIdConverter.java"
- line="36"
- column="70"/>
- </issue>
-
-</issues>
diff --git a/java/res/drawable/tcs_app_icon.xml b/java/res/drawable/tcs_app_icon.xml
deleted file mode 100644
index 8cce7ca..0000000
--- a/java/res/drawable/tcs_app_icon.xml
+++ /dev/null
@@ -1,11 +0,0 @@
-<?xml version="1.0" encoding="utf-8"?>
-<vector xmlns:android="http://schemas.android.com/apk/res/android"
- android:width="24dp"
- android:height="24dp"
- android:viewportWidth="24"
- android:viewportHeight="24">
-
- <path
- android:fillColor="#000000"
- android:pathData="M2.5 4v3h5v12h3V7h5V4h-13zm19 5h-9v3h3v7h3v-7h3V9z" />
-</vector>
\ No newline at end of file
diff --git a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
index a51c95d..beb155b 100644
--- a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
+++ b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
@@ -27,7 +27,7 @@
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.ConversationActions.Message;
-import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.ModelFileManager.ModelFile;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.intent.LabeledIntent;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
@@ -147,6 +147,9 @@
public static LabeledIntent.TitleChooser createTitleChooser(String actionType) {
if (ConversationAction.TYPE_OPEN_URL.equals(actionType)) {
return (labeledIntent, resolveInfo) -> {
+ if (resolveInfo == null) {
+ return labeledIntent.titleWithEntity;
+ }
if (resolveInfo.handleAllWebDataURI) {
return labeledIntent.titleWithEntity;
}
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index d2c1e38..1f1e958 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -16,57 +16,92 @@
package com.android.textclassifier;
+import android.content.Context;
import android.os.CancellationSignal;
import android.service.textclassifier.TextClassifierService;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.SelectionEvent;
import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifierEvent;
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
+import androidx.annotation.NonNull;
+import androidx.collection.LruCache;
+import com.android.textclassifier.common.ModelFileManager;
+import com.android.textclassifier.common.TextClassifierServiceExecutors;
+import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
-import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.FileDescriptor;
import java.io.PrintWriter;
+import java.util.Map;
import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
+import java.util.concurrent.Executor;
+import javax.annotation.Nullable;
/** An implementation of a TextClassifierService. */
public final class DefaultTextClassifierService extends TextClassifierService {
private static final String TAG = "default_tcs";
+ private final Injector injector;
// TODO: Figure out do we need more concurrency.
- private final ListeningExecutorService normPriorityExecutor =
- MoreExecutors.listeningDecorator(
- Executors.newFixedThreadPool(
- /* nThreads= */ 2,
- new ThreadFactoryBuilder()
- .setNameFormat("tcs-norm-prio-executor")
- .setPriority(Thread.NORM_PRIORITY)
- .build()));
-
- private final ListeningExecutorService lowPriorityExecutor =
- MoreExecutors.listeningDecorator(
- Executors.newSingleThreadExecutor(
- new ThreadFactoryBuilder()
- .setNameFormat("tcs-low-prio-executor")
- .setPriority(Thread.NORM_PRIORITY - 1)
- .build()));
-
+ private ListeningExecutorService normPriorityExecutor;
+ private ListeningExecutorService lowPriorityExecutor;
private TextClassifierImpl textClassifier;
+ private TextClassifierSettings settings;
+ private ModelFileManager modelFileManager;
+ private LruCache<TextClassificationSessionId, TextClassificationContext> sessionIdToContext;
+
+ public DefaultTextClassifierService() {
+ this.injector = new InjectorImpl(this);
+ }
+
+ @VisibleForTesting
+ DefaultTextClassifierService(Injector injector) {
+ this.injector = Preconditions.checkNotNull(injector);
+ }
+
+ private TextClassifierApiUsageLogger textClassifierApiUsageLogger;
@Override
public void onCreate() {
super.onCreate();
- textClassifier = new TextClassifierImpl(this, new TextClassifierSettings());
+
+ settings = injector.createTextClassifierSettings();
+ modelFileManager = injector.createModelFileManager(settings);
+ normPriorityExecutor = injector.createNormPriorityExecutor();
+ lowPriorityExecutor = injector.createLowPriorityExecutor();
+ textClassifier = injector.createTextClassifierImpl(settings, modelFileManager);
+ sessionIdToContext = new LruCache<>(settings.getSessionIdToContextCacheSize());
+ textClassifierApiUsageLogger =
+ injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor);
+ }
+
+ @Override
+ public void onDestroy() {
+ super.onDestroy();
+ }
+
+ @Override
+ public void onCreateTextClassificationSession(
+ @NonNull TextClassificationContext context, @NonNull TextClassificationSessionId sessionId) {
+ sessionIdToContext.put(sessionId, context);
+ }
+
+ @Override
+ public void onDestroyTextClassificationSession(@NonNull TextClassificationSessionId sessionId) {
+ sessionIdToContext.remove(sessionId);
}
@Override
@@ -76,7 +111,13 @@
CancellationSignal cancellationSignal,
Callback<TextSelection> callback) {
handleRequestAsync(
- () -> textClassifier.suggestSelection(request), callback, cancellationSignal);
+ () ->
+ textClassifier.suggestSelection(
+ sessionId, sessionIdToTextClassificationContext(sessionId), request),
+ callback,
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_SUGGEST_SELECTION, sessionId),
+ cancellationSignal);
}
@Override
@@ -85,7 +126,14 @@
TextClassification.Request request,
CancellationSignal cancellationSignal,
Callback<TextClassification> callback) {
- handleRequestAsync(() -> textClassifier.classifyText(request), callback, cancellationSignal);
+ handleRequestAsync(
+ () ->
+ textClassifier.classifyText(
+ sessionId, sessionIdToTextClassificationContext(sessionId), request),
+ callback,
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT, sessionId),
+ cancellationSignal);
}
@Override
@@ -94,7 +142,14 @@
TextLinks.Request request,
CancellationSignal cancellationSignal,
Callback<TextLinks> callback) {
- handleRequestAsync(() -> textClassifier.generateLinks(request), callback, cancellationSignal);
+ handleRequestAsync(
+ () ->
+ textClassifier.generateLinks(
+ sessionId, sessionIdToTextClassificationContext(sessionId), request),
+ callback,
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_GENERATE_LINKS, sessionId),
+ cancellationSignal);
}
@Override
@@ -104,7 +159,13 @@
CancellationSignal cancellationSignal,
Callback<ConversationActions> callback) {
handleRequestAsync(
- () -> textClassifier.suggestConversationActions(request), callback, cancellationSignal);
+ () ->
+ textClassifier.suggestConversationActions(
+ sessionId, sessionIdToTextClassificationContext(sessionId), request),
+ callback,
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_SUGGEST_CONVERSATION_ACTIONS, sessionId),
+ cancellationSignal);
}
@Override
@@ -113,12 +174,19 @@
TextLanguage.Request request,
CancellationSignal cancellationSignal,
Callback<TextLanguage> callback) {
- handleRequestAsync(() -> textClassifier.detectLanguage(request), callback, cancellationSignal);
+ handleRequestAsync(
+ () ->
+ textClassifier.detectLanguage(
+ sessionId, sessionIdToTextClassificationContext(sessionId), request),
+ callback,
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_DETECT_LANGUAGES, sessionId),
+ cancellationSignal);
}
@Override
public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) {
- handleEvent(() -> textClassifier.onSelectionEvent(event));
+ handleEvent(() -> textClassifier.onSelectionEvent(sessionId, event));
}
@Override
@@ -130,12 +198,31 @@
@Override
protected void dump(FileDescriptor fd, PrintWriter writer, String[] args) {
IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
+ // TODO(licha): Also dump ModelDownloadManager for debugging
textClassifier.dump(indentingPrintWriter);
+ dumpImpl(indentingPrintWriter);
indentingPrintWriter.flush();
}
+ private void dumpImpl(IndentingPrintWriter printWriter) {
+ printWriter.println("DefaultTextClassifierService:");
+ printWriter.increaseIndent();
+ printWriter.println("sessionIdToContext:");
+ printWriter.increaseIndent();
+ for (Map.Entry<TextClassificationSessionId, TextClassificationContext> entry :
+ sessionIdToContext.snapshot().entrySet()) {
+ printWriter.printPair(entry.getKey().getValue(), entry.getValue());
+ }
+ printWriter.decreaseIndent();
+ printWriter.decreaseIndent();
+ printWriter.println();
+ }
+
private <T> void handleRequestAsync(
- Callable<T> callable, Callback<T> callback, CancellationSignal cancellationSignal) {
+ Callable<T> callable,
+ Callback<T> callback,
+ TextClassifierApiUsageLogger.Session apiLoggerSession,
+ CancellationSignal cancellationSignal) {
ListenableFuture<T> result = normPriorityExecutor.submit(callable);
Futures.addCallback(
result,
@@ -143,12 +230,14 @@
@Override
public void onSuccess(T result) {
callback.onSuccess(result);
+ apiLoggerSession.reportSuccess();
}
@Override
public void onFailure(Throwable t) {
TcLog.e(TAG, "onFailure: ", t);
callback.onFailure(t.getMessage());
+ apiLoggerSession.reportFailure();
}
},
MoreExecutors.directExecutor());
@@ -175,4 +264,83 @@
},
MoreExecutors.directExecutor());
}
+
+ @Nullable
+ private TextClassificationContext sessionIdToTextClassificationContext(
+ @Nullable TextClassificationSessionId sessionId) {
+ if (sessionId == null) {
+ return null;
+ }
+ return sessionIdToContext.get(sessionId);
+ }
+
+ // Do not call any of these methods, except the constructor, before Service.onCreate is called.
+ private static class InjectorImpl implements Injector {
+ // Do not access the context object before Service.onCreate is invoked.
+ private final Context context;
+
+ private InjectorImpl(Context context) {
+ this.context = Preconditions.checkNotNull(context);
+ }
+
+ @Override
+ public Context getContext() {
+ return context;
+ }
+
+ @Override
+ public ModelFileManager createModelFileManager(TextClassifierSettings settings) {
+ return new ModelFileManager(context, settings);
+ }
+
+ @Override
+ public TextClassifierSettings createTextClassifierSettings() {
+ return new TextClassifierSettings();
+ }
+
+ @Override
+ public TextClassifierImpl createTextClassifierImpl(
+ TextClassifierSettings settings, ModelFileManager modelFileManager) {
+ return new TextClassifierImpl(context, settings, modelFileManager);
+ }
+
+ @Override
+ public ListeningExecutorService createNormPriorityExecutor() {
+ return TextClassifierServiceExecutors.getNormhPriorityExecutor();
+ }
+
+ @Override
+ public ListeningExecutorService createLowPriorityExecutor() {
+ return TextClassifierServiceExecutors.getLowPriorityExecutor();
+ }
+
+ @Override
+ public TextClassifierApiUsageLogger createTextClassifierApiUsageLogger(
+ TextClassifierSettings settings, Executor executor) {
+ return new TextClassifierApiUsageLogger(
+ settings::getTextClassifierApiLogSampleRate, executor);
+ }
+ }
+
+ /**
+ * Provides dependencies to the {@link DefaultTextClassifierService}. This makes the service
+ * class testable.
+ */
+ interface Injector {
+ Context getContext();
+
+ ModelFileManager createModelFileManager(TextClassifierSettings settings);
+
+ TextClassifierSettings createTextClassifierSettings();
+
+ TextClassifierImpl createTextClassifierImpl(
+ TextClassifierSettings settings, ModelFileManager modelFileManager);
+
+ ListeningExecutorService createNormPriorityExecutor();
+
+ ListeningExecutorService createLowPriorityExecutor();
+
+ TextClassifierApiUsageLogger createTextClassifierApiUsageLogger(
+ TextClassifierSettings settings, Executor executor);
+ }
}
diff --git a/java/src/com/android/textclassifier/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
deleted file mode 100644
index a6f64d8..0000000
--- a/java/src/com/android/textclassifier/ModelFileManager.java
+++ /dev/null
@@ -1,311 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier;
-
-import android.os.LocaleList;
-import android.os.ParcelFileDescriptor;
-import android.text.TextUtils;
-import androidx.annotation.GuardedBy;
-import com.android.textclassifier.common.base.TcLog;
-import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
-import com.google.common.base.Optional;
-import com.google.common.base.Preconditions;
-import com.google.common.base.Splitter;
-import com.google.common.collect.ImmutableList;
-import java.io.File;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.Locale;
-import java.util.Objects;
-import java.util.function.Function;
-import java.util.function.Supplier;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
-import java.util.stream.Collectors;
-import javax.annotation.Nullable;
-
-/** Manages model files that are listed by the model files supplier. */
-final class ModelFileManager {
- private static final String TAG = "ModelFileManager";
-
- private final Supplier<ImmutableList<ModelFile>> modelFileSupplier;
-
- public ModelFileManager(Supplier<ImmutableList<ModelFile>> modelFileSupplier) {
- this.modelFileSupplier = Preconditions.checkNotNull(modelFileSupplier);
- }
-
- /** Returns an immutable list of model files listed by the given model files supplier. */
- public ImmutableList<ModelFile> listModelFiles() {
- return modelFileSupplier.get();
- }
-
- /**
- * Returns the best model file for the given localelist, {@code null} if nothing is found.
- *
- * @param localeList the required locales, use {@code null} if there is no preference.
- */
- public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
- final String languages =
- localeList == null || localeList.isEmpty()
- ? LocaleList.getDefault().toLanguageTags()
- : localeList.toLanguageTags();
- final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
-
- ModelFile bestModel = null;
- for (ModelFile model : listModelFiles()) {
- if (model.isAnyLanguageSupported(languageRangeList)) {
- if (model.isPreferredTo(bestModel)) {
- bestModel = model;
- }
- }
- }
- return bestModel;
- }
-
- /** Default implementation of the model file supplier. */
- public static final class ModelFileSupplierImpl implements Supplier<ImmutableList<ModelFile>> {
- private final File updatedModelFile;
- private final File factoryModelDir;
- private final Pattern modelFilenamePattern;
- private final Function<Integer, Integer> versionSupplier;
- private final Function<Integer, String> supportedLocalesSupplier;
- private final Object lock = new Object();
-
- @GuardedBy("lock")
- private ImmutableList<ModelFile> factoryModels;
-
- public ModelFileSupplierImpl(
- File factoryModelDir,
- String factoryModelFileNameRegex,
- File updatedModelFile,
- Function<Integer, Integer> versionSupplier,
- Function<Integer, String> supportedLocalesSupplier) {
- this.updatedModelFile = Preconditions.checkNotNull(updatedModelFile);
- this.factoryModelDir = Preconditions.checkNotNull(factoryModelDir);
- modelFilenamePattern = Pattern.compile(Preconditions.checkNotNull(factoryModelFileNameRegex));
- this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
- this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
- }
-
- @Override
- public ImmutableList<ModelFile> get() {
- final List<ModelFile> modelFiles = new ArrayList<>();
- // The update model has the highest precedence.
- if (updatedModelFile.exists()) {
- final ModelFile updatedModel = createModelFile(updatedModelFile);
- if (updatedModel != null) {
- modelFiles.add(updatedModel);
- }
- }
- // Factory models should never have overlapping locales, so the order doesn't matter.
- synchronized (lock) {
- if (factoryModels == null) {
- factoryModels = getFactoryModels();
- }
- modelFiles.addAll(factoryModels);
- }
- return ImmutableList.copyOf(modelFiles);
- }
-
- private ImmutableList<ModelFile> getFactoryModels() {
- List<ModelFile> factoryModelFiles = new ArrayList<>();
- if (factoryModelDir.exists() && factoryModelDir.isDirectory()) {
- final File[] files = factoryModelDir.listFiles();
- for (File file : files) {
- final Matcher matcher = modelFilenamePattern.matcher(file.getName());
- if (matcher.matches() && file.isFile()) {
- final ModelFile model = createModelFile(file);
- if (model != null) {
- factoryModelFiles.add(model);
- }
- }
- }
- }
- return ImmutableList.copyOf(factoryModelFiles);
- }
-
- /** Returns null if the path did not point to a compatible model. */
- @Nullable
- private ModelFile createModelFile(File file) {
- if (!file.exists()) {
- return null;
- }
- ParcelFileDescriptor modelFd = null;
- try {
- modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
- if (modelFd == null) {
- return null;
- }
- final int modelFdInt = modelFd.getFd();
- final int version = versionSupplier.apply(modelFdInt);
- final String supportedLocalesStr = supportedLocalesSupplier.apply(modelFdInt);
- if (supportedLocalesStr.isEmpty()) {
- TcLog.d(TAG, "Ignoring " + file.getAbsolutePath());
- return null;
- }
- final List<Locale> supportedLocales = new ArrayList<>();
- for (String langTag : Splitter.on(',').split(supportedLocalesStr)) {
- supportedLocales.add(Locale.forLanguageTag(langTag));
- }
- return new ModelFile(
- file,
- version,
- supportedLocales,
- supportedLocalesStr,
- ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
- } catch (FileNotFoundException e) {
- TcLog.e(TAG, "Failed to find " + file.getAbsolutePath(), e);
- return null;
- } finally {
- maybeCloseAndLogError(modelFd);
- }
- }
-
- /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
- private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
- if (fd == null) {
- return;
- }
- try {
- fd.close();
- } catch (IOException e) {
- TcLog.e(TAG, "Error closing file.", e);
- }
- }
- }
-
- /** Describes TextClassifier model files on disk. */
- public static final class ModelFile {
- public static final String LANGUAGE_INDEPENDENT = "*";
-
- private final File file;
- private final int version;
- private final List<Locale> supportedLocales;
- private final String supportedLocalesStr;
- private final boolean languageIndependent;
-
- public ModelFile(
- File file,
- int version,
- List<Locale> supportedLocales,
- String supportedLocalesStr,
- boolean languageIndependent) {
- this.file = Preconditions.checkNotNull(file);
- this.version = version;
- this.supportedLocales = Preconditions.checkNotNull(supportedLocales);
- this.supportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr);
- this.languageIndependent = languageIndependent;
- }
-
- /** Returns the absolute path to the model file. */
- public String getPath() {
- return file.getAbsolutePath();
- }
-
- /** Returns a name to use for id generation, effectively the name of the model file. */
- public String getName() {
- return file.getName();
- }
-
- /** Returns the version tag in the model's metadata. */
- public int getVersion() {
- return version;
- }
-
- /** Returns whether the language supports any language in the given ranges. */
- public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
- Preconditions.checkNotNull(languageRanges);
- return languageIndependent || Locale.lookup(languageRanges, supportedLocales) != null;
- }
-
- /** Returns an immutable lists of supported locales. */
- public List<Locale> getSupportedLocales() {
- return Collections.unmodifiableList(supportedLocales);
- }
-
- /** Returns the original supported locals string read from the model file. */
- public String getSupportedLocalesStr() {
- return supportedLocalesStr;
- }
-
- /** Returns if this model file is preferred to the given one. */
- public boolean isPreferredTo(@Nullable ModelFile model) {
- // A model is preferred to no model.
- if (model == null) {
- return true;
- }
-
- // A language-specific model is preferred to a language independent
- // model.
- if (!languageIndependent && model.languageIndependent) {
- return true;
- }
- if (languageIndependent && !model.languageIndependent) {
- return false;
- }
-
- // A higher-version model is preferred.
- if (version > model.getVersion()) {
- return true;
- }
- return false;
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(getPath());
- }
-
- @Override
- public boolean equals(Object other) {
- if (this == other) {
- return true;
- }
- if (other instanceof ModelFile) {
- final ModelFile otherModel = (ModelFile) other;
- return TextUtils.equals(getPath(), otherModel.getPath());
- }
- return false;
- }
-
- public ModelInfo toModelInfo() {
- return new ModelInfo(getVersion(), supportedLocalesStr);
- }
-
- @Override
- public String toString() {
- return String.format(
- Locale.US,
- "ModelFile { path=%s name=%s version=%d locales=%s }",
- getPath(),
- getName(),
- version,
- supportedLocalesStr);
- }
-
- public static ImmutableList<Optional<ModelInfo>> toModelInfos(
- Optional<ModelFile>... modelFiles) {
- return Arrays.stream(modelFiles)
- .map(modelFile -> modelFile.transform(ModelFile::toModelInfo))
- .collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf));
- }
- }
-}
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index 5c028ef..bf326fb 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -22,17 +22,19 @@
import android.app.RemoteAction;
import android.content.Context;
import android.content.Intent;
+import android.content.res.AssetFileDescriptor;
import android.icu.util.ULocale;
import android.os.Bundle;
import android.os.LocaleList;
import android.os.Looper;
-import android.os.ParcelFileDescriptor;
import android.util.ArrayMap;
import android.view.View.OnClickListener;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.SelectionEvent;
import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassification.Request;
+import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextClassifierEvent;
@@ -42,7 +44,10 @@
import androidx.annotation.GuardedBy;
import androidx.annotation.WorkerThread;
import androidx.core.util.Pair;
-import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.ModelFileManager;
+import com.android.textclassifier.common.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.intent.LabeledIntent;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
@@ -55,14 +60,13 @@
import com.android.textclassifier.common.statsd.TextClassifierEventLogger;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.android.textclassifier.ActionsSuggestionsModel;
+import com.google.android.textclassifier.ActionsSuggestionsModel.ActionSuggestions;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
-import java.io.File;
-import java.io.FileNotFoundException;
import java.io.IOException;
import java.time.ZoneId;
import java.time.ZonedDateTime;
@@ -84,26 +88,8 @@
private static final String TAG = "TextClassifierImpl";
- private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/");
- // Annotator
- private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX =
- "textclassifier\\.(.*)\\.model";
- private static final File ANNOTATOR_UPDATED_MODEL_FILE =
- new File("/data/misc/textclassifier/textclassifier.model");
-
- // LangIdModel
- private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model";
- private static final File UPDATED_LANG_ID_MODEL_FILE =
- new File("/data/misc/textclassifier/lang_id.model");
-
- // Actions
- private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX =
- "actions_suggestions\\.(.*)\\.model";
- private static final File UPDATED_ACTIONS_MODEL =
- new File("/data/misc/textclassifier/actions_suggestions.model");
-
private final Context context;
- private final TextClassifier fallback;
+ private final ModelFileManager modelFileManager;
private final GenerateLinksLogger generateLinksLogger;
private final Object lock = new Object();
@@ -131,155 +117,131 @@
private final TextClassifierSettings settings;
- private final ModelFileManager annotatorModelFileManager;
- private final ModelFileManager langIdModelFileManager;
- private final ModelFileManager actionsModelFileManager;
private final TemplateIntentFactory templateIntentFactory;
- TextClassifierImpl(Context context, TextClassifierSettings settings, TextClassifier fallback) {
+ TextClassifierImpl(
+ Context context, TextClassifierSettings settings, ModelFileManager modelFileManager) {
this.context = Preconditions.checkNotNull(context);
- this.fallback = Preconditions.checkNotNull(fallback);
this.settings = Preconditions.checkNotNull(settings);
- generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate());
- annotatorModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX,
- ANNOTATOR_UPDATED_MODEL_FILE,
- AnnotatorModel::getVersion,
- AnnotatorModel::getLocales));
- langIdModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- LANG_ID_FACTORY_MODEL_FILENAME_REGEX,
- UPDATED_LANG_ID_MODEL_FILE,
- LangIdModel::getVersion,
- fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
- actionsModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
- UPDATED_ACTIONS_MODEL,
- ActionsSuggestionsModel::getVersion,
- ActionsSuggestionsModel::getLocales));
+ this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
+ generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate());
templateIntentFactory = new TemplateIntentFactory();
}
- TextClassifierImpl(Context context, TextClassifierSettings settings) {
- this(context, settings, TextClassifier.NO_OP);
+ @WorkerThread
+ TextSelection suggestSelection(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
+ TextSelection.Request request)
+ throws IOException {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ final int rangeLength = request.getEndIndex() - request.getStartIndex();
+ final String string = request.getText().toString();
+ Preconditions.checkArgument(!string.isEmpty(), "input string should not be empty");
+ Preconditions.checkArgument(
+ rangeLength <= settings.getClassifyTextMaxRangeLength(), "range is too large");
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ final LangIdModel langIdModel = getLangIdImpl();
+ final String detectLanguageTags =
+ String.join(",", detectLanguageTags(langIdModel, request.getText()));
+ final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
+ final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final int[] startEnd =
+ annotatorImpl.suggestSelection(
+ string,
+ request.getStartIndex(),
+ request.getEndIndex(),
+ AnnotatorModel.SelectionOptions.builder()
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(detectLanguageTags)
+ .build());
+ final int start = startEnd[0];
+ final int end = startEnd[1];
+ if (start >= end
+ || start < 0
+ || start > request.getStartIndex()
+ || end > string.length()
+ || end < request.getEndIndex()) {
+ throw new IllegalArgumentException("Got bad indices for input text. Ignoring result.");
+ }
+ final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
+ final AnnotatorModel.ClassificationResult[] results =
+ annotatorImpl.classifyText(
+ string,
+ start,
+ end,
+ AnnotatorModel.ClassificationOptions.builder()
+ .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
+ .setReferenceTimezone(refTime.getZone().getId())
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(detectLanguageTags)
+ .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
+ .build(),
+ // Passing null here to suppress intent generation
+ // TODO: Use an explicit flag to suppress it.
+ /* appContext */ null,
+ /* deviceLocales */ null);
+ final int size = results.length;
+ for (int i = 0; i < size; i++) {
+ tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
+ }
+ final String resultId =
+ createAnnotatorId(string, request.getStartIndex(), request.getEndIndex());
+ return tsBuilder.setId(resultId).build();
}
@WorkerThread
- TextSelection suggestSelection(TextSelection.Request request) {
+ TextClassification classifyText(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
+ Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
- try {
- final int rangeLength = request.getEndIndex() - request.getStartIndex();
- final String string = request.getText().toString();
- if (string.length() > 0 && rangeLength <= settings.getSuggestSelectionMaxRangeLength()) {
- final String localesString = concatenateLocales(request.getDefaultLocales());
- final Optional<LangIdModel> langIdModel = getLangIdImpl();
- final String detectLanguageTags =
- String.join(",", detectLanguageTags(langIdModel, request.getText()));
- final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
- final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
- final int[] startEnd =
- annotatorImpl.suggestSelection(
+ LangIdModel langId = getLangIdImpl();
+ List<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
+ final int rangeLength = request.getEndIndex() - request.getStartIndex();
+ final String string = request.getText().toString();
+ Preconditions.checkArgument(!string.isEmpty(), "input string should not be empty");
+ Preconditions.checkArgument(
+ rangeLength <= settings.getClassifyTextMaxRangeLength(), "range is too large");
+
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ final ZonedDateTime refTime =
+ request.getReferenceTime() != null
+ ? request.getReferenceTime()
+ : ZonedDateTime.now(ZoneId.systemDefault());
+ final AnnotatorModel.ClassificationResult[] results =
+ getAnnotatorImpl(request.getDefaultLocales())
+ .classifyText(
string,
request.getStartIndex(),
request.getEndIndex(),
- new AnnotatorModel.SelectionOptions(localesString, detectLanguageTags));
- final int start = startEnd[0];
- final int end = startEnd[1];
- if (start < end
- && start >= 0
- && end <= string.length()
- && start <= request.getStartIndex()
- && end >= request.getEndIndex()) {
- final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
- final AnnotatorModel.ClassificationResult[] results =
- annotatorImpl.classifyText(
- string,
- start,
- end,
- new AnnotatorModel.ClassificationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- detectLanguageTags),
- // Passing null here to suppress intent generation
- // TODO: Use an explicit flag to suppress it.
- /* appContext */ null,
- /* deviceLocales */ null);
- final int size = results.length;
- for (int i = 0; i < size; i++) {
- tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
- }
- final String resultId =
- createAnnotatorId(string, request.getStartIndex(), request.getEndIndex());
- return tsBuilder.setId(resultId).build();
- } else {
- // We can not trust the result. Log the issue and ignore the result.
- TcLog.d(TAG, "Got bad indices for input text. Ignoring result.");
- }
- }
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error suggesting selection for text. No changes to selection suggested.", t);
+ AnnotatorModel.ClassificationOptions.builder()
+ .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
+ .setReferenceTimezone(refTime.getZone().getId())
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
+ .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
+ .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
+ .build(),
+ context,
+ getResourceLocalesString());
+ if (results.length == 0) {
+ throw new IllegalStateException("Empty text classification. Something went wrong.");
}
- // Getting here means something went wrong, return a NO_OP result.
- return fallback.suggestSelection(request);
+ return createClassificationResult(
+ results, string, request.getStartIndex(), request.getEndIndex(), langId);
}
@WorkerThread
- TextClassification classifyText(TextClassification.Request request) {
- Preconditions.checkNotNull(request);
- checkMainThread();
- try {
- Optional<LangIdModel> langId = getLangIdImpl();
- List<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
- final int rangeLength = request.getEndIndex() - request.getStartIndex();
- final String string = request.getText().toString();
- if (string.length() > 0 && rangeLength <= settings.getClassifyTextMaxRangeLength()) {
- final String localesString = concatenateLocales(request.getDefaultLocales());
- final ZonedDateTime refTime =
- request.getReferenceTime() != null
- ? request.getReferenceTime()
- : ZonedDateTime.now(ZoneId.systemDefault());
- final AnnotatorModel.ClassificationResult[] results =
- getAnnotatorImpl(request.getDefaultLocales())
- .classifyText(
- string,
- request.getStartIndex(),
- request.getEndIndex(),
- new AnnotatorModel.ClassificationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- String.join(",", detectLanguageTags),
- AnnotatorModel.AnnotationUsecase.SMART.getValue(),
- LocaleList.getDefault().toLanguageTags()),
- context,
- getResourceLocalesString());
- if (results.length > 0) {
- return createClassificationResult(
- results, string, request.getStartIndex(), request.getEndIndex(), langId);
- }
- }
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error getting text classification info.", t);
- }
- // Getting here means something went wrong, return a NO_OP result.
- return fallback.classifyText(request);
- }
-
- @WorkerThread
- TextLinks generateLinks(TextLinks.Request request) {
+ TextLinks generateLinks(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
+ TextLinks.Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
Preconditions.checkArgument(
request.getText().length() <= getMaxGenerateLinksTextLength(),
@@ -290,74 +252,71 @@
final String textString = request.getText().toString();
final TextLinks.Builder builder = new TextLinks.Builder(textString);
- try {
- final long startTimeMs = System.currentTimeMillis();
- final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
- final Collection<String> entitiesToIdentify =
- request.getEntityConfig() != null
- ? request
- .getEntityConfig()
- .resolveEntityListModifications(
- getEntitiesForHints(request.getEntityConfig().getHints()))
- : settings.getEntityListDefault();
- final String localesString = concatenateLocales(request.getDefaultLocales());
- Optional<LangIdModel> langId = getLangIdImpl();
- ImmutableList<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
- final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
- final boolean isSerializedEntityDataEnabled =
- ExtrasUtils.isSerializedEntityDataEnabled(request);
- final AnnotatorModel.AnnotatedSpan[] annotations =
- annotatorImpl.annotate(
- textString,
- new AnnotatorModel.AnnotationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- String.join(",", detectLanguageTags),
- entitiesToIdentify,
- AnnotatorModel.AnnotationUsecase.SMART.getValue(),
- isSerializedEntityDataEnabled));
- for (AnnotatorModel.AnnotatedSpan span : annotations) {
- final AnnotatorModel.ClassificationResult[] results = span.getClassification();
- if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) {
- continue;
- }
- final Map<String, Float> entityScores = new ArrayMap<>();
- for (int i = 0; i < results.length; i++) {
- entityScores.put(results[i].getCollection(), results[i].getScore());
- }
- Bundle extras = new Bundle();
- if (isSerializedEntityDataEnabled) {
- ExtrasUtils.putEntities(extras, results);
- }
- builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
+ final long startTimeMs = System.currentTimeMillis();
+ final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
+ final Collection<String> entitiesToIdentify =
+ request.getEntityConfig() != null
+ ? request
+ .getEntityConfig()
+ .resolveEntityListModifications(
+ getEntitiesForHints(request.getEntityConfig().getHints()))
+ : settings.getEntityListDefault();
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ LangIdModel langId = getLangIdImpl();
+ ImmutableList<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
+ final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final boolean isSerializedEntityDataEnabled =
+ ExtrasUtils.isSerializedEntityDataEnabled(request);
+ final AnnotatorModel.AnnotatedSpan[] annotations =
+ annotatorImpl.annotate(
+ textString,
+ AnnotatorModel.AnnotationOptions.builder()
+ .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
+ .setReferenceTimezone(refTime.getZone().getId())
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
+ .setEntityTypes(entitiesToIdentify)
+ .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
+ .setIsSerializedEntityDataEnabled(isSerializedEntityDataEnabled)
+ .build());
+ for (AnnotatorModel.AnnotatedSpan span : annotations) {
+ final AnnotatorModel.ClassificationResult[] results = span.getClassification();
+ if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) {
+ continue;
}
- final TextLinks links = builder.build();
- final long endTimeMs = System.currentTimeMillis();
- final String callingPackageName =
- request.getCallingPackageName() == null
- ? context.getPackageName() // local (in process) TC.
- : request.getCallingPackageName();
- Optional<ModelInfo> annotatorModelInfo;
- Optional<ModelInfo> langIdModelInfo;
- synchronized (lock) {
- annotatorModelInfo =
- Optional.fromNullable(annotatorModelInUse).transform(ModelFile::toModelInfo);
- langIdModelInfo = Optional.fromNullable(langIdModelInUse).transform(ModelFile::toModelInfo);
+ final Map<String, Float> entityScores = new ArrayMap<>();
+ for (AnnotatorModel.ClassificationResult result : results) {
+ entityScores.put(result.getCollection(), result.getScore());
}
- generateLinksLogger.logGenerateLinks(
- request.getText(),
- links,
- callingPackageName,
- endTimeMs - startTimeMs,
- annotatorModelInfo,
- langIdModelInfo);
- return links;
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error getting links info.", t);
+ Bundle extras = new Bundle();
+ if (isSerializedEntityDataEnabled) {
+ ExtrasUtils.putEntities(extras, results);
+ }
+ builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
}
- return fallback.generateLinks(request);
+ final TextLinks links = builder.build();
+ final long endTimeMs = System.currentTimeMillis();
+ final String callingPackageName =
+ request.getCallingPackageName() == null
+ ? context.getPackageName() // local (in process) TC.
+ : request.getCallingPackageName();
+ Optional<ModelInfo> annotatorModelInfo;
+ Optional<ModelInfo> langIdModelInfo;
+ synchronized (lock) {
+ annotatorModelInfo =
+ Optional.fromNullable(annotatorModelInUse).transform(ModelFile::toModelInfo);
+ langIdModelInfo = Optional.fromNullable(langIdModelInUse).transform(ModelFile::toModelInfo);
+ }
+ generateLinksLogger.logGenerateLinks(
+ sessionId,
+ textClassificationContext,
+ request.getText(),
+ links,
+ callingPackageName,
+ endTimeMs - startTimeMs,
+ annotatorModelInfo,
+ langIdModelInfo);
+ return links;
}
int getMaxGenerateLinksTextLength() {
@@ -379,7 +338,7 @@
}
}
- void onSelectionEvent(SelectionEvent event) {
+ void onSelectionEvent(@Nullable TextClassificationSessionId sessionId, SelectionEvent event) {
TextClassifierEvent textClassifierEvent = SelectionEventConverter.toTextClassifierEvent(event);
if (textClassifierEvent == null) {
return;
@@ -394,60 +353,49 @@
TextClassifierEventConverter.fromPlatform(event));
}
- TextLanguage detectLanguage(TextLanguage.Request request) {
+ TextLanguage detectLanguage(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
+ TextLanguage.Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
- try {
- final TextLanguage.Builder builder = new TextLanguage.Builder();
- Optional<LangIdModel> langIdImpl = getLangIdImpl();
- if (langIdImpl.isPresent()) {
- final LangIdModel.LanguageResult[] langResults =
- langIdImpl.get().detectLanguages(request.getText().toString());
- for (int i = 0; i < langResults.length; i++) {
- builder.putLocale(
- ULocale.forLanguageTag(langResults[i].getLanguage()), langResults[i].getScore());
- }
- return builder.build();
- }
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error detecting text language.", t);
+ final TextLanguage.Builder builder = new TextLanguage.Builder();
+ LangIdModel langIdImpl = getLangIdImpl();
+ final LangIdModel.LanguageResult[] langResults =
+ langIdImpl.detectLanguages(request.getText().toString());
+ for (LangIdModel.LanguageResult langResult : langResults) {
+ builder.putLocale(ULocale.forLanguageTag(langResult.getLanguage()), langResult.getScore());
}
- return fallback.detectLanguage(request);
+ return builder.build();
}
- ConversationActions suggestConversationActions(ConversationActions.Request request) {
+ ConversationActions suggestConversationActions(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
+ ConversationActions.Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
- try {
- ActionsSuggestionsModel actionsImpl = getActionsImpl();
- if (actionsImpl == null) {
- // Actions model is optional, fallback if it is not available.
- return fallback.suggestConversationActions(request);
- }
- Optional<LangIdModel> langId = getLangIdImpl();
- ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
- ActionsSuggestionsHelper.toNativeMessages(
- request.getConversation(), text -> detectLanguageTags(langId, text));
- if (nativeMessages.length == 0) {
- return fallback.suggestConversationActions(request);
- }
- ActionsSuggestionsModel.Conversation nativeConversation =
- new ActionsSuggestionsModel.Conversation(nativeMessages);
-
- ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
- actionsImpl.suggestActionsWithIntents(
- nativeConversation,
- null,
- context,
- getResourceLocalesString(),
- getAnnotatorImpl(LocaleList.getDefault()));
- return createConversationActionResult(request, nativeSuggestions);
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error suggesting conversation actions.", t);
+ ActionsSuggestionsModel actionsImpl = getActionsImpl();
+ LangIdModel langId = getLangIdImpl();
+ ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ request.getConversation(), text -> detectLanguageTags(langId, text));
+ if (nativeMessages.length == 0) {
+ return new ConversationActions(ImmutableList.of(), /* id= */ null);
}
- return fallback.suggestConversationActions(request);
+ ActionsSuggestionsModel.Conversation nativeConversation =
+ new ActionsSuggestionsModel.Conversation(nativeMessages);
+
+ ActionSuggestions nativeSuggestions =
+ actionsImpl.suggestActionsWithIntents(
+ nativeConversation,
+ null,
+ context,
+ getResourceLocalesString(),
+ getAnnotatorImpl(LocaleList.getDefault()));
+ return createConversationActionResult(request, nativeSuggestions);
}
/**
@@ -457,11 +405,11 @@
* non-null component name is in the extras.
*/
private ConversationActions createConversationActionResult(
- ConversationActions.Request request,
- ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions) {
+ ConversationActions.Request request, ActionSuggestions nativeSuggestions) {
Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
List<ConversationAction> conversationActions = new ArrayList<>();
- for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) {
+ for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion :
+ nativeSuggestions.actionSuggestions) {
String actionType = nativeSuggestion.getActionType();
if (!expectedTypes.contains(actionType)) {
continue;
@@ -512,92 +460,61 @@
return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
}
- private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException {
+ private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws IOException {
synchronized (lock) {
localeList = localeList == null ? LocaleList.getDefault() : localeList;
final ModelFileManager.ModelFile bestModel =
- annotatorModelFileManager.findBestModelFile(localeList);
+ modelFileManager.findBestModelFile(ModelType.ANNOTATOR, localeList);
if (bestModel == null) {
- throw new FileNotFoundException("No annotator model for " + localeList.toLanguageTags());
+ throw new IllegalStateException("Failed to find the best annotator model");
}
if (annotatorImpl == null || !Objects.equals(annotatorModelInUse, bestModel)) {
TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- try {
- if (pfd != null) {
- // The current annotator model may be still used by another thread / model.
- // Do not call close() here, and let the GC to clean it up when no one else
- // is using it.
- annotatorImpl = new AnnotatorModel(pfd.getFd());
- Optional<LangIdModel> langIdModel = getLangIdImpl();
- if (langIdModel.isPresent()) {
- annotatorImpl.setLangIdModel(langIdModel.get());
- }
- annotatorModelInUse = bestModel;
- }
- } finally {
- maybeCloseAndLogError(pfd);
+ // The current annotator model may be still used by another thread / model.
+ // Do not call close() here, and let the GC to clean it up when no one else
+ // is using it.
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ annotatorImpl = new AnnotatorModel(afd);
+ annotatorImpl.setLangIdModel(getLangIdImpl());
+ annotatorModelInUse = bestModel;
}
}
return annotatorImpl;
}
}
- private Optional<LangIdModel> getLangIdImpl() {
+ private LangIdModel getLangIdImpl() throws IOException {
synchronized (lock) {
- final ModelFileManager.ModelFile bestModel = langIdModelFileManager.findBestModelFile(null);
+ final ModelFileManager.ModelFile bestModel =
+ modelFileManager.findBestModelFile(ModelType.LANG_ID, /* localePreferences= */ null);
if (bestModel == null) {
- return Optional.absent();
+ throw new IllegalStateException("Failed to find the best LangID model.");
}
if (langIdImpl == null || !Objects.equals(langIdModelInUse, bestModel)) {
TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd;
- try {
- pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- } catch (FileNotFoundException e) {
- TcLog.e(TAG, "Failed to open the LangID model file", e);
- return Optional.absent();
- }
- try {
- if (pfd != null) {
- langIdImpl = new LangIdModel(pfd.getFd());
- langIdModelInUse = bestModel;
- }
- } finally {
- maybeCloseAndLogError(pfd);
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ langIdImpl = new LangIdModel(afd);
+ langIdModelInUse = bestModel;
}
}
- return Optional.of(langIdImpl);
+ return langIdImpl;
}
}
- @Nullable
- private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
+ private ActionsSuggestionsModel getActionsImpl() throws IOException {
synchronized (lock) {
// TODO: Use LangID to determine the locale we should use here?
final ModelFileManager.ModelFile bestModel =
- actionsModelFileManager.findBestModelFile(LocaleList.getDefault());
+ modelFileManager.findBestModelFile(
+ ModelType.ACTIONS_SUGGESTIONS, LocaleList.getDefault());
if (bestModel == null) {
- return null;
+ throw new IllegalStateException("Failed to find the best actions model");
}
if (actionsImpl == null || !Objects.equals(actionModelInUse, bestModel)) {
TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- try {
- if (pfd == null) {
- TcLog.d(TAG, "Failed to read the model file: " + bestModel.getPath());
- return null;
- }
- actionsImpl = new ActionsSuggestionsModel(pfd.getFd());
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ actionsImpl = new ActionsSuggestionsModel(afd);
actionModelInUse = bestModel;
- } finally {
- maybeCloseAndLogError(pfd);
}
}
return actionsImpl;
@@ -625,7 +542,7 @@
String text,
int start,
int end,
- Optional<LangIdModel> langId) {
+ LangIdModel langId) {
final String classifiedText = text.substring(start, end);
final TextClassification.Builder builder =
new TextClassification.Builder().setText(classifiedText);
@@ -672,10 +589,7 @@
actionIntents.add(intent);
}
Bundle extras = new Bundle();
- Optional<Bundle> foreignLanguageExtra =
- langId
- .transform(model -> maybeCreateExtrasForTranslate(actionIntents, model))
- .or(Optional.<Bundle>absent());
+ Optional<Bundle> foreignLanguageExtra = maybeCreateExtrasForTranslate(actionIntents, langId);
if (foreignLanguageExtra.isPresent()) {
ExtrasUtils.putForeignLanguageExtra(extras, foreignLanguageExtra.get());
}
@@ -722,16 +636,10 @@
topLanguageWithScore.first, topLanguageWithScore.second, langId.getVersion()));
}
- private ImmutableList<String> detectLanguageTags(
- Optional<LangIdModel> langId, CharSequence text) {
- return langId
- .transform(
- model -> {
- float threshold = getLangIdThreshold(model);
- EntityConfidence languagesConfidence = detectLanguages(model, text, threshold);
- return ImmutableList.copyOf(languagesConfidence.getEntities());
- })
- .or(ImmutableList.of());
+ private ImmutableList<String> detectLanguageTags(LangIdModel langId, CharSequence text) {
+ float threshold = getLangIdThreshold(langId);
+ EntityConfidence languagesConfidence = detectLanguages(langId, text, threshold);
+ return ImmutableList.copyOf(languagesConfidence.getEntities());
}
/**
@@ -759,29 +667,14 @@
void dump(IndentingPrintWriter printWriter) {
synchronized (lock) {
printWriter.println("TextClassifierImpl:");
+
printWriter.increaseIndent();
- printWriter.println("Annotator model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile : annotatorModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
+ modelFileManager.dump(printWriter);
printWriter.decreaseIndent();
- printWriter.println("LangID model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile : langIdModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- printWriter.println("Actions model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile : actionsModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- printWriter.printPair("mFallback", fallback);
- printWriter.decreaseIndent();
+
printWriter.println();
settings.dump(printWriter);
+ printWriter.println();
}
}
@@ -796,29 +689,19 @@
}
}
- /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
- private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
- if (fd == null) {
- return;
- }
-
- try {
- fd.close();
- } catch (IOException e) {
- TcLog.e(TAG, "Error closing file.", e);
- }
- }
-
private static void checkMainThread() {
if (Looper.myLooper() == Looper.getMainLooper()) {
- TcLog.e(TAG, "TextClassifier called on main thread", new Exception());
+ TcLog.e(TAG, "TCS TextClassifier called on main thread", new Exception());
}
}
private static PendingIntent createPendingIntent(
final Context context, final Intent intent, int requestCode) {
return PendingIntent.getActivity(
- context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ context,
+ requestCode,
+ intent,
+ PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_IMMUTABLE);
}
@Nullable
diff --git a/java/src/com/android/textclassifier/TextClassifierSettings.java b/java/src/com/android/textclassifier/TextClassifierSettings.java
deleted file mode 100644
index 3decd38..0000000
--- a/java/src/com/android/textclassifier/TextClassifierSettings.java
+++ /dev/null
@@ -1,324 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier;
-
-import android.provider.DeviceConfig;
-import android.view.textclassifier.ConversationAction;
-import android.view.textclassifier.TextClassifier;
-import com.android.textclassifier.utils.IndentingPrintWriter;
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Splitter;
-import com.google.common.collect.ImmutableList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import javax.annotation.Nullable;
-
-/**
- * TextClassifier specific settings.
- *
- * <p>Currently, this class does not guarantee co-diverted flags are updated atomically.
- *
- * <p>Example of setting the values for testing.
- *
- * <pre>
- * adb shell cmd device_config put textclassifier system_textclassifier_enabled true
- * </pre>
- *
- * @see android.provider.DeviceConfig#NAMESPACE_TEXTCLASSIFIER
- */
-public final class TextClassifierSettings {
- private static final String DELIMITER = ":";
-
- /** Whether the user language profile feature is enabled. */
- private static final String USER_LANGUAGE_PROFILE_ENABLED = "user_language_profile_enabled";
- /** Max length of text that suggestSelection can accept. */
- @VisibleForTesting
- static final String SUGGEST_SELECTION_MAX_RANGE_LENGTH = "suggest_selection_max_range_length";
- /** Max length of text that classifyText can accept. */
- private static final String CLASSIFY_TEXT_MAX_RANGE_LENGTH = "classify_text_max_range_length";
- /** Max length of text that generateLinks can accept. */
- private static final String GENERATE_LINKS_MAX_TEXT_LENGTH = "generate_links_max_text_length";
- /** Sampling rate for generateLinks logging. */
- private static final String GENERATE_LINKS_LOG_SAMPLE_RATE = "generate_links_log_sample_rate";
- /**
- * Extra count that is added to some languages, e.g. system languages, when deducing the frequent
- * languages in {@link
- * com.android.textclassifier.ulp.LanguageProfileAnalyzer#getFrequentLanguages(int)}.
- */
-
- /**
- * A colon(:) separated string that specifies the default entities types for generateLinks when
- * hint is not given.
- */
- @VisibleForTesting static final String ENTITY_LIST_DEFAULT = "entity_list_default";
- /**
- * A colon(:) separated string that specifies the default entities types for generateLinks when
- * the text is in a not editable UI widget.
- */
- private static final String ENTITY_LIST_NOT_EDITABLE = "entity_list_not_editable";
- /**
- * A colon(:) separated string that specifies the default entities types for generateLinks when
- * the text is in an editable UI widget.
- */
- private static final String ENTITY_LIST_EDITABLE = "entity_list_editable";
- /**
- * A colon(:) separated string that specifies the default action types for
- * suggestConversationActions when the suggestions are used in an app.
- */
- private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT =
- "in_app_conversation_action_types_default";
- /**
- * A colon(:) separated string that specifies the default action types for
- * suggestConversationActions when the suggestions are used in a notification.
- */
- private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
- "notification_conversation_action_types_default";
- /** Threshold to accept a suggested language from LangID model. */
- @VisibleForTesting static final String LANG_ID_THRESHOLD_OVERRIDE = "lang_id_threshold_override";
- /** Whether to enable {@link com.android.textclassifier.intent.TemplateIntentFactory}. */
- @VisibleForTesting
- static final String TEMPLATE_INTENT_FACTORY_ENABLED = "template_intent_factory_enabled";
- /** Whether to enable "translate" action in classifyText. */
- private static final String TRANSLATE_IN_CLASSIFICATION_ENABLED =
- "translate_in_classification_enabled";
- /**
- * Whether to detect the languages of the text in request by using langId for the native model.
- */
- private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
- "detect_languages_from_text_enabled";
- /**
- * A colon(:) separated string that specifies the configuration to use when including surrounding
- * context text in language detection queries.
- *
- * <p>Format= minimumTextSize<int>:penalizeRatio<float>:textScoreRatio<float>
- *
- * <p>e.g. 20:1.0:0.4
- *
- * <p>Accept all text lengths with minimumTextSize=0
- *
- * <p>Reject all text less than minimumTextSize with penalizeRatio=0
- *
- * @see {@code TextClassifierImpl#detectLanguages(String, int, int)} for reference.
- */
- @VisibleForTesting static final String LANG_ID_CONTEXT_SETTINGS = "lang_id_context_settings";
- /** Default threshold to translate the language of the context the user selects */
- private static final String TRANSLATE_ACTION_THRESHOLD = "translate_action_threshold";
-
- // Sync this with ConversationAction.TYPE_ADD_CONTACT;
- public static final String TYPE_ADD_CONTACT = "add_contact";
- // Sync this with ConversationAction.COPY;
- public static final String TYPE_COPY = "copy";
-
- private static final int SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
- private static final int CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
- private static final int GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT = 100 * 1000;
- private static final int GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT = 100;
-
- private static final ImmutableList<String> ENTITY_LIST_DEFAULT_VALUE =
- ImmutableList.of(
- TextClassifier.TYPE_ADDRESS,
- TextClassifier.TYPE_EMAIL,
- TextClassifier.TYPE_PHONE,
- TextClassifier.TYPE_URL,
- TextClassifier.TYPE_DATE,
- TextClassifier.TYPE_DATE_TIME,
- TextClassifier.TYPE_FLIGHT_NUMBER);
- private static final ImmutableList<String> CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES =
- ImmutableList.of(
- ConversationAction.TYPE_TEXT_REPLY,
- ConversationAction.TYPE_CREATE_REMINDER,
- ConversationAction.TYPE_CALL_PHONE,
- ConversationAction.TYPE_OPEN_URL,
- ConversationAction.TYPE_SEND_EMAIL,
- ConversationAction.TYPE_SEND_SMS,
- ConversationAction.TYPE_TRACK_FLIGHT,
- ConversationAction.TYPE_VIEW_CALENDAR,
- ConversationAction.TYPE_VIEW_MAP,
- TYPE_ADD_CONTACT,
- TYPE_COPY);
- /**
- * < 0 : Not set. Use value from LangId model. 0 - 1: Override value in LangId model.
- *
- * @see EntityConfidence
- */
- private static final float LANG_ID_THRESHOLD_OVERRIDE_DEFAULT = -1f;
-
- private static final float TRANSLATE_ACTION_THRESHOLD_DEFAULT = 0.5f;
-
- private static final boolean USER_LANGUAGE_PROFILE_ENABLED_DEFAULT = true;
- private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
- private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
- private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
- private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
-
- public int getSuggestSelectionMaxRangeLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- SUGGEST_SELECTION_MAX_RANGE_LENGTH,
- SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT);
- }
-
- public int getClassifyTextMaxRangeLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CLASSIFY_TEXT_MAX_RANGE_LENGTH,
- CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT);
- }
-
- public int getGenerateLinksMaxTextLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- GENERATE_LINKS_MAX_TEXT_LENGTH,
- GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT);
- }
-
- public int getGenerateLinksLogSampleRate() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- GENERATE_LINKS_LOG_SAMPLE_RATE,
- GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT);
- }
-
- public List<String> getEntityListDefault() {
- return getDeviceConfigStringList(ENTITY_LIST_DEFAULT, ENTITY_LIST_DEFAULT_VALUE);
- }
-
- public List<String> getEntityListNotEditable() {
- return getDeviceConfigStringList(ENTITY_LIST_NOT_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
- }
-
- public List<String> getEntityListEditable() {
- return getDeviceConfigStringList(ENTITY_LIST_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
- }
-
- public List<String> getInAppConversationActionTypes() {
- return getDeviceConfigStringList(
- IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
- }
-
- public List<String> getNotificationConversationActionTypes() {
- return getDeviceConfigStringList(
- NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
- }
-
- public float getLangIdThresholdOverride() {
- return DeviceConfig.getFloat(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- LANG_ID_THRESHOLD_OVERRIDE,
- LANG_ID_THRESHOLD_OVERRIDE_DEFAULT);
- }
-
- public float getTranslateActionThreshold() {
- return DeviceConfig.getFloat(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TRANSLATE_ACTION_THRESHOLD,
- TRANSLATE_ACTION_THRESHOLD_DEFAULT);
- }
-
- public boolean isUserLanguageProfileEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- USER_LANGUAGE_PROFILE_ENABLED,
- USER_LANGUAGE_PROFILE_ENABLED_DEFAULT);
- }
-
- public boolean isTemplateIntentFactoryEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TEMPLATE_INTENT_FACTORY_ENABLED,
- TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
- }
-
- public boolean isTranslateInClassificationEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TRANSLATE_IN_CLASSIFICATION_ENABLED,
- TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT);
- }
-
- public boolean isDetectLanguagesFromTextEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- DETECT_LANGUAGES_FROM_TEXT_ENABLED,
- DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
- }
-
- public float[] getLangIdContextSettings() {
- return getDeviceConfigFloatArray(LANG_ID_CONTEXT_SETTINGS, LANG_ID_CONTEXT_SETTINGS_DEFAULT);
- }
-
- void dump(IndentingPrintWriter pw) {
- pw.println("TextClassifierSettings:");
- pw.increaseIndent();
- pw.printPair("classify_text_max_range_length", getClassifyTextMaxRangeLength());
- pw.printPair("detect_language_from_text_enabled", isDetectLanguagesFromTextEnabled());
- pw.printPair("entity_list_default", getEntityListDefault());
- pw.printPair("entity_list_editable", getEntityListEditable());
- pw.printPair("entity_list_not_editable", getEntityListNotEditable());
- pw.printPair("generate_links_log_sample_rate", getGenerateLinksLogSampleRate());
- pw.printPair("generate_links_max_text_length", getGenerateLinksMaxTextLength());
- pw.printPair("in_app_conversation_action_types_default", getInAppConversationActionTypes());
- pw.printPair("lang_id_context_settings", Arrays.toString(getLangIdContextSettings()));
- pw.printPair("lang_id_threshold_override", getLangIdThresholdOverride());
- pw.printPair("translate_action_threshold", getTranslateActionThreshold());
- pw.printPair(
- "notification_conversation_action_types_default", getNotificationConversationActionTypes());
- pw.printPair("suggest_selection_max_range_length", getSuggestSelectionMaxRangeLength());
- pw.printPair("user_language_profile_enabled", isUserLanguageProfileEnabled());
- pw.printPair("template_intent_factory_enabled", isTemplateIntentFactoryEnabled());
- pw.printPair("translate_in_classification_enabled", isTranslateInClassificationEnabled());
- pw.decreaseIndent();
- }
-
- private static List<String> getDeviceConfigStringList(String key, List<String> defaultValue) {
- return parse(
- DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null), defaultValue);
- }
-
- private static float[] getDeviceConfigFloatArray(String key, float[] defaultValue) {
- return parse(
- DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null), defaultValue);
- }
-
- private static List<String> parse(@Nullable String listStr, List<String> defaultValue) {
- if (listStr != null) {
- return Collections.unmodifiableList(Arrays.asList(listStr.split(DELIMITER)));
- }
- return defaultValue;
- }
-
- private static float[] parse(@Nullable String arrayStr, float[] defaultValue) {
- if (arrayStr != null) {
- final List<String> split = Splitter.onPattern(DELIMITER).splitToList(arrayStr);
- if (split.size() != defaultValue.length) {
- return defaultValue;
- }
- final float[] result = new float[split.size()];
- for (int i = 0; i < split.size(); i++) {
- try {
- result[i] = Float.parseFloat(split.get(i));
- } catch (NumberFormatException e) {
- return defaultValue;
- }
- }
- return result;
- } else {
- return defaultValue;
- }
- }
-}
diff --git a/java/src/com/android/textclassifier/common/ModelFileManager.java b/java/src/com/android/textclassifier/common/ModelFileManager.java
new file mode 100644
index 0000000..406a889
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/ModelFileManager.java
@@ -0,0 +1,603 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common;
+
+import android.content.Context;
+import android.content.res.AssetFileDescriptor;
+import android.content.res.AssetManager;
+import android.os.LocaleList;
+import android.os.ParcelFileDescriptor;
+import android.util.ArraySet;
+import androidx.annotation.GuardedBy;
+import androidx.collection.ArrayMap;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.android.textclassifier.ActionsSuggestionsModel;
+import com.google.android.textclassifier.AnnotatorModel;
+import com.google.android.textclassifier.LangIdModel;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Function;
+import com.google.common.base.Optional;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Supplier;
+import com.google.common.collect.ImmutableList;
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+
+// TODO(licha): Consider making this a singleton class
+// TODO(licha): Check whether this is thread-safe
+/**
+ * Manages all model files in storage. {@link TextClassifierImpl} depends on this class to get the
+ * model files to load.
+ */
+public final class ModelFileManager {
+
+  private static final String TAG = "ModelFileManager";
+
+  // Relative to the app's files dir; holds models fetched by the model downloader.
+  private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models/";
+  // System location written by the config-updater mechanism.
+  private static final File CONFIG_UPDATER_DIR = new File("/data/misc/textclassifier/");
+  // Asset folder inside the APK that holds the preloaded (factory) models.
+  private static final String ASSETS_DIR = "textclassifier";
+
+  // Listers are consulted in order by listModelFiles(); each one contributes candidate
+  // files for a single (model type, source) pair.
+  private final List<ModelFileLister> modelFileListers;
+  private final File modelDownloaderDir;
+
+  /**
+   * Creates a manager wired with the default listers for the three model types (annotator,
+   * actions suggestions, lang_id). Each type is backed by three sources: the downloader
+   * directory (gated on {@link TextClassifierSettings#isModelDownloadManagerEnabled}), the
+   * config-updater directory, and the APK assets.
+   */
+  public ModelFileManager(Context context, TextClassifierSettings settings) {
+    Preconditions.checkNotNull(context);
+    Preconditions.checkNotNull(settings);
+
+    AssetManager assetManager = context.getAssets();
+    this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+    modelFileListers =
+        ImmutableList.of(
+            // Annotator models.
+            new RegularFilePatternMatchLister(
+                ModelType.ANNOTATOR,
+                this.modelDownloaderDir,
+                "annotator\\.(.*)\\.model",
+                settings::isModelDownloadManagerEnabled),
+            new RegularFileFullMatchLister(
+                ModelType.ANNOTATOR,
+                new File(CONFIG_UPDATER_DIR, "textclassifier.model"),
+                /* isEnabled= */ () -> true),
+            new AssetFilePatternMatchLister(
+                assetManager,
+                ModelType.ANNOTATOR,
+                ASSETS_DIR,
+                "annotator\\.(.*)\\.model",
+                /* isEnabled= */ () -> true),
+            // Actions models.
+            new RegularFilePatternMatchLister(
+                ModelType.ACTIONS_SUGGESTIONS,
+                this.modelDownloaderDir,
+                "actions_suggestions\\.(.*)\\.model",
+                settings::isModelDownloadManagerEnabled),
+            new RegularFileFullMatchLister(
+                ModelType.ACTIONS_SUGGESTIONS,
+                new File(CONFIG_UPDATER_DIR, "actions_suggestions.model"),
+                /* isEnabled= */ () -> true),
+            new AssetFilePatternMatchLister(
+                assetManager,
+                ModelType.ACTIONS_SUGGESTIONS,
+                ASSETS_DIR,
+                "actions_suggestions\\.(.*)\\.model",
+                /* isEnabled= */ () -> true),
+            // LangID models.
+            new RegularFilePatternMatchLister(
+                ModelType.LANG_ID,
+                this.modelDownloaderDir,
+                "lang_id\\.(.*)\\.model",
+                settings::isModelDownloadManagerEnabled),
+            new RegularFileFullMatchLister(
+                ModelType.LANG_ID,
+                new File(CONFIG_UPDATER_DIR, "lang_id.model"),
+                /* isEnabled= */ () -> true),
+            // NOTE(review): unlike the others this passes a literal file name, not a regex with
+            // a capture group — presumably intentional since it still matches "lang_id.model".
+            new AssetFilePatternMatchLister(
+                assetManager,
+                ModelType.LANG_ID,
+                ASSETS_DIR,
+                "lang_id.model",
+                /* isEnabled= */ () -> true));
+  }
+
+  /** Test-only constructor taking an explicit list of listers. */
+  @VisibleForTesting
+  public ModelFileManager(Context context, List<ModelFileLister> modelFileListers) {
+    this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+    this.modelFileListers = ImmutableList.copyOf(modelFileListers);
+  }
+
+  /**
+   * Returns an immutable list of model files listed by the given model files supplier.
+   *
+   * <p>Results from all registered listers are concatenated in registration order.
+   *
+   * @param modelType which type of model files to look for
+   */
+  public ImmutableList<ModelFile> listModelFiles(@ModelTypeDef String modelType) {
+    Preconditions.checkNotNull(modelType);
+
+    ImmutableList.Builder<ModelFile> modelFiles = new ImmutableList.Builder<>();
+    for (ModelFileLister modelFileLister : modelFileListers) {
+      modelFiles.addAll(modelFileLister.list(modelType));
+    }
+    return modelFiles.build();
+  }
+
+  /** Lists model files. */
+  public interface ModelFileLister {
+    /** Returns the model files of {@code modelType} available from this source. */
+    List<ModelFile> list(@ModelTypeDef String modelType);
+  }
+
+  /** Lists model files by performing full match on file path. */
+  public static class RegularFileFullMatchLister implements ModelFileLister {
+    private final String modelType;
+    private final File targetFile;
+    // Re-evaluated on every list() call, so flag flips take effect without a restart.
+    private final Supplier<Boolean> isEnabled;
+
+    /**
+     * @param modelType the type of the model
+     * @param targetFile the expected model file
+     * @param isEnabled whether this lister is enabled
+     */
+    public RegularFileFullMatchLister(
+        @ModelTypeDef String modelType, File targetFile, Supplier<Boolean> isEnabled) {
+      this.modelType = Preconditions.checkNotNull(modelType);
+      this.targetFile = Preconditions.checkNotNull(targetFile);
+      this.isEnabled = Preconditions.checkNotNull(isEnabled);
+    }
+
+    @Override
+    public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+      // Guard clauses: wrong type, disabled, or missing file all yield an empty result.
+      if (!this.modelType.equals(modelType)) {
+        return ImmutableList.of();
+      }
+      if (!isEnabled.get()) {
+        return ImmutableList.of();
+      }
+      if (!targetFile.exists()) {
+        return ImmutableList.of();
+      }
+      try {
+        return ImmutableList.of(ModelFile.createFromRegularFile(targetFile, modelType));
+      } catch (IOException e) {
+        // Reading model metadata failed; treat the file as absent rather than crashing.
+        TcLog.e(
+            TAG, "Failed to call createFromRegularFile with: " + targetFile.getAbsolutePath(), e);
+      }
+      return ImmutableList.of();
+    }
+  }
+
+  /** Lists model file in a specified folder by doing pattern matching on file names. */
+  public static class RegularFilePatternMatchLister implements ModelFileLister {
+    private final String modelType;
+    private final File folder;
+    // Compiled once at construction; matched against bare file names, not full paths.
+    private final Pattern fileNamePattern;
+    // Re-evaluated on every list() call, so flag flips take effect without a restart.
+    private final Supplier<Boolean> isEnabled;
+
+    /**
+     * @param modelType the type of the model
+     * @param folder the folder to list files
+     * @param fileNameRegex the regex to match the file name in the specified folder
+     * @param isEnabled whether the lister is enabled
+     */
+    public RegularFilePatternMatchLister(
+        @ModelTypeDef String modelType,
+        File folder,
+        String fileNameRegex,
+        Supplier<Boolean> isEnabled) {
+      this.modelType = Preconditions.checkNotNull(modelType);
+      this.folder = Preconditions.checkNotNull(folder);
+      this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
+      this.isEnabled = Preconditions.checkNotNull(isEnabled);
+    }
+
+    @Override
+    public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+      // Guard clauses: wrong type, disabled, or not a directory all yield an empty result.
+      if (!this.modelType.equals(modelType)) {
+        return ImmutableList.of();
+      }
+      if (!isEnabled.get()) {
+        return ImmutableList.of();
+      }
+      if (!folder.isDirectory()) {
+        return ImmutableList.of();
+      }
+      File[] files = folder.listFiles();
+      // listFiles() returns null on I/O error, not just for empty directories.
+      if (files == null) {
+        return ImmutableList.of();
+      }
+      ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+      for (File file : files) {
+        final Matcher matcher = fileNamePattern.matcher(file.getName());
+        // Skip non-matching names and anything that is not a regular file (e.g. subdirs).
+        if (!matcher.matches() || !file.isFile()) {
+          continue;
+        }
+        try {
+          modelFilesBuilder.add(ModelFile.createFromRegularFile(file, modelType));
+        } catch (IOException e) {
+          // Best-effort: a single unreadable file should not hide the others.
+          TcLog.w(TAG, "Failed to call createFromRegularFile with: " + file.getAbsolutePath());
+        }
+      }
+      return modelFilesBuilder.build();
+    }
+  }
+
+  /** Lists the model files preloaded in the APK file. */
+  public static class AssetFilePatternMatchLister implements ModelFileLister {
+    private final AssetManager assetManager;
+    private final String modelType;
+    private final String pathToList;
+    private final Pattern fileNamePattern;
+    private final Supplier<Boolean> isEnabled;
+    // Private monitor; guards resultCache without exposing the lock to callers.
+    private final Object lock = new Object();
+    // Assets won't change without updating the app, so cache the result for performance reason.
+    @GuardedBy("lock")
+    private final Map<String, ImmutableList<ModelFile>> resultCache;
+
+    /**
+     * @param modelType the type of the model.
+     * @param pathToList the folder to list files
+     * @param fileNameRegex the regex to match the file name in the specified folder
+     * @param isEnabled whether this lister is enabled
+     */
+    public AssetFilePatternMatchLister(
+        AssetManager assetManager,
+        @ModelTypeDef String modelType,
+        String pathToList,
+        String fileNameRegex,
+        Supplier<Boolean> isEnabled) {
+      this.assetManager = Preconditions.checkNotNull(assetManager);
+      this.modelType = Preconditions.checkNotNull(modelType);
+      this.pathToList = Preconditions.checkNotNull(pathToList);
+      this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
+      this.isEnabled = Preconditions.checkNotNull(isEnabled);
+      resultCache = new ArrayMap<>();
+    }
+
+    @Override
+    public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+      if (!this.modelType.equals(modelType)) {
+        return ImmutableList.of();
+      }
+      if (!isEnabled.get()) {
+        return ImmutableList.of();
+      }
+      synchronized (lock) {
+        // Serve from cache when possible; asset contents are immutable for a given APK.
+        if (resultCache.get(modelType) != null) {
+          return resultCache.get(modelType);
+        }
+        String[] fileNames = null;
+        try {
+          fileNames = assetManager.list(pathToList);
+        } catch (IOException e) {
+          TcLog.e(TAG, "Failed to list assets", e);
+        }
+        // NOTE: a listing failure is NOT cached, so it will be retried on the next call.
+        if (fileNames == null) {
+          return ImmutableList.of();
+        }
+        ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+        for (String fileName : fileNames) {
+          final Matcher matcher = fileNamePattern.matcher(fileName);
+          if (!matcher.matches()) {
+            continue;
+          }
+          // Asset paths always use '/' separators, independent of the platform.
+          String absolutePath =
+              new StringBuilder(pathToList).append('/').append(fileName).toString();
+          try {
+            modelFilesBuilder.add(ModelFile.createFromAsset(assetManager, absolutePath, modelType));
+          } catch (IOException e) {
+            TcLog.w(TAG, "Failed to call createFromAsset with: " + absolutePath);
+          }
+        }
+        ImmutableList<ModelFile> result = modelFilesBuilder.build();
+        resultCache.put(modelType, result);
+        return result;
+      }
+    }
+  }
+
+  /**
+   * Returns the best model file for the given locale list, {@code null} if nothing is found.
+   *
+   * @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
+   * @param localePreferences an ordered list of user preferences for locales, use {@code null} if
+   *     there is no preference.
+   */
+  @Nullable
+  public ModelFile findBestModelFile(
+      @ModelTypeDef String modelType, @Nullable LocaleList localePreferences) {
+    // Fall back to the device's default locale list when the caller has no preference.
+    final String languages =
+        localePreferences == null || localePreferences.isEmpty()
+            ? LocaleList.getDefault().toLanguageTags()
+            : localePreferences.toLanguageTags();
+    final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
+
+    // Linear scan over all candidates; ModelFile#isPreferredTo decides the ranking.
+    ModelFile bestModel = null;
+    for (ModelFile model : listModelFiles(modelType)) {
+      // TODO(licha): update this when we want to support multiple languages
+      if (model.isAnyLanguageSupported(languageRangeList)) {
+        if (model.isPreferredTo(bestModel)) {
+          bestModel = model;
+        }
+      }
+    }
+    return bestModel;
+  }
+
+  /**
+   * Deletes model files that are not preferred for any locales in user's preference.
+   *
+   * <p>This method will be invoked as a clean-up after we download a new model successfully. Race
+   * conditions are hard to avoid because we do not hold locks for files. But it should rarely cause
+   * any issues since it's safe to delete a model file in use (b/c we mmap it to memory).
+   */
+  public void deleteUnusedModelFiles() {
+    TcLog.d(TAG, "Start to delete unused model files.");
+    LocaleList localeList = LocaleList.getDefault();
+    for (@ModelTypeDef String modelType : ModelType.values()) {
+      // Start with every known file for this type, then remove the ones worth keeping.
+      ArraySet<ModelFile> allModelFiles = new ArraySet<>(listModelFiles(modelType));
+      for (int i = 0; i < localeList.size(); i++) {
+        // If a model file is preferred for any locale in the locale list, then keep it
+        ModelFile bestModel = findBestModelFile(modelType, new LocaleList(localeList.get(i)));
+        allModelFiles.remove(bestModel);
+      }
+      for (ModelFile modelFile : allModelFiles) {
+        // Assets are read-only; only deletable regular files are removed.
+        if (modelFile.canWrite()) {
+          TcLog.d(TAG, "Deleting model: " + modelFile);
+          if (!modelFile.delete()) {
+            TcLog.w(TAG, "Failed to delete model: " + modelFile);
+          }
+        }
+      }
+    }
+  }
+
+  /** Returns the directory containing models downloaded by the downloader. */
+  public File getModelDownloaderDir() {
+    return modelDownloaderDir;
+  }
+
+  /**
+   * Dumps the internal state for debugging.
+   *
+   * @param printWriter writer to write dumped states
+   */
+  public void dump(IndentingPrintWriter printWriter) {
+    printWriter.println("ModelFileManager:");
+    printWriter.increaseIndent();
+    // One indented section per model type, listing every candidate file from every lister.
+    for (@ModelTypeDef String modelType : ModelType.values()) {
+      printWriter.println(modelType + " model file(s):");
+      printWriter.increaseIndent();
+      for (ModelFile modelFile : listModelFiles(modelType)) {
+        printWriter.println(modelFile.toString());
+      }
+      printWriter.decreaseIndent();
+    }
+    printWriter.decreaseIndent();
+  }
+
+  /** Fetches metadata (version and supported locales) from an already-open model file. */
+  private static class ModelInfoFetcher {
+    private final Function<AssetFileDescriptor, Integer> versionFetcher;
+    private final Function<AssetFileDescriptor, String> supportedLocalesFetcher;
+
+    private ModelInfoFetcher(
+        Function<AssetFileDescriptor, Integer> versionFetcher,
+        Function<AssetFileDescriptor, String> supportedLocalesFetcher) {
+      this.versionFetcher = versionFetcher;
+      this.supportedLocalesFetcher = supportedLocalesFetcher;
+    }
+
+    /** Reads the model version from the open descriptor. */
+    int getVersion(AssetFileDescriptor assetFileDescriptor) {
+      return versionFetcher.apply(assetFileDescriptor);
+    }
+
+    /** Reads the supported-locales string from the open descriptor. */
+    String getSupportedLocales(AssetFileDescriptor assetFileDescriptor) {
+      return supportedLocalesFetcher.apply(assetFileDescriptor);
+    }
+
+    /**
+     * Returns the fetcher matching the given model type.
+     *
+     * @throws IllegalStateException if {@code modelType} is not a recognized type
+     */
+    static ModelInfoFetcher create(@ModelTypeDef String modelType) {
+      switch (modelType) {
+        case ModelType.ANNOTATOR:
+          return new ModelInfoFetcher(AnnotatorModel::getVersion, AnnotatorModel::getLocales);
+        case ModelType.ACTIONS_SUGGESTIONS:
+          return new ModelInfoFetcher(
+              ActionsSuggestionsModel::getVersion, ActionsSuggestionsModel::getLocales);
+        case ModelType.LANG_ID:
+          // LangID models apply to all languages, so they report the independent sentinel.
+          return new ModelInfoFetcher(
+              LangIdModel::getVersion, afd -> ModelFile.LANGUAGE_INDEPENDENT);
+        default:
+          // Name the offending type so misconfigurations are diagnosable from the log.
+          throw new IllegalStateException("Unsupported model type: " + modelType);
+      }
+    }
+  }
+
+  /** Describes TextClassifier model files on disk. */
+  public static class ModelFile {
+    // Sentinel supported-locales string meaning the model applies to all languages.
+    @VisibleForTesting static final String LANGUAGE_INDEPENDENT = "*";
+
+    @ModelTypeDef public final String modelType;
+    public final String absolutePath;
+    public final int version;
+    // Empty when languageIndependent is true (the "*" sentinel is not parsed as a tag).
+    public final LocaleList supportedLocales;
+    public final boolean languageIndependent;
+    // True for files inside the APK assets (read-only); false for regular files on disk.
+    public final boolean isAsset;
+
+    /** Creates a ModelFile by reading metadata from a regular file on disk. */
+    public static ModelFile createFromRegularFile(File file, @ModelTypeDef String modelType)
+        throws IOException {
+      ParcelFileDescriptor pfd =
+          ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+      // Closing the AssetFileDescriptor also closes the wrapped ParcelFileDescriptor.
+      try (AssetFileDescriptor afd = new AssetFileDescriptor(pfd, 0, file.length())) {
+        return createFromAssetFileDescriptor(
+            file.getAbsolutePath(), modelType, afd, /* isAsset= */ false);
+      }
+    }
+
+    /** Creates a ModelFile by reading metadata from an APK asset. */
+    public static ModelFile createFromAsset(
+        AssetManager assetManager, String absolutePath, @ModelTypeDef String modelType)
+        throws IOException {
+      try (AssetFileDescriptor assetFileDescriptor = assetManager.openFd(absolutePath)) {
+        return createFromAssetFileDescriptor(
+            absolutePath, modelType, assetFileDescriptor, /* isAsset= */ true);
+      }
+    }
+
+    // Shared tail of both factories: pull version/locales out of the open descriptor.
+    private static ModelFile createFromAssetFileDescriptor(
+        String absolutePath,
+        @ModelTypeDef String modelType,
+        AssetFileDescriptor assetFileDescriptor,
+        boolean isAsset) {
+      ModelInfoFetcher modelInfoFetcher = ModelInfoFetcher.create(modelType);
+      return new ModelFile(
+          modelType,
+          absolutePath,
+          modelInfoFetcher.getVersion(assetFileDescriptor),
+          modelInfoFetcher.getSupportedLocales(assetFileDescriptor),
+          isAsset);
+    }
+
+    @VisibleForTesting
+    ModelFile(
+        @ModelTypeDef String modelType,
+        String absolutePath,
+        int version,
+        String supportedLocaleTags,
+        boolean isAsset) {
+      this.modelType = modelType;
+      this.absolutePath = absolutePath;
+      this.version = version;
+      this.languageIndependent = LANGUAGE_INDEPENDENT.equals(supportedLocaleTags);
+      this.supportedLocales =
+          languageIndependent
+              ? LocaleList.getEmptyLocaleList()
+              : LocaleList.forLanguageTags(supportedLocaleTags);
+      this.isAsset = isAsset;
+    }
+
+    /** Returns if this model file is preferred to the given one. */
+    public boolean isPreferredTo(@Nullable ModelFile model) {
+      // A model is preferred to no model.
+      if (model == null) {
+        return true;
+      }
+
+      // A language-specific model is preferred to a language independent
+      // model.
+      if (!languageIndependent && model.languageIndependent) {
+        return true;
+      }
+      if (languageIndependent && !model.languageIndependent) {
+        return false;
+      }
+
+      // A higher-version model is preferred. Equal versions are NOT preferred, so the
+      // incumbent wins ties in findBestModelFile.
+      if (version > model.version) {
+        return true;
+      }
+      return false;
+    }
+
+    /** Returns whether this model supports any language in the given ranges. */
+    public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
+      Preconditions.checkNotNull(languageRanges);
+      if (languageIndependent) {
+        return true;
+      }
+      List<String> supportedLocaleTags =
+          Arrays.asList(supportedLocales.toLanguageTags().split(","));
+      return Locale.lookupTag(languageRanges, supportedLocaleTags) != null;
+    }
+
+    /** Opens the model file; the caller is responsible for closing the returned descriptor. */
+    public AssetFileDescriptor open(AssetManager assetManager) throws IOException {
+      if (isAsset) {
+        return assetManager.openFd(absolutePath);
+      }
+      File file = new File(absolutePath);
+      ParcelFileDescriptor parcelFileDescriptor =
+          ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+      return new AssetFileDescriptor(parcelFileDescriptor, 0, file.length());
+    }
+
+    /** Returns whether this model file may be deleted; APK assets never can be. */
+    public boolean canWrite() {
+      if (isAsset) {
+        return false;
+      }
+      return new File(absolutePath).canWrite();
+    }
+
+    /** Deletes the backing regular file; throws if this model is an APK asset. */
+    public boolean delete() {
+      if (isAsset) {
+        throw new IllegalStateException("asset is read-only, deleting it is not allowed.");
+      }
+      return new File(absolutePath).delete();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof ModelFile)) {
+        return false;
+      }
+      ModelFile modelFile = (ModelFile) o;
+      return version == modelFile.version
+          && languageIndependent == modelFile.languageIndependent
+          && isAsset == modelFile.isAsset
+          && Objects.equals(modelType, modelFile.modelType)
+          && Objects.equals(absolutePath, modelFile.absolutePath)
+          && Objects.equals(supportedLocales, modelFile.supportedLocales);
+    }
+
+    @Override
+    public int hashCode() {
+      // Must stay consistent with equals(): same field set in both.
+      return Objects.hash(
+          modelType, absolutePath, version, supportedLocales, languageIndependent, isAsset);
+    }
+
+    /** Converts to the logging-friendly ModelInfo representation. */
+    public ModelInfo toModelInfo() {
+      return new ModelInfo(version, supportedLocales.toLanguageTags());
+    }
+
+    @Override
+    public String toString() {
+      return String.format(
+          Locale.US,
+          "ModelFile { type=%s path=%s version=%d locales=%s isAsset=%b}",
+          modelType,
+          absolutePath,
+          version,
+          languageIndependent ? LANGUAGE_INDEPENDENT : supportedLocales.toLanguageTags(),
+          isAsset);
+    }
+
+    /** Maps each optional ModelFile to an optional ModelInfo, preserving order. */
+    public static ImmutableList<Optional<ModelInfo>> toModelInfos(
+        Optional<ModelFileManager.ModelFile>... modelFiles) {
+      return Arrays.stream(modelFiles)
+          .map(modelFile -> modelFile.transform(ModelFileManager.ModelFile::toModelInfo))
+          .collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf));
+    }
+  }
+}
diff --git a/java/src/com/android/textclassifier/common/ModelType.java b/java/src/com/android/textclassifier/common/ModelType.java
new file mode 100644
index 0000000..a30fce0
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/ModelType.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common;
+
+import androidx.annotation.StringDef;
+import com.google.common.collect.ImmutableList;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+
+/** Effectively an enum class to represent types of models. */
+public final class ModelType {
+  /** TextClassifier model types as String. */
+  @Retention(RetentionPolicy.SOURCE)
+  @StringDef({ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS})
+  public @interface ModelTypeDef {}
+
+  // Plain String constants (not a real enum) so they can be used in the @StringDef above.
+  public static final String ANNOTATOR = "annotator";
+  public static final String LANG_ID = "lang_id";
+  public static final String ACTIONS_SUGGESTIONS = "actions_suggestions";
+
+  // Immutable, so exposing it directly from values() is safe.
+  public static final ImmutableList<String> VALUES =
+      ImmutableList.of(ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS);
+
+  /** Returns all known model types, mirroring an enum's values(). */
+  public static ImmutableList<String> values() {
+    return VALUES;
+  }
+
+  // Non-instantiable utility class.
+  private ModelType() {}
+}
diff --git a/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java b/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java
new file mode 100644
index 0000000..43164e0
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common;
+
+import com.android.textclassifier.common.base.TcLog;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import java.util.concurrent.Executors;
+
+// TODO(licha): Revisit the configurations of thread pools
+/**
+ * Holder of executor singletons.
+ *
+ * <p>Note: because we have two processes, we may keep two copies of these executors (one per
+ * process).
+ */
+public final class TextClassifierServiceExecutors {
+  private static final String TAG = "TextClassifierServiceExecutors";
+
+  /** Returns an executor with normal priority. Used for handling client requests. */
+  public static ListeningExecutorService getNormPriorityExecutor() {
+    return NormPriorityExecutorHolder.normPriorityExecutor;
+  }
+
+  /**
+   * @deprecated Misspelled alias kept for backward compatibility with existing callers; use
+   *     {@link #getNormPriorityExecutor()} instead.
+   */
+  @Deprecated
+  public static ListeningExecutorService getNormhPriorityExecutor() {
+    return getNormPriorityExecutor();
+  }
+
+  /** Returns a single-thread executor with low priority. Used for internal tasks like logging. */
+  public static ListeningExecutorService getLowPriorityExecutor() {
+    return LowPriorityExecutorHolder.lowPriorityExecutor;
+  }
+
+  // Lazy-initialization holder idiom: each executor is created on first access, and class
+  // loading guarantees thread safety without explicit locking.
+  private static class NormPriorityExecutorHolder {
+    static final ListeningExecutorService normPriorityExecutor =
+        init("tcs-norm-prio-executor-%d", Thread.NORM_PRIORITY, /* corePoolSize= */ 2);
+  }
+
+  private static class LowPriorityExecutorHolder {
+    static final ListeningExecutorService lowPriorityExecutor =
+        init("tcs-low-prio-executor-%d", Thread.NORM_PRIORITY - 1, /* corePoolSize= */ 1);
+  }
+
+  /** Builds a fixed-size listening executor whose threads use the given name format/priority. */
+  private static ListeningExecutorService init(String nameFormat, int priority, int corePoolSize) {
+    TcLog.v(TAG, "Creating executor: " + nameFormat);
+    return MoreExecutors.listeningDecorator(
+        Executors.newFixedThreadPool(
+            corePoolSize,
+            new ThreadFactoryBuilder()
+                .setNameFormat(nameFormat)
+                .setPriority(priority)
+                // In Android, those uncaught exceptions will crash the whole process if not handled
+                .setUncaughtExceptionHandler(
+                    (thread, throwable) ->
+                        TcLog.e(TAG, "Exception from executor: " + thread, throwable))
+                .build()));
+  }
+
+  private TextClassifierServiceExecutors() {}
+}
diff --git a/java/src/com/android/textclassifier/common/TextClassifierSettings.java b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
new file mode 100644
index 0000000..fdf259e
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
@@ -0,0 +1,502 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common;
+
+import static java.util.concurrent.TimeUnit.HOURS;
+
+import android.provider.DeviceConfig;
+import android.provider.DeviceConfig.Properties;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.TextClassifier;
+import androidx.annotation.NonNull;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import javax.annotation.Nullable;
+
+/**
+ * TextClassifier specific settings.
+ *
+ * <p>Currently, this class does not guarantee co-diverted flags are updated atomically.
+ *
+ * <p>Example of setting the values for testing.
+ *
+ * <pre>
+ * adb shell cmd device_config put textclassifier system_textclassifier_enabled true
+ * </pre>
+ *
+ * @see android.provider.DeviceConfig#NAMESPACE_TEXTCLASSIFIER
+ */
+public final class TextClassifierSettings {
+  private static final String TAG = "TextClassifierSettings";
+  public static final String NAMESPACE = DeviceConfig.NAMESPACE_TEXTCLASSIFIER;
+
+  /** Separator used by list- and array-valued flags. */
+  private static final String DELIMITER = ":";
+
+  /** Whether the user language profile feature is enabled. */
+  private static final String USER_LANGUAGE_PROFILE_ENABLED = "user_language_profile_enabled";
+  /** Max length of text that suggestSelection can accept. */
+  @VisibleForTesting
+  static final String SUGGEST_SELECTION_MAX_RANGE_LENGTH = "suggest_selection_max_range_length";
+  /** Max length of text that classifyText can accept. */
+  private static final String CLASSIFY_TEXT_MAX_RANGE_LENGTH = "classify_text_max_range_length";
+  /** Max length of text that generateLinks can accept. */
+  private static final String GENERATE_LINKS_MAX_TEXT_LENGTH = "generate_links_max_text_length";
+  /** Sampling rate for generateLinks logging. */
+  private static final String GENERATE_LINKS_LOG_SAMPLE_RATE = "generate_links_log_sample_rate";
+  /**
+   * A colon(:) separated string that specifies the default entities types for generateLinks when
+   * hint is not given.
+   */
+  @VisibleForTesting static final String ENTITY_LIST_DEFAULT = "entity_list_default";
+  /**
+   * A colon(:) separated string that specifies the default entities types for generateLinks when
+   * the text is in a not editable UI widget.
+   */
+  private static final String ENTITY_LIST_NOT_EDITABLE = "entity_list_not_editable";
+  /**
+   * A colon(:) separated string that specifies the default entities types for generateLinks when
+   * the text is in an editable UI widget.
+   */
+  private static final String ENTITY_LIST_EDITABLE = "entity_list_editable";
+  /**
+   * A colon(:) separated string that specifies the default action types for
+   * suggestConversationActions when the suggestions are used in an app.
+   */
+  private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT =
+      "in_app_conversation_action_types_default";
+  /**
+   * A colon(:) separated string that specifies the default action types for
+   * suggestConversationActions when the suggestions are used in a notification.
+   */
+  private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
+      "notification_conversation_action_types_default";
+  /** Threshold to accept a suggested language from LangID model. */
+  @VisibleForTesting static final String LANG_ID_THRESHOLD_OVERRIDE = "lang_id_threshold_override";
+  /** Whether to enable {@link com.android.textclassifier.intent.TemplateIntentFactory}. */
+  @VisibleForTesting
+  static final String TEMPLATE_INTENT_FACTORY_ENABLED = "template_intent_factory_enabled";
+  /** Whether to enable "translate" action in classifyText. */
+  private static final String TRANSLATE_IN_CLASSIFICATION_ENABLED =
+      "translate_in_classification_enabled";
+  /**
+   * Whether to detect the languages of the text in request by using langId for the native model.
+   */
+  private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
+      "detect_languages_from_text_enabled";
+
+  /** Whether to enable model downloading with ModelDownloadManager */
+  @VisibleForTesting
+  public static final String MODEL_DOWNLOAD_MANAGER_ENABLED = "model_download_manager_enabled";
+  /** Type of network to download model manifest. A String value of androidx.work.NetworkType. */
+  private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE =
+      "manifest_download_required_network_type";
+  /** Max attempts allowed for a single ModelDownloader downloading task. */
+  @VisibleForTesting
+  static final String MODEL_DOWNLOAD_MAX_ATTEMPTS = "model_download_max_attempts";
+
+  @VisibleForTesting
+  static final String MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS =
+      "model_download_backoff_delay_in_millis";
+  /** Flag name for manifest url is dynamically formatted based on model type and model language. */
+  @VisibleForTesting public static final String MANIFEST_URL_TEMPLATE = "manifest_url_%s_%s";
+  /** Sampling rate for TextClassifier API logging. */
+  static final String TEXTCLASSIFIER_API_LOG_SAMPLE_RATE = "textclassifier_api_log_sample_rate";
+
+  /** The size of the cache of the mapping of session id to text classification context. */
+  private static final String SESSION_ID_TO_CONTEXT_CACHE_SIZE = "session_id_to_context_cache_size";
+
+  /**
+   * A colon(:) separated string that specifies the configuration to use when including surrounding
+   * context text in language detection queries.
+   *
+   * <p>Format= minimumTextSize<int>:penalizeRatio<float>:textScoreRatio<float>
+   *
+   * <p>e.g. 20:1.0:0.4
+   *
+   * <p>Accept all text lengths with minimumTextSize=0
+   *
+   * <p>Reject all text less than minimumTextSize with penalizeRatio=0
+   *
+   * @see {@code TextClassifierImpl#detectLanguages(String, int, int)} for reference.
+   */
+  @VisibleForTesting static final String LANG_ID_CONTEXT_SETTINGS = "lang_id_context_settings";
+  /** Default threshold to translate the language of the context the user selects */
+  private static final String TRANSLATE_ACTION_THRESHOLD = "translate_action_threshold";
+
+  // Sync this with ConversationAction.TYPE_ADD_CONTACT;
+  public static final String TYPE_ADD_CONTACT = "add_contact";
+  // Sync this with ConversationAction.COPY;
+  public static final String TYPE_COPY = "copy";
+
+  private static final int SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
+  private static final int CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
+  private static final int GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT = 100 * 1000;
+  private static final int GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT = 100;
+
+  private static final ImmutableList<String> ENTITY_LIST_DEFAULT_VALUE =
+      ImmutableList.of(
+          TextClassifier.TYPE_ADDRESS,
+          TextClassifier.TYPE_EMAIL,
+          TextClassifier.TYPE_PHONE,
+          TextClassifier.TYPE_URL,
+          TextClassifier.TYPE_DATE,
+          TextClassifier.TYPE_DATE_TIME,
+          TextClassifier.TYPE_FLIGHT_NUMBER);
+  private static final ImmutableList<String> CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES =
+      ImmutableList.of(
+          ConversationAction.TYPE_TEXT_REPLY,
+          ConversationAction.TYPE_CREATE_REMINDER,
+          ConversationAction.TYPE_CALL_PHONE,
+          ConversationAction.TYPE_OPEN_URL,
+          ConversationAction.TYPE_SEND_EMAIL,
+          ConversationAction.TYPE_SEND_SMS,
+          ConversationAction.TYPE_TRACK_FLIGHT,
+          ConversationAction.TYPE_VIEW_CALENDAR,
+          ConversationAction.TYPE_VIEW_MAP,
+          TYPE_ADD_CONTACT,
+          TYPE_COPY);
+  /**
+   * < 0 : Not set. Use value from LangId model. 0 - 1: Override value in LangId model.
+   *
+   * @see EntityConfidence
+   */
+  private static final float LANG_ID_THRESHOLD_OVERRIDE_DEFAULT = -1f;
+
+  private static final float TRANSLATE_ACTION_THRESHOLD_DEFAULT = 0.5f;
+
+  private static final boolean USER_LANGUAGE_PROFILE_ENABLED_DEFAULT = true;
+  private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
+  private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
+  private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
+  private static final boolean MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT = false;
+  private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT = "UNMETERED";
+  private static final int MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT = 5;
+  private static final long MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS_DEFAULT = HOURS.toMillis(1);
+  private static final String MANIFEST_URL_DEFAULT = "";
+  private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
+  /**
+   * Sampling rate for API logging. For example, 100 means there is a 0.01 chance that the API call
+   * is logged.
+   */
+  private static final int TEXTCLASSIFIER_API_LOG_SAMPLE_RATE_DEFAULT = 10;
+
+  private static final int SESSION_ID_TO_CONTEXT_CACHE_SIZE_DEFAULT = 10;
+
+  // TODO(licha): Consider removing this. We can use real device config for testing.
+  /** DeviceConfig interface to facilitate testing. */
+  @VisibleForTesting
+  public interface IDeviceConfig {
+    default Properties getProperties(@NonNull String namespace, @NonNull String... names) {
+      return new Properties.Builder(namespace).build();
+    }
+
+    // Note: @NonNull is meaningless on primitives and has been removed from the
+    // int/long/float/boolean defaultValue parameters below.
+    default int getInt(@NonNull String namespace, @NonNull String name, int defaultValue) {
+      return defaultValue;
+    }
+
+    default long getLong(@NonNull String namespace, @NonNull String name, long defaultValue) {
+      return defaultValue;
+    }
+
+    default float getFloat(@NonNull String namespace, @NonNull String name, float defaultValue) {
+      return defaultValue;
+    }
+
+    default String getString(
+        @NonNull String namespace, @NonNull String name, @Nullable String defaultValue) {
+      return defaultValue;
+    }
+
+    default boolean getBoolean(
+        @NonNull String namespace, @NonNull String name, boolean defaultValue) {
+      return defaultValue;
+    }
+  }
+
+  /** Default implementation that simply delegates to the real {@link DeviceConfig}. */
+  private static final IDeviceConfig DEFAULT_DEVICE_CONFIG =
+      new IDeviceConfig() {
+        @Override
+        public Properties getProperties(@NonNull String namespace, @NonNull String... names) {
+          return DeviceConfig.getProperties(namespace, names);
+        }
+
+        @Override
+        public int getInt(@NonNull String namespace, @NonNull String name, int defaultValue) {
+          return DeviceConfig.getInt(namespace, name, defaultValue);
+        }
+
+        @Override
+        public long getLong(@NonNull String namespace, @NonNull String name, long defaultValue) {
+          return DeviceConfig.getLong(namespace, name, defaultValue);
+        }
+
+        @Override
+        public float getFloat(@NonNull String namespace, @NonNull String name, float defaultValue) {
+          return DeviceConfig.getFloat(namespace, name, defaultValue);
+        }
+
+        @Override
+        public String getString(
+            @NonNull String namespace, @NonNull String name, @Nullable String defaultValue) {
+          // Nullability matches the interface declaration; DeviceConfig.getString accepts a null
+          // default and then may return null.
+          return DeviceConfig.getString(namespace, name, defaultValue);
+        }
+
+        @Override
+        public boolean getBoolean(
+            @NonNull String namespace, @NonNull String name, boolean defaultValue) {
+          return DeviceConfig.getBoolean(namespace, name, defaultValue);
+        }
+      };
+
+  private final IDeviceConfig deviceConfig;
+
+  public TextClassifierSettings() {
+    this(DEFAULT_DEVICE_CONFIG);
+  }
+
+  @VisibleForTesting
+  public TextClassifierSettings(IDeviceConfig deviceConfig) {
+    this.deviceConfig = deviceConfig;
+  }
+
+  public int getSuggestSelectionMaxRangeLength() {
+    return deviceConfig.getInt(
+        NAMESPACE, SUGGEST_SELECTION_MAX_RANGE_LENGTH, SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT);
+  }
+
+  public int getClassifyTextMaxRangeLength() {
+    return deviceConfig.getInt(
+        NAMESPACE, CLASSIFY_TEXT_MAX_RANGE_LENGTH, CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT);
+  }
+
+  public int getGenerateLinksMaxTextLength() {
+    return deviceConfig.getInt(
+        NAMESPACE, GENERATE_LINKS_MAX_TEXT_LENGTH, GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT);
+  }
+
+  public int getGenerateLinksLogSampleRate() {
+    return deviceConfig.getInt(
+        NAMESPACE, GENERATE_LINKS_LOG_SAMPLE_RATE, GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT);
+  }
+
+  public List<String> getEntityListDefault() {
+    return getDeviceConfigStringList(ENTITY_LIST_DEFAULT, ENTITY_LIST_DEFAULT_VALUE);
+  }
+
+  public List<String> getEntityListNotEditable() {
+    return getDeviceConfigStringList(ENTITY_LIST_NOT_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
+  }
+
+  public List<String> getEntityListEditable() {
+    return getDeviceConfigStringList(ENTITY_LIST_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
+  }
+
+  public List<String> getInAppConversationActionTypes() {
+    return getDeviceConfigStringList(
+        IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
+  }
+
+  public List<String> getNotificationConversationActionTypes() {
+    return getDeviceConfigStringList(
+        NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
+  }
+
+  public float getLangIdThresholdOverride() {
+    return deviceConfig.getFloat(
+        NAMESPACE, LANG_ID_THRESHOLD_OVERRIDE, LANG_ID_THRESHOLD_OVERRIDE_DEFAULT);
+  }
+
+  public float getTranslateActionThreshold() {
+    return deviceConfig.getFloat(
+        NAMESPACE, TRANSLATE_ACTION_THRESHOLD, TRANSLATE_ACTION_THRESHOLD_DEFAULT);
+  }
+
+  public boolean isUserLanguageProfileEnabled() {
+    return deviceConfig.getBoolean(
+        NAMESPACE, USER_LANGUAGE_PROFILE_ENABLED, USER_LANGUAGE_PROFILE_ENABLED_DEFAULT);
+  }
+
+  public boolean isTemplateIntentFactoryEnabled() {
+    return deviceConfig.getBoolean(
+        NAMESPACE, TEMPLATE_INTENT_FACTORY_ENABLED, TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
+  }
+
+  public boolean isTranslateInClassificationEnabled() {
+    return deviceConfig.getBoolean(
+        NAMESPACE,
+        TRANSLATE_IN_CLASSIFICATION_ENABLED,
+        TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT);
+  }
+
+  public boolean isDetectLanguagesFromTextEnabled() {
+    return deviceConfig.getBoolean(
+        NAMESPACE, DETECT_LANGUAGES_FROM_TEXT_ENABLED, DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
+  }
+
+  public float[] getLangIdContextSettings() {
+    return getDeviceConfigFloatArray(LANG_ID_CONTEXT_SETTINGS, LANG_ID_CONTEXT_SETTINGS_DEFAULT);
+  }
+
+  public boolean isModelDownloadManagerEnabled() {
+    return deviceConfig.getBoolean(
+        NAMESPACE, MODEL_DOWNLOAD_MANAGER_ENABLED, MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT);
+  }
+
+  /** Returns a string which represents a androidx.work.NetworkType enum. */
+  public String getManifestDownloadRequiredNetworkType() {
+    return deviceConfig.getString(
+        NAMESPACE,
+        MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE,
+        MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT);
+  }
+
+  public int getModelDownloadMaxAttempts() {
+    return deviceConfig.getInt(
+        NAMESPACE, MODEL_DOWNLOAD_MAX_ATTEMPTS, MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
+  }
+
+  public long getModelDownloadBackoffDelayInMillis() {
+    return deviceConfig.getLong(
+        NAMESPACE,
+        MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS,
+        MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS_DEFAULT);
+  }
+
+  /**
+   * Get model's manifest url for given model type and language.
+   *
+   * @param modelType the type of model for the target url
+   * @param modelLanguageTag the language tag for the model (e.g. en), but can also be "universal"
+   * @return DeviceConfig configured url or empty string if not set
+   */
+  public String getManifestURL(@ModelType.ModelTypeDef String modelType, String modelLanguageTag) {
+    // E.g: manifest_url_annotator_zh, manifest_url_lang_id_universal,
+    // manifest_url_actions_suggestions_en
+    String urlFlagName = String.format(MANIFEST_URL_TEMPLATE, modelType, modelLanguageTag);
+    return deviceConfig.getString(NAMESPACE, urlFlagName, MANIFEST_URL_DEFAULT);
+  }
+
+  /**
+   * Gets all language variants configured for a specific ModelType.
+   *
+   * <p>For a specific language, there can be many variants: de-CH, de-LI, zh-Hans, zh-Hant. There
+   * is no easy way to hardcode the list in client. Therefore, we parse all configured flag's name
+   * in DeviceConfig, and let the client to choose the best variant to download.
+   */
+  public ImmutableList<String> getLanguageTagsForManifestURL(
+      @ModelType.ModelTypeDef String modelType) {
+    String urlFlagBaseName = String.format(MANIFEST_URL_TEMPLATE, modelType, /* language */ "");
+    Properties properties = deviceConfig.getProperties(NAMESPACE);
+    ImmutableList.Builder<String> variantsBuilder = ImmutableList.builder();
+    for (String name : properties.getKeyset()) {
+      // A flag counts only if it matches the prefix and actually has a configured value.
+      if (name.startsWith(urlFlagBaseName)
+          && properties.getString(name, /* defaultValue= */ null) != null) {
+        variantsBuilder.add(name.substring(urlFlagBaseName.length()));
+      }
+    }
+    return variantsBuilder.build();
+  }
+
+  public int getTextClassifierApiLogSampleRate() {
+    return deviceConfig.getInt(
+        NAMESPACE, TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, TEXTCLASSIFIER_API_LOG_SAMPLE_RATE_DEFAULT);
+  }
+
+  public int getSessionIdToContextCacheSize() {
+    return deviceConfig.getInt(
+        NAMESPACE, SESSION_ID_TO_CONTEXT_CACHE_SIZE, SESSION_ID_TO_CONTEXT_CACHE_SIZE_DEFAULT);
+  }
+
+  /** Dumps the current effective settings, one flag per line, for debugging (adb bugreports). */
+  public void dump(IndentingPrintWriter pw) {
+    pw.println("TextClassifierSettings:");
+    pw.increaseIndent();
+    pw.printPair(CLASSIFY_TEXT_MAX_RANGE_LENGTH, getClassifyTextMaxRangeLength());
+    pw.printPair(DETECT_LANGUAGES_FROM_TEXT_ENABLED, isDetectLanguagesFromTextEnabled());
+    pw.printPair(ENTITY_LIST_DEFAULT, getEntityListDefault());
+    pw.printPair(ENTITY_LIST_EDITABLE, getEntityListEditable());
+    pw.printPair(ENTITY_LIST_NOT_EDITABLE, getEntityListNotEditable());
+    pw.printPair(GENERATE_LINKS_LOG_SAMPLE_RATE, getGenerateLinksLogSampleRate());
+    pw.printPair(GENERATE_LINKS_MAX_TEXT_LENGTH, getGenerateLinksMaxTextLength());
+    pw.printPair(IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT, getInAppConversationActionTypes());
+    pw.printPair(LANG_ID_CONTEXT_SETTINGS, Arrays.toString(getLangIdContextSettings()));
+    pw.printPair(LANG_ID_THRESHOLD_OVERRIDE, getLangIdThresholdOverride());
+    pw.printPair(TRANSLATE_ACTION_THRESHOLD, getTranslateActionThreshold());
+    pw.printPair(
+        NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT, getNotificationConversationActionTypes());
+    pw.printPair(SUGGEST_SELECTION_MAX_RANGE_LENGTH, getSuggestSelectionMaxRangeLength());
+    pw.printPair(USER_LANGUAGE_PROFILE_ENABLED, isUserLanguageProfileEnabled());
+    pw.printPair(TEMPLATE_INTENT_FACTORY_ENABLED, isTemplateIntentFactoryEnabled());
+    pw.printPair(TRANSLATE_IN_CLASSIFICATION_ENABLED, isTranslateInClassificationEnabled());
+    pw.printPair(MODEL_DOWNLOAD_MANAGER_ENABLED, isModelDownloadManagerEnabled());
+    pw.printPair(MODEL_DOWNLOAD_MAX_ATTEMPTS, getModelDownloadMaxAttempts());
+    pw.printPair(TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, getTextClassifierApiLogSampleRate());
+    pw.printPair(SESSION_ID_TO_CONTEXT_CACHE_SIZE, getSessionIdToContextCacheSize());
+    // Single decreaseIndent to balance the single increaseIndent above. (The original code called
+    // decreaseIndent twice, once mid-list, leaving the writer's indentation unbalanced.)
+    pw.decreaseIndent();
+  }
+
+  /** Reads a colon-separated string flag as a list, falling back to {@code defaultValue}. */
+  private List<String> getDeviceConfigStringList(String key, List<String> defaultValue) {
+    return parse(deviceConfig.getString(NAMESPACE, key, null), defaultValue);
+  }
+
+  /** Reads a colon-separated string flag as a float array, falling back to {@code defaultValue}. */
+  private float[] getDeviceConfigFloatArray(String key, float[] defaultValue) {
+    return parse(deviceConfig.getString(NAMESPACE, key, null), defaultValue);
+  }
+
+  /** Splits {@code listStr} on {@link #DELIMITER}; returns {@code defaultValue} if null. */
+  private static List<String> parse(@Nullable String listStr, List<String> defaultValue) {
+    if (listStr != null) {
+      return Collections.unmodifiableList(Arrays.asList(listStr.split(DELIMITER)));
+    }
+    return defaultValue;
+  }
+
+  /**
+   * Parses {@code arrayStr} into floats. Returns {@code defaultValue} if the string is null, has a
+   * different element count than the default, or contains a non-numeric element.
+   */
+  private static float[] parse(@Nullable String arrayStr, float[] defaultValue) {
+    if (arrayStr != null) {
+      final List<String> split = Splitter.onPattern(DELIMITER).splitToList(arrayStr);
+      if (split.size() != defaultValue.length) {
+        return defaultValue;
+      }
+      final float[] result = new float[split.size()];
+      for (int i = 0; i < split.size(); i++) {
+        try {
+          result[i] = Float.parseFloat(split.get(i));
+        } catch (NumberFormatException e) {
+          return defaultValue;
+        }
+      }
+      return result;
+    } else {
+      return defaultValue;
+    }
+  }
+}
diff --git a/java/src/com/android/textclassifier/common/base/TcLog.java b/java/src/com/android/textclassifier/common/base/TcLog.java
index 87f1187..05a2443 100644
--- a/java/src/com/android/textclassifier/common/base/TcLog.java
+++ b/java/src/com/android/textclassifier/common/base/TcLog.java
@@ -16,6 +16,8 @@
package com.android.textclassifier.common.base;
+import android.util.Log;
+
/**
* Logging for android.view.textclassifier package.
*
@@ -31,27 +33,30 @@
public static final String TAG = "androidtc";
/** true: Enables full logging. false: Limits logging to debug level. */
- public static final boolean ENABLE_FULL_LOGGING =
- android.util.Log.isLoggable(TAG, android.util.Log.VERBOSE);
+ public static final boolean ENABLE_FULL_LOGGING = Log.isLoggable(TAG, Log.VERBOSE);
private TcLog() {}
public static void v(String tag, String msg) {
if (ENABLE_FULL_LOGGING) {
- android.util.Log.v(getTag(tag), msg);
+ Log.v(getTag(tag), msg);
}
}
public static void d(String tag, String msg) {
- android.util.Log.d(getTag(tag), msg);
+ Log.d(getTag(tag), msg);
}
public static void w(String tag, String msg) {
- android.util.Log.w(getTag(tag), msg);
+ Log.w(getTag(tag), msg);
+ }
+
+ public static void e(String tag, String msg) {
+ Log.e(getTag(tag), msg);
}
public static void e(String tag, String msg, Throwable tr) {
- android.util.Log.e(getTag(tag), msg, tr);
+ Log.e(getTag(tag), msg, tr);
}
private static String getTag(String customTag) {
diff --git a/java/src/com/android/textclassifier/common/intent/LabeledIntent.java b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
index b56d0bb..5c420ad 100644
--- a/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
+++ b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
@@ -29,6 +29,7 @@
import androidx.core.content.ContextCompat;
import androidx.core.graphics.drawable.IconCompat;
import com.android.textclassifier.common.base.TcLog;
+import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import javax.annotation.Nullable;
@@ -94,8 +95,26 @@
final ResolveInfo resolveInfo = pm.resolveActivity(intent, 0);
if (resolveInfo == null || resolveInfo.activityInfo == null) {
- TcLog.w(TAG, "resolveInfo or activityInfo is null");
- return null;
+ // Failed to resolve the intent. It could be because there are no apps to handle
+ // the intent. It could be also because the calling app has no visibility to the target app
+ // due to the app visibility feature introduced on R. For privacy reason, we don't want to
+ // force users of our library to ask for the visibility to the http/https view intent.
+ // Getting visibility to this intent effectively means getting visibility of ~70% of apps.
+ // This defeats the purpose of the app visibility feature. Practically speaking, all devices
+ // are very likely to have a browser installed. Thus, if it is a web intent, we assume we
+ // failed to resolve the intent just because of the app visibility feature. In which case, we
+ // return an implicit intent without an icon.
+ if (isWebIntent()) {
+ IconCompat icon = IconCompat.createWithResource(context, android.R.drawable.ic_menu_more);
+ RemoteActionCompat action =
+ createRemoteAction(
+ context, intent, icon, /* shouldShowIcon= */ false, resolveInfo, titleChooser);
+ // Create a clone so that the client does not modify the original intent.
+ return new Result(new Intent(intent), action);
+ } else {
+ TcLog.w(TAG, "resolveInfo or activityInfo is null");
+ return null;
+ }
}
if (!hasPermission(context, resolveInfo.activityInfo)) {
TcLog.d(TAG, "No permission to access: " + resolveInfo.activityInfo);
@@ -126,6 +145,19 @@
// RemoteAction requires that there be an icon.
icon = IconCompat.createWithResource(context, android.R.drawable.ic_menu_more);
}
+ RemoteActionCompat action =
+ createRemoteAction(
+ context, resolvedIntent, icon, shouldShowIcon, resolveInfo, titleChooser);
+ return new Result(resolvedIntent, action);
+ }
+
+ private RemoteActionCompat createRemoteAction(
+ Context context,
+ Intent resolvedIntent,
+ IconCompat icon,
+ boolean shouldShowIcon,
+ @Nullable ResolveInfo resolveInfo,
+ @Nullable TitleChooser titleChooser) {
final PendingIntent pendingIntent = createPendingIntent(context, resolvedIntent, requestCode);
titleChooser = titleChooser == null ? DEFAULT_TITLE_CHOOSER : titleChooser;
CharSequence title = titleChooser.chooseTitle(this, resolveInfo);
@@ -134,12 +166,25 @@
title = DEFAULT_TITLE_CHOOSER.chooseTitle(this, resolveInfo);
}
final RemoteActionCompat action =
- new RemoteActionCompat(icon, title, resolveDescription(resolveInfo, pm), pendingIntent);
+ new RemoteActionCompat(
+ icon,
+ title,
+ resolveDescription(resolveInfo, context.getPackageManager()),
+ pendingIntent);
action.setShouldShowIcon(shouldShowIcon);
- return new Result(resolvedIntent, action);
+ return action;
}
- private String resolveDescription(ResolveInfo resolveInfo, PackageManager packageManager) {
+ private boolean isWebIntent() {
+ if (!Intent.ACTION_VIEW.equals(intent.getAction())) {
+ return false;
+ }
+ final String scheme = intent.getScheme();
+ return Objects.equal(scheme, "http") || Objects.equal(scheme, "https");
+ }
+
+ private String resolveDescription(
+ @Nullable ResolveInfo resolveInfo, PackageManager packageManager) {
if (!TextUtils.isEmpty(descriptionWithAppName)) {
// Example string format of descriptionWithAppName: "Use %1$s to open map".
String applicationName = getApplicationName(resolveInfo, packageManager);
@@ -165,12 +210,16 @@
private static PendingIntent createPendingIntent(
final Context context, final Intent intent, int requestCode) {
return PendingIntent.getActivity(
- context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ context,
+ requestCode,
+ intent,
+ PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_IMMUTABLE);
}
@Nullable
- private static String getApplicationName(ResolveInfo resolveInfo, PackageManager packageManager) {
- if (resolveInfo.activityInfo == null) {
+ private static String getApplicationName(
+ @Nullable ResolveInfo resolveInfo, PackageManager packageManager) {
+ if (resolveInfo == null || resolveInfo.activityInfo == null) {
return null;
}
if ("android".equals(resolveInfo.activityInfo.packageName)) {
@@ -214,6 +263,6 @@
* is guaranteed to have a non-null {@code activityInfo}.
*/
@Nullable
- CharSequence chooseTitle(LabeledIntent labeledIntent, ResolveInfo resolveInfo);
+ CharSequence chooseTitle(LabeledIntent labeledIntent, @Nullable ResolveInfo resolveInfo);
}
}
diff --git a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
index c132749..df63f2f 100644
--- a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
+++ b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
@@ -16,23 +16,20 @@
package com.android.textclassifier.common.statsd;
-import android.util.StatsEvent;
-import android.util.StatsLog;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLinks;
import androidx.collection.ArrayMap;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
import com.android.textclassifier.common.logging.TextClassifierEvent;
-import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
-import java.util.UUID;
-import java.util.function.Supplier;
import javax.annotation.Nullable;
/** A helper for logging calls to generateLinks. */
@@ -42,7 +39,6 @@
private final Random random;
private final int sampleRate;
- private final Supplier<String> randomUuidSupplier;
/**
* @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01
@@ -50,24 +46,14 @@
* events, pass 1.
*/
public GenerateLinksLogger(int sampleRate) {
- this(sampleRate, () -> UUID.randomUUID().toString());
- }
-
- /**
- * @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01
- * chance that a call to logGenerateLinks results in an event being written). To write all
- * events, pass 1.
- * @param randomUuidSupplier supplies random UUIDs.
- */
- @VisibleForTesting
- GenerateLinksLogger(int sampleRate, Supplier<String> randomUuidSupplier) {
this.sampleRate = sampleRate;
random = new Random();
- this.randomUuidSupplier = Preconditions.checkNotNull(randomUuidSupplier);
}
/** Logs statistics about a call to generateLinks. */
public void logGenerateLinks(
+ @Nullable TextClassificationSessionId sessionId,
+ @Nullable TextClassificationContext textClassificationContext,
CharSequence text,
TextLinks links,
String callingPackageName,
@@ -97,20 +83,33 @@
totalStats.countLink(link);
perEntityTypeStats.computeIfAbsent(entityType, k -> new LinkifyStats()).countLink(link);
}
+ int widgetType = TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN;
+ if (textClassificationContext != null) {
+ widgetType = WidgetTypeConverter.toLoggingValue(textClassificationContext.getWidgetType());
+ }
- final String callId = randomUuidSupplier.get();
+ final String sessionIdStr = sessionId == null ? null : sessionId.getValue();
writeStats(
- callId, callingPackageName, null, totalStats, text, latencyMs, annotatorModel, langIdModel);
+ sessionIdStr,
+ callingPackageName,
+ null,
+ totalStats,
+ text,
+ widgetType,
+ latencyMs,
+ annotatorModel,
+ langIdModel);
// Sort the entity types to ensure the logging order is deterministic.
ImmutableList<String> sortedEntityTypes =
ImmutableList.sortedCopyOf(perEntityTypeStats.keySet());
for (String entityType : sortedEntityTypes) {
writeStats(
- callId,
+ sessionIdStr,
callingPackageName,
entityType,
perEntityTypeStats.get(entityType),
text,
+ widgetType,
latencyMs,
annotatorModel,
langIdModel);
@@ -132,34 +131,31 @@
/** Writes a log event for the given stats. */
private static void writeStats(
- String callId,
+ @Nullable String sessionId,
String callingPackageName,
@Nullable String entityType,
LinkifyStats stats,
CharSequence text,
+ int widgetType,
long latencyMs,
Optional<ModelInfo> annotatorModel,
Optional<ModelInfo> langIdModel) {
String annotatorModelName = annotatorModel.transform(ModelInfo::toModelName).or("");
String langIdModelName = langIdModel.transform(ModelInfo::toModelName).or("");
- StatsEvent statsEvent =
- StatsEvent.newBuilder()
- .setAtomId(TextClassifierEventLogger.TEXT_LINKIFY_EVENT_ATOM_ID)
- .writeString(callId)
- .writeInt(TextClassifierEvent.TYPE_LINKS_GENERATED)
- .writeString(annotatorModelName)
- .writeInt(TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN)
- .writeInt(/* eventIndex */ 0)
- .writeString(entityType)
- .writeInt(stats.numLinks)
- .writeInt(stats.numLinksTextLength)
- .writeInt(text.length())
- .writeLong(latencyMs)
- .writeString(callingPackageName)
- .writeString(langIdModelName)
- .usePooledBuffer()
- .build();
- StatsLog.write(statsEvent);
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
+ sessionId,
+ TextClassifierEvent.TYPE_LINKS_GENERATED,
+ annotatorModelName,
+ widgetType,
+ /* eventIndex */ 0,
+ entityType,
+ stats.numLinks,
+ stats.numLinksTextLength,
+ text.length(),
+ latencyMs,
+ callingPackageName,
+ langIdModelName);
if (TcLog.ENABLE_FULL_LOGGING) {
TcLog.v(
@@ -167,7 +163,7 @@
String.format(
Locale.US,
"%s:%s %d links (%d/%d chars) %dms %s annotator=%s langid=%s",
- callId,
+ sessionId,
entityType,
stats.numLinks,
stats.numLinksTextLength,
diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLogger.java b/java/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLogger.java
new file mode 100644
index 0000000..8a79d74
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLogger.java
@@ -0,0 +1,142 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static java.lang.annotation.RetentionPolicy.SOURCE;
+
+import android.os.SystemClock;
+import android.view.textclassifier.TextClassificationSessionId;
+import androidx.annotation.IntDef;
+import androidx.annotation.Nullable;
+import com.android.textclassifier.common.base.TcLog;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Supplier;
+import java.lang.annotation.Retention;
+import java.util.Locale;
+import java.util.Random;
+import java.util.concurrent.Executor;
+
+/** Logs the TextClassifier API usages. */
+public final class TextClassifierApiUsageLogger {
+ private static final String TAG = "ApiUsageLogger";
+
+ public static final int API_TYPE_SUGGEST_SELECTION =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_API_USAGE_REPORTED__API_TYPE__SUGGEST_SELECTION;
+ public static final int API_TYPE_CLASSIFY_TEXT =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_API_USAGE_REPORTED__API_TYPE__CLASSIFY_TEXT;
+ public static final int API_TYPE_GENERATE_LINKS =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_API_USAGE_REPORTED__API_TYPE__GENERATE_LINKS;
+ public static final int API_TYPE_SUGGEST_CONVERSATION_ACTIONS =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_API_USAGE_REPORTED__API_TYPE__SUGGEST_CONVERSATION_ACTIONS;
+ public static final int API_TYPE_DETECT_LANGUAGES =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_API_USAGE_REPORTED__API_TYPE__DETECT_LANGUAGES;
+
+ /** The type of the API. */
+ @Retention(SOURCE)
+ @IntDef({
+ API_TYPE_SUGGEST_SELECTION,
+ API_TYPE_CLASSIFY_TEXT,
+ API_TYPE_GENERATE_LINKS,
+ API_TYPE_SUGGEST_CONVERSATION_ACTIONS,
+ API_TYPE_DETECT_LANGUAGES
+ })
+ public @interface ApiType {}
+
+ private final Executor executor;
+
+ private final Supplier<Integer> sampleRateSupplier;
+
+ private final Random random;
+
+ /**
+ * @param sampleRateSupplier The rate at which log events are written. (e.g. 100 means there is a
+ * 0.01 chance that a call to logGenerateLinks results in an event being written). To write
+ * all events, pass 1. To disable logging, pass any number < 1. Sampling is used to reduce the
+ * amount of logging data generated.
+ * @param executor that is used to execute the logging work.
+ */
+ public TextClassifierApiUsageLogger(Supplier<Integer> sampleRateSupplier, Executor executor) {
+ this.executor = Preconditions.checkNotNull(executor);
+ this.sampleRateSupplier = sampleRateSupplier;
+ this.random = new Random();
+ }
+
+ public Session createSession(
+ @ApiType int apiType, @Nullable TextClassificationSessionId sessionId) {
+ return new Session(apiType, sessionId);
+ }
+
+ /** A session to log an API invocation. Creates a new session for each API call. */
+ public final class Session {
+ @ApiType private final int apiType;
+ @Nullable private final TextClassificationSessionId sessionId;
+ private final long beginElapsedRealTime;
+
+ private Session(@ApiType int apiType, @Nullable TextClassificationSessionId sessionId) {
+ this.apiType = apiType;
+ this.sessionId = sessionId;
+ beginElapsedRealTime = SystemClock.elapsedRealtime();
+ }
+
+ public void reportSuccess() {
+ reportInternal(/* success= */ true);
+ }
+
+ public void reportFailure() {
+ reportInternal(/* success= */ false);
+ }
+
+ private void reportInternal(boolean success) {
+ if (!shouldLog()) {
+ return;
+ }
+ final long latencyInMillis = SystemClock.elapsedRealtime() - beginElapsedRealTime;
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(
+ TAG,
+ String.format(
+ Locale.ENGLISH,
+ "TextClassifierApiUsageLogger: apiType=%d success=%b latencyInMillis=%d",
+ apiType,
+ success,
+ latencyInMillis));
+ }
+ executor.execute(
+ () ->
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_CLASSIFIER_API_USAGE_REPORTED,
+ apiType,
+ success
+ ? TextClassifierStatsLog
+ .TEXT_CLASSIFIER_API_USAGE_REPORTED__RESULT_TYPE__SUCCESS
+ : TextClassifierStatsLog
+ .TEXT_CLASSIFIER_API_USAGE_REPORTED__RESULT_TYPE__FAIL,
+ latencyInMillis,
+ sessionId == null ? "" : sessionId.getValue()));
+ }
+ }
+
+ /** Returns whether this particular event should be logged. */
+ private boolean shouldLog() {
+ if (sampleRateSupplier.get() < 1) {
+ return false;
+ } else {
+ return random.nextInt(sampleRateSupplier.get()) == 0;
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
index 41f546c..06ad44f 100644
--- a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
+++ b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
@@ -19,9 +19,6 @@
import static com.google.common.base.Charsets.UTF_8;
import static com.google.common.base.Strings.nullToEmpty;
-import android.util.StatsEvent;
-import android.util.StatsLog;
-import android.view.textclassifier.TextClassifier;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.ResultIdUtils;
import com.android.textclassifier.common.logging.TextClassificationContext;
@@ -36,11 +33,6 @@
/** Logs {@link android.view.textclassifier.TextClassifierEvent}. */
public final class TextClassifierEventLogger {
private static final String TAG = "TCEventLogger";
- // These constants are defined in atoms.proto.
- private static final int TEXT_SELECTION_EVENT_ATOM_ID = 219;
- static final int TEXT_LINKIFY_EVENT_ATOM_ID = 220;
- private static final int CONVERSATION_ACTIONS_EVENT_ATOM_ID = 221;
- private static final int LANGUAGE_DETECTION_EVENT_ATOM_ID = 222;
/** Emits a text classifier event to the logs. */
public void writeEvent(
@@ -69,24 +61,20 @@
@Nullable TextClassificationSessionId sessionId,
TextClassifierEvent.TextSelectionEvent event) {
ImmutableList<String> modelNames = getModelNames(event);
- StatsEvent statsEvent =
- StatsEvent.newBuilder()
- .setAtomId(TEXT_SELECTION_EVENT_ATOM_ID)
- .writeString(sessionId == null ? null : sessionId.getValue())
- .writeInt(getEventType(event))
- .writeString(getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null))
- .writeInt(getWidgetType(event))
- .writeInt(event.getEventIndex())
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
- .writeInt(event.getRelativeWordStartIndex())
- .writeInt(event.getRelativeWordEndIndex())
- .writeInt(event.getRelativeSuggestedWordStartIndex())
- .writeInt(event.getRelativeSuggestedWordEndIndex())
- .writeString(getPackageName(event))
- .writeString(getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null))
- .usePooledBuffer()
- .build();
- StatsLog.write(statsEvent);
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_SELECTION_EVENT,
+ sessionId == null ? null : sessionId.getValue(),
+ getEventType(event),
+ getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null),
+ getWidgetType(event),
+ event.getEventIndex(),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ event.getRelativeWordStartIndex(),
+ event.getRelativeWordEndIndex(),
+ event.getRelativeSuggestedWordStartIndex(),
+ event.getRelativeSuggestedWordEndIndex(),
+ getPackageName(event),
+ getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null));
}
private static int getEventType(TextClassifierEvent.TextSelectionEvent event) {
@@ -103,24 +91,20 @@
private static void logTextLinkifyEvent(
TextClassificationSessionId sessionId, TextClassifierEvent.TextLinkifyEvent event) {
ImmutableList<String> modelNames = getModelNames(event);
- StatsEvent statsEvent =
- StatsEvent.newBuilder()
- .setAtomId(TEXT_LINKIFY_EVENT_ATOM_ID)
- .writeString(sessionId == null ? null : sessionId.getValue())
- .writeInt(event.getEventType())
- .writeString(getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null))
- .writeInt(getWidgetType(event))
- .writeInt(event.getEventIndex())
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
- .writeInt(/* numOfLinks */ 0)
- .writeInt(/* linkedTextLength */ 0)
- .writeInt(/* textLength */ 0)
- .writeLong(/* latencyInMillis */ 0L)
- .writeString(getPackageName(event))
- .writeString(getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null))
- .usePooledBuffer()
- .build();
- StatsLog.write(statsEvent);
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
+ sessionId == null ? null : sessionId.getValue(),
+ event.getEventType(),
+ getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null),
+ getWidgetType(event),
+ event.getEventIndex(),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ /* numOfLinks */ 0,
+ /* linkedTextLength */ 0,
+ /* textLength */ 0,
+ /* latencyInMillis */ 0L,
+ getPackageName(event),
+ getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null));
}
private static void logConversationActionsEvent(
@@ -128,46 +112,37 @@
TextClassifierEvent.ConversationActionsEvent event) {
String resultId = nullToEmpty(event.getResultId());
ImmutableList<String> modelNames = ResultIdUtils.getModelNames(resultId);
- StatsEvent statsEvent =
- StatsEvent.newBuilder()
- .setAtomId(CONVERSATION_ACTIONS_EVENT_ATOM_ID)
- // TODO: Update ExtServices to set the session id.
- .writeString(
- sessionId == null
- ? Hashing.goodFastHash(64).hashString(resultId, UTF_8).toString()
- : sessionId.getValue())
- .writeInt(event.getEventType())
- .writeString(getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null))
- .writeInt(getWidgetType(event))
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 1))
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 2))
- .writeFloat(getFloatAt(event.getScores(), /* index= */ 0))
- .writeString(getPackageName(event))
- .writeString(getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null))
- .writeString(getItemAt(modelNames, /* index= */ 2, /* defaultValue= */ null))
- .usePooledBuffer()
- .build();
- StatsLog.write(statsEvent);
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.CONVERSATION_ACTIONS_EVENT,
+ // TODO: Update ExtServices to set the session id.
+ sessionId == null
+ ? Hashing.goodFastHash(64).hashString(resultId, UTF_8).toString()
+ : sessionId.getValue(),
+ event.getEventType(),
+ getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null),
+ getWidgetType(event),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ getItemAt(event.getEntityTypes(), /* index= */ 1),
+ getItemAt(event.getEntityTypes(), /* index= */ 2),
+ getFloatAt(event.getScores(), /* index= */ 0),
+ getPackageName(event),
+ getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null),
+ getItemAt(modelNames, /* index= */ 2, /* defaultValue= */ null));
}
private static void logLanguageDetectionEvent(
@Nullable TextClassificationSessionId sessionId,
TextClassifierEvent.LanguageDetectionEvent event) {
- StatsEvent statsEvent =
- StatsEvent.newBuilder()
- .setAtomId(LANGUAGE_DETECTION_EVENT_ATOM_ID)
- .writeString(sessionId == null ? null : sessionId.getValue())
- .writeInt(event.getEventType())
- .writeString(getItemAt(getModelNames(event), /* index= */ 0, /* defaultValue= */ null))
- .writeInt(getWidgetType(event))
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
- .writeFloat(getFloatAt(event.getScores(), /* index= */ 0))
- .writeInt(getIntAt(event.getActionIndices(), /* index= */ 0))
- .writeString(getPackageName(event))
- .usePooledBuffer()
- .build();
- StatsLog.write(statsEvent);
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.LANGUAGE_DETECTION_EVENT,
+ sessionId == null ? null : sessionId.getValue(),
+ event.getEventType(),
+ getItemAt(getModelNames(event), /* index= */ 0, /* defaultValue= */ null),
+ getWidgetType(event),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ getFloatAt(event.getScores(), /* index= */ 0),
+ getIntAt(event.getActionIndices(), /* index= */ 0),
+ getPackageName(event));
}
@Nullable
@@ -219,6 +194,14 @@
return ResultIdUtils.getModelNames(event.getResultId());
}
+ private static int getWidgetType(TextClassifierEvent event) {
+ TextClassificationContext eventContext = event.getEventContext();
+ if (eventContext == null) {
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN;
+ }
+ return WidgetTypeConverter.toLoggingValue(eventContext.getWidgetType());
+ }
+
@Nullable
private static String getPackageName(TextClassifierEvent event) {
TextClassificationContext eventContext = event.getEventContext();
@@ -227,52 +210,4 @@
}
return eventContext.getPackageName();
}
-
- private static int getWidgetType(TextClassifierEvent event) {
- TextClassificationContext eventContext = event.getEventContext();
- if (eventContext == null) {
- return WidgetType.WIDGET_TYPE_UNKNOWN;
- }
- switch (eventContext.getWidgetType()) {
- case TextClassifier.WIDGET_TYPE_UNKNOWN:
- return WidgetType.WIDGET_TYPE_UNKNOWN;
- case TextClassifier.WIDGET_TYPE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_EDITTEXT:
- return WidgetType.WIDGET_TYPE_EDITTEXT;
- case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_WEBVIEW:
- return WidgetType.WIDGET_TYPE_WEBVIEW;
- case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW:
- return WidgetType.WIDGET_TYPE_EDIT_WEBVIEW;
- case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_CUSTOM_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT:
- return WidgetType.WIDGET_TYPE_CUSTOM_EDITTEXT;
- case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_NOTIFICATION:
- return WidgetType.WIDGET_TYPE_NOTIFICATION;
- default: // fall out
- }
- return WidgetType.WIDGET_TYPE_UNKNOWN;
- }
-
- /** Widget type constants for logging. */
- public static final class WidgetType {
- // Sync these constants with textclassifier_enums.proto.
- public static final int WIDGET_TYPE_UNKNOWN = 0;
- public static final int WIDGET_TYPE_TEXTVIEW = 1;
- public static final int WIDGET_TYPE_EDITTEXT = 2;
- public static final int WIDGET_TYPE_UNSELECTABLE_TEXTVIEW = 3;
- public static final int WIDGET_TYPE_WEBVIEW = 4;
- public static final int WIDGET_TYPE_EDIT_WEBVIEW = 5;
- public static final int WIDGET_TYPE_CUSTOM_TEXTVIEW = 6;
- public static final int WIDGET_TYPE_CUSTOM_EDITTEXT = 7;
- public static final int WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW = 8;
- public static final int WIDGET_TYPE_NOTIFICATION = 9;
-
- private WidgetType() {}
- }
}
diff --git a/java/src/com/android/textclassifier/common/statsd/WidgetTypeConverter.java b/java/src/com/android/textclassifier/common/statsd/WidgetTypeConverter.java
new file mode 100644
index 0000000..13c04d1
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/statsd/WidgetTypeConverter.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import android.view.textclassifier.TextClassifier;
+
+/** Converts TextClassifier's WidgetTypes to enum values that are logged to server. */
+final class WidgetTypeConverter {
+ public static int toLoggingValue(String widgetType) {
+ switch (widgetType) {
+ case TextClassifier.WIDGET_TYPE_UNKNOWN:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN;
+ case TextClassifier.WIDGET_TYPE_TEXTVIEW:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_EDITTEXT:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_EDITTEXT;
+ case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW:
+ return TextClassifierStatsLog
+ .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNSELECTABLE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_WEBVIEW:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_WEBVIEW;
+ case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_EDIT_WEBVIEW;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW:
+ return TextClassifierStatsLog
+ .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CUSTOM_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT:
+ return TextClassifierStatsLog
+ .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CUSTOM_EDITTEXT;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW:
+ return TextClassifierStatsLog
+ .TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_NOTIFICATION:
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_NOTIFICATION;
+ case "clipboard": // TODO(tonymak) Replace it with WIDGET_TYPE_CLIPBOARD once S SDK is dropped
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_CLIPBOARD;
+ default: // fall out
+ }
+ return TextClassifierStatsLog.TEXT_SELECTION_EVENT__WIDGET_TYPE__WIDGET_TYPE_UNKNOWN;
+ }
+
+ private WidgetTypeConverter() {}
+}
diff --git a/java/tests/instrumentation/Android.bp b/java/tests/instrumentation/Android.bp
index 1f9fc23..74261c1 100644
--- a/java/tests/instrumentation/Android.bp
+++ b/java/tests/instrumentation/Android.bp
@@ -45,6 +45,7 @@
"TextClassifierServiceLib",
"statsdprotolite",
"textclassifierprotoslite",
+ "TextClassifierCoverageLib"
],
jni_libs: [
@@ -53,14 +54,16 @@
],
test_suites: [
- "device-tests", "mts-extservices"
+ "general-tests", "mts-extservices"
],
plugins: ["androidx.room_room-compiler-plugin",],
- platform_apis: true,
+ min_sdk_version: "30",
+ sdk_version: "system_current",
use_embedded_native_libs: true,
compile_multilib: "both",
instrumentation_for: "TextClassifierService",
- min_sdk_version: "30",
+
+ data: ["testdata/*"],
}
diff --git a/java/tests/instrumentation/AndroidManifest.xml b/java/tests/instrumentation/AndroidManifest.xml
index 4964caf..3ee30da 100644
--- a/java/tests/instrumentation/AndroidManifest.xml
+++ b/java/tests/instrumentation/AndroidManifest.xml
@@ -2,8 +2,9 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.android.textclassifier.tests">
- <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="30"/>
+ <uses-sdk android:minSdkVersion="30" android:targetSdkVersion="30"/>
<uses-permission android:name="android.permission.QUERY_ALL_PACKAGES" />
+ <uses-permission android:name="android.permission.MANAGE_EXTERNAL_STORAGE"/>
<application>
<uses-library android:name="android.test.runner"/>
diff --git a/java/tests/instrumentation/AndroidTest.xml b/java/tests/instrumentation/AndroidTest.xml
index e02a338..6c47a1a 100644
--- a/java/tests/instrumentation/AndroidTest.xml
+++ b/java/tests/instrumentation/AndroidTest.xml
@@ -13,8 +13,8 @@
See the License for the specific language governing permissions and
limitations under the License.
-->
-<!-- This test config file is auto-generated. -->
<configuration description="Runs TextClassifierServiceTest.">
+ <option name="config-descriptor:metadata" key="mainline-param" value="com.google.android.extservices.apex" />
<option name="test-suite-tag" value="apct" />
<option name="test-suite-tag" value="apct-instrumentation" />
<target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
@@ -22,9 +22,15 @@
<option name="test-file-name" value="TextClassifierServiceTest.apk" />
</target_preparer>
+ <target_preparer class="com.android.compatibility.common.tradefed.targetprep.FilePusher">
+ <option name="cleanup" value="true" />
+ <option name="push" value="testdata->/data/local/tmp/TextClassifierServiceTest/testdata" />
+ </target_preparer>
+
<test class="com.android.tradefed.testtype.AndroidJUnitTest" >
<option name="package" value="com.android.textclassifier.tests" />
<option name="runner" value="androidx.test.runner.AndroidJUnitRunner" />
+ <option name="device-listeners" value="com.android.textclassifier.testing.TextClassifierInstrumentationListener" />
</test>
<object type="module_controller" class="com.android.tradefed.testtype.suite.module.MainlineTestModuleController">
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
index 59dc41a..ebfeed3 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
@@ -25,6 +25,8 @@
import android.app.RemoteAction;
import android.content.ComponentName;
import android.content.Intent;
+import android.content.pm.ActivityInfo;
+import android.content.pm.ResolveInfo;
import android.graphics.drawable.Icon;
import android.net.Uri;
import android.os.Bundle;
@@ -34,6 +36,7 @@
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import com.android.textclassifier.common.intent.LabeledIntent;
+import com.android.textclassifier.common.intent.LabeledIntent.TitleChooser;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.RemoteActionTemplate;
@@ -223,7 +226,7 @@
public void createLabeledIntentResult_null() {
ActionsSuggestionsModel.ActionSuggestion nativeSuggestion =
new ActionsSuggestionsModel.ActionSuggestion(
- "text", ConversationAction.TYPE_OPEN_URL, 1.0f, null, null, null);
+ "text", ConversationAction.TYPE_OPEN_URL, 1.0f, null, null, null, null);
LabeledIntent.Result labeledIntentResult =
ActionsSuggestionsHelper.createLabeledIntentResult(
@@ -243,7 +246,8 @@
1.0f,
null,
null,
- new RemoteActionTemplate[0]);
+ new RemoteActionTemplate[0],
+ null);
LabeledIntent.Result labeledIntentResult =
ActionsSuggestionsHelper.createLabeledIntentResult(
@@ -277,7 +281,8 @@
null,
null,
0)
- });
+ },
+ null);
LabeledIntent.Result labeledIntentResult =
ActionsSuggestionsHelper.createLabeledIntentResult(
@@ -289,6 +294,92 @@
assertThat(labeledIntentResult.resolvedIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
}
+ @Test
+ public void createTitleChooser_notOpenUrl() {
+ assertThat(ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_CALL_PHONE))
+ .isNull();
+ }
+
+ @Test
+ public void createTitleChooser_openUrl_resolveInfoIsNull() {
+ TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_OPEN_URL);
+ LabeledIntent labeledIntent = createWebLabeledIntent();
+
+ assertThat(titleChooser.chooseTitle(labeledIntent, /* resolveInfo= */ null).toString())
+ .isEqualTo("titleWithEntity");
+ }
+
+ @Test
+ public void createTitleChooser_openUrl_packageIsNotAndroidAndHandleAllWebDataUriTrue() {
+ TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_OPEN_URL);
+ LabeledIntent labeledIntent = createWebLabeledIntent();
+
+ assertThat(
+ titleChooser
+ .chooseTitle(
+ labeledIntent,
+ createResolveInfo("com.android.chrome", /* handleAllWebDataURI= */ true))
+ .toString())
+ .isEqualTo("titleWithEntity");
+ }
+
+ @Test
+ public void createTitleChooser_openUrl_packageIsNotAndroidAndHandleAllWebDataUriFalse() {
+ TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_OPEN_URL);
+ LabeledIntent labeledIntent = createWebLabeledIntent();
+
+ assertThat(
+ titleChooser
+ .chooseTitle(
+ labeledIntent,
+ createResolveInfo("com.youtube", /* handleAllWebDataURI= */ false))
+ .toString())
+ .isEqualTo("titleWithoutEntity");
+ }
+
+ @Test
+ public void createTitleChooser_openUrl_packageIsAndroidAndHandleAllWebDataUriFalse() {
+ TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_OPEN_URL);
+ LabeledIntent labeledIntent = createWebLabeledIntent();
+
+ assertThat(
+ titleChooser
+ .chooseTitle(
+ labeledIntent, createResolveInfo("android", /* handleAllWebDataURI= */ false))
+ .toString())
+ .isEqualTo("titleWithEntity");
+ }
+
+ @Test
+ public void createTitleChooser_openUrl_packageIsAndroidAndHandleAllWebDataUriTrue() {
+ TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_OPEN_URL);
+ LabeledIntent labeledIntent = createWebLabeledIntent();
+
+ assertThat(
+ titleChooser
+ .chooseTitle(
+ labeledIntent, createResolveInfo("android", /* handleAllWebDataURI= */ true))
+ .toString())
+ .isEqualTo("titleWithEntity");
+ }
+
+ private LabeledIntent createWebLabeledIntent() {
+ Intent webIntent = new Intent(Intent.ACTION_VIEW);
+ webIntent.setData(Uri.parse("http://www.android.com"));
+ return new LabeledIntent(
+ "titleWithoutEntity",
+ "titleWithEntity",
+ "description",
+ "descriptionWithAppName",
+ webIntent,
+ /* requestCode= */ 0);
+ }
+
private static ZonedDateTime createZonedDateTimeFromMsUtc(long msUtc) {
return ZonedDateTime.ofInstant(Instant.ofEpochMilli(msUtc), ZoneId.of("UTC"));
}
@@ -303,4 +394,12 @@
assertThat(nativeMessage.getDetectedTextLanguageTags()).isEqualTo(LOCALE_TAG);
assertThat(nativeMessage.getReferenceTimeMsUtc()).isEqualTo(referenceTimeInMsUtc);
}
+
+ private static ResolveInfo createResolveInfo(String packageName, boolean handleAllWebDataURI) {
+ ResolveInfo resolveInfo = new ResolveInfo();
+ resolveInfo.activityInfo = new ActivityInfo();
+ resolveInfo.activityInfo.packageName = packageName;
+ resolveInfo.handleAllWebDataURI = handleAllWebDataURI;
+ return resolveInfo;
+ }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
new file mode 100644
index 0000000..746931b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -0,0 +1,305 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.verify;
+
+import android.content.Context;
+import android.os.Binder;
+import android.os.CancellationSignal;
+import android.os.Parcel;
+import android.service.textclassifier.TextClassifierService;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationSessionId;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLanguage;
+import android.view.textclassifier.TextLinks;
+import android.view.textclassifier.TextLinks.TextLink;
+import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.internal.os.StatsdConfigProto.StatsdConfig;
+import com.android.os.AtomsProto;
+import com.android.os.AtomsProto.Atom;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
+import com.android.textclassifier.common.ModelFileManager;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.statsd.StatsdTestUtils;
+import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.util.List;
+import java.util.concurrent.Executor;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class DefaultTextClassifierServiceTest {
+ /** A statsd config ID, which is arbitrary. */
+ private static final long CONFIG_ID = 689777;
+
+ private static final long SHORT_TIMEOUT_MS = 1000;
+
+ private static final String SESSION_ID = "abcdef";
+
+ private TestInjector testInjector;
+ private DefaultTextClassifierService defaultTextClassifierService;
+ @Mock private TextClassifierService.Callback<TextClassification> textClassificationCallback;
+ @Mock private TextClassifierService.Callback<TextSelection> textSelectionCallback;
+ @Mock private TextClassifierService.Callback<TextLinks> textLinksCallback;
+ @Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback;
+ @Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+
+ testInjector = new TestInjector(ApplicationProvider.getApplicationContext());
+ defaultTextClassifierService = new DefaultTextClassifierService(testInjector);
+ defaultTextClassifierService.onCreate();
+ }
+
+ @Before
+ public void setupStatsdTestUtils() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+
+ StatsdConfig.Builder builder =
+ StatsdConfig.newBuilder()
+ .setId(CONFIG_ID)
+ .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
+ StatsdTestUtils.addAtomMatcher(builder, Atom.TEXT_CLASSIFIER_API_USAGE_REPORTED_FIELD_NUMBER);
+ StatsdTestUtils.pushConfig(builder.build());
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+ }
+
+ @Test
+ public void classifyText_success() throws Exception {
+ String text = "www.android.com";
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, 0, text.length()).build();
+
+ defaultTextClassifierService.onClassifyText(
+ createTextClassificationSessionId(),
+ request,
+ new CancellationSignal(),
+ textClassificationCallback);
+
+ ArgumentCaptor<TextClassification> captor = ArgumentCaptor.forClass(TextClassification.class);
+ verify(textClassificationCallback).onSuccess(captor.capture());
+ assertThat(captor.getValue().getEntityCount()).isGreaterThan(0);
+ assertThat(captor.getValue().getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ verifyApiUsageLog(ApiType.CLASSIFY_TEXT, ResultType.SUCCESS);
+ }
+
+ @Test
+ public void suggestSelection_success() throws Exception {
+ String text = "Visit http://www.android.com for more information";
+ String selected = "http";
+ String suggested = "http://www.android.com";
+ int start = text.indexOf(selected);
+ int end = start + suggested.length();
+ TextSelection.Request request = new TextSelection.Request.Builder(text, start, end).build();
+
+ defaultTextClassifierService.onSuggestSelection(
+ createTextClassificationSessionId(),
+ request,
+ new CancellationSignal(),
+ textSelectionCallback);
+
+ ArgumentCaptor<TextSelection> captor = ArgumentCaptor.forClass(TextSelection.class);
+ verify(textSelectionCallback).onSuccess(captor.capture());
+ assertThat(captor.getValue().getEntityCount()).isGreaterThan(0);
+ assertThat(captor.getValue().getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ verifyApiUsageLog(ApiType.SUGGEST_SELECTION, ResultType.SUCCESS);
+ }
+
+ @Test
+ public void generateLinks_success() throws Exception {
+ String text = "Visit http://www.android.com for more information";
+ TextLinks.Request request = new TextLinks.Request.Builder(text).build();
+
+ defaultTextClassifierService.onGenerateLinks(
+ createTextClassificationSessionId(), request, new CancellationSignal(), textLinksCallback);
+
+ ArgumentCaptor<TextLinks> captor = ArgumentCaptor.forClass(TextLinks.class);
+ verify(textLinksCallback).onSuccess(captor.capture());
+ assertThat(captor.getValue().getLinks()).hasSize(1);
+ TextLink textLink = captor.getValue().getLinks().iterator().next();
+ assertThat(textLink.getEntityCount()).isGreaterThan(0);
+ assertThat(textLink.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ verifyApiUsageLog(ApiType.GENERATE_LINKS, ResultType.SUCCESS);
+ }
+
+ @Test
+ public void detectLanguage_success() throws Exception {
+ String text = "ピカチュウ";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+
+ defaultTextClassifierService.onDetectLanguage(
+ createTextClassificationSessionId(),
+ request,
+ new CancellationSignal(),
+ textLanguageCallback);
+
+ ArgumentCaptor<TextLanguage> captor = ArgumentCaptor.forClass(TextLanguage.class);
+ verify(textLanguageCallback).onSuccess(captor.capture());
+ assertThat(captor.getValue().getLocaleHypothesisCount()).isGreaterThan(0);
+ assertThat(captor.getValue().getLocale(0).toLanguageTag()).isEqualTo("ja");
+ verifyApiUsageLog(ApiType.DETECT_LANGUAGES, ResultType.SUCCESS);
+ }
+
+ @Test
+ public void suggestConversationActions_success() throws Exception {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Checkout www.android.com")
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(ImmutableList.of(message)).build();
+
+ defaultTextClassifierService.onSuggestConversationActions(
+ createTextClassificationSessionId(),
+ request,
+ new CancellationSignal(),
+ conversationActionsCallback);
+
+ ArgumentCaptor<ConversationActions> captor = ArgumentCaptor.forClass(ConversationActions.class);
+ verify(conversationActionsCallback).onSuccess(captor.capture());
+ List<ConversationAction> conversationActions = captor.getValue().getConversationActions();
+ assertThat(conversationActions.size()).isGreaterThan(0);
+ assertThat(conversationActions.get(0).getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
+ verifyApiUsageLog(ApiType.SUGGEST_CONVERSATION_ACTIONS, ResultType.SUCCESS);
+ }
+
+ @Test
+ public void missingModelFile_onFailureShouldBeCalled() throws Exception {
+ testInjector.setModelFileManager(
+ new ModelFileManager(ApplicationProvider.getApplicationContext(), ImmutableList.of()));
+ defaultTextClassifierService.onCreate();
+
+ TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build();
+ defaultTextClassifierService.onClassifyText(
+ createTextClassificationSessionId(),
+ request,
+ new CancellationSignal(),
+ textClassificationCallback);
+
+ verify(textClassificationCallback).onFailure(Mockito.anyString());
+ verifyApiUsageLog(ApiType.CLASSIFY_TEXT, ResultType.FAIL);
+ }
+
+ private static void verifyApiUsageLog(
+ AtomsProto.TextClassifierApiUsageReported.ApiType expectedApiType,
+ AtomsProto.TextClassifierApiUsageReported.ResultType expectedResultApiType)
+ throws Exception {
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
+ ImmutableList<TextClassifierApiUsageReported> loggedEvents =
+ ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .map(Atom::getTextClassifierApiUsageReported)
+ .collect(Collectors.toList()));
+ assertThat(loggedEvents).hasSize(1);
+ TextClassifierApiUsageReported loggedEvent = loggedEvents.get(0);
+ assertThat(loggedEvent.getLatencyMillis()).isGreaterThan(0L);
+ assertThat(loggedEvent.getApiType()).isEqualTo(expectedApiType);
+ assertThat(loggedEvent.getResultType()).isEqualTo(expectedResultApiType);
+ assertThat(loggedEvent.getSessionId()).isEqualTo(SESSION_ID);
+ }
+
+ private static TextClassificationSessionId createTextClassificationSessionId() {
+    // Uses a hack to create TextClassificationSessionId because its constructor is @hide.
+ Parcel parcel = Parcel.obtain();
+ parcel.writeString(SESSION_ID);
+ parcel.writeStrongBinder(new Binder());
+ parcel.setDataPosition(0);
+ return TextClassificationSessionId.CREATOR.createFromParcel(parcel);
+ }
+
+ private static final class TestInjector implements DefaultTextClassifierService.Injector {
+ private final Context context;
+ private ModelFileManager modelFileManager;
+
+ private TestInjector(Context context) {
+ this.context = Preconditions.checkNotNull(context);
+ }
+
+ private void setModelFileManager(ModelFileManager modelFileManager) {
+ this.modelFileManager = modelFileManager;
+ }
+
+ @Override
+ public Context getContext() {
+ return context;
+ }
+
+ @Override
+ public ModelFileManager createModelFileManager(TextClassifierSettings settings) {
+ if (modelFileManager == null) {
+ return TestDataUtils.createModelFileManagerForTesting(context);
+ }
+ return modelFileManager;
+ }
+
+ @Override
+ public TextClassifierSettings createTextClassifierSettings() {
+ return new TextClassifierSettings();
+ }
+
+ @Override
+ public TextClassifierImpl createTextClassifierImpl(
+ TextClassifierSettings settings, ModelFileManager modelFileManager) {
+ return new TextClassifierImpl(context, settings, modelFileManager);
+ }
+
+ @Override
+ public ListeningExecutorService createNormPriorityExecutor() {
+ return MoreExecutors.newDirectExecutorService();
+ }
+
+ @Override
+ public ListeningExecutorService createLowPriorityExecutor() {
+ return MoreExecutors.newDirectExecutorService();
+ }
+
+ @Override
+ public TextClassifierApiUsageLogger createTextClassifierApiUsageLogger(
+ TextClassifierSettings settings, Executor executor) {
+ return new TextClassifierApiUsageLogger(
+ /* sampleRateSupplier= */ () -> 1, MoreExecutors.directExecutor());
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
deleted file mode 100644
index 06d47d6..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
+++ /dev/null
@@ -1,385 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier;
-
-import static com.google.common.truth.Truth.assertThat;
-import static org.mockito.Mockito.when;
-
-import android.os.LocaleList;
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.ext.junit.runners.AndroidJUnit4;
-import androidx.test.filters.SmallTest;
-import com.android.textclassifier.ModelFileManager.ModelFile;
-import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
-import com.google.common.base.Optional;
-import com.google.common.collect.ImmutableList;
-import java.io.File;
-import java.io.IOException;
-import java.util.Collections;
-import java.util.List;
-import java.util.Locale;
-import java.util.function.Supplier;
-import java.util.stream.Collectors;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class ModelFileManagerTest {
- private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
- @Mock private Supplier<ImmutableList<ModelFile>> modelFileSupplier;
- private ModelFileManager.ModelFileSupplierImpl modelFileSupplierImpl;
- private ModelFileManager modelFileManager;
- private File rootTestDir;
- private File factoryModelDir;
- private File updatedModelFile;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
- modelFileManager = new ModelFileManager(modelFileSupplier);
- rootTestDir = ApplicationProvider.getApplicationContext().getCacheDir();
- factoryModelDir = new File(rootTestDir, "factory");
- updatedModelFile = new File(rootTestDir, "updated.model");
-
- modelFileSupplierImpl =
- new ModelFileManager.ModelFileSupplierImpl(
- factoryModelDir,
- "test\\d.model",
- updatedModelFile,
- fd -> 1,
- fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT);
-
- rootTestDir.mkdirs();
- factoryModelDir.mkdirs();
-
- Locale.setDefault(DEFAULT_LOCALE);
- }
-
- @After
- public void removeTestDir() {
- recursiveDelete(rootTestDir);
- }
-
- @Test
- public void get() {
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
- when(modelFileSupplier.get()).thenReturn(ImmutableList.of(modelFile));
-
- List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles();
-
- assertThat(modelFiles).hasSize(1);
- assertThat(modelFiles.get(0)).isEqualTo(modelFile);
- }
-
- @Test
- public void findBestModel_versionCode() {
- ModelFileManager.ModelFile olderModelFile =
- new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
-
- ModelFileManager.ModelFile newerModelFile =
- new ModelFileManager.ModelFile(new File("/path/b"), 2, ImmutableList.of(), "", true);
- when(modelFileSupplier.get()).thenReturn(ImmutableList.of(olderModelFile, newerModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());
-
- assertThat(bestModelFile).isEqualTo(newerModelFile);
- }
-
- @Test
- public void findBestModel_languageDependentModelIsPreferred() {
- Locale locale = Locale.forLanguageTag("ja");
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
-
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 1,
- Collections.singletonList(locale),
- locale.toLanguageTag(),
- false);
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.forLanguageTags(locale.toLanguageTag()));
- assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
- }
-
- @Test
- public void findBestModel_noMatchedLanguageModel() {
- Locale locale = Locale.forLanguageTag("ja");
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(new File("/path/a"), 1, Collections.emptyList(), "", true);
-
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 1,
- Collections.singletonList(locale),
- locale.toLanguageTag(),
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.forLanguageTags("zh-hk"));
- assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
- }
-
- @Test
- public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() {
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
-
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 1,
- Collections.singletonList(DEFAULT_LOCALE),
- DEFAULT_LOCALE.toLanguageTag(),
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.forLanguageTags("zh-hk"));
- assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
- }
-
- @Test
- public void findBestModel_languageIsMoreImportantThanVersion() {
- ModelFileManager.ModelFile matchButOlderModel =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("fr")),
- "fr",
- false);
-
- ModelFileManager.ModelFile mismatchButNewerModel =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 2,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(matchButOlderModel, mismatchButNewerModel));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.forLanguageTags("fr"));
- assertThat(bestModelFile).isEqualTo(matchButOlderModel);
- }
-
- @Test
- public void findBestModel_languageIsMoreImportantThanVersion_bestModelComesFirst() {
- ModelFileManager.ModelFile matchLocaleModel =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- ModelFileManager.ModelFile languageIndependentModel =
- new ModelFileManager.ModelFile(new File("/path/a"), 2, ImmutableList.of(), "", true);
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(matchLocaleModel, languageIndependentModel));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.forLanguageTags("ja"));
-
- assertThat(bestModelFile).isEqualTo(matchLocaleModel);
- }
-
- @Test
- public void modelFileEquals() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA).isEqualTo(modelB);
- }
-
- @Test
- public void modelFile_different() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA).isNotEqualTo(modelB);
- }
-
- @Test
- public void modelFile_getPath() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA.getPath()).isEqualTo("/path/a");
- }
-
- @Test
- public void modelFile_getName() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA.getName()).isEqualTo("a");
- }
-
- @Test
- public void modelFile_isPreferredTo_languageDependentIsBetter() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(new File("/path/b"), 2, ImmutableList.of(), "", true);
-
- assertThat(modelA.isPreferredTo(modelB)).isTrue();
- }
-
- @Test
- public void modelFile_isPreferredTo_version() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 2,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(new File("/path/b"), 1, Collections.emptyList(), "", false);
-
- assertThat(modelA.isPreferredTo(modelB)).isTrue();
- }
-
- @Test
- public void modelFile_toModelInfo() {
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
-
- ModelInfo modelInfo = modelFile.toModelInfo();
-
- assertThat(modelInfo.toModelName()).isEqualTo("ja_v2");
- }
-
- @Test
- public void modelFile_toModelInfos() {
- ModelFile englishModelFile =
- new ModelFile(new File("/path/a"), 1, ImmutableList.of(Locale.ENGLISH), "en", false);
- ModelFile japaneseModelFile =
- new ModelFile(new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
-
- ImmutableList<Optional<ModelInfo>> modelInfos =
- ModelFileManager.ModelFile.toModelInfos(
- Optional.of(englishModelFile), Optional.of(japaneseModelFile));
-
- assertThat(
- modelInfos.stream()
- .map(modelFile -> modelFile.transform(ModelInfo::toModelName).or(""))
- .collect(Collectors.toList()))
- .containsExactly("en_v1", "ja_v2")
- .inOrder();
- }
-
- @Test
- public void testFileSupplierImpl_updatedFileOnly() throws IOException {
- updatedModelFile.createNewFile();
- File model1 = new File(factoryModelDir, "test1.model");
- model1.createNewFile();
- File model2 = new File(factoryModelDir, "test2.model");
- model2.createNewFile();
- new File(factoryModelDir, "not_match_regex.model").createNewFile();
-
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
- List<String> modelFilePaths =
- modelFiles.stream().map(modelFile -> modelFile.getPath()).collect(Collectors.toList());
-
- assertThat(modelFiles).hasSize(3);
- assertThat(modelFilePaths)
- .containsExactly(
- updatedModelFile.getAbsolutePath(), model1.getAbsolutePath(), model2.getAbsolutePath());
- }
-
- @Test
- public void testFileSupplierImpl_empty() {
- factoryModelDir.delete();
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
-
- assertThat(modelFiles).hasSize(0);
- }
-
- private static void recursiveDelete(File f) {
- if (f.isDirectory()) {
- for (File innerFile : f.listFiles()) {
- recursiveDelete(innerFile);
- }
- }
- f.delete();
- }
-}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
new file mode 100644
index 0000000..5c1d95e
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.content.Context;
+import com.android.textclassifier.common.ModelFileManager;
+import com.android.textclassifier.common.ModelFileManager.RegularFileFullMatchLister;
+import com.android.textclassifier.common.ModelType;
+import com.google.common.collect.ImmutableList;
+import java.io.File;
+
+/** Utils to access test data files. */
+public final class TestDataUtils {
+ private static final String TEST_ANNOTATOR_MODEL_PATH = "testdata/annotator.model";
+ private static final String TEST_ACTIONS_MODEL_PATH = "testdata/actions.model";
+ private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model";
+
+ /** Returns the root folder that contains the test data. */
+ public static File getTestDataFolder() {
+ return new File("/data/local/tmp/TextClassifierServiceTest/");
+ }
+
+ public static File getTestAnnotatorModelFile() {
+ return new File(getTestDataFolder(), TEST_ANNOTATOR_MODEL_PATH);
+ }
+
+ public static File getTestActionsModelFile() {
+ return new File(getTestDataFolder(), TEST_ACTIONS_MODEL_PATH);
+ }
+
+ public static File getLangIdModelFile() {
+ return new File(getTestDataFolder(), TEST_LANGID_MODEL_PATH);
+ }
+
+ public static ModelFileManager createModelFileManagerForTesting(Context context) {
+ return new ModelFileManager(
+ context,
+ ImmutableList.of(
+ new RegularFileFullMatchLister(
+ ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true),
+ new RegularFileFullMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true),
+ new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)));
+ }
+
+ private TestDataUtils() {}
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
new file mode 100644
index 0000000..27ea7f0
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
@@ -0,0 +1,212 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.UiAutomation;
+import android.content.pm.PackageManager;
+import android.content.pm.PackageManager.NameNotFoundException;
+import android.icu.util.ULocale;
+import android.provider.DeviceConfig;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLanguage;
+import android.view.textclassifier.TextLinks;
+import android.view.textclassifier.TextLinks.TextLink;
+import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import androidx.test.platform.app.InstrumentationRegistry;
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExternalResource;
+import org.junit.runner.RunWith;
+
+/**
+ * End-to-end tests for the {@link TextClassifier} APIs.
+ *
+ * <p>Unlike {@link TextClassifierImplTest}, these tests run in an environment that is closer to
+ * the production environment. For example, we are not injecting the model files.
+ */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassifierApiTest {
+
+ private TextClassifier textClassifier;
+
+ @Rule
+ public final ExtServicesTextClassifierRule extServicesTextClassifierRule =
+ new ExtServicesTextClassifierRule();
+
+ @Before
+ public void setup() {
+ textClassifier = extServicesTextClassifierRule.getTextClassifier();
+ }
+
+ @Test
+ public void suggestSelection() {
+ String text = "Visit http://www.android.com for more information";
+ String selected = "http";
+ String suggested = "http://www.android.com";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ int smartStartIndex = text.indexOf(suggested);
+ int smartEndIndex = smartStartIndex + suggested.length();
+
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
+
+ TextSelection selection = textClassifier.suggestSelection(request);
+ assertThat(selection.getEntityCount()).isGreaterThan(0);
+ assertThat(selection.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ assertThat(selection.getSelectionStartIndex()).isEqualTo(smartStartIndex);
+ assertThat(selection.getSelectionEndIndex()).isEqualTo(smartEndIndex);
+ }
+
+ @Test
+ public void classifyText() {
+ String text = "Contact me at droid@android.com";
+ String classifiedText = "droid@android.com";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
+
+ TextClassification classification = textClassifier.classifyText(request);
+ assertThat(classification.getEntityCount()).isGreaterThan(0);
+ assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL);
+ assertThat(classification.getText()).isEqualTo(classifiedText);
+ assertThat(classification.getActions()).isNotEmpty();
+ }
+
+ @Test
+ public void generateLinks() {
+ String text = "Check this out, http://www.android.com";
+
+ TextLinks.Request request = new TextLinks.Request.Builder(text).build();
+
+ TextLinks textLinks = textClassifier.generateLinks(request);
+
+ List<TextLink> links = new ArrayList<>(textLinks.getLinks());
+ assertThat(textLinks.getText().toString()).isEqualTo(text);
+ assertThat(links).hasSize(1);
+ assertThat(links.get(0).getEntityCount()).isGreaterThan(0);
+ assertThat(links.get(0).getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ assertThat(links.get(0).getConfidenceScore(TextClassifier.TYPE_URL)).isGreaterThan(0f);
+ }
+
+ @Test
+ public void detectedLanguage() {
+ String text = "朝、ピカチュウ";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+
+ TextLanguage textLanguage = textClassifier.detectLanguage(request);
+
+ assertThat(textLanguage.getLocaleHypothesisCount()).isGreaterThan(0);
+ assertThat(textLanguage.getLocale(0).getLanguage()).isEqualTo("ja");
+ assertThat(textLanguage.getConfidenceScore(ULocale.JAPANESE)).isGreaterThan(0f);
+ }
+
+ @Test
+ public void suggestConversationActions() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Check this out: https://www.android.com")
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(ImmutableList.of(message)).build();
+
+ ConversationActions conversationActions = textClassifier.suggestConversationActions(request);
+
+ assertThat(conversationActions.getConversationActions()).hasSize(1);
+ ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
+ assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
+ assertThat(conversationAction.getAction()).isNotNull();
+ }
+
+ /** A rule that manages a text classifier that is backed by the ExtServices. */
+ private static class ExtServicesTextClassifierRule extends ExternalResource {
+ private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
+ "textclassifier_service_package_override";
+ private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services";
+ private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services";
+
+ private String textClassifierServiceOverrideFlagOldValue;
+
+ @Override
+ protected void before() {
+ UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ try {
+ uiAutomation.adoptShellPermissionIdentity();
+ textClassifierServiceOverrideFlagOldValue =
+ DeviceConfig.getString(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ null);
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ getExtServicesPackageName(),
+ /* makeDefault= */ false);
+ } finally {
+ uiAutomation.dropShellPermissionIdentity();
+ }
+ }
+
+ @Override
+ protected void after() {
+ UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ try {
+ uiAutomation.adoptShellPermissionIdentity();
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ textClassifierServiceOverrideFlagOldValue,
+ /* makeDefault= */ false);
+ } finally {
+ uiAutomation.dropShellPermissionIdentity();
+ }
+ }
+
+ private static String getExtServicesPackageName() {
+ PackageManager packageManager =
+ ApplicationProvider.getApplicationContext().getPackageManager();
+ try {
+ packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0);
+ return PKG_NAME_GOOGLE_EXTSERVICES;
+ } catch (NameNotFoundException e) {
+ return PKG_NAME_AOSP_EXTSERVICES;
+ }
+ }
+
+ public TextClassifier getTextClassifier() {
+ TextClassificationManager textClassificationManager =
+ ApplicationProvider.getApplicationContext()
+ .getSystemService(TextClassificationManager.class);
+ return textClassificationManager.getTextClassifier();
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index 6d80673..81aa832 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -21,7 +21,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
-import static org.testng.Assert.assertThrows;
+import static org.testng.Assert.expectThrows;
import android.app.RemoteAction;
import android.content.Context;
@@ -38,10 +38,14 @@
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
+import com.android.textclassifier.common.ModelFileManager;
+import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.testing.FakeContextBuilder;
import com.google.common.collect.ImmutableList;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -50,7 +54,6 @@
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.junit.Before;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -63,6 +66,8 @@
private static final String NO_TYPE = null;
private TextClassifierImpl classifier;
+ private final ModelFileManager modelFileManager =
+ TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext());
@Before
public void setup() {
@@ -71,11 +76,12 @@
.setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
.setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
.build();
- classifier = new TextClassifierImpl(context, new TextClassifierSettings());
+ TextClassifierSettings settings = new TextClassifierSettings();
+ classifier = new TextClassifierImpl(context, settings, modelFileManager);
}
@Test
- public void testSuggestSelection() {
+ public void testSuggestSelection() throws IOException {
String text = "Contact me at droid@android.com";
String selected = "droid";
String suggested = "droid@android.com";
@@ -88,13 +94,13 @@
.setDefaultLocales(LOCALES)
.build();
- TextSelection selection = classifier.suggestSelection(request);
+ TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(
selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL));
}
@Test
- public void testSuggestSelection_url() {
+ public void testSuggestSelection_url() throws IOException {
String text = "Visit http://www.android.com for more information";
String selected = "http";
String suggested = "http://www.android.com";
@@ -107,12 +113,12 @@
.setDefaultLocales(LOCALES)
.build();
- TextSelection selection = classifier.suggestSelection(request);
+ TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
}
@Test
- public void testSmartSelection_withEmoji() {
+ public void testSmartSelection_withEmoji() throws IOException {
String text = "\uD83D\uDE02 Hello.";
String selected = "Hello";
int startIndex = text.indexOf(selected);
@@ -122,12 +128,12 @@
.setDefaultLocales(LOCALES)
.build();
- TextSelection selection = classifier.suggestSelection(request);
+ TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
}
@Test
- public void testClassifyText() {
+ public void testClassifyText() throws IOException {
String text = "Contact me at droid@android.com";
String classifiedText = "droid@android.com";
int startIndex = text.indexOf(classifiedText);
@@ -137,12 +143,13 @@
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification =
+ classifier.classifyText(/* sessionId= */ null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL));
}
@Test
- public void testClassifyText_url() {
+ public void testClassifyText_url() throws IOException {
String text = "Visit www.android.com for more information";
String classifiedText = "www.android.com";
int startIndex = text.indexOf(classifiedText);
@@ -152,25 +159,25 @@
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
}
@Test
- public void testClassifyText_address() {
+ public void testClassifyText_address() throws IOException {
String text = "Brandschenkestrasse 110, Zürich, Switzerland";
TextClassification.Request request =
new TextClassification.Request.Builder(text, 0, text.length())
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
}
@Test
- public void testClassifyText_url_inCaps() {
+ public void testClassifyText_url_inCaps() throws IOException {
String text = "Visit HTTP://ANDROID.COM for more information";
String classifiedText = "HTTP://ANDROID.COM";
int startIndex = text.indexOf(classifiedText);
@@ -180,13 +187,13 @@
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
}
@Test
- public void testClassifyText_date() {
+ public void testClassifyText_date() throws IOException {
String text = "Let's meet on January 9, 2018.";
String classifiedText = "January 9, 2018";
int startIndex = text.indexOf(classifiedText);
@@ -196,7 +203,7 @@
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
Bundle extras = classification.getExtras();
List<Bundle> entities = ExtrasUtils.getEntities(extras);
@@ -207,7 +214,7 @@
}
@Test
- public void testClassifyText_datetime() {
+ public void testClassifyText_datetime() throws IOException {
String text = "Let's meet 2018/01/01 10:30:20.";
String classifiedText = "2018/01/01 10:30:20";
int startIndex = text.indexOf(classifiedText);
@@ -217,15 +224,12 @@
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
}
- // TODO(tonymak): Enable it once we drop the v8 image to Android. I have already run this test
- // after pushing a test model to a device manually.
- @Ignore
@Test
- public void testClassifyText_foreignText() {
+ public void testClassifyText_foreignText() throws IOException {
LocaleList originalLocales = LocaleList.getDefault();
LocaleList.setDefault(LocaleList.forLanguageTags("en"));
String japaneseText = "これは日本語のテキストです";
@@ -234,7 +238,7 @@
.setDefaultLocales(LOCALES)
.build();
- TextClassification classification = classifier.classifyText(request);
+ TextClassification classification = classifier.classifyText(null, null, request);
RemoteAction translateAction = classification.getActions().get(0);
assertEquals(1, classification.getActions().size());
assertEquals("Translate", translateAction.getTitle().toString());
@@ -254,16 +258,16 @@
}
@Test
- public void testGenerateLinks_phone() {
+ public void testGenerateLinks_phone() throws IOException {
String text = "The number is +12122537077. See you tonight!";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
assertThat(
- classifier.generateLinks(request),
+ classifier.generateLinks(null, null, request),
isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
}
@Test
- public void testGenerateLinks_exclude() {
+ public void testGenerateLinks_exclude() throws IOException {
String text = "You want apple@banana.com. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = ImmutableList.of();
@@ -274,12 +278,12 @@
.setDefaultLocales(LOCALES)
.build();
assertThat(
- classifier.generateLinks(request),
+ classifier.generateLinks(null, null, request),
not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
}
@Test
- public void testGenerateLinks_explicit_address() {
+ public void testGenerateLinks_explicit_address() throws IOException {
String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
TextLinks.Request request =
@@ -288,13 +292,13 @@
.setDefaultLocales(LOCALES)
.build();
assertThat(
- classifier.generateLinks(request),
+ classifier.generateLinks(null, null, request),
isTextLinksContaining(
text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS));
}
@Test
- public void testGenerateLinks_exclude_override() {
+ public void testGenerateLinks_exclude_override() throws IOException {
String text = "You want apple@banana.com. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
@@ -305,26 +309,26 @@
.setDefaultLocales(LOCALES)
.build();
assertThat(
- classifier.generateLinks(request),
+ classifier.generateLinks(null, null, request),
not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
}
@Test
- public void testGenerateLinks_maxLength() {
+ public void testGenerateLinks_maxLength() throws IOException {
char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
Arrays.fill(manySpaces, ' ');
TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
- TextLinks links = classifier.generateLinks(request);
+ TextLinks links = classifier.generateLinks(null, null, request);
assertTrue(links.getLinks().isEmpty());
}
@Test
- public void testApplyLinks_unsupportedCharacter() {
+ public void testApplyLinks_unsupportedCharacter() throws IOException {
Spannable url = new SpannableString("\u202Emoc.diordna.com");
TextLinks.Request request = new TextLinks.Request.Builder(url).build();
assertEquals(
TextLinks.STATUS_UNSUPPORTED_CHARACTER,
- classifier.generateLinks(request).apply(url, 0, null));
+ classifier.generateLinks(null, null, request).apply(url, 0, null));
}
@Test
@@ -332,17 +336,18 @@
char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1];
Arrays.fill(manySpaces, ' ');
TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
- assertThrows(IllegalArgumentException.class, () -> classifier.generateLinks(request));
+ expectThrows(
+ IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request));
}
@Test
- public void testGenerateLinks_entityData() {
+ public void testGenerateLinks_entityData() throws IOException {
String text = "The number is +12122537077.";
Bundle extras = new Bundle();
ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
- TextLinks textLinks = classifier.generateLinks(request);
+ TextLinks textLinks = classifier.generateLinks(null, null, request);
assertThat(textLinks.getLinks()).hasSize(1);
TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
@@ -353,11 +358,11 @@
}
@Test
- public void testGenerateLinks_entityData_disabled() {
+ public void testGenerateLinks_entityData_disabled() throws IOException {
String text = "The number is +12122537077.";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
- TextLinks textLinks = classifier.generateLinks(request);
+ TextLinks textLinks = classifier.generateLinks(null, null, request);
assertThat(textLinks.getLinks()).hasSize(1);
TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
@@ -366,24 +371,23 @@
}
@Test
- public void testDetectLanguage() {
+ public void testDetectLanguage() throws IOException {
String text = "This is English text";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
- TextLanguage textLanguage = classifier.detectLanguage(request);
+ TextLanguage textLanguage = classifier.detectLanguage(null, null, request);
assertThat(textLanguage, isTextLanguage("en"));
}
@Test
- public void testDetectLanguage_japanese() {
+ public void testDetectLanguage_japanese() throws IOException {
String text = "これは日本語のテキストです";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
- TextLanguage textLanguage = classifier.detectLanguage(request);
+ TextLanguage textLanguage = classifier.detectLanguage(null, null, request);
assertThat(textLanguage, isTextLanguage("ja"));
}
- @Ignore // Doesn't work without a language-based model.
@Test
- public void testSuggestConversationActions_textReplyOnly_maxOne() {
+ public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Where are you?")
@@ -399,16 +403,16 @@
.setTypeConfig(typeConfig)
.build();
- ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ ConversationActions conversationActions =
+ classifier.suggestConversationActions(null, null, request);
assertThat(conversationActions.getConversationActions()).hasSize(1);
ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
assertThat(conversationAction.getTextReply()).isNotNull();
}
- @Ignore // Doesn't work without a language-based model.
@Test
- public void testSuggestConversationActions_textReplyOnly_noMax() {
+ public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Where are you?")
@@ -423,7 +427,8 @@
.setTypeConfig(typeConfig)
.build();
- ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ ConversationActions conversationActions =
+ classifier.suggestConversationActions(null, null, request);
assertTrue(conversationActions.getConversationActions().size() > 1);
for (ConversationAction conversationAction : conversationActions.getConversationActions()) {
assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY));
@@ -431,7 +436,7 @@
}
@Test
- public void testSuggestConversationActions_openUrl() {
+ public void testSuggestConversationActions_openUrl() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Check this out: https://www.android.com")
@@ -447,7 +452,8 @@
.setTypeConfig(typeConfig)
.build();
- ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ ConversationActions conversationActions =
+ classifier.suggestConversationActions(null, null, request);
assertThat(conversationActions.getConversationActions()).hasSize(1);
ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
@@ -457,9 +463,8 @@
assertNoPackageInfoInExtras(actionIntent);
}
- @Ignore // Doesn't work without a language-based model.
@Test
- public void testSuggestConversationActions_copy() {
+ public void testSuggestConversationActions_copy() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Authentication code: 12345")
@@ -475,7 +480,8 @@
.setTypeConfig(typeConfig)
.build();
- ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ ConversationActions conversationActions =
+ classifier.suggestConversationActions(null, null, request);
assertThat(conversationActions.getConversationActions()).hasSize(1);
ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY);
@@ -487,7 +493,7 @@
}
@Test
- public void testSuggestConversationActions_deduplicate() {
+ public void testSuggestConversationActions_deduplicate() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("a@android.com b@android.com")
@@ -497,7 +503,8 @@
.setMaxSuggestions(3)
.build();
- ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ ConversationActions conversationActions =
+ classifier.suggestConversationActions(null, null, request);
assertThat(conversationActions.getConversationActions()).isEmpty();
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
deleted file mode 100644
index 21ed0b6..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import android.provider.DeviceConfig;
-import androidx.test.ext.junit.runners.AndroidJUnit4;
-import androidx.test.filters.SmallTest;
-import androidx.test.platform.app.InstrumentationRegistry;
-import java.util.function.Consumer;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class TextClassifierSettingsTest {
- private static final String WRITE_DEVICE_CONFIG_PERMISSION =
- "android.permission.WRITE_DEVICE_CONFIG";
- private static final float EPSILON = 0.0001f;
-
- @Before
- public void setup() {
- InstrumentationRegistry.getInstrumentation()
- .getUiAutomation()
- .adoptShellPermissionIdentity(WRITE_DEVICE_CONFIG_PERMISSION);
- }
-
- @After
- public void tearDown() {
- InstrumentationRegistry.getInstrumentation().getUiAutomation().dropShellPermissionIdentity();
- }
-
- @Test
- public void booleanSetting() {
- assertSettings(
- TextClassifierSettings.TEMPLATE_INTENT_FACTORY_ENABLED,
- "false",
- settings -> assertThat(settings.isTemplateIntentFactoryEnabled()).isFalse());
- }
-
- @Test
- public void intSetting() {
- assertSettings(
- TextClassifierSettings.SUGGEST_SELECTION_MAX_RANGE_LENGTH,
- "8",
- settings -> assertThat(settings.getSuggestSelectionMaxRangeLength()).isEqualTo(8));
- }
-
- @Test
- public void floatSetting() {
- assertSettings(
- TextClassifierSettings.LANG_ID_THRESHOLD_OVERRIDE,
- "3.14",
- settings -> assertThat(settings.getLangIdThresholdOverride()).isWithin(EPSILON).of(3.14f));
- }
-
- @Test
- public void stringListSetting() {
- assertSettings(
- TextClassifierSettings.ENTITY_LIST_DEFAULT,
- "email:url",
- settings ->
- assertThat(settings.getEntityListDefault()).containsExactly("email", "url").inOrder());
- }
-
- @Test
- public void floatListSetting() {
- assertSettings(
- TextClassifierSettings.LANG_ID_CONTEXT_SETTINGS,
- "30:0.5:0.3",
- settings ->
- assertThat(settings.getLangIdContextSettings())
- .usingTolerance(EPSILON)
- .containsExactly(30f, 0.5f, 0.3f)
- .inOrder());
- }
-
- private static void assertSettings(
- String key, String value, Consumer<TextClassifierSettings> settingsConsumer) {
- final String originalValue =
- DeviceConfig.getProperty(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key);
- TextClassifierSettings settings = new TextClassifierSettings();
- try {
- setDeviceConfig(key, value);
- settingsConsumer.accept(settings);
- } finally {
- setDeviceConfig(key, originalValue);
- }
- }
-
- private static void setDeviceConfig(String key, String value) {
- DeviceConfig.setProperty(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, value, /* makeDefault */ false);
- }
-}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
new file mode 100644
index 0000000..40838ac
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
@@ -0,0 +1,507 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common;
+
+import static com.android.textclassifier.common.ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT;
+import static com.google.common.truth.Truth.assertThat;
+
+import android.os.LocaleList;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.TestDataUtils;
+import com.android.textclassifier.common.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.ModelFileManager.RegularFileFullMatchLister;
+import com.android.textclassifier.common.ModelFileManager.RegularFilePatternMatchLister;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
+import com.google.common.base.Optional;
+import com.google.common.collect.ImmutableList;
+import com.google.common.io.Files;
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class ModelFileManagerTest {
+ private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
+
+ @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+
+ @Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
+
+ @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+
+ private File rootTestDir;
+ private ModelFileManager modelFileManager;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+
+ rootTestDir =
+ new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
+ rootTestDir.mkdirs();
+ modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ new TextClassifierSettings(mockDeviceConfig));
+ }
+
+ @After
+ public void removeTestDir() {
+ recursiveDelete(rootTestDir);
+ }
+
+ @Test
+ public void annotatorModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
+ }
+
+ @Test
+ public void actionsModelPreloaded() {
+ verifyModelPreloadedAsAsset(
+ ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
+ }
+
+ @Test
+ public void langIdModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
+ }
+
+ private void verifyModelPreloadedAsAsset(
+ @ModelTypeDef String modelType, String expectedModelPath) {
+ List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
+ List<ModelFile> assetFiles =
+ modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
+
+ assertThat(assetFiles).hasSize(1);
+ assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
+ }
+
+ @Test
+ public void findBestModel_versionCode() {
+ ModelFileManager.ModelFile olderModelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager.ModelFile newerModelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(olderModelFile, newerModelFile)));
+
+ ModelFile bestModelFile = modelFileManager.findBestModelFile(MODEL_TYPE, null);
+ assertThat(bestModelFile).isEqualTo(newerModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageDependentModelIsPreferred() {
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager.ModelFile languageDependentModelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
+ DEFAULT_LOCALE.toLanguageTag(),
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType ->
+ ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
+ assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_noMatchedLanguageModel() {
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager.ModelFile languageDependentModelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
+ DEFAULT_LOCALE.toLanguageTag(),
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType ->
+ ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageIsMoreImportantThanVersion() {
+ ModelFileManager.ModelFile matchButOlderModel =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ "fr",
+ /* isAsset= */ false);
+ ModelFileManager.ModelFile mismatchButNewerModel =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
+ "ja",
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType -> ImmutableList.of(matchButOlderModel, mismatchButNewerModel)));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("fr"));
+ assertThat(bestModelFile).isEqualTo(matchButOlderModel);
+ }
+
+ @Test
+ public void findBestModel_preferMatchedLocaleModel() {
+ ModelFileManager.ModelFile matchLocaleModel =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ "ja",
+ /* isAsset= */ false);
+ ModelFileManager.ModelFile languageIndependentModel =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType -> ImmutableList.of(matchLocaleModel, languageIndependentModel)));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("ja"));
+
+ assertThat(bestModelFile).isEqualTo(matchLocaleModel);
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_olderModelDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "ja", /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isFalse();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_languageIndependentOlderModelDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ model1.getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ model2.getAbsolutePath(),
+ /* version= */ 2,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isFalse();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_modelOnlySupportingLocalesNotInListDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 1, "en", /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ assertThat(model2.exists()).isFalse();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_multiLocalesInLocaleList() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "en", /* isAsset= */ false);
+ setDefaultLocalesRule.set(
+ new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("en")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_readOnlyModelsUntouched() throws Exception {
+ File readOnlyDir = new File(rootTestDir, "read_only/");
+ readOnlyDir.mkdirs();
+ File model1 = new File(readOnlyDir, "model1.fb");
+ model1.createNewFile();
+ readOnlyDir.setWritable(false);
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile)));
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ }
+
+ @Test
+ public void modelFileEquals() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ assertThat(modelA).isEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_different() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ assertThat(modelA).isNotEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_languageDependentIsBetter() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/b", /* version= */ 2, LANGUAGE_INDEPENDENT, /* isAsset= */ false);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_version() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void modelFile_toModelInfo() {
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+ ModelInfo modelInfo = modelFile.toModelInfo();
+
+ assertThat(modelInfo.toModelName()).isEqualTo("ja_v2");
+ }
+
+ @Test
+ public void modelFile_toModelInfos() {
+ ModelFile englishModelFile =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "en", /* isAsset= */ false);
+ ModelFile japaneseModelFile =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+ ImmutableList<Optional<ModelInfo>> modelInfos =
+ ModelFileManager.ModelFile.toModelInfos(
+ Optional.of(englishModelFile), Optional.of(japaneseModelFile));
+
+ assertThat(
+ modelInfos.stream()
+ .map(modelFile -> modelFile.transform(ModelInfo::toModelName).or(""))
+ .collect(Collectors.toList()))
+ .containsExactly("en_v1", "ja_v2")
+ .inOrder();
+ }
+
+ @Test
+ public void regularFileFullMatchLister() throws IOException {
+ File modelFile = new File(rootTestDir, "test.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
+ File wrongFile = new File(rootTestDir, "wrong.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
+
+ RegularFileFullMatchLister regularFileFullMatchLister =
+ new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
+ ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).hasSize(1);
+ assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
+ assertThat(listedModels.get(0).isAsset).isFalse();
+ }
+
+ @Test
+ public void regularFilePatternMatchLister() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+ File modelFile2 = new File(rootTestDir, "annotator.fr.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
+ File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
+
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).hasSize(2);
+ assertThat(listedModels.get(0).isAsset).isFalse();
+ assertThat(listedModels.get(1).isAsset).isFalse();
+ assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
+ .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
+ }
+
+ @Test
+ public void regularFilePatternMatchLister_disabled() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).isEmpty();
+ }
+
+ private static void recursiveDelete(File f) {
+ if (f.isDirectory()) {
+ for (File innerFile : f.listFiles()) {
+ recursiveDelete(innerFile);
+ }
+ }
+ f.delete();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/TextClassifierSettingsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/TextClassifierSettingsTest.java
new file mode 100644
index 0000000..21d6943
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/TextClassifierSettingsTest.java
@@ -0,0 +1,182 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.provider.DeviceConfig;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import androidx.test.platform.app.InstrumentationRegistry;
+import com.google.common.collect.ImmutableMap;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Consumer;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassifierSettingsTest {
+ private static final String WRITE_DEVICE_CONFIG_PERMISSION =
+ "android.permission.WRITE_DEVICE_CONFIG";
+ private static final float EPSILON = 0.0001f;
+
+ @Before
+ public void setup() {
+ InstrumentationRegistry.getInstrumentation()
+ .getUiAutomation()
+ .adoptShellPermissionIdentity(WRITE_DEVICE_CONFIG_PERMISSION);
+ }
+
+ @After
+ public void tearDown() {
+ InstrumentationRegistry.getInstrumentation().getUiAutomation().dropShellPermissionIdentity();
+ }
+
+ @Test
+ public void booleanSetting() {
+ assertSettings(
+ TextClassifierSettings.TEMPLATE_INTENT_FACTORY_ENABLED,
+ "false",
+ settings -> assertThat(settings.isTemplateIntentFactoryEnabled()).isFalse());
+ }
+
+ @Test
+ public void intSetting() {
+ assertSettings(
+ TextClassifierSettings.SUGGEST_SELECTION_MAX_RANGE_LENGTH,
+ "8",
+ settings -> assertThat(settings.getSuggestSelectionMaxRangeLength()).isEqualTo(8));
+ }
+
+ @Test
+ public void floatSetting() {
+ assertSettings(
+ TextClassifierSettings.LANG_ID_THRESHOLD_OVERRIDE,
+ "3.14",
+ settings -> assertThat(settings.getLangIdThresholdOverride()).isWithin(EPSILON).of(3.14f));
+ }
+
+ @Test
+ public void stringListSetting() {
+ assertSettings(
+ TextClassifierSettings.ENTITY_LIST_DEFAULT,
+ "email:url",
+ settings ->
+ assertThat(settings.getEntityListDefault()).containsExactly("email", "url").inOrder());
+ }
+
+ @Test
+ public void floatListSetting() {
+ assertSettings(
+ TextClassifierSettings.LANG_ID_CONTEXT_SETTINGS,
+ "30:0.5:0.3",
+ settings ->
+ assertThat(settings.getLangIdContextSettings())
+ .usingTolerance(EPSILON)
+ .containsExactly(30f, 0.5f, 0.3f)
+ .inOrder());
+ }
+
+ @Test
+ public void getManifestURLSetting() {
+ assertSettings(
+ "manifest_url_annotator_en",
+ "https://annotator",
+ settings ->
+ assertThat(settings.getManifestURL(ModelType.ANNOTATOR, "en"))
+ .isEqualTo("https://annotator"));
+ assertSettings(
+ "manifest_url_lang_id_universal",
+ "https://lang_id",
+ settings ->
+ assertThat(settings.getManifestURL(ModelType.LANG_ID, "universal"))
+ .isEqualTo("https://lang_id"));
+ assertSettings(
+ "manifest_url_actions_suggestions_zh",
+ "https://actions_suggestions",
+ settings ->
+ assertThat(settings.getManifestURL(ModelType.ACTIONS_SUGGESTIONS, "zh"))
+ .isEqualTo("https://actions_suggestions"));
+ }
+
+ @Test
+ public void getLanguageTagsForManifestURL() {
+ assertSettings(
+ ImmutableMap.of(
+ "manifest_url_annotator_en", "https://annotator-en",
+ "manifest_url_annotator_en-us", "https://annotator-en-us",
+ "manifest_url_annotator_zh-hant-hk", "https://annotator-zh",
+ "manifest_url_lang_id_universal", "https://lang_id"),
+ settings ->
+ assertThat(settings.getLanguageTagsForManifestURL(ModelType.ANNOTATOR))
+ .containsExactly("en", "en-us", "zh-hant-hk"));
+
+ assertSettings(
+ ImmutableMap.of(
+ "manifest_url_annotator_en", "https://annotator-en",
+ "manifest_url_annotator_en-us", "https://annotator-en-us",
+ "manifest_url_annotator_zh-hant-hk", "https://annotator-zh",
+ "manifest_url_lang_id_universal", "https://lang_id"),
+ settings ->
+ assertThat(settings.getLanguageTagsForManifestURL(ModelType.LANG_ID))
+ .containsExactly("universal"));
+
+ assertSettings(
+ ImmutableMap.of(
+ "manifest_url_annotator_en", "https://annotator-en",
+ "manifest_url_annotator_en-us", "https://annotator-en-us",
+ "manifest_url_annotator_zh-hant-hk", "https://annotator-zh",
+ "manifest_url_lang_id_universal", "https://lang_id"),
+ settings ->
+ assertThat(settings.getLanguageTagsForManifestURL(ModelType.ACTIONS_SUGGESTIONS))
+ .isEmpty());
+ }
+
+ private static void assertSettings(
+ String key, String value, Consumer<TextClassifierSettings> settingsConsumer) {
+ assertSettings(ImmutableMap.of(key, value), settingsConsumer);
+ }
+
+ private static void assertSettings(
+ Map<String, String> keyValueMap, Consumer<TextClassifierSettings> settingsConsumer) {
+ HashMap<String, String> keyOriginalValueMap = new HashMap<>();
+ for (String key : keyValueMap.keySet()) {
+ keyOriginalValueMap.put(
+ key, DeviceConfig.getProperty(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key));
+ }
+ TextClassifierSettings settings = new TextClassifierSettings();
+ try {
+ for (String key : keyValueMap.keySet()) {
+ setDeviceConfig(key, keyValueMap.get(key));
+ }
+ settingsConsumer.accept(settings);
+ } finally {
+ for (String key : keyValueMap.keySet()) {
+ setDeviceConfig(key, keyOriginalValueMap.get(key));
+ }
+ }
+ }
+
+ private static void setDeviceConfig(String key, String value) {
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, value, /* makeDefault */ false);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
index a1d9dcf..fdc454d 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
@@ -17,7 +17,7 @@
package com.android.textclassifier.common.intent;
import static com.google.common.truth.Truth.assertThat;
-import static org.testng.Assert.assertThrows;
+import static org.testng.Assert.expectThrows;
import android.content.ComponentName;
import android.content.Context;
@@ -119,7 +119,7 @@
@Test
public void resolve_missingTitle() {
- assertThrows(
+ expectThrows(
IllegalArgumentException.class,
() -> new LabeledIntent(null, null, DESCRIPTION, null, INTENT, REQUEST_CODE));
}
@@ -154,4 +154,56 @@
assertThat(result.remoteAction.getContentDescription().toString())
.isEqualTo("Use fake to open map");
}
+
+ @Test
+ public void resolve_noVisibilityToWebIntentHandler() {
+ Context context =
+ new FakeContextBuilder()
+ .setIntentComponent(Intent.ACTION_VIEW, /* component= */ null)
+ .build();
+ Intent webIntent = new Intent(Intent.ACTION_VIEW);
+ webIntent.setData(Uri.parse("https://www.android.com"));
+ LabeledIntent labeledIntent =
+ new LabeledIntent(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ /* descriptionWithAppName= */ null,
+ webIntent,
+ REQUEST_CODE);
+
+ LabeledIntent.Result result = labeledIntent.resolve(context, /*titleChooser*/ null);
+
+ assertThat(result).isNotNull();
+ assertThat(result.remoteAction.getTitle().toString()).isEqualTo(TITLE_WITH_ENTITY);
+ assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+ assertThat(result.resolvedIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
+ assertThat(result.resolvedIntent.getComponent()).isNull();
+ }
+
+ @Test
+ public void resolve_noVisibilityToWebIntentHandler_withDescriptionWithAppName() {
+ Context context =
+ new FakeContextBuilder()
+ .setIntentComponent(Intent.ACTION_VIEW, /* component= */ null)
+ .build();
+ Intent webIntent = new Intent(Intent.ACTION_VIEW);
+ webIntent.setData(Uri.parse("https://www.android.com"));
+ LabeledIntent labeledIntent =
+ new LabeledIntent(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ /* descriptionWithAppName= */ "name",
+ webIntent,
+ REQUEST_CODE);
+
+ LabeledIntent.Result result = labeledIntent.resolve(context, /*titleChooser*/ null);
+
+ assertThat(result).isNotNull();
+ assertThat(result.remoteAction.getTitle().toString()).isEqualTo(TITLE_WITH_ENTITY);
+ assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+ assertThat(result.resolvedIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
+ assertThat(result.resolvedIntent.getComponent()).isNull();
+ }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
index ab241c5..216cd5d 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
@@ -107,7 +107,7 @@
assertThat(intent.getData().toString()).isEqualTo(DATA);
assertThat(intent.getType()).isEqualTo(TYPE);
assertThat(intent.getFlags()).isEqualTo(FLAG);
- assertThat(intent.getCategories()).containsExactly((Object[]) CATEGORY);
+ assertThat(intent.getCategories()).containsExactlyElementsIn(CATEGORY);
assertThat(intent.getPackage()).isNull();
assertThat(intent.getStringExtra(KEY_ONE)).isEqualTo(VALUE_ONE);
assertThat(intent.getIntExtra(KEY_TWO, 0)).isEqualTo(VALUE_TWO);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
index c2a911a..e215b15 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
@@ -18,8 +18,12 @@
import static com.google.common.truth.Truth.assertThat;
+import android.os.Binder;
+import android.os.Parcel;
import android.stats.textclassifier.EventType;
import android.stats.textclassifier.WidgetType;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLinks;
import androidx.test.core.app.ApplicationProvider;
@@ -49,10 +53,17 @@
/** A statsd config ID, which is arbitrary. */
private static final long CONFIG_ID = 689777;
+ private static final long SHORT_TIMEOUT_MS = 1000;
+
private static final ModelInfo ANNOTATOR_MODEL =
new ModelInfo(1, ImmutableList.of(Locale.ENGLISH));
private static final ModelInfo LANGID_MODEL =
new ModelInfo(2, ImmutableList.of(Locale.forLanguageTag("*")));
+ private static final String SESSION_ID = "123456";
+ private static final String WIDGET_TYPE = TextClassifier.WIDGET_TYPE_WEBVIEW;
+ private static final WidgetType WIDGET_TYPE_ENUM = WidgetType.WIDGET_TYPE_WEBVIEW;
+ private final TextClassificationContext textClassificationContext =
+ new TextClassificationContext.Builder(PACKAGE_NAME, WIDGET_TYPE).build();
@Before
public void setup() throws Exception {
@@ -81,18 +92,18 @@
new TextLinks.Builder(testText)
.addLink(phoneOffset, phoneOffset + phoneText.length(), phoneEntityScores)
.build();
- String uuid = "uuid";
- GenerateLinksLogger generateLinksLogger =
- new GenerateLinksLogger(/* sampleRate= */ 1, () -> uuid);
+ GenerateLinksLogger generateLinksLogger = new GenerateLinksLogger(/* sampleRate= */ 1);
generateLinksLogger.logGenerateLinks(
+ createTextClassificationSessionId(),
+ textClassificationContext,
testText,
links,
PACKAGE_NAME,
LATENCY_MS,
Optional.of(ANNOTATOR_MODEL),
Optional.of(LANGID_MODEL));
- ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
ImmutableList<TextLinkifyEvent> loggedEvents =
ImmutableList.copyOf(
@@ -101,10 +112,10 @@
assertThat(loggedEvents).hasSize(2);
TextLinkifyEvent summaryEvent =
AtomsProto.TextLinkifyEvent.newBuilder()
- .setSessionId(uuid)
+ .setSessionId(SESSION_ID)
.setEventIndex(0)
.setModelName("en_v1")
- .setWidgetType(WidgetType.WIDGET_TYPE_UNKNOWN)
+ .setWidgetType(WIDGET_TYPE_ENUM)
.setEventType(EventType.LINKS_GENERATED)
.setPackageName(PACKAGE_NAME)
.setEntityType("")
@@ -116,10 +127,10 @@
.build();
TextLinkifyEvent phoneEvent =
AtomsProto.TextLinkifyEvent.newBuilder()
- .setSessionId(uuid)
+ .setSessionId(SESSION_ID)
.setEventIndex(0)
.setModelName("en_v1")
- .setWidgetType(WidgetType.WIDGET_TYPE_UNKNOWN)
+ .setWidgetType(WIDGET_TYPE_ENUM)
.setEventType(EventType.LINKS_GENERATED)
.setPackageName(PACKAGE_NAME)
.setEntityType(TextClassifier.TYPE_PHONE)
@@ -146,18 +157,18 @@
.addLink(phoneOffset, phoneOffset + phoneText.length(), phoneEntityScores)
.addLink(addressOffset, addressOffset + addressText.length(), addressEntityScores)
.build();
- String uuid = "uuid";
- GenerateLinksLogger generateLinksLogger =
- new GenerateLinksLogger(/* sampleRate= */ 1, () -> uuid);
+ GenerateLinksLogger generateLinksLogger = new GenerateLinksLogger(/* sampleRate= */ 1);
generateLinksLogger.logGenerateLinks(
+ createTextClassificationSessionId(),
+ textClassificationContext,
testText,
links,
PACKAGE_NAME,
LATENCY_MS,
Optional.of(ANNOTATOR_MODEL),
Optional.of(LANGID_MODEL));
- ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
ImmutableList<TextLinkifyEvent> loggedEvents =
ImmutableList.copyOf(
@@ -180,4 +191,13 @@
assertThat(phoneEvent.getNumLinks()).isEqualTo(1);
assertThat(phoneEvent.getLinkedTextLength()).isEqualTo(phoneText.length());
}
+
+ private static TextClassificationSessionId createTextClassificationSessionId() {
+ // A hack to create TextClassificationSessionId because its constructor is @hide.
+ Parcel parcel = Parcel.obtain();
+ parcel.writeString(SESSION_ID);
+ parcel.writeStrongBinder(new Binder());
+ parcel.setDataPosition(0);
+ return TextClassificationSessionId.CREATOR.createFromParcel(parcel);
+ }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
index f2b8223..1bcd7b7 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
@@ -46,7 +46,6 @@
/** Util functions to make statsd testing easier by using adb shell cmd stats commands. */
public class StatsdTestUtils {
private static final String TAG = "StatsdTestUtils";
- private static final long SHORT_WAIT_MS = 1000;
private StatsdTestUtils() {}
@@ -74,9 +73,10 @@
/**
* Extracts logged atoms from the report, sorted by logging time, and deletes the saved report.
*/
- public static ImmutableList<Atom> getLoggedAtoms(long configId) throws Exception {
- // There is no callback to notify us the log is collected. So we do a short wait here.
- Thread.sleep(SHORT_WAIT_MS);
+ public static ImmutableList<Atom> getLoggedAtoms(long configId, long timeoutMillis)
+ throws Exception {
+ // There is no callback to notify us the log is collected. So we do a wait here.
+ Thread.sleep(timeoutMillis);
ConfigMetricsReportList reportList = getAndRemoveReportList(configId);
assertThat(reportList.getReportsCount()).isEqualTo(1);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLoggerTest.java
new file mode 100644
index 0000000..b9b7a95
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLoggerTest.java
@@ -0,0 +1,156 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.os.Binder;
+import android.os.Parcel;
+import android.view.textclassifier.TextClassificationSessionId;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.LargeTest;
+import com.android.internal.os.StatsdConfigProto.StatsdConfig;
+import com.android.os.AtomsProto.Atom;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
+import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger.Session;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@RunWith(AndroidJUnit4.class)
+@LargeTest
+public class TextClassifierApiUsageLoggerTest {
+ /** A statsd config ID, which is arbitrary. */
+ private static final long CONFIG_ID = 689777;
+
+ private static final long SHORT_TIMEOUT_MS = 1000;
+
+ private static final String SESSION_ID = "abcdef";
+
+ @Before
+ public void setup() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+
+ StatsdConfig.Builder builder =
+ StatsdConfig.newBuilder()
+ .setId(CONFIG_ID)
+ .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
+ StatsdTestUtils.addAtomMatcher(builder, Atom.TEXT_CLASSIFIER_API_USAGE_REPORTED_FIELD_NUMBER);
+ StatsdTestUtils.pushConfig(builder.build());
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+ }
+
+ @Test
+ public void reportSuccess() throws Exception {
+ TextClassifierApiUsageLogger textClassifierApiUsageLogger =
+ new TextClassifierApiUsageLogger(
+ /* sampleRateSupplier= */ () -> 1, MoreExecutors.directExecutor());
+ Session session =
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_SUGGEST_SELECTION,
+ createTextClassificationSessionId());
+ // so that the latency we log is greater than 0.
+ Thread.sleep(10);
+ session.reportSuccess();
+
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
+
+ ImmutableList<TextClassifierApiUsageReported> loggedEvents =
+ ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .map(Atom::getTextClassifierApiUsageReported)
+ .collect(Collectors.toList()));
+
+ assertThat(loggedEvents).hasSize(1);
+ TextClassifierApiUsageReported event = loggedEvents.get(0);
+ assertThat(event.getApiType()).isEqualTo(ApiType.SUGGEST_SELECTION);
+ assertThat(event.getResultType()).isEqualTo(ResultType.SUCCESS);
+ assertThat(event.getLatencyMillis()).isGreaterThan(0L);
+ assertThat(event.getSessionId()).isEqualTo(SESSION_ID);
+ }
+
+ @Test
+ public void reportFailure() throws Exception {
+ TextClassifierApiUsageLogger textClassifierApiUsageLogger =
+ new TextClassifierApiUsageLogger(
+ /* sampleRateSupplier= */ () -> 1, MoreExecutors.directExecutor());
+ Session session =
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT,
+ createTextClassificationSessionId());
+ // so that the latency we log is greater than 0.
+ Thread.sleep(10);
+ session.reportFailure();
+
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
+
+ ImmutableList<TextClassifierApiUsageReported> loggedEvents =
+ ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .map(Atom::getTextClassifierApiUsageReported)
+ .collect(Collectors.toList()));
+
+ assertThat(loggedEvents).hasSize(1);
+ TextClassifierApiUsageReported event = loggedEvents.get(0);
+ assertThat(event.getApiType()).isEqualTo(ApiType.CLASSIFY_TEXT);
+ assertThat(event.getResultType()).isEqualTo(ResultType.FAIL);
+ assertThat(event.getLatencyMillis()).isGreaterThan(0L);
+ assertThat(event.getSessionId()).isEqualTo(SESSION_ID);
+ }
+
+ @Test
+ public void noLog_sampleRateZero() throws Exception {
+ TextClassifierApiUsageLogger textClassifierApiUsageLogger =
+ new TextClassifierApiUsageLogger(
+ /* sampleRateSupplier= */ () -> 0, MoreExecutors.directExecutor());
+ Session session =
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT,
+ createTextClassificationSessionId());
+ session.reportSuccess();
+
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
+
+ ImmutableList<TextClassifierApiUsageReported> loggedEvents =
+ ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .map(Atom::getTextClassifierApiUsageReported)
+ .collect(Collectors.toList()));
+
+ assertThat(loggedEvents).isEmpty();
+ }
+
+ private static TextClassificationSessionId createTextClassificationSessionId() {
+ // A hack to create TextClassificationSessionId because its constructor is @hide.
+ Parcel parcel = Parcel.obtain();
+ parcel.writeString(SESSION_ID);
+ parcel.writeStrongBinder(new Binder());
+ parcel.setDataPosition(0);
+ return TextClassificationSessionId.CREATOR.createFromParcel(parcel);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java
index 719fc31..f105e26 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java
@@ -45,6 +45,8 @@
/** A statsd config ID, which is arbitrary. */
private static final long CONFIG_ID = 689777;
+ private static final long SHORT_TIMEOUT_MS = 1000;
+
private TextClassifierEventLogger textClassifierEventLogger;
@Before
@@ -102,7 +104,7 @@
.setPackageName(PKG_NAME)
.setLangidModelName("und_v1")
.build();
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getTextSelectionEvent()).isEqualTo(event);
}
@@ -119,7 +121,7 @@
textClassifierEventLogger.writeEvent(sessionId, textSelectionEvent);
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getTextSelectionEvent().getEventType())
.isEqualTo(EventType.SMART_SELECTION_SINGLE);
@@ -137,7 +139,7 @@
textClassifierEventLogger.writeEvent(sessionId, textSelectionEvent);
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getTextSelectionEvent().getEventType())
.isEqualTo(EventType.SMART_SELECTION_MULTI);
@@ -155,7 +157,7 @@
textClassifierEventLogger.writeEvent(sessionId, textSelectionEvent);
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getTextSelectionEvent().getEventType())
.isEqualTo(EventType.AUTO_SELECTION);
@@ -189,7 +191,7 @@
.setPackageName(PKG_NAME)
.setLangidModelName("und_v1")
.build();
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getTextLinkifyEvent()).isEqualTo(event);
}
@@ -223,7 +225,7 @@
.setAnnotatorModelName("zh_v2")
.setLangidModelName("und_v3")
.build();
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getConversationActionsEvent()).isEqualTo(event);
}
@@ -254,7 +256,7 @@
.setActionIndex(1)
.setPackageName(PKG_NAME)
.build();
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getLanguageDetectionEvent()).isEqualTo(event);
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java b/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
deleted file mode 100644
index 569143f..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.subjects;
-
-import static com.google.common.truth.Truth.assertAbout;
-
-import com.android.textclassifier.Entity;
-import com.google.common.truth.FailureMetadata;
-import com.google.common.truth.Subject;
-import javax.annotation.Nullable;
-
-/** Test helper for checking {@link com.android.textclassifier.Entity} results. */
-public final class EntitySubject extends Subject {
-
- private static final float TOLERANCE = 0.0001f;
-
- private final Entity entity;
-
- public static EntitySubject assertThat(@Nullable Entity entity) {
- return assertAbout(EntitySubject::new).that(entity);
- }
-
- private EntitySubject(FailureMetadata failureMetadata, @Nullable Entity entity) {
- super(failureMetadata, entity);
- this.entity = entity;
- }
-
- public void isMatchWithinTolerance(@Nullable Entity entity) {
- if (!entity.getEntityType().equals(this.entity.getEntityType())) {
- failWithActual("expected to have type", entity.getEntityType());
- }
- check("expected to have confidence score").that(entity.getScore()).isWithin(TOLERANCE)
- .of(this.entity.getScore());
- }
-}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java
new file mode 100644
index 0000000..ec1405b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.testing;
+
+import android.os.LocaleList;
+import org.junit.rules.ExternalResource;
+
+public class SetDefaultLocalesRule extends ExternalResource {
+
+ private LocaleList originalValue;
+
+ @Override
+ protected void before() throws Throwable {
+ super.before();
+ originalValue = LocaleList.getDefault();
+ }
+
+ public void set(LocaleList newValue) {
+ LocaleList.setDefault(newValue);
+ }
+
+ @Override
+ protected void after() {
+ super.after();
+ LocaleList.setDefault(originalValue);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingDeviceConfig.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingDeviceConfig.java
new file mode 100644
index 0000000..670e3d0
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingDeviceConfig.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.testing;
+
+import android.provider.DeviceConfig.Properties;
+import androidx.annotation.NonNull;
+import com.android.textclassifier.common.TextClassifierSettings;
+import java.util.HashMap;
+import javax.annotation.Nullable;
+
+/** A fake DeviceConfig implementation for testing purpose. */
+public final class TestingDeviceConfig implements TextClassifierSettings.IDeviceConfig {
+
+ private final HashMap<String, String> strConfigs;
+ private final HashMap<String, Boolean> boolConfigs;
+
+ public TestingDeviceConfig() {
+ this.strConfigs = new HashMap<>();
+ this.boolConfigs = new HashMap<>();
+ }
+
+ public void setConfig(String key, String value) {
+ strConfigs.put(key, value);
+ }
+
+ public void setConfig(String key, boolean value) {
+ boolConfigs.put(key, value);
+ }
+
+ @Override
+ public Properties getProperties(@NonNull String namespace, @NonNull String... names) {
+ Properties.Builder builder = new Properties.Builder(namespace);
+ for (String key : strConfigs.keySet()) {
+ builder.setString(key, strConfigs.get(key));
+ }
+ for (String key : boolConfigs.keySet()) {
+ builder.setBoolean(key, boolConfigs.get(key));
+ }
+ return builder.build();
+ }
+
+ @Override
+ public boolean getBoolean(@NonNull String namespace, @NonNull String name, boolean defaultValue) {
+ return boolConfigs.containsKey(name) ? boolConfigs.get(name) : defaultValue;
+ }
+
+ @Override
+ public String getString(
+ @NonNull String namespace, @NonNull String name, @Nullable String defaultValue) {
+ return strConfigs.containsKey(name) ? strConfigs.get(name) : defaultValue;
+ }
+}
diff --git a/java/tests/instrumentation/testdata/actions.model b/java/tests/instrumentation/testdata/actions.model
new file mode 100755
index 0000000..74422f6
--- /dev/null
+++ b/java/tests/instrumentation/testdata/actions.model
Binary files differ
diff --git a/java/tests/instrumentation/testdata/annotator.model b/java/tests/instrumentation/testdata/annotator.model
new file mode 100755
index 0000000..f5fcc23
--- /dev/null
+++ b/java/tests/instrumentation/testdata/annotator.model
Binary files differ
diff --git a/java/tests/instrumentation/testdata/langid.model b/java/tests/instrumentation/testdata/langid.model
new file mode 100755
index 0000000..e94dada
--- /dev/null
+++ b/java/tests/instrumentation/testdata/langid.model
Binary files differ
diff --git a/jni/Android.bp b/jni/Android.bp
index d94821e..4300d8e 100644
--- a/jni/Android.bp
+++ b/jni/Android.bp
@@ -23,6 +23,8 @@
java_library_static {
name: "libtextclassifier-java",
- sdk_version: "core_current",
srcs: ["**/*.java"],
+ static_libs: ["guava"],
+ sdk_version: "system_current",
+ min_sdk_version: "28",
}
diff --git a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
index 3af04e8..b5c8ab6 100644
--- a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -16,7 +16,9 @@
package com.google.android.textclassifier;
+import android.content.res.AssetFileDescriptor;
import java.util.concurrent.atomic.AtomicBoolean;
+import javax.annotation.Nullable;
/**
* Java wrapper for ActionsSuggestions native library interface. This library is used to suggest
@@ -37,7 +39,7 @@
* Creates a new instance of Actions predictor, using the provided model image, given as a file
* descriptor.
*/
- public ActionsSuggestionsModel(int fileDescriptor, byte[] serializedPreconditions) {
+ public ActionsSuggestionsModel(int fileDescriptor, @Nullable byte[] serializedPreconditions) {
actionsModelPtr = nativeNewActionsModel(fileDescriptor, serializedPreconditions);
if (actionsModelPtr == 0L) {
throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
@@ -52,7 +54,7 @@
* Creates a new instance of Actions predictor, using the provided model image, given as a file
* path.
*/
- public ActionsSuggestionsModel(String path, byte[] serializedPreconditions) {
+ public ActionsSuggestionsModel(String path, @Nullable byte[] serializedPreconditions) {
actionsModelPtr = nativeNewActionsModelFromPath(path, serializedPreconditions);
if (actionsModelPtr == 0L) {
throw new IllegalArgumentException("Couldn't initialize actions model from given file.");
@@ -63,8 +65,29 @@
this(path, /* serializedPreconditions= */ null);
}
+ /**
+ * Creates a new instance of Actions predictor, using the provided model image, given as an {@link
+ * AssetFileDescriptor}.
+ */
+ public ActionsSuggestionsModel(
+ AssetFileDescriptor assetFileDescriptor, @Nullable byte[] serializedPreconditions) {
+ actionsModelPtr =
+ nativeNewActionsModelWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength(),
+ serializedPreconditions);
+ if (actionsModelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
+ }
+ }
+
+ public ActionsSuggestionsModel(AssetFileDescriptor assetFileDescriptor) {
+ this(assetFileDescriptor, /* serializedPreconditions= */ null);
+ }
+
/** Suggests actions / replies to the given conversation. */
- public ActionSuggestion[] suggestActions(
+ public ActionSuggestions suggestActions(
Conversation conversation, ActionSuggestionOptions options, AnnotatorModel annotator) {
return nativeSuggestActions(
actionsModelPtr,
@@ -76,7 +99,7 @@
/* generateAndroidIntents= */ false);
}
- public ActionSuggestion[] suggestActionsWithIntents(
+ public ActionSuggestions suggestActionsWithIntents(
Conversation conversation,
ActionSuggestionOptions options,
Object appContext,
@@ -115,40 +138,88 @@
return nativeGetLocales(fd);
}
+ /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
+ public static String getLocales(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetLocalesWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Returns the version of the model. */
public static int getVersion(int fd) {
return nativeGetVersion(fd);
}
+ /** Returns the version of the model. */
+ public static int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetVersionWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Returns the name of the model. */
public static String getName(int fd) {
return nativeGetName(fd);
}
+ /** Returns the name of the model. */
+ public static String getName(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetNameWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
+ /** Initializes conversation intent detection, passing the given serialized config to it. */
+ public void initializeConversationIntentDetection(byte[] serializedConfig) {
+ if (!nativeInitializeConversationIntentDetection(actionsModelPtr, serializedConfig)) {
+ throw new IllegalArgumentException("Couldn't initialize conversation intent detection");
+ }
+ }
+
+ /** Represents a list of suggested actions of a given conversation. */
+ public static final class ActionSuggestions {
+ /** A list of suggested actions, sorted by score in descending order. */
+ public final ActionSuggestion[] actionSuggestions;
+ /** Whether the input conversation is considered as sensitive. */
+ public final boolean isSensitive;
+
+ public ActionSuggestions(ActionSuggestion[] actionSuggestions, boolean isSensitive) {
+ this.actionSuggestions = actionSuggestions;
+ this.isSensitive = isSensitive;
+ }
+ }
+
/** Action suggestion that contains a response text and the type of the response. */
public static final class ActionSuggestion {
- private final String responseText;
+ @Nullable private final String responseText;
private final String actionType;
private final float score;
- private final NamedVariant[] entityData;
- private final byte[] serializedEntityData;
- private final RemoteActionTemplate[] remoteActionTemplates;
+ @Nullable private final NamedVariant[] entityData;
+ @Nullable private final byte[] serializedEntityData;
+ @Nullable private final RemoteActionTemplate[] remoteActionTemplates;
+ @Nullable private final Slot[] slots;
public ActionSuggestion(
- String responseText,
+ @Nullable String responseText,
String actionType,
float score,
- NamedVariant[] entityData,
- byte[] serializedEntityData,
- RemoteActionTemplate[] remoteActionTemplates) {
+ @Nullable NamedVariant[] entityData,
+ @Nullable byte[] serializedEntityData,
+ @Nullable RemoteActionTemplate[] remoteActionTemplates,
+ @Nullable Slot[] slots) {
this.responseText = responseText;
this.actionType = actionType;
this.score = score;
this.entityData = entityData;
this.serializedEntityData = serializedEntityData;
this.remoteActionTemplates = remoteActionTemplates;
+ this.slots = slots;
}
+ @Nullable
public String getResponseText() {
return responseText;
}
@@ -162,33 +233,41 @@
return score;
}
+ @Nullable
public NamedVariant[] getEntityData() {
return entityData;
}
+ @Nullable
public byte[] getSerializedEntityData() {
return serializedEntityData;
}
+ @Nullable
public RemoteActionTemplate[] getRemoteActionTemplates() {
return remoteActionTemplates;
}
+
+ @Nullable
+ public Slot[] getSlots() {
+ return slots;
+ }
}
/** Represents a single message in the conversation. */
public static final class ConversationMessage {
private final int userId;
- private final String text;
+ @Nullable private final String text;
private final long referenceTimeMsUtc;
- private final String referenceTimezone;
- private final String detectedTextLanguageTags;
+ @Nullable private final String referenceTimezone;
+ @Nullable private final String detectedTextLanguageTags;
public ConversationMessage(
int userId,
- String text,
+ @Nullable String text,
long referenceTimeMsUtc,
- String referenceTimezone,
- String detectedTextLanguageTags) {
+ @Nullable String referenceTimezone,
+ @Nullable String detectedTextLanguageTags) {
this.userId = userId;
this.text = text;
this.referenceTimeMsUtc = referenceTimeMsUtc;
@@ -201,6 +280,7 @@
return userId;
}
+ @Nullable
public String getText() {
return text;
}
@@ -213,11 +293,13 @@
return referenceTimeMsUtc;
}
+ @Nullable
public String getReferenceTimezone() {
return referenceTimezone;
}
/** Returns a comma separated list of BCP 47 language tags. */
+ @Nullable
public String getDetectedTextLanguageTags() {
return detectedTextLanguageTags;
}
@@ -241,6 +323,33 @@
public ActionSuggestionOptions() {}
}
+ /** Represents a slot for an {@link ActionSuggestion}. */
+ public static final class Slot {
+
+ public final String type;
+ public final int messageIndex;
+ public final int startIndex;
+ public final int endIndex;
+ public final float confidenceScore;
+
+ public Slot(
+ String type, int messageIndex, int startIndex, int endIndex, float confidenceScore) {
+ this.type = type;
+ this.messageIndex = messageIndex;
+ this.startIndex = startIndex;
+ this.endIndex = endIndex;
+ this.confidenceScore = confidenceScore;
+ }
+ }
+
+ /**
+ * Retrieves the pointer to the native object. Note: Need to keep the {@code
+ * ActionsSuggestionsModel} alive as long as the pointer is used.
+ */
+ long getNativeModelPointer() {
+ return nativeGetNativeModelPtr(actionsModelPtr);
+ }
+
private static native long nativeNewActionsModel(int fd, byte[] serializedPreconditions);
private static native long nativeNewActionsModelFromPath(
@@ -249,6 +358,9 @@
private static native long nativeNewActionsModelWithOffset(
int fd, long offset, long size, byte[] preconditionsOverwrite);
+ private native boolean nativeInitializeConversationIntentDetection(
+ long actionsModelPtr, byte[] serializedConfig);
+
private static native String nativeGetLocales(int fd);
private static native String nativeGetLocalesWithOffset(int fd, long offset, long size);
@@ -261,7 +373,7 @@
private static native String nativeGetNameWithOffset(int fd, long offset, long size);
- private native ActionSuggestion[] nativeSuggestActions(
+ private native ActionSuggestions nativeSuggestActions(
long context,
Conversation conversation,
ActionSuggestionOptions options,
@@ -271,4 +383,6 @@
boolean generateAndroidIntents);
private native void nativeCloseActionsModel(long ptr);
+
+ private native long nativeGetNativeModelPtr(long context);
}
diff --git a/jni/com/google/android/textclassifier/AnnotatorModel.java b/jni/com/google/android/textclassifier/AnnotatorModel.java
index 7658bf5..47a369e 100644
--- a/jni/com/google/android/textclassifier/AnnotatorModel.java
+++ b/jni/com/google/android/textclassifier/AnnotatorModel.java
@@ -16,8 +16,10 @@
package com.google.android.textclassifier;
+import android.content.res.AssetFileDescriptor;
import java.util.Collection;
import java.util.concurrent.atomic.AtomicBoolean;
+import javax.annotation.Nullable;
/**
* Java wrapper for Annotator native library interface. This library is used for detecting entities
@@ -49,7 +51,7 @@
private long annotatorPtr;
// To tell GC to keep the LangID model alive at least as long as this object.
- private LangIdModel langIdModel;
+ @Nullable private LangIdModel langIdModel;
/** Enumeration for specifying the usecase of the annotations. */
public static enum AnnotationUsecase {
@@ -73,6 +75,25 @@
}
};
+ /** Enumeration for specifying the annotate mode. */
+ public static enum AnnotateMode {
+ /** Result contains entity annotation for each input fragment. */
+ ENTITY_ANNOTATION(0),
+
+ /** Result will include both entity annotation and topicality annotation. */
+ ENTITY_AND_TOPICALITY_ANNOTATION(1);
+
+ private final int value;
+
+ AnnotateMode(int value) {
+ this.value = value;
+ }
+
+ public int getValue() {
+ return value;
+ }
+ };
+
/**
* Creates a new instance of SmartSelect predictor, using the provided model image, given as a
* file descriptor.
@@ -95,6 +116,21 @@
}
}
+ /**
+ * Creates a new instance of SmartSelect predictor, using the provided model image, given as an
+ * {@link AssetFileDescriptor}.
+ */
+ public AnnotatorModel(AssetFileDescriptor assetFileDescriptor) {
+ annotatorPtr =
+ nativeNewAnnotatorWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ if (annotatorPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize TC from asset file descriptor.");
+ }
+ }
+
/** Initializes the knowledge engine, passing the given serialized config to it. */
public void initializeKnowledgeEngine(byte[] serializedConfig) {
if (!nativeInitializeKnowledgeEngine(annotatorPtr, serializedConfig)) {
@@ -121,12 +157,26 @@
* before this object is closed. Also, this object does not take the memory ownership of the given
* LangIdModel object.
*/
- public void setLangIdModel(LangIdModel langIdModel) {
+ public void setLangIdModel(@Nullable LangIdModel langIdModel) {
this.langIdModel = langIdModel;
nativeSetLangId(annotatorPtr, langIdModel == null ? 0 : langIdModel.getNativePointer());
}
/**
+ * Initializes the person name engine, using the provided model image, given as an {@link
+ * AssetFileDescriptor}.
+ */
+ public void initializePersonNameEngine(AssetFileDescriptor assetFileDescriptor) {
+ if (!nativeInitializePersonNameEngine(
+ annotatorPtr,
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength())) {
+ throw new IllegalArgumentException("Couldn't initialize the person name engine");
+ }
+ }
+
+ /**
* Given a string context and current selection, computes the selection suggestion.
*
* <p>The begin and end are character indices into the context UTF8 string. selectionBegin is the
@@ -183,8 +233,7 @@
* Annotates multiple fragments of text at once. There will be one AnnotatedSpan array for each
* input fragment to annotate.
*/
- public AnnotatedSpan[][] annotateStructuredInput(
- InputFragment[] fragments, AnnotationOptions options) {
+ public Annotations annotateStructuredInput(InputFragment[] fragments, AnnotationOptions options) {
return nativeAnnotateStructuredInput(annotatorPtr, fragments, options);
}
@@ -219,16 +268,40 @@
return nativeGetLocales(fd);
}
+ /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
+ public static String getLocales(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetLocalesWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Returns the version of the model. */
public static int getVersion(int fd) {
return nativeGetVersion(fd);
}
+ /** Returns the version of the model. */
+ public static int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetVersionWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Returns the name of the model. */
public static String getName(int fd) {
return nativeGetName(fd);
}
+ /** Returns the name of the model. */
+ public static String getName(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetNameWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Information about a parsed time/date. */
public static final class DatetimeResult {
@@ -261,20 +334,22 @@
public static final class ClassificationResult {
private final String collection;
private final float score;
- private final DatetimeResult datetimeResult;
- private final byte[] serializedKnowledgeResult;
- private final String contactName;
- private final String contactGivenName;
- private final String contactFamilyName;
- private final String contactNickname;
- private final String contactEmailAddress;
- private final String contactPhoneNumber;
- private final String contactId;
- private final String appName;
- private final String appPackageName;
- private final NamedVariant[] entityData;
- private final byte[] serializedEntityData;
- private final RemoteActionTemplate[] remoteActionTemplates;
+ @Nullable private final DatetimeResult datetimeResult;
+ @Nullable private final byte[] serializedKnowledgeResult;
+ @Nullable private final String contactName;
+ @Nullable private final String contactGivenName;
+ @Nullable private final String contactFamilyName;
+ @Nullable private final String contactNickname;
+ @Nullable private final String contactEmailAddress;
+ @Nullable private final String contactPhoneNumber;
+ @Nullable private final String contactAccountType;
+ @Nullable private final String contactAccountName;
+ @Nullable private final String contactId;
+ @Nullable private final String appName;
+ @Nullable private final String appPackageName;
+ @Nullable private final NamedVariant[] entityData;
+ @Nullable private final byte[] serializedEntityData;
+ @Nullable private final RemoteActionTemplate[] remoteActionTemplates;
private final long durationMs;
private final long numericValue;
private final double numericDoubleValue;
@@ -282,20 +357,22 @@
public ClassificationResult(
String collection,
float score,
- DatetimeResult datetimeResult,
- byte[] serializedKnowledgeResult,
- String contactName,
- String contactGivenName,
- String contactFamilyName,
- String contactNickname,
- String contactEmailAddress,
- String contactPhoneNumber,
- String contactId,
- String appName,
- String appPackageName,
- NamedVariant[] entityData,
- byte[] serializedEntityData,
- RemoteActionTemplate[] remoteActionTemplates,
+ @Nullable DatetimeResult datetimeResult,
+ @Nullable byte[] serializedKnowledgeResult,
+ @Nullable String contactName,
+ @Nullable String contactGivenName,
+ @Nullable String contactFamilyName,
+ @Nullable String contactNickname,
+ @Nullable String contactEmailAddress,
+ @Nullable String contactPhoneNumber,
+ @Nullable String contactAccountType,
+ @Nullable String contactAccountName,
+ @Nullable String contactId,
+ @Nullable String appName,
+ @Nullable String appPackageName,
+ @Nullable NamedVariant[] entityData,
+ @Nullable byte[] serializedEntityData,
+ @Nullable RemoteActionTemplate[] remoteActionTemplates,
long durationMs,
long numericValue,
double numericDoubleValue) {
@@ -309,6 +386,8 @@
this.contactNickname = contactNickname;
this.contactEmailAddress = contactEmailAddress;
this.contactPhoneNumber = contactPhoneNumber;
+ this.contactAccountType = contactAccountType;
+ this.contactAccountName = contactAccountName;
this.contactId = contactId;
this.appName = appName;
this.appPackageName = appPackageName;
@@ -330,58 +409,82 @@
return score;
}
+ @Nullable
public DatetimeResult getDatetimeResult() {
return datetimeResult;
}
+ @Nullable
public byte[] getSerializedKnowledgeResult() {
return serializedKnowledgeResult;
}
+ @Nullable
public String getContactName() {
return contactName;
}
+ @Nullable
public String getContactGivenName() {
return contactGivenName;
}
+ @Nullable
public String getContactFamilyName() {
return contactFamilyName;
}
+ @Nullable
public String getContactNickname() {
return contactNickname;
}
+ @Nullable
public String getContactEmailAddress() {
return contactEmailAddress;
}
+ @Nullable
public String getContactPhoneNumber() {
return contactPhoneNumber;
}
+ @Nullable
+ public String getContactAccountType() {
+ return contactAccountType;
+ }
+
+ @Nullable
+ public String getContactAccountName() {
+ return contactAccountName;
+ }
+
+ @Nullable
public String getContactId() {
return contactId;
}
+ @Nullable
public String getAppName() {
return appName;
}
+ @Nullable
public String getAppPackageName() {
return appPackageName;
}
+ @Nullable
public NamedVariant[] getEntityData() {
return entityData;
}
+ @Nullable
public byte[] getSerializedEntityData() {
return serializedEntityData;
}
+ @Nullable
public RemoteActionTemplate[] getRemoteActionTemplates() {
return remoteActionTemplates;
}
@@ -424,6 +527,28 @@
}
}
+ /**
+ * Represents a result of Annotate call, which will include both entity annotations and topicality
+ * annotations.
+ */
+ public static final class Annotations {
+ private final AnnotatedSpan[][] annotatedSpans;
+ private final ClassificationResult[] topicalityResults;
+
+ Annotations(AnnotatedSpan[][] annotatedSpans, ClassificationResult[] topicalityResults) {
+ this.annotatedSpans = annotatedSpans;
+ this.topicalityResults = topicalityResults;
+ }
+
+ public AnnotatedSpan[][] getAnnotatedSpans() {
+ return annotatedSpans;
+ }
+
+ public ClassificationResult[] getTopicalityResults() {
+ return topicalityResults;
+ }
+ }
+
/** Represents a fragment of text to the AnnotateStructuredInput call. */
public static final class InputFragment {
@@ -441,22 +566,40 @@
public InputFragment(String text) {
this.text = text;
this.datetimeOptionsNullable = null;
+ this.boundingBoxTop = 0;
+ this.boundingBoxHeight = 0;
}
- public InputFragment(String text, DatetimeOptions datetimeOptions) {
+ public InputFragment(
+ String text,
+ DatetimeOptions datetimeOptions,
+ float boundingBoxTop,
+ float boundingBoxHeight) {
this.text = text;
this.datetimeOptionsNullable = datetimeOptions;
+ this.boundingBoxTop = boundingBoxTop;
+ this.boundingBoxHeight = boundingBoxHeight;
}
private final String text;
// The DatetimeOptions can't be Optional because the _api16 build of the TCLib SDK does not
// support java.util.Optional.
private final DatetimeOptions datetimeOptionsNullable;
+ private final float boundingBoxTop;
+ private final float boundingBoxHeight;
public String getText() {
return text;
}
+ public float getBoundingBoxTop() {
+ return boundingBoxTop;
+ }
+
+ public float getBoundingBoxHeight() {
+ return boundingBoxHeight;
+ }
+
public boolean hasDatetimeOptions() {
return datetimeOptionsNullable != null;
}
@@ -470,37 +613,111 @@
}
}
- /**
- * Represents options for the suggestSelection call. TODO(b/63427420): Use location with Selection
- * options.
- */
+ /** Represents options for the suggestSelection call. */
public static final class SelectionOptions {
- private final String locales;
- private final String detectedTextLanguageTags;
+ @Nullable private final String locales;
+ @Nullable private final String detectedTextLanguageTags;
private final int annotationUsecase;
private final double userLocationLat;
private final double userLocationLng;
private final float userLocationAccuracyMeters;
+ private final boolean usePodNer;
+ private final boolean useVocabAnnotator;
- public SelectionOptions(
- String locales, String detectedTextLanguageTags, int annotationUsecase) {
+ private SelectionOptions(
+ @Nullable String locales,
+ @Nullable String detectedTextLanguageTags,
+ int annotationUsecase,
+ double userLocationLat,
+ double userLocationLng,
+ float userLocationAccuracyMeters,
+ boolean usePodNer,
+ boolean useVocabAnnotator) {
this.locales = locales;
this.detectedTextLanguageTags = detectedTextLanguageTags;
this.annotationUsecase = annotationUsecase;
- this.userLocationLat = INVALID_LATITUDE;
- this.userLocationLng = INVALID_LONGITUDE;
- this.userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ this.userLocationLat = userLocationLat;
+ this.userLocationLng = userLocationLng;
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
+ this.usePodNer = usePodNer;
+ this.useVocabAnnotator = useVocabAnnotator;
}
- public SelectionOptions(String locales, String detectedTextLanguageTags) {
- this(locales, detectedTextLanguageTags, AnnotationUsecase.SMART.getValue());
+ /** Can be used to build a SelectionOptions instance. */
+ public static class Builder {
+ @Nullable private String locales;
+ @Nullable private String detectedTextLanguageTags;
+ private int annotationUsecase = AnnotationUsecase.SMART.getValue();
+ private double userLocationLat = INVALID_LATITUDE;
+ private double userLocationLng = INVALID_LONGITUDE;
+ private float userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ private boolean usePodNer = true;
+ private boolean useVocabAnnotator = true;
+
+ public Builder setLocales(@Nullable String locales) {
+ this.locales = locales;
+ return this;
+ }
+
+ public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) {
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ return this;
+ }
+
+ public Builder setAnnotationUsecase(int annotationUsecase) {
+ this.annotationUsecase = annotationUsecase;
+ return this;
+ }
+
+ public Builder setUserLocationLat(double userLocationLat) {
+ this.userLocationLat = userLocationLat;
+ return this;
+ }
+
+ public Builder setUserLocationLng(double userLocationLng) {
+ this.userLocationLng = userLocationLng;
+ return this;
+ }
+
+ public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) {
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
+ return this;
+ }
+
+ public Builder setUsePodNer(boolean usePodNer) {
+ this.usePodNer = usePodNer;
+ return this;
+ }
+
+ public Builder setUseVocabAnnotator(boolean useVocabAnnotator) {
+ this.useVocabAnnotator = useVocabAnnotator;
+ return this;
+ }
+
+ public SelectionOptions build() {
+ return new SelectionOptions(
+ locales,
+ detectedTextLanguageTags,
+ annotationUsecase,
+ userLocationLat,
+ userLocationLng,
+ userLocationAccuracyMeters,
+ usePodNer,
+ useVocabAnnotator);
+ }
}
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ @Nullable
public String getLocales() {
return locales;
}
/** Returns a comma separated list of BCP 47 language tags. */
+ @Nullable
public String getDetectedTextLanguageTags() {
return detectedTextLanguageTags;
}
@@ -520,53 +737,153 @@
public float getUserLocationAccuracyMeters() {
return userLocationAccuracyMeters;
}
+
+ public boolean getUsePodNer() {
+ return usePodNer;
+ }
+
+ public boolean getUseVocabAnnotator() {
+ return useVocabAnnotator;
+ }
}
- /**
- * Represents options for the classifyText call. TODO(b/63427420): Use location with
- * Classification options.
- */
+ /** Represents options for the classifyText call. */
public static final class ClassificationOptions {
private final long referenceTimeMsUtc;
private final String referenceTimezone;
- private final String locales;
- private final String detectedTextLanguageTags;
+ @Nullable private final String locales;
+ @Nullable private final String detectedTextLanguageTags;
private final int annotationUsecase;
private final double userLocationLat;
private final double userLocationLng;
private final float userLocationAccuracyMeters;
private final String userFamiliarLanguageTags;
+ private final boolean usePodNer;
+ private final boolean triggerDictionaryOnBeginnerWords;
+ private final boolean useVocabAnnotator;
- public ClassificationOptions(
+ private ClassificationOptions(
long referenceTimeMsUtc,
String referenceTimezone,
- String locales,
- String detectedTextLanguageTags,
+ @Nullable String locales,
+ @Nullable String detectedTextLanguageTags,
int annotationUsecase,
- String userFamiliarLanguageTags) {
+ double userLocationLat,
+ double userLocationLng,
+ float userLocationAccuracyMeters,
+ String userFamiliarLanguageTags,
+ boolean usePodNer,
+ boolean triggerDictionaryOnBeginnerWords,
+ boolean useVocabAnnotator) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
this.locales = locales;
this.detectedTextLanguageTags = detectedTextLanguageTags;
this.annotationUsecase = annotationUsecase;
- this.userLocationLat = INVALID_LATITUDE;
- this.userLocationLng = INVALID_LONGITUDE;
- this.userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ this.userLocationLat = userLocationLat;
+ this.userLocationLng = userLocationLng;
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
this.userFamiliarLanguageTags = userFamiliarLanguageTags;
+ this.usePodNer = usePodNer;
+ this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
+ this.useVocabAnnotator = useVocabAnnotator;
}
- public ClassificationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags) {
- this(
- referenceTimeMsUtc,
- referenceTimezone,
- locales,
- detectedTextLanguageTags,
- AnnotationUsecase.SMART.getValue(),
- "");
+ /** Can be used to build a ClassificationOptions instance. */
+ public static class Builder {
+ private long referenceTimeMsUtc;
+ @Nullable private String referenceTimezone;
+ @Nullable private String locales;
+ @Nullable private String detectedTextLanguageTags;
+ private int annotationUsecase = AnnotationUsecase.SMART.getValue();
+ private double userLocationLat = INVALID_LATITUDE;
+ private double userLocationLng = INVALID_LONGITUDE;
+ private float userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ private String userFamiliarLanguageTags = "";
+ private boolean usePodNer = true;
+ private boolean triggerDictionaryOnBeginnerWords = false;
+ private boolean useVocabAnnotator = true;
+
+ public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) {
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
+ return this;
+ }
+
+ public Builder setReferenceTimezone(String referenceTimezone) {
+ this.referenceTimezone = referenceTimezone;
+ return this;
+ }
+
+ public Builder setLocales(@Nullable String locales) {
+ this.locales = locales;
+ return this;
+ }
+
+ public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) {
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ return this;
+ }
+
+ public Builder setAnnotationUsecase(int annotationUsecase) {
+ this.annotationUsecase = annotationUsecase;
+ return this;
+ }
+
+ public Builder setUserLocationLat(double userLocationLat) {
+ this.userLocationLat = userLocationLat;
+ return this;
+ }
+
+ public Builder setUserLocationLng(double userLocationLng) {
+ this.userLocationLng = userLocationLng;
+ return this;
+ }
+
+ public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) {
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
+ return this;
+ }
+
+ public Builder setUserFamiliarLanguageTags(String userFamiliarLanguageTags) {
+ this.userFamiliarLanguageTags = userFamiliarLanguageTags;
+ return this;
+ }
+
+ public Builder setUsePodNer(boolean usePodNer) {
+ this.usePodNer = usePodNer;
+ return this;
+ }
+
+ public Builder setTrigerringDictionaryOnBeginnerWords(
+ boolean triggerDictionaryOnBeginnerWords) {
+ this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
+ return this;
+ }
+
+ public Builder setUseVocabAnnotator(boolean useVocabAnnotator) {
+ this.useVocabAnnotator = useVocabAnnotator;
+ return this;
+ }
+
+ public ClassificationOptions build() {
+ return new ClassificationOptions(
+ referenceTimeMsUtc,
+ referenceTimezone,
+ locales,
+ detectedTextLanguageTags,
+ annotationUsecase,
+ userLocationLat,
+ userLocationLng,
+ userLocationAccuracyMeters,
+ userFamiliarLanguageTags,
+ usePodNer,
+ triggerDictionaryOnBeginnerWords,
+ useVocabAnnotator);
+ }
+ }
+
+ public static Builder builder() {
+ return new Builder();
}
public long getReferenceTimeMsUtc() {
@@ -577,11 +894,13 @@
return referenceTimezone;
}
+ @Nullable
public String getLocale() {
return locales;
}
/** Returns a comma separated list of BCP 47 language tags. */
+ @Nullable
public String getDetectedTextLanguageTags() {
return detectedTextLanguageTags;
}
@@ -605,15 +924,28 @@
public String getUserFamiliarLanguageTags() {
return userFamiliarLanguageTags;
}
+
+ public boolean getUsePodNer() {
+ return usePodNer;
+ }
+
+ public boolean getTriggerDictionaryOnBeginnerWords() {
+ return triggerDictionaryOnBeginnerWords;
+ }
+
+ public boolean getUseVocabAnnotator() {
+ return useVocabAnnotator;
+ }
}
/** Represents options for the annotate call. */
public static final class AnnotationOptions {
private final long referenceTimeMsUtc;
private final String referenceTimezone;
- private final String locales;
- private final String detectedTextLanguageTags;
+ @Nullable private final String locales;
+ @Nullable private final String detectedTextLanguageTags;
private final String[] entityTypes;
+ private final int annotateMode;
private final int annotationUsecase;
private final boolean hasLocationPermission;
private final boolean hasPersonalizationPermission;
@@ -621,25 +953,33 @@
private final double userLocationLat;
private final double userLocationLng;
private final float userLocationAccuracyMeters;
+ private final boolean usePodNer;
+ private final boolean triggerDictionaryOnBeginnerWords;
+ private final boolean useVocabAnnotator;
- public AnnotationOptions(
+ private AnnotationOptions(
long referenceTimeMsUtc,
String referenceTimezone,
- String locales,
- String detectedTextLanguageTags,
- Collection<String> entityTypes,
+ @Nullable String locales,
+ @Nullable String detectedTextLanguageTags,
+ @Nullable Collection<String> entityTypes,
+ int annotateMode,
int annotationUsecase,
boolean hasLocationPermission,
boolean hasPersonalizationPermission,
boolean isSerializedEntityDataEnabled,
double userLocationLat,
double userLocationLng,
- float userLocationAccuracyMeters) {
+ float userLocationAccuracyMeters,
+ boolean usePodNer,
+ boolean triggerDictionaryOnBeginnerWords,
+ boolean useVocabAnnotator) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
this.locales = locales;
this.detectedTextLanguageTags = detectedTextLanguageTags;
this.entityTypes = entityTypes == null ? new String[0] : entityTypes.toArray(new String[0]);
+ this.annotateMode = annotateMode;
this.annotationUsecase = annotationUsecase;
this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled;
this.userLocationLat = userLocationLat;
@@ -647,68 +987,133 @@
this.userLocationAccuracyMeters = userLocationAccuracyMeters;
this.hasLocationPermission = hasLocationPermission;
this.hasPersonalizationPermission = hasPersonalizationPermission;
+ this.usePodNer = usePodNer;
+ this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
+ this.useVocabAnnotator = useVocabAnnotator;
}
- public AnnotationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags,
- Collection<String> entityTypes,
- int annotationUsecase,
- boolean isSerializedEntityDataEnabled,
- double userLocationLat,
- double userLocationLng,
- float userLocationAccuracyMeters) {
- this(
- referenceTimeMsUtc,
- referenceTimezone,
- locales,
- detectedTextLanguageTags,
- entityTypes,
- annotationUsecase,
- /* hasLocationPermission */ true,
- /* hasPersonalizationPermission */ true,
- isSerializedEntityDataEnabled,
- userLocationLat,
- userLocationLng,
- userLocationAccuracyMeters);
+ /** Can be used to build an AnnotationOptions instance. */
+ public static class Builder {
+ private long referenceTimeMsUtc;
+ @Nullable private String referenceTimezone;
+ @Nullable private String locales;
+ @Nullable private String detectedTextLanguageTags;
+ @Nullable private Collection<String> entityTypes;
+ private int annotateMode = AnnotateMode.ENTITY_ANNOTATION.getValue();
+ private int annotationUsecase = AnnotationUsecase.SMART.getValue();
+ private boolean hasLocationPermission = true;
+ private boolean hasPersonalizationPermission = true;
+ private boolean isSerializedEntityDataEnabled = false;
+ private double userLocationLat = INVALID_LATITUDE;
+ private double userLocationLng = INVALID_LONGITUDE;
+ private float userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ private boolean usePodNer = true;
+ private boolean triggerDictionaryOnBeginnerWords = false;
+ private boolean useVocabAnnotator = true;
+
+ public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) {
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
+ return this;
+ }
+
+ public Builder setReferenceTimezone(String referenceTimezone) {
+ this.referenceTimezone = referenceTimezone;
+ return this;
+ }
+
+ public Builder setLocales(@Nullable String locales) {
+ this.locales = locales;
+ return this;
+ }
+
+ public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) {
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ return this;
+ }
+
+ public Builder setEntityTypes(Collection<String> entityTypes) {
+ this.entityTypes = entityTypes;
+ return this;
+ }
+
+ public Builder setAnnotateMode(int annotateMode) {
+ this.annotateMode = annotateMode;
+ return this;
+ }
+
+ public Builder setAnnotationUsecase(int annotationUsecase) {
+ this.annotationUsecase = annotationUsecase;
+ return this;
+ }
+
+ public Builder setHasLocationPermission(boolean hasLocationPermission) {
+ this.hasLocationPermission = hasLocationPermission;
+ return this;
+ }
+
+ public Builder setHasPersonalizationPermission(boolean hasPersonalizationPermission) {
+ this.hasPersonalizationPermission = hasPersonalizationPermission;
+ return this;
+ }
+
+ public Builder setIsSerializedEntityDataEnabled(boolean isSerializedEntityDataEnabled) {
+ this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled;
+ return this;
+ }
+
+ public Builder setUserLocationLat(double userLocationLat) {
+ this.userLocationLat = userLocationLat;
+ return this;
+ }
+
+ public Builder setUserLocationLng(double userLocationLng) {
+ this.userLocationLng = userLocationLng;
+ return this;
+ }
+
+ public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) {
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
+ return this;
+ }
+
+ public Builder setUsePodNer(boolean usePodNer) {
+ this.usePodNer = usePodNer;
+ return this;
+ }
+
+ public Builder setTriggerDictionaryOnBeginnerWords(boolean triggerDictionaryOnBeginnerWords) {
+ this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
+ return this;
+ }
+
+ public Builder setUseVocabAnnotator(boolean useVocabAnnotator) {
+ this.useVocabAnnotator = useVocabAnnotator;
+ return this;
+ }
+
+ public AnnotationOptions build() {
+ return new AnnotationOptions(
+ referenceTimeMsUtc,
+ referenceTimezone,
+ locales,
+ detectedTextLanguageTags,
+ entityTypes,
+ annotateMode,
+ annotationUsecase,
+ hasLocationPermission,
+ hasPersonalizationPermission,
+ isSerializedEntityDataEnabled,
+ userLocationLat,
+ userLocationLng,
+ userLocationAccuracyMeters,
+ usePodNer,
+ triggerDictionaryOnBeginnerWords,
+ useVocabAnnotator);
+ }
}
- public AnnotationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags,
- Collection<String> entityTypes,
- int annotationUsecase,
- boolean isSerializedEntityDataEnabled) {
- this(
- referenceTimeMsUtc,
- referenceTimezone,
- locales,
- detectedTextLanguageTags,
- entityTypes,
- annotationUsecase,
- isSerializedEntityDataEnabled,
- INVALID_LATITUDE,
- INVALID_LONGITUDE,
- INVALID_LOCATION_ACCURACY_METERS);
- }
-
- public AnnotationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags) {
- this(
- referenceTimeMsUtc,
- referenceTimezone,
- locales,
- detectedTextLanguageTags,
- null,
- AnnotationUsecase.SMART.getValue(),
- /* isSerializedEntityDataEnabled */ false);
+ public static Builder builder() {
+ return new Builder();
}
public long getReferenceTimeMsUtc() {
@@ -719,11 +1124,13 @@
return referenceTimezone;
}
+ @Nullable
public String getLocale() {
return locales;
}
/** Returns a comma separated list of BCP 47 language tags. */
+ @Nullable
public String getDetectedTextLanguageTags() {
return detectedTextLanguageTags;
}
@@ -732,6 +1139,10 @@
return entityTypes;
}
+ public int getAnnotateMode() {
+ return annotateMode;
+ }
+
public int getAnnotationUsecase() {
return annotationUsecase;
}
@@ -759,6 +1170,18 @@
public boolean hasPersonalizationPermission() {
return hasPersonalizationPermission;
}
+
+ public boolean getUsePodNer() {
+ return usePodNer;
+ }
+
+ public boolean getTriggerDictionaryOnBeginnerWords() {
+ return triggerDictionaryOnBeginnerWords;
+ }
+
+ public boolean getUseVocabAnnotator() {
+ return useVocabAnnotator;
+ }
}
/**
@@ -815,7 +1238,7 @@
private native AnnotatedSpan[] nativeAnnotate(
long context, String text, AnnotationOptions options);
- private native AnnotatedSpan[][] nativeAnnotateStructuredInput(
+ private native Annotations nativeAnnotateStructuredInput(
long context, InputFragment[] inputFragments, AnnotationOptions options);
private native byte[] nativeLookUpKnowledgeEntity(long context, String id);
diff --git a/jni/com/google/android/textclassifier/LangIdModel.java b/jni/com/google/android/textclassifier/LangIdModel.java
index 0015826..890c9b0 100644
--- a/jni/com/google/android/textclassifier/LangIdModel.java
+++ b/jni/com/google/android/textclassifier/LangIdModel.java
@@ -16,6 +16,7 @@
package com.google.android.textclassifier;
+import android.content.res.AssetFileDescriptor;
import java.util.concurrent.atomic.AtomicBoolean;
/**
@@ -48,6 +49,29 @@
}
}
+ /**
+ * Creates a new instance of LangId predictor, using the provided model image, given as an {@link
+ * AssetFileDescriptor}.
+ */
+ public LangIdModel(AssetFileDescriptor assetFileDescriptor) {
+ modelPtr =
+ nativeNewWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from asset file descriptor.");
+ }
+ }
+
+ /** Creates a new instance of LangId predictor, using the provided model image. */
+ public LangIdModel(int fd, long offset, long size) {
+ modelPtr = nativeNewWithOffset(fd, offset, size);
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from given file descriptor.");
+ }
+ }
+
/** Detects the languages for given text. */
public LanguageResult[] detectLanguages(String text) {
return nativeDetectLanguages(modelPtr, text);
@@ -95,14 +119,22 @@
return nativeGetVersion(modelPtr);
}
- public float getLangIdThreshold() {
- return nativeGetLangIdThreshold(modelPtr);
- }
-
public static int getVersion(int fd) {
return nativeGetVersionFromFd(fd);
}
+ /** Returns the version of the model. */
+ public static int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetVersionWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
+ public float getLangIdThreshold() {
+ return nativeGetLangIdThreshold(modelPtr);
+ }
+
/** Retrieves the pointer to the native object. */
long getNativePointer() {
return modelPtr;
@@ -130,6 +162,8 @@
private static native long nativeNewFromPath(String path);
+ private static native long nativeNewWithOffset(int fd, long offset, long size);
+
private native LanguageResult[] nativeDetectLanguages(long nativePtr, String text);
private native void nativeClose(long nativePtr);
@@ -143,4 +177,6 @@
private native float nativeGetLangIdNoiseThreshold(long nativePtr);
private native int nativeGetMinTextSizeInBytes(long nativePtr);
+
+ private static native int nativeGetVersionWithOffset(int fd, long offset, long size);
}
diff --git a/native/Android.bp b/native/Android.bp
index 352b12a..61b7ade 100644
--- a/native/Android.bp
+++ b/native/Android.bp
@@ -39,7 +39,6 @@
host_supported: true,
srcs: [
"utils/hash/farmhash.cc",
- "util/hash/hash.cc",
],
cflags: [
"-DNAMESPACE_FOR_HASH_FUNCTIONS=farmhash",
@@ -93,13 +92,16 @@
"-funsigned-char",
"-fvisibility=hidden",
+
"-DLIBTEXTCLASSIFIER_UNILIB_ICU",
"-DZLIB_CONST",
"-DSAFTM_COMPACT_LOGGING",
"-DTC3_WITH_ACTIONS_OPS",
"-DTC3_UNILIB_JAVAICU",
"-DTC3_CALENDAR_JAVAICU",
- "-DTC3_AOSP"
+ "-DTC3_AOSP",
+ "-DTC3_VOCAB_ANNOTATOR_IMPL",
+ "-DTC3_POD_NER_ANNOTATOR_IMPL",
],
product_variables: {
@@ -109,33 +111,11 @@
},
},
- generated_headers: [
- "libtextclassifier_fbgen_flatbuffers",
- "libtextclassifier_fbgen_tokenizer",
- "libtextclassifier_fbgen_codepoint_range",
- "libtextclassifier_fbgen_entity-data",
- "libtextclassifier_fbgen_zlib_buffer",
- "libtextclassifier_fbgen_resources_extra",
- "libtextclassifier_fbgen_intent_config",
- "libtextclassifier_fbgen_annotator_model",
- "libtextclassifier_fbgen_annotator_experimental_model",
- "libtextclassifier_fbgen_actions_model",
- "libtextclassifier_fbgen_tflite_text_encoder_config",
- "libtextclassifier_fbgen_lang_id_embedded_network",
- "libtextclassifier_fbgen_lang_id_model",
- "libtextclassifier_fbgen_actions-entity-data",
- "libtextclassifier_fbgen_normalization",
- "libtextclassifier_fbgen_language-tag",
- "libtextclassifier_fbgen_person_name_model",
- "libtextclassifier_fbgen_grammar_dates",
- "libtextclassifier_fbgen_timezone_code",
- "libtextclassifier_fbgen_grammar_rules"
- ],
-
header_libs: [
"jni_headers",
"tensorflow_headers",
"flatbuffer_headers",
+ "libtextclassifier_flatbuffer_headers",
],
shared_libs: [
@@ -144,18 +124,20 @@
],
static_libs: [
+ "marisa-trie",
+ "libtextclassifier_abseil",
"liblua",
"libutf",
"libtflite_static",
+ "tflite_support"
],
- min_sdk_version: "30",
}
// -----------------
// Generate headers with FlatBuffer schema compiler.
// -----------------
genrule_defaults {
- name: "fbgen",
+ name: "fbgen",
tools: ["flatc"],
// "depfile" is used here in conjunction with flatc's -M to gather the deps
cmd: "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -I external/libtextclassifier/native -M $(in) >$(depfile) && " +
@@ -164,161 +146,37 @@
}
genrule {
- name: "libtextclassifier_fbgen_flatbuffers",
- srcs: ["utils/flatbuffers.fbs"],
- out: ["utils/flatbuffers_generated.h"],
+ name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers_test",
+ srcs: ["utils/flatbuffers/flatbuffers_test.fbs"],
+ out: ["utils/flatbuffers/flatbuffers_test_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_tokenizer",
- srcs: ["utils/tokenizer.fbs"],
- out: ["utils/tokenizer_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_codepoint_range",
- srcs: ["utils/codepoint-range.fbs"],
- out: ["utils/codepoint-range_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_resources_extra",
- srcs: ["utils/resources.fbs"],
- out: ["utils/resources_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_entity-data",
- srcs: ["annotator/entity-data.fbs"],
- out: ["annotator/entity-data_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_zlib_buffer",
- srcs: ["utils/zlib/buffer.fbs"],
- out: ["utils/zlib/buffer_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_intent_config",
- srcs: ["utils/intents/intent-config.fbs"],
- out: ["utils/intents/intent-config_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_annotator_model",
- srcs: ["annotator/model.fbs"],
- out: ["annotator/model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_annotator_experimental_model",
- srcs: ["annotator/experimental/experimental.fbs"],
- out: ["annotator/experimental/experimental_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_actions_model",
- srcs: ["actions/actions_model.fbs"],
- out: ["actions/actions_model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_tflite_text_encoder_config",
- srcs: ["utils/tflite/text_encoder_config.fbs"],
- out: ["utils/tflite/text_encoder_config_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_lang_id_embedded_network",
- srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
- out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_lang_id_model",
- srcs: ["lang_id/common/flatbuffers/model.fbs"],
- out: ["lang_id/common/flatbuffers/model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_actions-entity-data",
- srcs: ["actions/actions-entity-data.fbs"],
- out: ["actions/actions-entity-data_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_normalization",
- srcs: ["utils/normalization.fbs"],
- out: ["utils/normalization_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_language-tag",
- srcs: ["utils/i18n/language-tag.fbs"],
- out: ["utils/i18n/language-tag_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_person_name_model",
- srcs: ["annotator/person_name/person_name_model.fbs"],
- out: ["annotator/person_name/person_name_model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_grammar_dates",
- srcs: ["annotator/grammar/dates/dates.fbs"],
- out: ["annotator/grammar/dates/dates_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_timezone_code",
- srcs: ["annotator/grammar/dates/timezone-code.fbs"],
- out: ["annotator/grammar/dates/timezone-code_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_grammar_rules",
- srcs: ["utils/grammar/rules.fbs"],
- out: ["utils/grammar/rules_generated.h"],
+ name: "libtextclassifier_fbgen_utils_lua_utils_tests",
+ srcs: ["utils/lua_utils_tests.fbs"],
+ out: ["utils/lua_utils_tests_generated.h"],
defaults: ["fbgen"],
}
// -----------------
// libtextclassifier
// -----------------
-cc_library_shared {
+cc_library {
name: "libtextclassifier",
defaults: ["libtextclassifier_defaults"],
-
+ min_sdk_version: "30",
srcs: ["**/*.cc"],
exclude_srcs: [
- "**/*_test.cc",
- "**/*-test-lib.cc",
- "**/testing/*.cc",
+ "**/*_test.*",
+ "**/*-test-lib.*",
+ "**/testing/*.*",
"**/*test-util.*",
"**/*test-utils.*",
+ "**/*test_util.*",
+ "**/*test_utils.*",
"**/*_test-include.*",
- "**/*unittest.cc",
+ "**/*unittest.*",
],
version_script: "jni.lds",
@@ -336,32 +194,87 @@
name: "libtextclassifier_tests",
defaults: ["libtextclassifier_defaults"],
- test_suites: ["device-tests", "mts-extservices"],
+ test_suites: ["general-tests", "mts-extservices"],
data: [
- "annotator/test_data/**/*",
- "actions/test_data/**/*",
+ "**/test_data/*",
+ "**/*.bfbs",
],
srcs: ["**/*.cc"],
+ exclude_srcs: [":libtextclassifier_java_test_sources"],
header_libs: ["jni_headers"],
static_libs: [
"libgmock_ndk",
"libgtest_ndk_c++",
+ "libbase_ndk",
],
- multilib: {
- lib32: {
- suffix: "32",
- cppflags: ["-DTC3_TEST_DATA_DIR=\"/data/nativetest/libtextclassifier_tests/test_data/\""],
- },
- lib64: {
- suffix: "64",
- cppflags: ["-DTC3_TEST_DATA_DIR=\"/data/nativetest64/libtextclassifier_tests/test_data/\""],
- },
- },
+ generated_headers: [
+ "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers_test",
+ "libtextclassifier_fbgen_utils_lua_utils_tests",
+ ],
+
+ compile_multilib: "prefer32",
+
+ // A workaround for code coverage. See b/166040889#comment23
+ sdk_variant_only: true,
+}
+
+// ------------------------------------
+// Native tests require the JVM to run
+// ------------------------------------
+cc_test_library {
+ name: "libjvm_test_launcher",
+ defaults: ["libtextclassifier_defaults"],
+ srcs: [
+ ":libtextclassifier_java_test_sources",
+ "annotator/datetime/testing/*.cc",
+ "actions/test-utils.cc",
+ "utils/testing/annotator.cc",
+ "utils/testing/logging_event_listener.cc",
+ "testing/jvm_test_launcher.cc",
+ ],
+ version_script: "jni.lds",
+ static_libs: [
+ "libgmock_ndk",
+ "libgtest_ndk_c++",
+ "libbase_ndk",
+ "libtextclassifier",
+ ],
+ header_libs: [
+ "libtextclassifier_flatbuffer_testonly_headers",
+ ],
+}
+
+android_test {
+ name: "libtextclassifier_java_tests",
+ srcs: ["testing/JvmTestLauncher.java"],
+ min_sdk_version: "30",
+ test_suites: [
+ "general-tests",
+ "mts-extservices",
+ ],
+ static_libs: [
+ "androidx.test.ext.junit",
+ "androidx.test.rules",
+ "androidx.test.espresso.core",
+ "androidx.test.ext.truth",
+ "truth-prebuilt",
+ "TextClassifierCoverageLib",
+ ],
+ jni_libs: [
+ "libjvm_test_launcher",
+ ],
+ jni_uses_sdk_apis: true,
+ data: [
+ "**/*.bfbs",
+ "**/test_data/*",
+ ],
+ test_config: "JavaTest.xml",
+ compile_multilib: "both",
}
// ----------------
@@ -407,3 +320,5 @@
src: "models/lang_id.model",
sub_dir: "textclassifier",
}
+
+build = ["FlatBufferHeaders.bp", "JavaTests.bp"]
diff --git a/native/AndroidManifest.xml b/native/AndroidManifest.xml
new file mode 100644
index 0000000..ddf4c4c
--- /dev/null
+++ b/native/AndroidManifest.xml
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.google.android.textclassifier.tests">
+
+ <uses-sdk android:minSdkVersion="30" android:targetSdkVersion="30"/>
+
+ <application>
+ <uses-library android:name="android.test.runner"/>
+ </application>
+
+ <instrumentation
+ android:name="androidx.test.runner.AndroidJUnitRunner"
+ android:targetPackage="com.google.android.textclassifier.tests"/>
+</manifest>
diff --git a/native/AndroidTest.xml b/native/AndroidTest.xml
index cee26dd..6f707e0 100644
--- a/native/AndroidTest.xml
+++ b/native/AndroidTest.xml
@@ -14,17 +14,20 @@
limitations under the License.
-->
<configuration description="Config for libtextclassifier_tests">
+ <option name="config-descriptor:metadata" key="mainline-param" value="com.google.android.extservices.apex" />
<option name="test-suite-tag" value="apct" />
<option name="test-suite-tag" value="mts" />
<target_preparer class="com.android.compatibility.common.tradefed.targetprep.FilePusher">
<option name="cleanup" value="true" />
- <option name="push" value="libtextclassifier_tests->/data/local/tmp/libtextclassifier_tests" />
- <option name="append-bitness" value="true" />
+ <option name="push" value="libtextclassifier_tests->/data/local/tests/unrestricted/libtextclassifier_tests" />
+ <option name="push" value="actions->/data/local/tests/unrestricted/actions" />
+ <option name="push" value="annotator->/data/local/tests/unrestricted/annotator" />
+ <option name="push" value="utils->/data/local/tests/unrestricted/utils" />
</target_preparer>
<test class="com.android.tradefed.testtype.GTest" >
- <option name="native-test-device-path" value="/data/local/tmp" />
+ <option name="native-test-device-path" value="/data/local/tests/unrestricted" />
<option name="module-name" value="libtextclassifier_tests" />
</test>
diff --git a/native/FlatBufferHeaders.bp b/native/FlatBufferHeaders.bp
new file mode 100644
index 0000000..950eee6
--- /dev/null
+++ b/native/FlatBufferHeaders.bp
@@ -0,0 +1,243 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+genrule {
+ name: "libtextclassifier_fbgen_actions_actions_model",
+ srcs: ["actions/actions_model.fbs"],
+ out: ["actions/actions_model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_actions_actions-entity-data",
+ srcs: ["actions/actions-entity-data.fbs"],
+ out: ["actions/actions-entity-data_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
+ out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ srcs: ["lang_id/common/flatbuffers/model.fbs"],
+ out: ["lang_id/common/flatbuffers/model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ srcs: ["annotator/person_name/person_name_model.fbs"],
+ out: ["annotator/person_name/person_name_model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_annotator_datetime_datetime",
+ srcs: ["annotator/datetime/datetime.fbs"],
+ out: ["annotator/datetime/datetime_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_annotator_experimental_experimental",
+ srcs: ["annotator/experimental/experimental.fbs"],
+ out: ["annotator/experimental/experimental_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_annotator_entity-data",
+ srcs: ["annotator/entity-data.fbs"],
+ out: ["annotator/entity-data_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_annotator_model",
+ srcs: ["annotator/model.fbs"],
+ out: ["annotator/model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ srcs: ["utils/flatbuffers/flatbuffers.fbs"],
+ out: ["utils/flatbuffers/flatbuffers_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ srcs: ["utils/tflite/text_encoder_config.fbs"],
+ out: ["utils/tflite/text_encoder_config_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_resources",
+ srcs: ["utils/resources.fbs"],
+ out: ["utils/resources_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_zlib_buffer",
+ srcs: ["utils/zlib/buffer.fbs"],
+ out: ["utils/zlib/buffer_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_container_bit-vector",
+ srcs: ["utils/container/bit-vector.fbs"],
+ out: ["utils/container/bit-vector_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_intents_intent-config",
+ srcs: ["utils/intents/intent-config.fbs"],
+ out: ["utils/intents/intent-config_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_normalization",
+ srcs: ["utils/normalization.fbs"],
+ out: ["utils/normalization_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ srcs: ["utils/grammar/semantics/expression.fbs"],
+ out: ["utils/grammar/semantics/expression_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_grammar_rules",
+ srcs: ["utils/grammar/rules.fbs"],
+ out: ["utils/grammar/rules_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_grammar_testing_value",
+ srcs: ["utils/grammar/testing/value.fbs"],
+ out: ["utils/grammar/testing/value_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_codepoint-range",
+ srcs: ["utils/codepoint-range.fbs"],
+ out: ["utils/codepoint-range_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_tokenizer",
+ srcs: ["utils/tokenizer.fbs"],
+ out: ["utils/tokenizer_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_i18n_language-tag",
+ srcs: ["utils/i18n/language-tag.fbs"],
+ out: ["utils/i18n/language-tag_generated.h"],
+ defaults: ["fbgen"],
+}
+
+cc_library_headers {
+ name: "libtextclassifier_flatbuffer_headers",
+ stl: "libc++_static",
+ sdk_version: "current",
+ min_sdk_version: "30",
+ apex_available: [
+ "//apex_available:platform",
+ "com.android.extservices",
+ ],
+ generated_headers: [
+ "libtextclassifier_fbgen_actions_actions_model",
+ "libtextclassifier_fbgen_actions_actions-entity-data",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ "libtextclassifier_fbgen_annotator_datetime_datetime",
+ "libtextclassifier_fbgen_annotator_experimental_experimental",
+ "libtextclassifier_fbgen_annotator_entity-data",
+ "libtextclassifier_fbgen_annotator_model",
+ "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ "libtextclassifier_fbgen_utils_resources",
+ "libtextclassifier_fbgen_utils_zlib_buffer",
+ "libtextclassifier_fbgen_utils_container_bit-vector",
+ "libtextclassifier_fbgen_utils_intents_intent-config",
+ "libtextclassifier_fbgen_utils_normalization",
+ "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ "libtextclassifier_fbgen_utils_grammar_rules",
+ "libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_i18n_language-tag",
+ ],
+ export_generated_headers: [
+ "libtextclassifier_fbgen_actions_actions_model",
+ "libtextclassifier_fbgen_actions_actions-entity-data",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ "libtextclassifier_fbgen_annotator_datetime_datetime",
+ "libtextclassifier_fbgen_annotator_experimental_experimental",
+ "libtextclassifier_fbgen_annotator_entity-data",
+ "libtextclassifier_fbgen_annotator_model",
+ "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ "libtextclassifier_fbgen_utils_resources",
+ "libtextclassifier_fbgen_utils_zlib_buffer",
+ "libtextclassifier_fbgen_utils_container_bit-vector",
+ "libtextclassifier_fbgen_utils_intents_intent-config",
+ "libtextclassifier_fbgen_utils_normalization",
+ "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ "libtextclassifier_fbgen_utils_grammar_rules",
+ "libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_i18n_language-tag",
+ ],
+}
+
+cc_library_headers {
+ name: "libtextclassifier_flatbuffer_testonly_headers",
+ stl: "libc++_static",
+ sdk_version: "current",
+ min_sdk_version: "30",
+ apex_available: [
+ "//apex_available:platform",
+ "com.android.extservices",
+ ],
+ generated_headers: [
+ "libtextclassifier_fbgen_utils_grammar_testing_value",
+ ],
+ export_generated_headers: [
+ "libtextclassifier_fbgen_utils_grammar_testing_value",
+ ],
+}
diff --git a/native/JavaTest.xml b/native/JavaTest.xml
new file mode 100644
index 0000000..5393fd8
--- /dev/null
+++ b/native/JavaTest.xml
@@ -0,0 +1,40 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2020 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<configuration description="Runs libtextclassifier_java_tests.">
+ <option name="test-suite-tag" value="apct" />
+ <option name="test-suite-tag" value="apct-instrumentation" />
+ <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
+ <option name="cleanup-apks" value="true" />
+ <option name="test-file-name" value="libtextclassifier_java_tests.apk" />
+ </target_preparer>
+ <target_preparer class="com.android.compatibility.common.tradefed.targetprep.FilePusher">
+ <option name="cleanup" value="true" />
+ <option name="push" value="actions->/data/local/tmp/actions" />
+ <option name="push" value="annotator->/data/local/tmp/annotator" />
+ <option name="push" value="utils->/data/local/tmp/utils" />
+ </target_preparer>
+
+ <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
+ <option name="package" value="com.google.android.textclassifier.tests" />
+ <option name="runner" value="androidx.test.runner.AndroidJUnitRunner" />
+ <option name="device-listeners" value="com.android.textclassifier.testing.TextClassifierInstrumentationListener" />
+ </test>
+
+ <object type="module_controller" class="com.android.tradefed.testtype.suite.module.MainlineTestModuleController">
+ <option name="mainline-module-package-name" value="com.google.android.extservices" />
+ </object>
+</configuration>
diff --git a/native/JavaTests.bp b/native/JavaTests.bp
new file mode 100644
index 0000000..78d5748
--- /dev/null
+++ b/native/JavaTests.bp
@@ -0,0 +1,46 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+filegroup {
+ name: "libtextclassifier_java_test_sources",
+ srcs: [
+ "actions/grammar-actions_test.cc",
+ "actions/actions-suggestions_test.cc",
+ "annotator/pod_ner/pod-ner-impl_test.cc",
+ "annotator/datetime/regex-parser_test.cc",
+ "annotator/datetime/grammar-parser_test.cc",
+ "annotator/datetime/datetime-grounder_test.cc",
+ "utils/intents/intent-generator-test-lib.cc",
+ "utils/calendar/calendar_test.cc",
+ "utils/regex-match_test.cc",
+ "utils/grammar/parsing/lexer_test.cc",
+ "annotator/number/number_test-include.cc",
+ "annotator/annotator_test-include.cc",
+ "annotator/grammar/grammar-annotator_test.cc",
+ "annotator/grammar/test-utils.cc",
+ "utils/utf8/unilib_test-include.cc",
+ "utils/grammar/analyzer_test.cc",
+ "utils/grammar/semantics/composer_test.cc",
+ "utils/grammar/semantics/evaluators/arithmetic-eval_test.cc",
+ "utils/grammar/semantics/evaluators/merge-values-eval_test.cc",
+ "utils/grammar/semantics/evaluators/const-eval_test.cc",
+ "utils/grammar/semantics/evaluators/compose-eval_test.cc",
+ "utils/grammar/semantics/evaluators/span-eval_test.cc",
+ "utils/grammar/semantics/evaluators/parse-number-eval_test.cc",
+ "utils/grammar/semantics/evaluators/constituent-eval_test.cc",
+ "utils/grammar/parsing/parser_test.cc",
+ ],
+}
diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs
new file mode 100644
index 0000000..7421579
--- /dev/null
+++ b/native/actions/actions-entity-data.bfbs
Binary files differ
diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs
old mode 100755
new mode 100644
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index 1fcd35c..b1a042c 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -17,45 +17,33 @@
#include "actions/actions-suggestions.h"
#include <memory>
+#include <vector>
+#include "utils/base/statusor.h"
+
+#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-actions.h"
+#endif
+#include "actions/ngram-model.h"
+#include "actions/tflite-sensitive-model.h"
#include "actions/types.h"
#include "actions/utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
-#include "utils/flatbuffers.h"
+#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
+#endif
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
+#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
#include "tensorflow/lite/string_util.h"
namespace libtextclassifier3 {
-const std::string& ActionsSuggestions::kViewCalendarType =
- *[]() { return new std::string("view_calendar"); }();
-const std::string& ActionsSuggestions::kViewMapType =
- *[]() { return new std::string("view_map"); }();
-const std::string& ActionsSuggestions::kTrackFlightType =
- *[]() { return new std::string("track_flight"); }();
-const std::string& ActionsSuggestions::kOpenUrlType =
- *[]() { return new std::string("open_url"); }();
-const std::string& ActionsSuggestions::kSendSmsType =
- *[]() { return new std::string("send_sms"); }();
-const std::string& ActionsSuggestions::kCallPhoneType =
- *[]() { return new std::string("call_phone"); }();
-const std::string& ActionsSuggestions::kSendEmailType =
- *[]() { return new std::string("send_email"); }();
-const std::string& ActionsSuggestions::kShareLocation =
- *[]() { return new std::string("share_location"); }();
-
-// Name for a datetime annotation that only includes time but no date.
-const std::string& kTimeAnnotation =
- *[]() { return new std::string("time"); }();
-
constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;
@@ -89,6 +77,36 @@
: max_conversation_history_length);
}
+template <typename T>
+std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs,
+ const int max_length,
+ const T pad_value) {
+ if (inputs.size() >= max_length) {
+ return std::vector<T>(inputs.begin(), inputs.begin() + max_length);
+ } else {
+ std::vector<T> result;
+ result.reserve(max_length);
+ result.insert(result.begin(), inputs.begin(), inputs.end());
+ result.insert(result.end(), max_length - inputs.size(), pad_value);
+ return result;
+ }
+}
+
+template <typename T>
+void SetVectorOrScalarAsModelInput(
+ const int param_index, const Variant& param_value,
+ tflite::Interpreter* interpreter,
+ const std::unique_ptr<const TfLiteModelExecutor>& model_executor) {
+ if (param_value.Has<std::vector<T>>()) {
+ model_executor->SetInput<T>(
+ param_index, param_value.ConstRefValue<std::vector<T>>(), interpreter);
+ } else if (param_value.Has<T>()) {
+ model_executor->SetInput<float>(param_index, param_value.Value<T>(),
+ interpreter);
+ } else {
+ TC3_LOG(ERROR) << "Variant type error!";
+ }
+}
} // namespace
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
@@ -285,7 +303,7 @@
}
entity_data_builder_.reset(
- new ReflectiveFlatbufferBuilder(entity_data_schema_));
+ new MutableFlatbufferBuilder(entity_data_schema_));
} else {
entity_data_schema_ = nullptr;
}
@@ -321,6 +339,7 @@
}
}
+#if !defined(TC3_DISABLE_LUA)
std::string actions_script;
if (GetUncompressedString(model_->lua_actions_script(),
model_->compressed_lua_actions_script(),
@@ -331,6 +350,7 @@
return false;
}
}
+#endif // TC3_DISABLE_LUA
if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
model_->ranking_options(), decompressor.get(),
@@ -370,15 +390,23 @@
// Create low confidence model if specified.
if (model_->low_confidence_ngram_model() != nullptr) {
- ngram_model_ = NGramModel::Create(
+ sensitive_model_ = NGramSensitiveModel::Create(
unilib_, model_->low_confidence_ngram_model(),
feature_processor_ == nullptr ? nullptr
: feature_processor_->tokenizer());
- if (ngram_model_ == nullptr) {
+ if (sensitive_model_ == nullptr) {
TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
return false;
}
}
+ if (model_->low_confidence_tflite_model() != nullptr) {
+ sensitive_model_ =
+ TFLiteSensitiveModel::Create(model_->low_confidence_tflite_model());
+ if (sensitive_model_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not create TFLite sensitive model.";
+ return false;
+ }
+ }
return true;
}
@@ -654,8 +682,17 @@
return false;
}
if (model_->tflite_model_spec()->input_context() >= 0) {
- model_executor_->SetInput<std::string>(
- model_->tflite_model_spec()->input_context(), context, interpreter);
+ if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
+ model_executor_->SetInput<std::string>(
+ model_->tflite_model_spec()->input_context(),
+ PadOrTruncateToTargetLength(
+ context, model_->tflite_model_spec()->input_length_to_pad(),
+ std::string("")),
+ interpreter);
+ } else {
+ model_executor_->SetInput<std::string>(
+ model_->tflite_model_spec()->input_context(), context, interpreter);
+ }
}
if (model_->tflite_model_spec()->input_context_length() >= 0) {
model_executor_->SetInput<int>(
@@ -663,8 +700,16 @@
interpreter);
}
if (model_->tflite_model_spec()->input_user_id() >= 0) {
- model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
- user_ids, interpreter);
+ if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
+ model_executor_->SetInput<int>(
+ model_->tflite_model_spec()->input_user_id(),
+ PadOrTruncateToTargetLength(
+ user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0),
+ interpreter);
+ } else {
+ model_executor_->SetInput<int>(
+ model_->tflite_model_spec()->input_user_id(), user_ids, interpreter);
+ }
}
if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
model_executor_->SetInput<int>(
@@ -710,16 +755,24 @@
const bool has_value = param_value_it != model_parameters.end();
switch (param_type) {
case kTfLiteFloat32:
- model_executor_->SetInput<float>(
- param_index,
- has_value ? param_value_it->second.Value<float>() : kDefaultFloat,
- interpreter);
+ if (has_value) {
+ SetVectorOrScalarAsModelInput<float>(param_index,
+ param_value_it->second,
+ interpreter, model_executor_);
+ } else {
+ model_executor_->SetInput<float>(param_index, kDefaultFloat,
+ interpreter);
+ }
break;
case kTfLiteInt32:
- model_executor_->SetInput<int32_t>(
- param_index,
- has_value ? param_value_it->second.Value<int>() : kDefaultInt,
- interpreter);
+ if (has_value) {
+ SetVectorOrScalarAsModelInput<int32_t>(
+ param_index, param_value_it->second, interpreter,
+ model_executor_);
+ } else {
+ model_executor_->SetInput<int32_t>(param_index, kDefaultInt,
+ interpreter);
+ }
break;
case kTfLiteInt64:
model_executor_->SetInput<int64_t>(
@@ -777,7 +830,7 @@
void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
: nullptr;
FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
@@ -806,7 +859,7 @@
if (triggering) {
ActionSuggestion suggestion;
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
: nullptr;
FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
@@ -844,13 +897,12 @@
return false;
}
response->sensitivity_score = sensitive_topic_score.data()[0];
- response->output_filtered_sensitivity =
- (response->sensitivity_score >
- preconditions_.max_sensitive_topic_score);
+ response->is_sensitive = (response->sensitivity_score >
+ preconditions_.max_sensitive_topic_score);
}
// Suppress model outputs.
- if (response->output_filtered_sensitivity) {
+ if (response->is_sensitive) {
return true;
}
@@ -881,7 +933,7 @@
// Create action from model output.
ActionSuggestion suggestion;
suggestion.type = action_type->name()->str();
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
: nullptr;
FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
@@ -932,6 +984,12 @@
std::unique_ptr<tflite::Interpreter>* interpreter) const {
TC3_CHECK_LE(num_messages, conversation.messages.size());
+ if (sensitive_model_ != nullptr &&
+ sensitive_model_->EvalConversation(conversation, num_messages).first) {
+ response->is_sensitive = true;
+ return true;
+ }
+
if (!model_executor_) {
return true;
}
@@ -987,6 +1045,18 @@
return ReadModelOutput(interpreter->get(), options, response);
}
+Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection(
+ const Conversation& conversation, const ActionSuggestionOptions& options,
+ std::vector<ActionSuggestion>* actions) const {
+ TC3_ASSIGN_OR_RETURN(
+ std::vector<ActionSuggestion> new_actions,
+ conversation_intent_detection_->SuggestActions(conversation, options));
+ for (auto& action : new_actions) {
+ actions->push_back(std::move(action));
+ }
+ return Status::OK;
+}
+
AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
const ConversationMessage& message) const {
AnnotationOptions options;
@@ -1036,30 +1106,7 @@
if (message->annotations.empty()) {
message->annotations = annotator->Annotate(
message->text, AnnotationOptionsForMessage(*message));
- for (int i = 0; i < message->annotations.size(); i++) {
- ClassificationResult* classification =
- &message->annotations[i].classification.front();
-
- // Specialize datetime annotation to time annotation if no date
- // component is present.
- if (classification->collection == Collections::DateTime() &&
- classification->datetime_parse_result.IsSet()) {
- bool has_only_time = true;
- for (const DatetimeComponent& component :
- classification->datetime_parse_result.datetime_components) {
- if (component.component_type !=
- DatetimeComponent::ComponentType::UNSPECIFIED &&
- component.component_type <
- DatetimeComponent::ComponentType::HOUR) {
- has_only_time = false;
- break;
- }
- }
- if (has_only_time) {
- classification->collection = kTimeAnnotation;
- }
- }
- }
+ ConvertDatetimeToTime(&message->annotations);
}
}
return annotated_conversation;
@@ -1160,7 +1207,7 @@
continue;
}
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
: nullptr;
@@ -1220,6 +1267,7 @@
return result;
}
+#if !defined(TC3_DISABLE_LUA)
bool ActionsSuggestions::SuggestActionsFromLua(
const Conversation& conversation, const TfLiteModelExecutor* model_executor,
const tflite::Interpreter* interpreter,
@@ -1238,6 +1286,15 @@
}
return lua_actions->SuggestActions(actions);
}
+#else
+bool ActionsSuggestions::SuggestActionsFromLua(
+ const Conversation& conversation, const TfLiteModelExecutor* model_executor,
+ const tflite::Interpreter* interpreter,
+ const reflection::Schema* annotation_entity_data_schema,
+ std::vector<ActionSuggestion>* actions) const {
+ return true;
+}
+#endif
bool ActionsSuggestions::GatherActionsSuggestions(
const Conversation& conversation, const Annotator* annotator,
@@ -1305,10 +1362,7 @@
std::vector<const UniLib::RegexPattern*> post_check_rules;
if (preconditions_.suppress_on_low_confidence_input) {
- if ((ngram_model_ != nullptr &&
- ngram_model_->EvalConversation(annotated_conversation,
- num_messages)) ||
- regex_actions_->IsLowConfidenceInput(annotated_conversation,
+ if (regex_actions_->IsLowConfidenceInput(annotated_conversation,
num_messages, &post_check_rules)) {
response->output_filtered_low_confidence = true;
return true;
@@ -1322,12 +1376,24 @@
return false;
}
+ // SuggestActionsFromModel also detects if the conversation is sensitive,
+ // either by using the old ngram model or the new model.
// Suppress all predictions if the conversation was deemed sensitive.
- if (preconditions_.suppress_on_sensitive_topic &&
- response->output_filtered_sensitivity) {
+ if (preconditions_.suppress_on_sensitive_topic && response->is_sensitive) {
return true;
}
+ if (conversation_intent_detection_) {
+ // TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works.
+ auto actions = SuggestActionsFromConversationIntentDetection(
+ annotated_conversation, options, &response->actions);
+ if (!actions.ok()) {
+ TC3_LOG(ERROR) << "Could not run conversation intent detection: "
+ << actions.error_message();
+ return false;
+ }
+ }
+
if (!SuggestActionsFromLua(
annotated_conversation, model_executor_.get(), interpreter.get(),
annotator != nullptr ? annotator->entity_data_schema() : nullptr,
@@ -1363,6 +1429,21 @@
if (conversation.messages[i].reference_time_ms_utc <
conversation.messages[i - 1].reference_time_ms_utc) {
TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
+ return response;
+ }
+ }
+
+ // Check that messages are valid utf8.
+ for (const ConversationMessage& message : conversation.messages) {
+ if (message.text.size() > std::numeric_limits<int>::max()) {
+ TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size();
+ return {};
+ }
+
+ if (!unilib_->IsValidUtf8(UTF8ToUnicodeText(
+ message.text.data(), message.text.size(), /*do_copy=*/false))) {
+ TC3_LOG(ERROR) << "Not valid utf8 provided.";
+ return response;
}
}
@@ -1397,4 +1478,16 @@
return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
}
+bool ActionsSuggestions::InitializeConversationIntentDetection(
+ const std::string& serialized_config) {
+ auto conversation_intent_detection =
+ std::make_unique<ConversationIntentDetection>();
+ if (!conversation_intent_detection->Initialize(serialized_config).ok()) {
+ TC3_LOG(ERROR) << "Failed to initialize conversation intent detection.";
+ return false;
+ }
+ conversation_intent_detection_ = std::move(conversation_intent_detection);
+ return true;
+}
+
} // namespace libtextclassifier3
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 2a321f0..32edc78 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -25,16 +25,18 @@
#include <vector>
#include "actions/actions_model_generated.h"
+#include "actions/conversation_intent_detection/conversation-intent-detection.h"
#include "actions/feature-processor.h"
#include "actions/grammar-actions.h"
-#include "actions/ngram-model.h"
#include "actions/ranker.h"
#include "actions/regex-actions.h"
+#include "actions/sensitive-classifier-base.h"
#include "actions/types.h"
#include "annotator/annotator.h"
#include "annotator/model-executor.h"
#include "annotator/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/i18n/locale.h"
#include "utils/memory/mmap.h"
#include "utils/tflite-model-executor.h"
@@ -44,12 +46,6 @@
namespace libtextclassifier3 {
-// Options for suggesting actions.
-struct ActionSuggestionOptions {
- static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
- std::unordered_map<std::string, Variant> model_parameters;
-};
-
// Class for predicting actions following a conversation.
class ActionsSuggestions {
public:
@@ -109,22 +105,14 @@
const Conversation& conversation, const Annotator* annotator,
const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
+ bool InitializeConversationIntentDetection(
+ const std::string& serialized_config);
+
const ActionsModel* model() const;
const reflection::Schema* entity_data_schema() const;
static constexpr int kLocalUserId = 0;
- // Should be in sync with those defined in Android.
- // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
- static const std::string& kViewCalendarType;
- static const std::string& kViewMapType;
- static const std::string& kTrackFlightType;
- static const std::string& kOpenUrlType;
- static const std::string& kSendSmsType;
- static const std::string& kCallPhoneType;
- static const std::string& kSendEmailType;
- static const std::string& kShareLocation;
-
protected:
// Exposed for testing.
bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
@@ -206,6 +194,10 @@
ActionsSuggestionsResponse* response,
std::unique_ptr<tflite::Interpreter>* interpreter) const;
+ Status SuggestActionsFromConversationIntentDetection(
+ const Conversation& conversation, const ActionSuggestionOptions& options,
+ std::vector<ActionSuggestion>* actions) const;
+
// Creates options for annotation of a message.
AnnotationOptions AnnotationOptionsForMessage(
const ConversationMessage& message) const;
@@ -262,7 +254,7 @@
// Builder for creating extra data.
const reflection::Schema* entity_data_schema_;
- std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
+ std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
std::unique_ptr<ActionsSuggestionsRanker> ranker_;
std::string lua_bytecode_;
@@ -274,7 +266,11 @@
const TriggeringPreconditions* triggering_preconditions_overlay_;
// Low confidence input ngram classifier.
- std::unique_ptr<const NGramModel> ngram_model_;
+ std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_;
+
+ // Conversation intent detection model for additional actions.
+ std::unique_ptr<const ConversationIntentDetection>
+ conversation_intent_detection_;
};
// Interprets the buffer as a Model flatbuffer and returns it for reading.
@@ -297,6 +293,72 @@
return function(model);
}
+class ActionsSuggestionsTypes {
+ public:
+ // Should be in sync with those defined in Android.
+ // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
+ static const std::string& ViewCalendar() {
+ static const std::string& value =
+ *[]() { return new std::string("view_calendar"); }();
+ return value;
+ }
+ static const std::string& ViewMap() {
+ static const std::string& value =
+ *[]() { return new std::string("view_map"); }();
+ return value;
+ }
+ static const std::string& TrackFlight() {
+ static const std::string& value =
+ *[]() { return new std::string("track_flight"); }();
+ return value;
+ }
+ static const std::string& OpenUrl() {
+ static const std::string& value =
+ *[]() { return new std::string("open_url"); }();
+ return value;
+ }
+ static const std::string& SendSms() {
+ static const std::string& value =
+ *[]() { return new std::string("send_sms"); }();
+ return value;
+ }
+ static const std::string& CallPhone() {
+ static const std::string& value =
+ *[]() { return new std::string("call_phone"); }();
+ return value;
+ }
+ static const std::string& SendEmail() {
+ static const std::string& value =
+ *[]() { return new std::string("send_email"); }();
+ return value;
+ }
+ static const std::string& ShareLocation() {
+ static const std::string& value =
+ *[]() { return new std::string("share_location"); }();
+ return value;
+ }
+ static const std::string& CreateReminder() {
+ static const std::string& value =
+ *[]() { return new std::string("create_reminder"); }();
+ return value;
+ }
+ static const std::string& TextReply() {
+ static const std::string& value =
+ *[]() { return new std::string("text_reply"); }();
+ return value;
+ }
+ static const std::string& AddContact() {
+ static const std::string& value =
+ *[]() { return new std::string("add_contact"); }();
+ return value;
+ }
+ static const std::string& Copy() {
+ static const std::string& value =
+ *[]() { return new std::string("copy"); }();
+ return value;
+ }
+};
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
new file mode 100644
index 0000000..7fe69fc
--- /dev/null
+++ b/native/actions/actions-suggestions_test.cc
@@ -0,0 +1,1829 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/actions-suggestions.h"
+
+#include <fstream>
+#include <iterator>
+#include <memory>
+#include <string>
+
+#include "actions/actions_model_generated.h"
+#include "actions/test-utils.h"
+#include "actions/zlib-utils.h"
+#include "annotator/collections.h"
+#include "annotator/types.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/utils/locale-shard-map.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/hash/farmhash.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/test-data-test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::FloatEq;
+using ::testing::IsEmpty;
+using ::testing::NotNull;
+using ::testing::SizeIs;
+
+constexpr char kModelFileName[] = "actions_suggestions_test.model";
+constexpr char kModelGrammarFileName[] =
+ "actions_suggestions_grammar_test.model";
+constexpr char kMultiTaskTF2TestModelFileName[] =
+ "actions_suggestions_test.multi_task_tf2_test.model";
+constexpr char kMultiTaskModelFileName[] =
+ "actions_suggestions_test.multi_task_9heads.model";
+constexpr char kHashGramModelFileName[] =
+ "actions_suggestions_test.hashgram.model";
+constexpr char kMultiTaskSrP13nModelFileName[] =
+ "actions_suggestions_test.multi_task_sr_p13n.model";
+constexpr char kMultiTaskSrEmojiModelFileName[] =
+ "actions_suggestions_test.multi_task_sr_emoji.model";
+constexpr char kSensitiveTFliteModelFileName[] =
+ "actions_suggestions_test.sensitive_tflite.model";
+
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+std::string GetModelPath() { return GetTestDataPath("actions/test_data/"); }
+
+class ActionsSuggestionsTest : public testing::Test {
+ protected:
+ explicit ActionsSuggestionsTest() : unilib_(CreateUniLibForTesting()) {}
+ std::unique_ptr<ActionsSuggestions> LoadTestModel(
+ const std::string model_file_name) {
+ return ActionsSuggestions::FromPath(GetModelPath() + model_file_name,
+ unilib_.get());
+ }
+ std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
+ return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
+ unilib_.get());
+ }
+ std::unique_ptr<ActionsSuggestions> LoadMultiTaskTestModel() {
+ return ActionsSuggestions::FromPath(
+ GetModelPath() + kMultiTaskModelFileName, unilib_.get());
+ }
+
+ std::unique_ptr<ActionsSuggestions> LoadMultiTaskSrP13nTestModel() {
+ return ActionsSuggestions::FromPath(
+ GetModelPath() + kMultiTaskSrP13nModelFileName, unilib_.get());
+ }
+ std::unique_ptr<UniLib> unilib_;
+};
+
+TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
+ EXPECT_THAT(LoadTestModel(kModelFileName), NotNull());
+}
+
+TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidInput) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?\xf0\x9f",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_THAT(response.actions, IsEmpty());
+}
+
+TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidUtf8) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1,
+ "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_THAT(response.actions, IsEmpty());
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActions) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies*/);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsNoActionsForUnknownLocale) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"zz"}}});
+ EXPECT_THAT(response.actions, testing::IsEmpty());
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotations) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("address", 1.0)};
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "are you at home?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions.front().type, "view_map");
+ EXPECT_EQ(response.actions.front().score, 1.0);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotationsWithEntityData) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ SetTestEntityDataSchema(actions_model.get());
+
+ // Set custom actions from annotations config.
+ actions_model->annotation_actions_spec->annotation_mapping.clear();
+ actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
+ new AnnotationActionsSpec_::AnnotationMappingT);
+ AnnotationActionsSpec_::AnnotationMappingT* mapping =
+ actions_model->annotation_actions_spec->annotation_mapping.back().get();
+ mapping->annotation_collection = "address";
+ mapping->action.reset(new ActionSuggestionSpecT);
+ mapping->action->type = "save_location";
+ mapping->action->score = 1.0;
+ mapping->action->priority_score = 2.0;
+ mapping->entity_field.reset(new FlatbufferFieldPathT);
+ mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
+ mapping->entity_field->field.back()->field_name = "location";
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("address", 1.0)};
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "are you at home?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions.front().type, "save_location");
+ EXPECT_EQ(response.actions.front().score, 1.0);
+
+ // Check that the `location` entity field holds the text from the address
+ // annotation.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ response.actions.front().serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
+ "home");
+}
+
+TEST_F(ActionsSuggestionsTest,
+ SuggestsActionsFromAnnotationsWithNormalization) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ SetTestEntityDataSchema(actions_model.get());
+
+ // Set custom actions from annotations config.
+ actions_model->annotation_actions_spec->annotation_mapping.clear();
+ actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
+ new AnnotationActionsSpec_::AnnotationMappingT);
+ AnnotationActionsSpec_::AnnotationMappingT* mapping =
+ actions_model->annotation_actions_spec->annotation_mapping.back().get();
+ mapping->annotation_collection = "address";
+ mapping->action.reset(new ActionSuggestionSpecT);
+ mapping->action->type = "save_location";
+ mapping->action->score = 1.0;
+ mapping->action->priority_score = 2.0;
+ mapping->entity_field.reset(new FlatbufferFieldPathT);
+ mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
+ mapping->entity_field->field.back()->field_name = "location";
+ mapping->normalization_options.reset(new NormalizationOptionsT);
+ mapping->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("address", 1.0)};
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "are you at home?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions.front().type, "save_location");
+ EXPECT_EQ(response.actions.front().score, 1.0);
+
+ // Check that the `location` entity field holds the normalized text of the
+ // annotation.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ response.actions.front().serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
+ "HOME");
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsFromDuplicatedAnnotations) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+ AnnotatedSpan flight_annotation;
+ flight_annotation.span = {11, 15};
+ flight_annotation.classification = {ClassificationResult("flight", 2.5)};
+ AnnotatedSpan flight_annotation2;
+ flight_annotation2.span = {35, 39};
+ flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
+ AnnotatedSpan email_annotation;
+ email_annotation.span = {43, 56};
+ email_annotation.classification = {ClassificationResult("email", 2.0)};
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1,
+ "call me at LX38 or send message to LX38 or test@test.com.",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {flight_annotation, flight_annotation2, email_annotation},
+ /*locales=*/"en"}}});
+
+ ASSERT_GE(response.actions.size(), 2);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[0].score, 3.0);
+ EXPECT_EQ(response.actions[1].type, "send_email");
+ EXPECT_EQ(response.actions[1].score, 2.0);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsAnnotationsWithNoDeduplication) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ // Disable deduplication.
+ actions_model->annotation_actions_spec->deduplicate_annotations = false;
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+ AnnotatedSpan flight_annotation;
+ flight_annotation.span = {11, 15};
+ flight_annotation.classification = {ClassificationResult("flight", 2.5)};
+ AnnotatedSpan flight_annotation2;
+ flight_annotation2.span = {35, 39};
+ flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
+ AnnotatedSpan email_annotation;
+ email_annotation.span = {43, 56};
+ email_annotation.classification = {ClassificationResult("email", 2.0)};
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1,
+ "call me at LX38 or send message to LX38 or test@test.com.",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {flight_annotation, flight_annotation2, email_annotation},
+ /*locales=*/"en"}}});
+
+ ASSERT_GE(response.actions.size(), 3);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[0].score, 3.0);
+ EXPECT_EQ(response.actions[1].type, "track_flight");
+ EXPECT_EQ(response.actions[1].score, 2.5);
+ EXPECT_EQ(response.actions[2].type, "send_email");
+ EXPECT_EQ(response.actions[2].score, 2.0);
+}
+
+ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
+ const std::function<void(ActionsModelT*)>& set_config_fn,
+ const UniLib* unilib = nullptr) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+
+ // Set custom config.
+ set_config_fn(actions_model.get());
+
+ // Disable smart reply for easier testing.
+ actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib);
+
+ AnnotatedSpan flight_annotation;
+ flight_annotation.span = {15, 19};
+ flight_annotation.classification = {ClassificationResult("flight", 2.0)};
+ AnnotatedSpan email_annotation;
+ email_annotation.span = {0, 16};
+ email_annotation.classification = {ClassificationResult("email", 1.0)};
+
+ return actions_suggestions->SuggestActions(
+ {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
+ "hehe@android.com",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {email_annotation},
+ /*locales=*/"en"},
+ {/*user_id=*/2,
+ "yoyo@android.com",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {email_annotation},
+ /*locales=*/"en"},
+ {/*user_id=*/1,
+ "test@android.com",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {email_annotation},
+ /*locales=*/"en"},
+ {/*user_id=*/1,
+ "I am on flight LX38.",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {flight_annotation},
+ /*locales=*/"en"}}});
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastMessage) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ false;
+ actions_model->annotation_actions_spec->only_until_last_sent = true;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 1;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 1;
+ },
+ unilib_.get());
+ EXPECT_THAT(response.actions, SizeIs(1));
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastPerson) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ false;
+ actions_model->annotation_actions_spec->only_until_last_sent = true;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 1;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 3;
+ },
+ unilib_.get());
+ EXPECT_THAT(response.actions, SizeIs(2));
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[1].type, "send_email");
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsFromAny) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ false;
+ actions_model->annotation_actions_spec->only_until_last_sent = true;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 2;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 1;
+ },
+ unilib_.get());
+ EXPECT_THAT(response.actions, SizeIs(2));
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[1].type, "send_email");
+}
+
+TEST_F(ActionsSuggestionsTest,
+ SuggestsActionsWithAnnotationsFromAnyManyMessages) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ false;
+ actions_model->annotation_actions_spec->only_until_last_sent = true;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 3;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 1;
+ },
+ unilib_.get());
+ EXPECT_THAT(response.actions, SizeIs(3));
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[1].type, "send_email");
+ EXPECT_EQ(response.actions[2].type, "send_email");
+}
+
+TEST_F(ActionsSuggestionsTest,
+ SuggestsActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ false;
+ actions_model->annotation_actions_spec->only_until_last_sent = true;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 5;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 1;
+ },
+ unilib_.get());
+ EXPECT_THAT(response.actions, SizeIs(3));
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[1].type, "send_email");
+ EXPECT_EQ(response.actions[2].type, "send_email");
+}
+
+TEST_F(ActionsSuggestionsTest,
+ SuggestsActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ true;
+ actions_model->annotation_actions_spec->only_until_last_sent = false;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 5;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 1;
+ },
+ unilib_.get());
+ EXPECT_THAT(response.actions, SizeIs(4));
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[1].type, "send_email");
+ EXPECT_EQ(response.actions[2].type, "send_email");
+ EXPECT_EQ(response.actions[3].type, "send_email");
+}
+
+void TestSuggestActionsWithThreshold(
+ const std::function<void(ActionsModelT*)>& set_value_fn,
+ const UniLib* unilib = nullptr, const int expected_size = 0,
+ const std::string& preconditions_overwrite = "") {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ set_value_fn(actions_model.get());
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib, preconditions_overwrite);
+ ASSERT_TRUE(actions_suggestions);
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "I have the low-ground. Where are you?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_LE(response.actions.size(), expected_size);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithTriggeringScore) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
+ },
+ unilib_.get(),
+ /*expected_size=*/1 /*no smart reply, only actions*/
+ );
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinReplyScore) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->min_reply_score_threshold = 1.0;
+ },
+ unilib_.get(),
+ /*expected_size=*/1 /*no smart reply, only actions*/
+ );
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithSensitiveTopicScore) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->max_sensitive_topic_score = 0.0;
+ },
+ unilib_.get(),
+      /*expected_size=*/4 /* no sensitive prediction in test model */);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMaxInputLength) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->max_input_length = 0;
+ },
+ unilib_.get());
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinInputLength) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->min_input_length = 100;
+ },
+ unilib_.get());
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithPreconditionsOverwrite) {
+ TriggeringPreconditionsT preconditions_overwrite;
+ preconditions_overwrite.max_input_length = 0;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(
+ TriggeringPreconditions::Pack(builder, &preconditions_overwrite));
+ TestSuggestActionsWithThreshold(
+ // Keep model untouched.
+ [](ActionsModelT* actions_model) {}, unilib_.get(),
+ /*expected_size=*/0,
+ std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize()));
+}
+
+#ifdef TC3_UNILIB_ICU
+TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidence) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->suppress_on_low_confidence_input = true;
+ actions_model->low_confidence_rules.reset(new RulesModelT);
+ actions_model->low_confidence_rules->regex_rule.emplace_back(
+ new RulesModel_::RegexRuleT);
+ actions_model->low_confidence_rules->regex_rule.back()->pattern =
+ "low-ground";
+ },
+ unilib_.get());
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidenceInputOutput) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ // Add custom triggering rule.
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
+ RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
+ rule->pattern = "^(?i:hello\\s(there))$";
+ {
+ std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
+ new RulesModel_::RuleActionSpecT);
+ rule_action->action.reset(new ActionSuggestionSpecT);
+ rule_action->action->type = "text_reply";
+ rule_action->action->response_text = "General Desaster!";
+ rule_action->action->score = 1.0f;
+ rule_action->action->priority_score = 1.0f;
+ rule->actions.push_back(std::move(rule_action));
+ }
+ {
+ std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
+ new RulesModel_::RuleActionSpecT);
+ rule_action->action.reset(new ActionSuggestionSpecT);
+ rule_action->action->type = "text_reply";
+ rule_action->action->response_text = "General Kenobi!";
+ rule_action->action->score = 1.0f;
+ rule_action->action->priority_score = 1.0f;
+ rule->actions.push_back(std::move(rule_action));
+ }
+
+ // Add input-output low confidence rule.
+ actions_model->preconditions->suppress_on_low_confidence_input = true;
+ actions_model->low_confidence_rules.reset(new RulesModelT);
+ actions_model->low_confidence_rules->regex_rule.emplace_back(
+ new RulesModel_::RegexRuleT);
+ actions_model->low_confidence_rules->regex_rule.back()->pattern = "hello";
+ actions_model->low_confidence_rules->regex_rule.back()->output_pattern =
+ "(?i:desaster)";
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+ ASSERT_TRUE(actions_suggestions);
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "hello there",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
+}
+
+TEST_F(ActionsSuggestionsTest,
+ SuggestsActionsLowConfidenceInputOutputOverwrite) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ actions_model->low_confidence_rules.reset();
+
+ // Add custom triggering rule.
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
+ RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
+ rule->pattern = "^(?i:hello\\s(there))$";
+ {
+ std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
+ new RulesModel_::RuleActionSpecT);
+ rule_action->action.reset(new ActionSuggestionSpecT);
+ rule_action->action->type = "text_reply";
+ rule_action->action->response_text = "General Desaster!";
+ rule_action->action->score = 1.0f;
+ rule_action->action->priority_score = 1.0f;
+ rule->actions.push_back(std::move(rule_action));
+ }
+ {
+ std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
+ new RulesModel_::RuleActionSpecT);
+ rule_action->action.reset(new ActionSuggestionSpecT);
+ rule_action->action->type = "text_reply";
+ rule_action->action->response_text = "General Kenobi!";
+ rule_action->action->score = 1.0f;
+ rule_action->action->priority_score = 1.0f;
+ rule->actions.push_back(std::move(rule_action));
+ }
+
+ // Add custom triggering rule via overwrite.
+ actions_model->preconditions->low_confidence_rules.reset();
+ TriggeringPreconditionsT preconditions;
+ preconditions.suppress_on_low_confidence_input = true;
+ preconditions.low_confidence_rules.reset(new RulesModelT);
+ preconditions.low_confidence_rules->regex_rule.emplace_back(
+ new RulesModel_::RegexRuleT);
+ preconditions.low_confidence_rules->regex_rule.back()->pattern = "hello";
+ preconditions.low_confidence_rules->regex_rule.back()->output_pattern =
+ "(?i:desaster)";
+ flatbuffers::FlatBufferBuilder preconditions_builder;
+ preconditions_builder.Finish(
+ TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
+ std::string serialize_preconditions = std::string(
+ reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
+ preconditions_builder.GetSize());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), serialize_preconditions);
+
+ ASSERT_TRUE(actions_suggestions);
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "hello there",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
+}
+#endif
+
+TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+
+  // Skip the test if the model does not produce a sensitivity score.
+ if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
+ return;
+ }
+
+ actions_model->preconditions->max_sensitive_topic_score = 0.0;
+ actions_model->preconditions->suppress_on_sensitive_topic = true;
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {
+ ClassificationResult(Collections::Address(), 1.0)};
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "are you at home?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ EXPECT_THAT(response.actions, testing::IsEmpty());
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithLongerConversation) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+
+ // Allow a larger conversation context.
+ actions_model->max_conversation_history_length = 10;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {
+ ClassificationResult(Collections::Address(), 1.0)};
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
+ /*reference_time_ms_utc=*/10000,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"},
+ {/*user_id=*/1, "good! are you at home?",
+ /*reference_time_ms_utc=*/15000,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].type, "view_map");
+ EXPECT_EQ(response.actions[0].score, 1.0);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsFromTF2MultiTaskModel) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kMultiTaskTF2TestModelFileName);
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Hello how are you",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ EXPECT_EQ(response.actions.size(), 4);
+ EXPECT_EQ(response.actions[0].response_text, "Okay");
+ EXPECT_EQ(response.actions[0].type, "REPLY_SUGGESTION");
+ EXPECT_EQ(response.actions[3].type, "TEST_CLASSIFIER_INTENT");
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsFromPhoneGrammarAnnotations) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelGrammarFileName);
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("phone", 0.0)};
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Contact us at: *1234",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions.front().type, "call_phone");
+ EXPECT_EQ(response.actions.front().score, 0.0);
+ EXPECT_EQ(response.actions.front().priority_score, 0.0);
+ EXPECT_EQ(response.actions.front().annotations.size(), 1);
+ EXPECT_EQ(response.actions.front().annotations.front().span.span.first, 15);
+ EXPECT_EQ(response.actions.front().annotations.front().span.span.second, 20);
+}
+
+TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+ AnnotatedSpan annotation;
+ annotation.span = {8, 12};
+ annotation.classification = {
+ ClassificationResult(Collections::Flight(), 1.0)};
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "I'm on LX38?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+
+ ASSERT_GE(response.actions.size(), 2);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[0].score, 1.0);
+ EXPECT_THAT(response.actions[0].annotations, SizeIs(1));
+ EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
+ EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
+}
+
+#ifdef TC3_UNILIB_ICU
+TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
+ RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
+ rule->pattern = "^(?i:hello\\s(there))$";
+ rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
+ rule->actions.back()->action.reset(new ActionSuggestionSpecT);
+ ActionSuggestionSpecT* action = rule->actions.back()->action.get();
+ action->type = "text_reply";
+ action->response_text = "General Kenobi!";
+ action->score = 1.0f;
+ action->priority_score = 1.0f;
+
+ // Set capturing groups for entity data.
+ rule->actions.back()->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
+ rule->actions.back()->capturing_group.back().get();
+ greeting_group->group_id = 0;
+ greeting_group->entity_field.reset(new FlatbufferFieldPathT);
+ greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
+ greeting_group->entity_field->field.back()->field_name = "greeting";
+ rule->actions.back()->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* location_group =
+ rule->actions.back()->capturing_group.back().get();
+ location_group->group_id = 1;
+ location_group->entity_field.reset(new FlatbufferFieldPathT);
+ location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
+ location_group->entity_field->field.back()->field_name = "location";
+
+ // Set test entity data schema.
+ SetTestEntityDataSchema(actions_model.get());
+
+  // Use metadata to generate custom serialized entity data.
+ MutableFlatbufferBuilder entity_data_builder(
+ flatbuffers::GetRoot<reflection::Schema>(
+ actions_model->actions_entity_data_schema.data()));
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder.NewRoot();
+ entity_data->Set("person", "Kenobi");
+ action->serialized_entity_data = entity_data->Serialize();
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ response.actions[0].serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "hello there");
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
+ "there");
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Kenobi");
+}
+
+TEST_F(ActionsSuggestionsTest, CreateActionsFromRulesWithNormalization) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
+ RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
+ rule->pattern = "^(?i:hello\\sthere)$";
+ rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
+ rule->actions.back()->action.reset(new ActionSuggestionSpecT);
+ ActionSuggestionSpecT* action = rule->actions.back()->action.get();
+ action->type = "text_reply";
+ action->response_text = "General Kenobi!";
+ action->score = 1.0f;
+ action->priority_score = 1.0f;
+
+ // Set capturing groups for entity data.
+ rule->actions.back()->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
+ rule->actions.back()->capturing_group.back().get();
+ greeting_group->group_id = 0;
+ greeting_group->entity_field.reset(new FlatbufferFieldPathT);
+ greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
+ greeting_group->entity_field->field.back()->field_name = "greeting";
+ greeting_group->normalization_options.reset(new NormalizationOptionsT);
+ greeting_group->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
+
+ // Set test entity data schema.
+ SetTestEntityDataSchema(actions_model.get());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ response.actions[0].serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "HELLOTHERE");
+}
+
+TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
+ RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
+ rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
+ rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
+
+ // Set capturing groups for entity data.
+ rule->actions.back()->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
+ rule->actions.back()->capturing_group.back().get();
+ code_group->group_id = 1;
+ code_group->text_reply.reset(new ActionSuggestionSpecT);
+ code_group->text_reply->score = 1.0f;
+ code_group->text_reply->priority_score = 1.0f;
+ code_group->normalization_options.reset(new NormalizationOptionsT);
+ code_group->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1,
+ "visit test.com or reply STOP to cancel your subscription",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].response_text, "stop");
+}
+
+TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+ actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
+
+ // Set tokenizer options.
+ RulesModel_::GrammarRulesT* action_grammar_rules =
+ actions_model->rules->grammar_rules.get();
+ action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
+ action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
+ action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
+ false;
+
+ // Setup test rules.
+ action_grammar_rules->rules.reset(new grammar::RulesSetT);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add(
+ "<knock>", {"<^>", "ventura", "!?", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/0);
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules->rules.get());
+ action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
+ RulesModel_::RuleActionSpecT* actions_spec =
+ action_grammar_rules->actions.back().get();
+ actions_spec->action.reset(new ActionSuggestionSpecT);
+ actions_spec->action->response_text = "Yes, Satan?";
+ actions_spec->action->priority_score = 1.0;
+ actions_spec->action->score = 1.0;
+ actions_spec->action->type = "text_reply";
+ action_grammar_rules->rule_match.emplace_back(
+ new RulesModel_::GrammarRules_::RuleMatchT);
+ action_grammar_rules->rule_match.back()->action_id.push_back(0);
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Ventura!",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+
+ EXPECT_THAT(response.actions, ElementsAre(IsSmartReply("Yes, Satan?")));
+}
+
+#if defined(TC3_UNILIB_ICU) && !defined(TEST_NO_DATETIME)
+TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) {
+ std::unique_ptr<Annotator> annotator =
+ Annotator::FromPath(GetModelPath() + "en.fb", unilib_.get());
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+ actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
+
+ // Set tokenizer options.
+ RulesModel_::GrammarRulesT* action_grammar_rules =
+ actions_model->rules->grammar_rules.get();
+ action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
+ action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
+ action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
+ false;
+
+ // Setup test rules.
+ action_grammar_rules->rules.reset(new grammar::RulesSetT);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add(
+ "<event>", {"it", "is", "at", "<time>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/0);
+ rules.BindAnnotation("<time>", "time");
+ rules.AddAnnotation("datetime");
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules->rules.get());
+ action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
+ RulesModel_::RuleActionSpecT* actions_spec =
+ action_grammar_rules->actions.back().get();
+ actions_spec->action.reset(new ActionSuggestionSpecT);
+ actions_spec->action->priority_score = 1.0;
+ actions_spec->action->score = 1.0;
+ actions_spec->action->type = "create_event";
+ action_grammar_rules->rule_match.emplace_back(
+ new RulesModel_::GrammarRules_::RuleMatchT);
+ action_grammar_rules->rule_match.back()->action_id.push_back(0);
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "it is at 10:30",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}},
+ annotator.get());
+
+ EXPECT_THAT(response.actions, ElementsAre(IsActionOfType("create_event")));
+}
+#endif
+
+TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+ ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+
+ // Check that the location sharing model triggered.
+ bool has_location_sharing_action = false;
+ for (const ActionSuggestion& action : response.actions) {
+ if (action.type == ActionsSuggestionsTypes::ShareLocation()) {
+ has_location_sharing_action = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_location_sharing_action);
+ const int num_actions = response.actions.size();
+
+ // Add custom rule for location sharing.
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
+ actions_model->rules->regex_rule.back()->pattern =
+ "^(?i:where are you[.?]?)$";
+ actions_model->rules->regex_rule.back()->actions.emplace_back(
+ new RulesModel_::RuleActionSpecT);
+ actions_model->rules->regex_rule.back()->actions.back()->action.reset(
+ new ActionSuggestionSpecT);
+ ActionSuggestionSpecT* action =
+ actions_model->rules->regex_rule.back()->actions.back()->action.get();
+ action->score = 1.0f;
+ action->type = ActionsSuggestionsTypes::ShareLocation();
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+
+ response = actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_THAT(response.actions, SizeIs(num_actions));
+}
+
+TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+ AnnotatedSpan annotation;
+ annotation.span = {7, 11};
+ annotation.classification = {
+ ClassificationResult(Collections::Flight(), 1.0)};
+ ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "I'm on LX38",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+
+  // Check that the flight actions are present.
+ EXPECT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+
+ // Add custom rule.
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
+ RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
+ rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
+ rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
+ rule->actions.back()->action.reset(new ActionSuggestionSpecT);
+ ActionSuggestionSpecT* action = rule->actions.back()->action.get();
+ action->score = 1.0f;
+ action->priority_score = 2.0f;
+ action->type = "test_code";
+ rule->actions.back()->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
+ rule->actions.back()->capturing_group.back().get();
+ code_group->group_id = 1;
+ code_group->annotation_name = "code";
+ code_group->annotation_type = "code";
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get());
+
+ response = actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "I'm on LX38",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ EXPECT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].type, "test_code");
+}
+#endif
+
+TEST_F(ActionsSuggestionsTest, RanksActions) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+ std::vector<AnnotatedSpan> annotations(2);
+ annotations[0].span = {11, 15};
+ annotations[0].classification = {ClassificationResult("address", 1.0)};
+ annotations[1].span = {19, 23};
+ annotations[1].classification = {ClassificationResult("address", 2.0)};
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "are you at home or work?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/annotations,
+ /*locales=*/"en"}}});
+ EXPECT_GE(response.actions.size(), 2);
+ EXPECT_EQ(response.actions[0].type, "view_map");
+ EXPECT_EQ(response.actions[0].score, 2.0);
+ EXPECT_EQ(response.actions[1].type, "view_map");
+ EXPECT_EQ(response.actions[1].score, 1.0);
+}
+
+TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
+ EXPECT_TRUE(VisitActionsModel<bool>(GetModelPath() + kModelFileName,
+ [](const ActionsModel* model) {
+ if (model == nullptr) {
+ return false;
+ }
+ return true;
+ }));
+ EXPECT_FALSE(VisitActionsModel<bool>(GetModelPath() + "non_existing_model.fb",
+ [](const ActionsModel* model) {
+ if (model == nullptr) {
+ return false;
+ }
+ return true;
+ }));
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsWithHashGramModel) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadHashGramTestModel();
+ ASSERT_TRUE(actions_suggestions != nullptr);
+ {
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "hello",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ EXPECT_THAT(response.actions, testing::IsEmpty());
+ }
+ {
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "where are you",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ EXPECT_THAT(
+ response.actions,
+ ElementsAre(testing::Field(&ActionSuggestion::type, "share_location")));
+ }
+ {
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "do you know johns number",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ EXPECT_THAT(
+ response.actions,
+ ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact")));
+ }
+}
+
+// Test class to expose token embedding methods for testing.
+class TestingMessageEmbedder : private ActionsSuggestions {
+ public:
+ explicit TestingMessageEmbedder(const ActionsModel* model);
+
+ using ActionsSuggestions::EmbedAndFlattenTokens;
+ using ActionsSuggestions::EmbedTokensPerMessage;
+
+ protected:
+ // EmbeddingExecutor that always returns features based on
+ // the id of the sparse features.
+ class FakeEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+ bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ const int dest_size) const override {
+ TC3_CHECK_GE(dest_size, 1);
+ EXPECT_EQ(sparse_features.size(), 1);
+ dest[0] = sparse_features.data()[0];
+ return true;
+ }
+ };
+
+ std::unique_ptr<UniLib> unilib_;
+};
+
+TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model)
+ : unilib_(CreateUniLibForTesting()) {
+ model_ = model;
+ const ActionsTokenFeatureProcessorOptions* options =
+ model->feature_processor_options();
+ feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_.get()));
+ embedding_executor_.reset(new FakeEmbeddingExecutor());
+ EXPECT_TRUE(
+ EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
+ EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
+ EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
+ token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
+ EXPECT_EQ(token_embedding_size_, 1);
+}
+
+class EmbeddingTest : public testing::Test {
+ protected:
+ explicit EmbeddingTest() {
+ model_.feature_processor_options.reset(
+ new ActionsTokenFeatureProcessorOptionsT);
+ options_ = model_.feature_processor_options.get();
+ options_->chargram_orders = {1};
+ options_->num_buckets = 1000;
+ options_->embedding_size = 1;
+ options_->start_token_id = 0;
+ options_->end_token_id = 1;
+ options_->padding_token_id = 2;
+ options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
+ }
+
+ TestingMessageEmbedder CreateTestingMessageEmbedder() {
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
+ buffer_ = builder.Release();
+ return TestingMessageEmbedder(
+ flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
+ }
+
+ flatbuffers::DetachedBuffer buffer_;
+ ActionsModelT model_;
+ ActionsTokenFeatureProcessorOptionsT* options_;
+};
+
+TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
+ const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+ std::vector<std::vector<Token>> tokens = {
+ {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+ std::vector<float> embeddings;
+ int max_num_tokens_per_message = 0;
+
+ EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
+ &max_num_tokens_per_message));
+
+ EXPECT_EQ(max_num_tokens_per_message, 3);
+ EXPECT_EQ(embeddings.size(), 3);
+ EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+ options_->num_buckets));
+}
+
+TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
+ options_->min_num_tokens_per_message = 5;
+ const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+ std::vector<std::vector<Token>> tokens = {
+ {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+ std::vector<float> embeddings;
+ int max_num_tokens_per_message = 0;
+
+ EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
+ &max_num_tokens_per_message));
+
+ EXPECT_EQ(max_num_tokens_per_message, 5);
+ EXPECT_EQ(embeddings.size(), 5);
+ EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[3], FloatEq(options_->padding_token_id));
+ EXPECT_THAT(embeddings[4], FloatEq(options_->padding_token_id));
+}
+
+TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
+ options_->max_num_tokens_per_message = 2;
+ const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+ std::vector<std::vector<Token>> tokens = {
+ {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+ std::vector<float> embeddings;
+ int max_num_tokens_per_message = 0;
+
+ EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
+ &max_num_tokens_per_message));
+
+ EXPECT_EQ(max_num_tokens_per_message, 2);
+ EXPECT_EQ(embeddings.size(), 2);
+ EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+ options_->num_buckets));
+}
+
+TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
+ const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+ std::vector<std::vector<Token>> tokens = {
+ {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
+ {Token("d", 0, 1), Token("e", 2, 3)}};
+ std::vector<float> embeddings;
+ int max_num_tokens_per_message = 0;
+
+ EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
+ &max_num_tokens_per_message));
+
+ EXPECT_EQ(max_num_tokens_per_message, 3);
+ EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[3], FloatEq(tc3farmhash::Fingerprint64("d", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[4], FloatEq(tc3farmhash::Fingerprint64("e", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[5], FloatEq(options_->padding_token_id));
+}
+
+TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
+ const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+ std::vector<std::vector<Token>> tokens = {
+ {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+ std::vector<float> embeddings;
+ int total_token_count = 0;
+
+ EXPECT_TRUE(
+ embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
+
+ EXPECT_EQ(total_token_count, 5);
+ EXPECT_EQ(embeddings.size(), 5);
+ EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
+ EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[3], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
+}
+
+TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
+ options_->min_num_total_tokens = 7;
+ const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+ std::vector<std::vector<Token>> tokens = {
+ {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+ std::vector<float> embeddings;
+ int total_token_count = 0;
+
+ EXPECT_TRUE(
+ embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
+
+ EXPECT_EQ(total_token_count, 7);
+ EXPECT_EQ(embeddings.size(), 7);
+ EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
+ EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[3], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
+ EXPECT_THAT(embeddings[5], FloatEq(options_->padding_token_id));
+ EXPECT_THAT(embeddings[6], FloatEq(options_->padding_token_id));
+}
+
+TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
+ options_->max_num_total_tokens = 3;
+ const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+ std::vector<std::vector<Token>> tokens = {
+ {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+ std::vector<float> embeddings;
+ int total_token_count = 0;
+
+ EXPECT_TRUE(
+ embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
+
+ EXPECT_EQ(total_token_count, 3);
+ EXPECT_EQ(embeddings.size(), 3);
+ EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[2], FloatEq(options_->end_token_id));
+}
+
+TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
+ const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+ std::vector<std::vector<Token>> tokens = {
+ {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
+ {Token("d", 0, 1), Token("e", 2, 3)}};
+ std::vector<float> embeddings;
+ int total_token_count = 0;
+
+ EXPECT_TRUE(
+ embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
+
+ EXPECT_EQ(total_token_count, 9);
+ EXPECT_EQ(embeddings.size(), 9);
+ EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
+ EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[3], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
+ EXPECT_THAT(embeddings[5], FloatEq(options_->start_token_id));
+ EXPECT_THAT(embeddings[6], FloatEq(tc3farmhash::Fingerprint64("d", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[7], FloatEq(tc3farmhash::Fingerprint64("e", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[8], FloatEq(options_->end_token_id));
+}
+
+TEST_F(EmbeddingTest,
+ EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
+ options_->max_num_total_tokens = 7;
+ const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+ std::vector<std::vector<Token>> tokens = {
+ {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
+ {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
+ std::vector<float> embeddings;
+ int total_token_count = 0;
+
+ EXPECT_TRUE(
+ embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
+
+ EXPECT_EQ(total_token_count, 7);
+ EXPECT_EQ(embeddings.size(), 7);
+ EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[1], FloatEq(options_->end_token_id));
+ EXPECT_THAT(embeddings[2], FloatEq(options_->start_token_id));
+ EXPECT_THAT(embeddings[3], FloatEq(tc3farmhash::Fingerprint64("d", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[4], FloatEq(tc3farmhash::Fingerprint64("e", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[5], FloatEq(tc3farmhash::Fingerprint64("f", 1) %
+ options_->num_buckets));
+ EXPECT_THAT(embeddings[6], FloatEq(options_->end_token_id));
+}
+
+TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsDefault) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadMultiTaskTestModel();
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_EQ(response.actions.size(),
+ 11 /* 8 binary classification + 3 smart replies*/);
+}
+
+const float kDisableThresholdVal = 2.0;
+
+constexpr char kSpamThreshold[] = "spam_confidence_threshold";
+constexpr char kLocationThreshold[] = "location_confidence_threshold";
+constexpr char kPhoneThreshold[] = "phone_confidence_threshold";
+constexpr char kWeatherThreshold[] = "weather_confidence_threshold";
+constexpr char kRestaurantsThreshold[] = "restaurants_confidence_threshold";
+constexpr char kMoviesThreshold[] = "movies_confidence_threshold";
+constexpr char kTtrThreshold[] = "time_to_reply_binary_threshold";
+constexpr char kReminderThreshold[] = "reminder_intent_confidence_threshold";
+constexpr char kDiversificationParm[] = "diversification_distance_threshold";
+constexpr char kEmpiricalProbFactor[] = "empirical_probability_factor";
+
+ActionSuggestionOptions GetOptionsToDisableAllClassification() {
+ ActionSuggestionOptions options;
+ // Disable all classification heads.
+ options.model_parameters.insert(
+ {kSpamThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
+ options.model_parameters.insert(
+ {kLocationThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
+ options.model_parameters.insert(
+ {kPhoneThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
+ options.model_parameters.insert(
+ {kWeatherThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
+ options.model_parameters.insert(
+ {kRestaurantsThreshold,
+ libtextclassifier3::Variant(kDisableThresholdVal)});
+ options.model_parameters.insert(
+ {kMoviesThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
+ options.model_parameters.insert(
+ {kTtrThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
+ options.model_parameters.insert(
+ {kReminderThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
+ return options;
+}
+
+TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyOnly) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadMultiTaskTestModel();
+ const ActionSuggestionOptions options =
+ GetOptionsToDisableAllClassification();
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}},
+ /*annotator=*/nullptr, options);
+ EXPECT_THAT(response.actions,
+ ElementsAre(IsSmartReply("Here"), IsSmartReply("I'm here"),
+ IsSmartReply("I'm home")));
+ EXPECT_EQ(response.actions.size(), 3 /*3 smart replies*/);
+}
+
+const int kUserProfileSize = 1000;
+constexpr char kUserProfileTokenIndex[] = "user_profile_token_index";
+constexpr char kUserProfileTokenWeight[] = "user_profile_token_weight";
+
+ActionSuggestionOptions GetOptionsForSmartReplyP13nModel() {
+ ActionSuggestionOptions options;
+ const std::vector<int> user_profile_token_indexes(kUserProfileSize, 1);
+ const std::vector<float> user_profile_token_weights(kUserProfileSize, 0.1f);
+ options.model_parameters.insert(
+ {kUserProfileTokenIndex,
+ libtextclassifier3::Variant(user_profile_token_indexes)});
+ options.model_parameters.insert(
+ {kUserProfileTokenWeight,
+ libtextclassifier3::Variant(user_profile_token_weights)});
+ return options;
+}
+
+TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyP13n) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadMultiTaskSrP13nTestModel();
+ const ActionSuggestionOptions options = GetOptionsForSmartReplyP13nModel();
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "How are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}},
+ /*annotator=*/nullptr, options);
+ EXPECT_EQ(response.actions.size(), 3 /*3 smart replies*/);
+}
+
+TEST_F(ActionsSuggestionsTest,
+ MultiTaskSuggestActionsDiversifiedSmartReplyAndLocation) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadMultiTaskTestModel();
+ ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
+ options.model_parameters[kLocationThreshold] =
+ libtextclassifier3::Variant(0.35f);
+ options.model_parameters.insert(
+ {kDiversificationParm, libtextclassifier3::Variant(0.5f)});
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}},
+ /*annotator=*/nullptr, options);
+ EXPECT_THAT(
+ response.actions,
+ ElementsAre(IsActionOfType("LOCATION_SHARE"), IsSmartReply("Here"),
+ IsSmartReply("Yes"), IsSmartReply("😟")));
+ EXPECT_EQ(response.actions.size(), 4 /*1 location share + 3 smart replies*/);
+}
+
+TEST_F(ActionsSuggestionsTest,
+ MultiTaskSuggestActionsEmProBoostedSmartReplyAndLocationAndReminder) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadMultiTaskTestModel();
+ ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
+ options.model_parameters[kLocationThreshold] =
+ libtextclassifier3::Variant(0.35f);
+ // The reminder head always triggers since its threshold is set to zero.
+ options.model_parameters[kReminderThreshold] =
+ libtextclassifier3::Variant(0.0f);
+ options.model_parameters.insert(
+ {kEmpiricalProbFactor, libtextclassifier3::Variant(2.0f)});
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}},
+ /*annotator=*/nullptr, options);
+ EXPECT_THAT(
+ response.actions,
+ ElementsAre(IsSmartReply("Okay"), IsActionOfType("LOCATION_SHARE"),
+ IsSmartReply("Yes"),
+ /*Different emoji than previous test*/ IsSmartReply("😊"),
+ IsActionOfType("REMINDER_INTENT")));
+ EXPECT_EQ(response.actions.size(), 5 /*1 location share + 3 smart replies + 1 reminder*/);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kMultiTaskSrEmojiModelFileName);
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "hello?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ EXPECT_EQ(response.actions.size(), 5);
+ EXPECT_EQ(response.actions[0].response_text, "😁");
+ EXPECT_EQ(response.actions[0].type, "EMOJI_CONCEPT");
+ EXPECT_EQ(response.actions[1].response_text, "Yes");
+ EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION");
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestsActionsFromSensitiveTfLiteModel) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kSensitiveTFliteModelFileName);
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "I want to kill myself",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ EXPECT_EQ(response.actions.size(), 0);
+ EXPECT_TRUE(response.is_sensitive);
+ EXPECT_FALSE(response.output_filtered_low_confidence);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/actions/actions_jni.cc b/native/actions/actions_jni.cc
index 7dd0169..9e15a2e 100644
--- a/native/actions/actions_jni.cc
+++ b/native/actions/actions_jni.cc
@@ -28,6 +28,7 @@
#include "annotator/annotator.h"
#include "annotator/annotator_jni_common.h"
#include "utils/base/integral_types.h"
+#include "utils/base/status_macros.h"
#include "utils/base/statusor.h"
#include "utils/intents/intent-generator.h"
#include "utils/intents/jni.h"
@@ -35,19 +36,17 @@
#include "utils/java/jni-base.h"
#include "utils/java/jni-cache.h"
#include "utils/java/jni-helper.h"
-#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"
using libtextclassifier3::ActionsSuggestions;
using libtextclassifier3::ActionsSuggestionsResponse;
-using libtextclassifier3::ActionSuggestion;
using libtextclassifier3::ActionSuggestionOptions;
using libtextclassifier3::Annotator;
using libtextclassifier3::Conversation;
using libtextclassifier3::IntentGenerator;
+using libtextclassifier3::JStringToUtf8String;
using libtextclassifier3::ScopedLocalRef;
using libtextclassifier3::StatusOr;
-using libtextclassifier3::ToStlString;
// When using the Java's ICU, UniLib needs to be instantiated with a JavaVM
// pointer from JNI. When using a standard ICU the pointer is not needed and the
@@ -74,13 +73,14 @@
std::unique_ptr<IntentGenerator> intent_generator =
IntentGenerator::Create(model->model()->android_intent_options(),
model->model()->resources(), jni_cache);
- std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
- libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
-
- if (intent_generator == nullptr || template_handler == nullptr) {
+ if (intent_generator == nullptr) {
return nullptr;
}
+ TC3_ASSIGN_OR_RETURN_NULL(
+ std::unique_ptr<RemoteActionTemplatesHandler> template_handler,
+ libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache));
+
return new ActionsSuggestionsJniContext(jni_cache, std::move(model),
std::move(intent_generator),
std::move(template_handler));
@@ -121,63 +121,89 @@
return options;
}
-StatusOr<ScopedLocalRef<jobjectArray>> ActionSuggestionsToJObjectArray(
+StatusOr<ScopedLocalRef<jobject>> ActionSuggestionsToJObject(
JNIEnv* env, const ActionsSuggestionsJniContext* context,
jobject app_context,
const reflection::Schema* annotations_entity_data_schema,
- const std::vector<ActionSuggestion>& action_result,
+ const ActionsSuggestionsResponse& action_response,
const Conversation& conversation, const jstring device_locales,
const bool generate_intents) {
- auto status_or_result_class = JniHelper::FindClass(
+ // Find the class ActionSuggestion.
+ auto status_or_action_class = JniHelper::FindClass(
env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ActionSuggestion");
- if (!status_or_result_class.ok()) {
+ if (!status_or_action_class.ok()) {
TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
+ return status_or_action_class.status();
+ }
+ ScopedLocalRef<jclass> action_class =
+ std::move(status_or_action_class.ValueOrDie());
+
+ // Find the class ActionSuggestions
+ auto status_or_result_class = JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ActionSuggestions");
+ if (!status_or_result_class.ok()) {
+ TC3_LOG(ERROR) << "Couldn't find ActionSuggestions class.";
return status_or_result_class.status();
}
ScopedLocalRef<jclass> result_class =
std::move(status_or_result_class.ValueOrDie());
+ // Find the class Slot.
+ auto status_or_slot_class = JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$Slot");
+ if (!status_or_slot_class.ok()) {
+ TC3_LOG(ERROR) << "Couldn't find Slot class.";
+ return status_or_slot_class.status();
+ }
+ ScopedLocalRef<jclass> slot_class =
+ std::move(status_or_slot_class.ValueOrDie());
+
TC3_ASSIGN_OR_RETURN(
- const jmethodID result_class_constructor,
+ const jmethodID action_class_constructor,
JniHelper::GetMethodID(
- env, result_class.get(), "<init>",
+ env, action_class.get(), "<init>",
"(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
TC3_NAMED_VARIANT_CLASS_NAME_STR
";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR
- ";)V"));
- TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobjectArray> results,
- JniHelper::NewObjectArray(env, action_result.size(),
- result_class.get(), nullptr));
- for (int i = 0; i < action_result.size(); i++) {
+ ";[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$Slot;)V"));
+ TC3_ASSIGN_OR_RETURN(const jmethodID slot_class_constructor,
+ JniHelper::GetMethodID(env, slot_class.get(), "<init>",
+ "(Ljava/lang/String;IIIF)V"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> actions,
+ JniHelper::NewObjectArray(env, action_response.actions.size(),
+ action_class.get(), nullptr));
+ for (int i = 0; i < action_response.actions.size(); i++) {
ScopedLocalRef<jobjectArray> extras;
const reflection::Schema* actions_entity_data_schema =
context->model()->entity_data_schema();
if (actions_entity_data_schema != nullptr &&
- !action_result[i].serialized_entity_data.empty()) {
+ !action_response.actions[i].serialized_entity_data.empty()) {
TC3_ASSIGN_OR_RETURN(
extras, context->template_handler()->EntityDataAsNamedVariantArray(
actions_entity_data_schema,
- action_result[i].serialized_entity_data));
+ action_response.actions[i].serialized_entity_data));
}
ScopedLocalRef<jbyteArray> serialized_entity_data;
- if (!action_result[i].serialized_entity_data.empty()) {
+ if (!action_response.actions[i].serialized_entity_data.empty()) {
TC3_ASSIGN_OR_RETURN(
serialized_entity_data,
JniHelper::NewByteArray(
- env, action_result[i].serialized_entity_data.size()));
- env->SetByteArrayRegion(
- serialized_entity_data.get(), 0,
- action_result[i].serialized_entity_data.size(),
+ env, action_response.actions[i].serialized_entity_data.size()));
+ TC3_RETURN_IF_ERROR(JniHelper::SetByteArrayRegion(
+ env, serialized_entity_data.get(), 0,
+ action_response.actions[i].serialized_entity_data.size(),
reinterpret_cast<const jbyte*>(
- action_result[i].serialized_entity_data.data()));
+ action_response.actions[i].serialized_entity_data.data())));
}
ScopedLocalRef<jobjectArray> remote_action_templates_result;
if (generate_intents) {
std::vector<RemoteActionTemplate> remote_action_templates;
if (context->intent_generator()->GenerateIntents(
- device_locales, action_result[i], conversation, app_context,
+ device_locales, action_response.actions[i], conversation,
+ app_context,
/*annotations_entity_data_schema=*/annotations_entity_data_schema,
/*actions_entity_data_schema=*/actions_entity_data_schema,
&remote_action_templates)) {
@@ -190,22 +216,58 @@
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> reply,
context->jni_cache()->ConvertToJavaString(
- action_result[i].response_text));
+ action_response.actions[i].response_text));
TC3_ASSIGN_OR_RETURN(
ScopedLocalRef<jstring> action_type,
- JniHelper::NewStringUTF(env, action_result[i].type.c_str()));
+ JniHelper::NewStringUTF(env, action_response.actions[i].type.c_str()));
+
+ ScopedLocalRef<jobjectArray> slots;
+ if (!action_response.actions[i].slots.empty()) {
+ TC3_ASSIGN_OR_RETURN(slots,
+ JniHelper::NewObjectArray(
+ env, action_response.actions[i].slots.size(),
+ slot_class.get(), nullptr));
+ for (int j = 0; j < action_response.actions[i].slots.size(); j++) {
+ const Slot& slot_c = action_response.actions[i].slots[j];
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> slot_type,
+ JniHelper::NewStringUTF(env, slot_c.type.c_str()));
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> slot,
+ JniHelper::NewObject(
+ env, slot_class.get(), slot_class_constructor, slot_type.get(),
+ slot_c.span.message_index, slot_c.span.span.first,
+ slot_c.span.span.second, slot_c.confidence_score));
+
+ TC3_RETURN_IF_ERROR(
+ JniHelper::SetObjectArrayElement(env, slots.get(), j, slot.get()));
+ }
+ }
TC3_ASSIGN_OR_RETURN(
- ScopedLocalRef<jobject> result,
- JniHelper::NewObject(env, result_class.get(), result_class_constructor,
- reply.get(), action_type.get(),
- static_cast<jfloat>(action_result[i].score),
- extras.get(), serialized_entity_data.get(),
- remote_action_templates_result.get()));
- env->SetObjectArrayElement(results.get(), i, result.get());
+ ScopedLocalRef<jobject> action,
+ JniHelper::NewObject(
+ env, action_class.get(), action_class_constructor, reply.get(),
+ action_type.get(),
+ static_cast<jfloat>(action_response.actions[i].score), extras.get(),
+ serialized_entity_data.get(), remote_action_templates_result.get(),
+ slots.get()));
+ TC3_RETURN_IF_ERROR(
+ JniHelper::SetObjectArrayElement(env, actions.get(), i, action.get()));
}
- return results;
+
+ // Create the ActionSuggestions object.
+ TC3_ASSIGN_OR_RETURN(
+ const jmethodID result_class_constructor,
+ JniHelper::GetMethodID(env, result_class.get(), "<init>",
+ "([L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+ "$ActionSuggestion;Z)V"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(env, result_class.get(), result_class_constructor,
+ actions.get(), action_response.is_sensitive));
+ return result;
}
StatusOr<ConversationMessage> FromJavaConversationMessage(JNIEnv* env,
@@ -262,13 +324,14 @@
env, jmessage, get_detected_text_language_tags_method));
ConversationMessage message;
- TC3_ASSIGN_OR_RETURN(message.text, ToStlString(env, text.get()));
+ TC3_ASSIGN_OR_RETURN(message.text, JStringToUtf8String(env, text.get()));
message.user_id = user_id;
message.reference_time_ms_utc = reference_time;
TC3_ASSIGN_OR_RETURN(message.reference_timezone,
- ToStlString(env, reference_timezone.get()));
- TC3_ASSIGN_OR_RETURN(message.detected_text_language_tags,
- ToStlString(env, detected_text_language_tags.get()));
+ JStringToUtf8String(env, reference_timezone.get()));
+ TC3_ASSIGN_OR_RETURN(
+ message.detected_text_language_tags,
+ JStringToUtf8String(env, detected_text_language_tags.get()));
return message;
}
@@ -295,7 +358,8 @@
env, jconversation, get_conversation_messages_method));
std::vector<ConversationMessage> messages;
- const int size = env->GetArrayLength(jmessages.get());
+ TC3_ASSIGN_OR_RETURN(const int size,
+ JniHelper::GetArrayLength(env, jmessages.get()));
for (int i = 0; i < size; i++) {
TC3_ASSIGN_OR_RETURN(
ScopedLocalRef<jobject> jmessage,
@@ -350,84 +414,89 @@
} // namespace libtextclassifier3
using libtextclassifier3::ActionsSuggestionsJniContext;
-using libtextclassifier3::ActionSuggestionsToJObjectArray;
+using libtextclassifier3::ActionSuggestionsToJObject;
using libtextclassifier3::FromJavaActionSuggestionOptions;
using libtextclassifier3::FromJavaConversation;
+using libtextclassifier3::JByteArrayToString;
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
-(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) {
+(JNIEnv* env, jobject clazz, jint fd, jbyteArray jserialized_preconditions) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
libtextclassifier3::JniCache::Create(env);
- std::string preconditions;
- if (serialized_preconditions != nullptr &&
- !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
- &preconditions)) {
- TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
- return 0;
+ std::string serialized_preconditions;
+ if (jserialized_preconditions != nullptr) {
+ TC3_ASSIGN_OR_RETURN_0(
+ serialized_preconditions,
+ JByteArrayToString(env, jserialized_preconditions),
+ TC3_LOG(ERROR) << "Could not convert serialized preconditions.");
}
+
#ifdef TC3_UNILIB_JAVAICU
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
- jni_cache,
- ActionsSuggestions::FromFileDescriptor(
- fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions)));
+ jni_cache, ActionsSuggestions::FromFileDescriptor(
+ fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+ serialized_preconditions)));
#else
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
- jni_cache, ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr,
- preconditions)));
+ jni_cache, ActionsSuggestions::FromFileDescriptor(
+ fd, /*unilib=*/nullptr, serialized_preconditions)));
#endif // TC3_UNILIB_JAVAICU
}
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
-(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
+(JNIEnv* env, jobject clazz, jstring path,
+ jbyteArray jserialized_preconditions) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
libtextclassifier3::JniCache::Create(env);
- TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
- std::string preconditions;
- if (serialized_preconditions != nullptr &&
- !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
- &preconditions)) {
- TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
- return 0;
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str,
+ JStringToUtf8String(env, path));
+ std::string serialized_preconditions;
+ if (jserialized_preconditions != nullptr) {
+ TC3_ASSIGN_OR_RETURN_0(
+ serialized_preconditions,
+ JByteArrayToString(env, jserialized_preconditions),
+ TC3_LOG(ERROR) << "Could not convert serialized preconditions.");
}
#ifdef TC3_UNILIB_JAVAICU
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
jni_cache, ActionsSuggestions::FromPath(
path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
- preconditions)));
+ serialized_preconditions)));
#else
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
jni_cache, ActionsSuggestions::FromPath(path_str, /*unilib=*/nullptr,
- preconditions)));
+ serialized_preconditions)));
#endif // TC3_UNILIB_JAVAICU
}
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size,
- jbyteArray serialized_preconditions) {
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size,
+ jbyteArray jserialized_preconditions) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
libtextclassifier3::JniCache::Create(env);
- std::string preconditions;
- if (serialized_preconditions != nullptr &&
- !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
- &preconditions)) {
- TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
- return 0;
+ std::string serialized_preconditions;
+ if (jserialized_preconditions != nullptr) {
+ TC3_ASSIGN_OR_RETURN_0(
+ serialized_preconditions,
+ JByteArrayToString(env, jserialized_preconditions),
+ TC3_LOG(ERROR) << "Could not convert serialized preconditions.");
}
#ifdef TC3_UNILIB_JAVAICU
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
jni_cache,
ActionsSuggestions::FromFileDescriptor(
fd, offset, size, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
- preconditions)));
+ serialized_preconditions)));
#else
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
- jni_cache, ActionsSuggestions::FromFileDescriptor(
- fd, offset, size, /*unilib=*/nullptr, preconditions)));
+ jni_cache,
+ ActionsSuggestions::FromFileDescriptor(
+ fd, offset, size, /*unilib=*/nullptr, serialized_preconditions)));
#endif // TC3_UNILIB_JAVAICU
}
-TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
-(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
+TC3_JNI_METHOD(jobject, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
+(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation, jobject joptions,
jlong annotatorPtr, jobject app_context, jstring device_locales,
jboolean generate_intents) {
if (!ptr) {
@@ -448,15 +517,15 @@
annotator ? annotator->entity_data_schema() : nullptr;
TC3_ASSIGN_OR_RETURN_NULL(
- ScopedLocalRef<jobjectArray> result,
- ActionSuggestionsToJObjectArray(
- env, context, app_context, anntotations_entity_data_schema,
- response.actions, conversation, device_locales, generate_intents));
+ ScopedLocalRef<jobject> result,
+ ActionSuggestionsToJObject(
+ env, context, app_context, anntotations_entity_data_schema, response,
+ conversation, device_locales, generate_intents));
return result.release();
}
TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
-(JNIEnv* env, jobject clazz, jlong model_ptr) {
+(JNIEnv* env, jobject thiz, jlong model_ptr) {
const ActionsSuggestionsJniContext* context =
reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr);
delete context;
@@ -515,3 +584,30 @@
new libtextclassifier3::ScopedMmap(fd, offset, size));
return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
}
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeGetNativeModelPtr)
+(JNIEnv* env, jobject thiz, jlong ptr) {
+ if (!ptr) {
+ return 0L;
+ }
+ return reinterpret_cast<jlong>(
+ reinterpret_cast<ActionsSuggestionsJniContext*>(ptr)->model());
+}
+
+TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME,
+ nativeInitializeConversationIntentDetection)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray jserialized_config) {
+ if (!ptr) {
+ return false;
+ }
+
+ ActionsSuggestions* model =
+ reinterpret_cast<ActionsSuggestionsJniContext*>(ptr)->model();
+
+ std::string serialized_config;
+ TC3_ASSIGN_OR_RETURN_0(
+ serialized_config, JByteArrayToString(env, jserialized_config),
+ TC3_LOG(ERROR) << "Could not convert serialized conversation intent "
+ "detection config.");
+ return model->InitializeConversationIntentDetection(serialized_config);
+}
diff --git a/native/actions/actions_jni.h b/native/actions/actions_jni.h
index 276e361..2d2d103 100644
--- a/native/actions/actions_jni.h
+++ b/native/actions/actions_jni.h
@@ -32,16 +32,20 @@
#endif
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
-(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions);
+(JNIEnv* env, jobject clazz, jint fd, jbyteArray serialized_preconditions);
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
-(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions);
+(JNIEnv* env, jobject clazz, jstring path, jbyteArray serialized_preconditions);
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size,
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size,
jbyteArray serialized_preconditions);
-TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
+TC3_JNI_METHOD(jboolean, TC3_ACTIONS_CLASS_NAME,
+ nativeInitializeConversationIntentDetection)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray jserialized_config);
+
+TC3_JNI_METHOD(jobject, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation, jobject joptions,
jlong annotatorPtr, jobject app_context, jstring device_locales,
jboolean generate_intents);
@@ -67,6 +71,9 @@
TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersionWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeGetNativeModelPtr)
+(JNIEnv* env, jobject thiz, jlong ptr);
+
#ifdef __cplusplus
}
#endif
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
old mode 100755
new mode 100644
index 251610e..8c03eeb
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -17,7 +17,7 @@
include "actions/actions-entity-data.fbs";
include "annotator/model.fbs";
include "utils/codepoint-range.fbs";
-include "utils/flatbuffers.fbs";
+include "utils/flatbuffers/flatbuffers.fbs";
include "utils/grammar/rules.fbs";
include "utils/intents/intent-config.fbs";
include "utils/normalization.fbs";
@@ -75,13 +75,10 @@
// int, the number of smart replies to produce.
input_num_suggestions:int = 4;
- // float, the output diversification distance parameter.
reserved_7:int (deprecated);
- // float, the empirical probability factor parameter.
reserved_8:int (deprecated);
- // float, the confidence threshold.
reserved_9:int (deprecated);
// Input port for hashed and embedded tokens, a (num messages, max tokens,
@@ -119,6 +116,10 @@
// Map of additional input tensor name to its index.
input_name_index:[TensorflowLiteModelSpec_.InputNameIndexEntry];
+
+ // If greater than 0, pad or truncate the input_user_id and input_context
+ // tensor to length of input_length_to_pad.
+ input_length_to_pad:int = 0;
}
// Configuration for the tokenizer.
@@ -245,6 +246,17 @@
tokenizer_options:ActionsTokenizerOptions;
}
+// TFLite based sensitive topic classifier model.
+namespace libtextclassifier3;
+table TFLiteSensitiveClassifierConfig {
+ // Specification of the model.
+ model_spec:TensorflowLiteModelSpec;
+
+ // Triggering threshold, if a sensitive topic has a score higher than this
+ // value, it triggers the classifier.
+ threshold:float;
+}
+
namespace libtextclassifier3;
table TriggeringPreconditions {
// Lower bound thresholds for the smart reply model prediction output.
@@ -280,7 +292,9 @@
low_confidence_rules:RulesModel;
reserved_11:float (deprecated);
+
reserved_12:float (deprecated);
+
reserved_13:float (deprecated);
// Smart reply thresholds.
@@ -551,6 +565,8 @@
// Feature processor options.
feature_processor_options:ActionsTokenFeatureProcessorOptions;
+
+ low_confidence_tflite_model:TFLiteSensitiveClassifierConfig;
}
root_type libtextclassifier3.ActionsModel;
diff --git a/native/actions/conversation_intent_detection/conversation-intent-detection-dummy.h b/native/actions/conversation_intent_detection/conversation-intent-detection-dummy.h
new file mode 100644
index 0000000..66255c5
--- /dev/null
+++ b/native/actions/conversation_intent_detection/conversation-intent-detection-dummy.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_DUMMY_H_
+
+#include <string>
+#include <vector>
+
+#include "actions/types.h"
+#include "utils/base/status.h"
+#include "utils/base/statusor.h"
+
+namespace libtextclassifier3 {
+
+// A dummy implementation of conversation intent detection.
+class ConversationIntentDetection {
+ public:
+ ConversationIntentDetection() {}
+
+ Status Initialize(const std::string& serialized_config) { return Status::OK; }
+
+ StatusOr<std::vector<ActionSuggestion>> SuggestActions(
+ const Conversation& conversation, const ActionSuggestionOptions& options =
+ ActionSuggestionOptions()) const {
+ return Status::OK;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_DUMMY_H_
diff --git a/native/actions/conversation_intent_detection/conversation-intent-detection.h b/native/actions/conversation_intent_detection/conversation-intent-detection.h
new file mode 100644
index 0000000..949ceaf
--- /dev/null
+++ b/native/actions/conversation_intent_detection/conversation-intent-detection.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_H_
+
+#include "actions/conversation_intent_detection/conversation-intent-detection-dummy.h"
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_CONVERSATION_INTENT_DETECTION_CONVERSATION_INTENT_DETECTION_H_
diff --git a/native/actions/feature-processor_test.cc b/native/actions/feature-processor_test.cc
index 969bbf7..e36af90 100644
--- a/native/actions/feature-processor_test.cc
+++ b/native/actions/feature-processor_test.cc
@@ -47,9 +47,9 @@
std::vector<float> storage_;
};
-class FeatureProcessorTest : public ::testing::Test {
+class ActionsFeatureProcessorTest : public ::testing::Test {
protected:
- FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+ ActionsFeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
ActionsTokenFeatureProcessorOptionsT* options) const {
@@ -62,7 +62,7 @@
UniLib unilib_;
};
-TEST_F(FeatureProcessorTest, TokenEmbeddings) {
+TEST_F(ActionsFeatureProcessorTest, TokenEmbeddings) {
ActionsTokenFeatureProcessorOptionsT options;
options.embedding_size = 4;
options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
@@ -81,7 +81,7 @@
EXPECT_THAT(token_features, SizeIs(4));
}
-TEST_F(FeatureProcessorTest, TokenEmbeddingsCaseFeature) {
+TEST_F(ActionsFeatureProcessorTest, TokenEmbeddingsCaseFeature) {
ActionsTokenFeatureProcessorOptionsT options;
options.embedding_size = 4;
options.extract_case_feature = true;
@@ -102,7 +102,7 @@
EXPECT_THAT(token_features[4], FloatEq(1.0));
}
-TEST_F(FeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
+TEST_F(ActionsFeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
ActionsTokenFeatureProcessorOptionsT options;
options.embedding_size = 4;
options.extract_case_feature = true;
diff --git a/native/actions/grammar-actions.cc b/native/actions/grammar-actions.cc
index 7f3e71f..bf99edc 100644
--- a/native/actions/grammar-actions.cc
+++ b/native/actions/grammar-actions.cc
@@ -16,204 +16,120 @@
#include "actions/grammar-actions.h"
-#include <algorithm>
-#include <unordered_map>
-
#include "actions/feature-processor.h"
#include "actions/utils.h"
#include "annotator/types.h"
-#include "utils/grammar/callback-delegate.h"
-#include "utils/grammar/match.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/rules-utils.h"
-#include "utils/i18n/language-tag_generated.h"
+#include "utils/base/arena.h"
+#include "utils/base/statusor.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
-namespace {
-
-class GrammarActionsCallbackDelegate : public grammar::CallbackDelegate {
- public:
- GrammarActionsCallbackDelegate(const UniLib* unilib,
- const RulesModel_::GrammarRules* grammar_rules)
- : unilib_(*unilib), grammar_rules_(grammar_rules) {}
-
- // Handle a grammar rule match in the actions grammar.
- void MatchFound(const grammar::Match* match, grammar::CallbackId type,
- int64 value, grammar::Matcher* matcher) override {
- switch (static_cast<GrammarActions::Callback>(type)) {
- case GrammarActions::Callback::kActionRuleMatch: {
- HandleRuleMatch(match, /*rule_id=*/value);
- return;
- }
- default:
- grammar::CallbackDelegate::MatchFound(match, type, value, matcher);
- }
- }
-
- // Deduplicate, verify and populate actions from grammar matches.
- bool GetActions(const Conversation& conversation,
- const std::string& smart_reply_action_type,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
- std::vector<ActionSuggestion>* action_suggestions) const {
- std::vector<UnicodeText::const_iterator> codepoint_offsets;
- const UnicodeText message_unicode =
- UTF8ToUnicodeText(conversation.messages.back().text,
- /*do_copy=*/false);
- for (auto it = message_unicode.begin(); it != message_unicode.end(); it++) {
- codepoint_offsets.push_back(it);
- }
- codepoint_offsets.push_back(message_unicode.end());
- for (const grammar::Derivation& candidate :
- grammar::DeduplicateDerivations(candidates_)) {
- // Check that assertions are fulfilled.
- if (!VerifyAssertions(candidate.match)) {
- continue;
- }
- if (!InstantiateActionsFromMatch(
- codepoint_offsets,
- /*message_index=*/conversation.messages.size() - 1,
- smart_reply_action_type, candidate, entity_data_builder,
- action_suggestions)) {
- return false;
- }
- }
- return true;
- }
-
- private:
- // Handles action rule matches.
- void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) {
- candidates_.push_back(grammar::Derivation{match, rule_id});
- }
-
- // Instantiates action suggestions from verified and deduplicated rule matches
- // and appends them to the result.
- // Expects the message as codepoints for text extraction from capturing
- // matches as well as the index of the message, for correct span production.
- bool InstantiateActionsFromMatch(
- const std::vector<UnicodeText::const_iterator>& message_codepoint_offsets,
- int message_index, const std::string& smart_reply_action_type,
- const grammar::Derivation& candidate,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
- std::vector<ActionSuggestion>* result) const {
- const RulesModel_::GrammarRules_::RuleMatch* rule_match =
- grammar_rules_->rule_match()->Get(candidate.rule_id);
- if (rule_match == nullptr || rule_match->action_id() == nullptr) {
- TC3_LOG(ERROR) << "No rule action defined.";
- return false;
- }
-
- // Gather active capturing matches.
- std::unordered_map<uint16, const grammar::Match*> capturing_matches;
- for (const grammar::MappingMatch* match :
- grammar::SelectAllOfType<grammar::MappingMatch>(
- candidate.match, grammar::Match::kMappingMatch)) {
- capturing_matches[match->id] = match;
- }
-
- // Instantiate actions from the rule match.
- for (const uint16 action_id : *rule_match->action_id()) {
- const RulesModel_::RuleActionSpec* action_spec =
- grammar_rules_->actions()->Get(action_id);
- std::vector<ActionSuggestionAnnotation> annotations;
-
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder != nullptr ? entity_data_builder->NewRoot()
- : nullptr;
-
- // Set information from capturing matches.
- if (action_spec->capturing_group() != nullptr) {
- for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
- *action_spec->capturing_group()) {
- auto it = capturing_matches.find(group->group_id());
- if (it == capturing_matches.end()) {
- // Capturing match is not active, skip.
- continue;
- }
-
- const grammar::Match* capturing_match = it->second;
- StringPiece match_text = StringPiece(
- message_codepoint_offsets[capturing_match->codepoint_span.first]
- .utf8_data(),
- message_codepoint_offsets[capturing_match->codepoint_span.second]
- .utf8_data() -
- message_codepoint_offsets[capturing_match->codepoint_span
- .first]
- .utf8_data());
- UnicodeText normalized_match_text =
- NormalizeMatchText(unilib_, group, match_text);
-
- if (!MergeEntityDataFromCapturingMatch(
- group, normalized_match_text.ToUTF8String(),
- entity_data.get())) {
- TC3_LOG(ERROR)
- << "Could not merge entity data from a capturing match.";
- return false;
- }
-
- // Add smart reply suggestions.
- SuggestTextRepliesFromCapturingMatch(entity_data_builder, group,
- normalized_match_text,
- smart_reply_action_type, result);
-
- // Add annotation.
- ActionSuggestionAnnotation annotation;
- if (FillAnnotationFromCapturingMatch(
- /*span=*/capturing_match->codepoint_span, group,
- /*message_index=*/message_index, match_text, &annotation)) {
- if (group->use_annotation_match()) {
- const grammar::AnnotationMatch* annotation_match =
- grammar::SelectFirstOfType<grammar::AnnotationMatch>(
- capturing_match, grammar::Match::kAnnotationMatch);
- if (!annotation_match) {
- TC3_LOG(ERROR) << "Could not get annotation for match.";
- return false;
- }
- annotation.entity = *annotation_match->annotation;
- }
- annotations.push_back(std::move(annotation));
- }
- }
- }
-
- if (action_spec->action() != nullptr) {
- ActionSuggestion suggestion;
- suggestion.annotations = annotations;
- FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
- &suggestion);
- result->push_back(std::move(suggestion));
- }
- }
- return true;
- }
-
- const UniLib& unilib_;
- const RulesModel_::GrammarRules* grammar_rules_;
-
- // All action rule match candidates.
- // Grammar rule matches are recorded, deduplicated, verified and then
- // instantiated.
- std::vector<grammar::Derivation> candidates_;
-};
-} // namespace
GrammarActions::GrammarActions(
const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
const std::string& smart_reply_action_type)
: unilib_(*unilib),
grammar_rules_(grammar_rules),
tokenizer_(CreateTokenizer(grammar_rules->tokenizer_options(), unilib)),
- lexer_(unilib, grammar_rules->rules()),
entity_data_builder_(entity_data_builder),
- smart_reply_action_type_(smart_reply_action_type),
- rules_locales_(ParseRulesLocales(grammar_rules->rules())) {}
+ analyzer_(unilib, grammar_rules->rules(), tokenizer_.get()),
+ smart_reply_action_type_(smart_reply_action_type) {}
+bool GrammarActions::InstantiateActionsFromMatch(
+ const grammar::TextContext& text_context, const int message_index,
+ const grammar::Derivation& derivation,
+ std::vector<ActionSuggestion>* result) const {
+ const RulesModel_::GrammarRules_::RuleMatch* rule_match =
+ grammar_rules_->rule_match()->Get(derivation.rule_id);
+ if (rule_match == nullptr || rule_match->action_id() == nullptr) {
+ TC3_LOG(ERROR) << "No rule action defined.";
+ return false;
+ }
+
+ // Gather active capturing matches.
+ std::unordered_map<uint16, const grammar::ParseTree*> capturing_matches;
+ for (const grammar::MappingNode* mapping_node :
+ grammar::SelectAllOfType<grammar::MappingNode>(
+ derivation.parse_tree, grammar::ParseTree::Type::kMapping)) {
+ capturing_matches[mapping_node->id] = mapping_node;
+ }
+
+ // Instantiate actions from the rule match.
+ for (const uint16 action_id : *rule_match->action_id()) {
+ const RulesModel_::RuleActionSpec* action_spec =
+ grammar_rules_->actions()->Get(action_id);
+ std::vector<ActionSuggestionAnnotation> annotations;
+
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
+ : nullptr;
+
+ // Set information from capturing matches.
+ if (action_spec->capturing_group() != nullptr) {
+ for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
+ *action_spec->capturing_group()) {
+ auto it = capturing_matches.find(group->group_id());
+ if (it == capturing_matches.end()) {
+ // Capturing match is not active, skip.
+ continue;
+ }
+
+ const grammar::ParseTree* capturing_match = it->second;
+ const UnicodeText match_text =
+ text_context.Span(capturing_match->codepoint_span);
+ UnicodeText normalized_match_text =
+ NormalizeMatchText(unilib_, group, match_text);
+
+ if (!MergeEntityDataFromCapturingMatch(
+ group, normalized_match_text.ToUTF8String(),
+ entity_data.get())) {
+ TC3_LOG(ERROR)
+ << "Could not merge entity data from a capturing match.";
+ return false;
+ }
+
+ // Add smart reply suggestions.
+ SuggestTextRepliesFromCapturingMatch(entity_data_builder_, group,
+ normalized_match_text,
+ smart_reply_action_type_, result);
+
+ // Add annotation.
+ ActionSuggestionAnnotation annotation;
+ if (FillAnnotationFromCapturingMatch(
+ /*span=*/capturing_match->codepoint_span, group,
+ /*message_index=*/message_index, match_text.ToUTF8String(),
+ &annotation)) {
+ if (group->use_annotation_match()) {
+ std::vector<const grammar::AnnotationNode*> annotations =
+ grammar::SelectAllOfType<grammar::AnnotationNode>(
+ capturing_match, grammar::ParseTree::Type::kAnnotation);
+ if (annotations.size() != 1) {
+ TC3_LOG(ERROR) << "Could not get annotation for match.";
+ return false;
+ }
+ annotation.entity = *annotations.front()->annotation;
+ }
+ annotations.push_back(std::move(annotation));
+ }
+ }
+ }
+
+ if (action_spec->action() != nullptr) {
+ ActionSuggestion suggestion;
+ suggestion.annotations = annotations;
+ FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
+ &suggestion);
+ result->push_back(std::move(suggestion));
+ }
+ }
+ return true;
+}
bool GrammarActions::SuggestActions(
const Conversation& conversation,
std::vector<ActionSuggestion>* result) const {
- if (grammar_rules_->rules()->rules() == nullptr) {
+ if (grammar_rules_->rules()->rules() == nullptr ||
+ conversation.messages.back().text.empty()) {
// Nothing to do.
return true;
}
@@ -225,30 +141,32 @@
return false;
}
- // Select locale matching rules.
- std::vector<const grammar::RulesSet_::Rules*> locale_rules =
- SelectLocaleMatchingShards(grammar_rules_->rules(), rules_locales_,
- locales);
- if (locale_rules.empty()) {
- // Nothing to do.
- return true;
+ const int message_index = conversation.messages.size() - 1;
+ grammar::TextContext text = analyzer_.BuildTextContextForInput(
+ UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false),
+ locales);
+ text.annotations = conversation.messages.back().annotations;
+
+ UnsafeArena arena(/*block_size=*/16 << 10);
+ StatusOr<std::vector<grammar::EvaluatedDerivation>> evaluated_derivations =
+ analyzer_.Parse(text, &arena);
+ // TODO(b/171294882): Return the status here and below.
+ if (!evaluated_derivations.ok()) {
+ TC3_LOG(ERROR) << "Could not run grammar analyzer: "
+ << evaluated_derivations.status().error_message();
+ return false;
}
- GrammarActionsCallbackDelegate callback_handler(&unilib_, grammar_rules_);
- grammar::Matcher matcher(&unilib_, grammar_rules_->rules(), locale_rules,
- &callback_handler);
+ for (const grammar::EvaluatedDerivation& evaluated_derivation :
+ evaluated_derivations.ValueOrDie()) {
+ if (!InstantiateActionsFromMatch(text, message_index, evaluated_derivation,
+ result)) {
+ TC3_LOG(ERROR) << "Could not instantiate actions from a grammar match.";
+ return false;
+ }
+ }
- const UnicodeText text =
- UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false);
-
- // Run grammar on last message.
- lexer_.Process(text, tokenizer_->Tokenize(text),
- /*annotations=*/&conversation.messages.back().annotations,
- &matcher);
-
- // Populate results.
- return callback_handler.GetActions(conversation, smart_reply_action_type_,
- entity_data_builder_, result);
+ return true;
}
} // namespace libtextclassifier3
diff --git a/native/actions/grammar-actions.h b/native/actions/grammar-actions.h
index fc3270d..2a1725f 100644
--- a/native/actions/grammar-actions.h
+++ b/native/actions/grammar-actions.h
@@ -22,11 +22,11 @@
#include "actions/actions_model_generated.h"
#include "actions/types.h"
-#include "utils/flatbuffers.h"
-#include "utils/grammar/lexer.h"
-#include "utils/grammar/types.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/analyzer.h"
+#include "utils/grammar/evaluated-derivation.h"
+#include "utils/grammar/text-context.h"
#include "utils/i18n/locale.h"
-#include "utils/strings/stringpiece.h"
#include "utils/tokenizer.h"
#include "utils/utf8/unilib.h"
@@ -35,27 +35,28 @@
// Grammar backed actions suggestions.
class GrammarActions {
public:
- enum class Callback : grammar::CallbackId { kActionRuleMatch = 1 };
-
- explicit GrammarActions(
- const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
- const std::string& smart_reply_action_type);
+ explicit GrammarActions(const UniLib* unilib,
+ const RulesModel_::GrammarRules* grammar_rules,
+ const MutableFlatbufferBuilder* entity_data_builder,
+ const std::string& smart_reply_action_type);
// Suggests actions for a conversation from a message stream.
bool SuggestActions(const Conversation& conversation,
std::vector<ActionSuggestion>* result) const;
private:
+ // Creates action suggestions from a grammar match result.
+ bool InstantiateActionsFromMatch(const grammar::TextContext& text_context,
+ int message_index,
+ const grammar::Derivation& derivation,
+ std::vector<ActionSuggestion>* result) const;
+
const UniLib& unilib_;
const RulesModel_::GrammarRules* grammar_rules_;
const std::unique_ptr<Tokenizer> tokenizer_;
- const grammar::Lexer lexer_;
- const ReflectiveFlatbufferBuilder* entity_data_builder_;
+ const MutableFlatbufferBuilder* entity_data_builder_;
+ const grammar::Analyzer analyzer_;
const std::string smart_reply_action_type_;
-
- // Pre-parsed locales of the rules.
- const std::vector<std::vector<Locale>> rules_locales_;
};
} // namespace libtextclassifier3
diff --git a/native/actions/grammar-actions_test.cc b/native/actions/grammar-actions_test.cc
new file mode 100644
index 0000000..02deea9
--- /dev/null
+++ b/native/actions/grammar-actions_test.cc
@@ -0,0 +1,708 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/grammar-actions.h"
+
+#include <iostream>
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "actions/test-utils.h"
+#include "actions/types.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/types.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/jvm-test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+
+using ::libtextclassifier3::grammar::LocaleShardMap;
+
+class TestGrammarActions : public GrammarActions {
+ public:
+ explicit TestGrammarActions(
+ const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
+ const MutableFlatbufferBuilder* entity_data_builder = nullptr)
+ : GrammarActions(unilib, grammar_rules, entity_data_builder,
+
+ /*smart_reply_action_type=*/"text_reply") {}
+};
+
+class GrammarActionsTest : public testing::Test {
+ protected:
+ struct AnnotationSpec {
+ int group_id = 0;
+ std::string annotation_name = "";
+ bool use_annotation_match = false;
+ };
+
+ GrammarActionsTest()
+ : unilib_(CreateUniLibForTesting()),
+ serialized_entity_data_schema_(TestEntityDataSchema()),
+ entity_data_builder_(new MutableFlatbufferBuilder(
+ flatbuffers::GetRoot<reflection::Schema>(
+ serialized_entity_data_schema_.data()))) {}
+
+ void SetTokenizerOptions(
+ RulesModel_::GrammarRulesT* action_grammar_rules) const {
+ action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
+ action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
+ action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
+ false;
+ }
+
+ int AddActionSpec(const std::string& type, const std::string& response_text,
+ const std::vector<AnnotationSpec>& annotations,
+ RulesModel_::GrammarRulesT* action_grammar_rules) const {
+ const int action_id = action_grammar_rules->actions.size();
+ action_grammar_rules->actions.emplace_back(
+ new RulesModel_::RuleActionSpecT);
+ RulesModel_::RuleActionSpecT* actions_spec =
+ action_grammar_rules->actions.back().get();
+ actions_spec->action.reset(new ActionSuggestionSpecT);
+ actions_spec->action->response_text = response_text;
+ actions_spec->action->priority_score = 1.0;
+ actions_spec->action->score = 1.0;
+ actions_spec->action->type = type;
+ // Create annotations for specified capturing groups.
+ for (const AnnotationSpec& annotation : annotations) {
+ actions_spec->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ actions_spec->capturing_group.back()->group_id = annotation.group_id;
+ actions_spec->capturing_group.back()->annotation_name =
+ annotation.annotation_name;
+ actions_spec->capturing_group.back()->annotation_type =
+ annotation.annotation_name;
+ actions_spec->capturing_group.back()->use_annotation_match =
+ annotation.use_annotation_match;
+ }
+
+ return action_id;
+ }
+
+ int AddSmartReplySpec(
+ const std::string& response_text,
+ RulesModel_::GrammarRulesT* action_grammar_rules) const {
+ return AddActionSpec("text_reply", response_text, {}, action_grammar_rules);
+ }
+
+ int AddCapturingMatchSmartReplySpec(
+ const int match_id,
+ RulesModel_::GrammarRulesT* action_grammar_rules) const {
+ const int action_id = action_grammar_rules->actions.size();
+ action_grammar_rules->actions.emplace_back(
+ new RulesModel_::RuleActionSpecT);
+ RulesModel_::RuleActionSpecT* actions_spec =
+ action_grammar_rules->actions.back().get();
+ actions_spec->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ actions_spec->capturing_group.back()->group_id = match_id;
+ actions_spec->capturing_group.back()->text_reply.reset(
+ new ActionSuggestionSpecT);
+ actions_spec->capturing_group.back()->text_reply->priority_score = 1.0;
+ actions_spec->capturing_group.back()->text_reply->score = 1.0;
+ return action_id;
+ }
+
+ int AddRuleMatch(const std::vector<int>& action_ids,
+ RulesModel_::GrammarRulesT* action_grammar_rules) const {
+ const int rule_match_id = action_grammar_rules->rule_match.size();
+ action_grammar_rules->rule_match.emplace_back(
+ new RulesModel_::GrammarRules_::RuleMatchT);
+ action_grammar_rules->rule_match.back()->action_id.insert(
+ action_grammar_rules->rule_match.back()->action_id.end(),
+ action_ids.begin(), action_ids.end());
+ return rule_match_id;
+ }
+
+ std::unique_ptr<UniLib> unilib_;
+ const std::string serialized_entity_data_schema_;
+ std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
+};
+
+TEST_F(GrammarActionsTest, ProducesSmartReplies) {
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+
+ // Create test rules.
+ // Rule: ^knock knock.?$ -> "Who's there?", "Yes?"
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ rules.Add(
+ "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules),
+ AddSmartReplySpec("Yes?", &action_grammar_rules)},
+ &action_grammar_rules));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Knock knock"}}}, &result));
+
+ EXPECT_THAT(result,
+ ElementsAre(IsSmartReply("Who's there?"), IsSmartReply("Yes?")));
+}
+
+TEST_F(GrammarActionsTest, ProducesSmartRepliesFromCapturingMatches) {
+ // Create test rules.
+ // Rule: ^Text <reply> to <command>
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+
+ rules.Add(
+ "<scripted_reply>",
+ {"<^>", "text", "<captured_reply>", "to", "<command>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddCapturingMatchSmartReplySpec(
+ /*match_id=*/0, &action_grammar_rules)},
+ &action_grammar_rules));
+
+ // <command> ::= unsubscribe | cancel | confirm | receive
+ rules.Add("<command>", {"unsubscribe"});
+ rules.Add("<command>", {"cancel"});
+ rules.Add("<command>", {"confirm"});
+ rules.Add("<command>", {"receive"});
+
+ // <reply> ::= help | stop | cancel | yes
+ rules.Add("<reply>", {"help"});
+ rules.Add("<reply>", {"stop"});
+ rules.Add("<reply>", {"cancel"});
+ rules.Add("<reply>", {"yes"});
+ rules.AddValueMapping("<captured_reply>", {"<reply>"},
+ /*value=*/0);
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0,
+ /*text=*/"Text YES to confirm your subscription"}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("YES")));
+ }
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0,
+ /*text=*/"text Stop to cancel your order"}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("Stop")));
+ }
+}
+
+TEST_F(GrammarActionsTest, ProducesAnnotationsForActions) {
+ // Create test rules.
+ // Rule: please dial <phone>
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+
+ rules.Add(
+ "<call_phone>", {"please", "dial", "<phone>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
+ /*annotations=*/{{0 /*value*/, "phone"}},
+ &action_grammar_rules)},
+ &action_grammar_rules));
+ // phone ::= +00 00 000 00 00
+ rules.AddValueMapping("<phone>",
+ {"+", "<2_digits>", "<2_digits>", "<3_digits>",
+ "<2_digits>", "<2_digits>"},
+ /*value=*/0);
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
+ EXPECT_THAT(result.front().annotations,
+ ElementsAre(IsActionSuggestionAnnotation(
+ "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
+}
+
+TEST_F(GrammarActionsTest, HandlesLocales) {
+ // Create test rules.
+ // Rule: ^knock knock.?$ -> "Who's there?"
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ LocaleShardMap locale_shard_map =
+ LocaleShardMap::CreateLocaleShardMap({"", "fr-CH"});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add(
+ "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules)},
+ &action_grammar_rules));
+ rules.Add(
+ "<toc>", {"<knock>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddSmartReplySpec("Qui est là?", &action_grammar_rules)},
+ &action_grammar_rules),
+ /*max_whitespace_gap=*/-1,
+ /*case_sensitive=*/false,
+ /*shard=*/1);
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ // Set locales for rules.
+ action_grammar_rules.rules->rules.back()->locale.emplace_back(
+ new LanguageTagT);
+ action_grammar_rules.rules->rules.back()->locale.back()->language = "fr";
+
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
+
+ // Check default.
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"UTC", /*annotations=*/{},
+ /*detected_text_language_tags=*/"en"}}},
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("Who's there?")));
+ }
+
+ // Check fr.
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"UTC", /*annotations=*/{},
+ /*detected_text_language_tags=*/"fr-CH"}}},
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("Who's there?"),
+ IsSmartReply("Qui est là?")));
+ }
+}
+
+TEST_F(GrammarActionsTest, HandlesAssertions) {
+ // Create test rules.
+ // Rule: <flight> -> Track flight.
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+
+ // Capture flight code.
+ rules.AddValueMapping("<flight>", {"<carrier>", "<flight_code>"},
+ /*value=*/0);
+
+ // Flight: carrier + flight code and check right context.
+ rules.Add(
+ "<track_flight>", {"<flight>", "<context_assertion>?"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddActionSpec("track_flight", /*response_text=*/"",
+ /*annotations=*/{{0 /*value*/, "flight"}},
+ &action_grammar_rules)},
+ &action_grammar_rules));
+
+ // Exclude matches like: LX 38.00 etc.
+ rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
+ /*negative=*/true);
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"LX38 aa 44 LX 38.38"}}},
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("track_flight"),
+ IsActionOfType("track_flight")));
+ EXPECT_THAT(result[0].annotations,
+ ElementsAre(IsActionSuggestionAnnotation("flight", "LX38",
+ CodepointSpan{0, 4})));
+ EXPECT_THAT(result[1].annotations,
+ ElementsAre(IsActionSuggestionAnnotation("flight", "aa 44",
+ CodepointSpan{5, 10})));
+}
+
+TEST_F(GrammarActionsTest, SetsFixedEntityData) {
+ // Create test rules.
+ // Rule: ^hello there$
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+
+ // Create smart reply and static entity data.
+ const int spec_id =
+ AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+ entity_data->Set("person", "Kenobi");
+ action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
+ entity_data->Serialize();
+ action_grammar_rules.actions[spec_id]->action->entity_data.reset(
+ new ActionsEntityDataT);
+ action_grammar_rules.actions[spec_id]->action->entity_data->text =
+ "I have the high ground.";
+
+ rules.Add(
+ "<greeting>", {"<^>", "hello", "there", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({spec_id}, &action_grammar_rules));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get(),
+ entity_data_builder_.get());
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));
+
+  // Check the produced smart replies.
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ result[0].serialized_entity_data.data()));
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "I have the high ground.");
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Kenobi");
+}
+
+TEST_F(GrammarActionsTest, SetsEntityDataFromCapturingMatches) {
+ // Create test rules.
+ // Rule: ^hello there$
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+
+ // Create smart reply and static entity data.
+ const int spec_id =
+ AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+ entity_data->Set("person", "Kenobi");
+ action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
+ entity_data->Serialize();
+
+ // Specify results for capturing matches.
+ const int greeting_match_id = 0;
+ const int location_match_id = 1;
+ {
+ action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
+ action_grammar_rules.actions[spec_id]->capturing_group.back().get();
+ group->group_id = greeting_match_id;
+ group->entity_field.reset(new FlatbufferFieldPathT);
+ group->entity_field->field.emplace_back(new FlatbufferFieldT);
+ group->entity_field->field.back()->field_name = "greeting";
+ }
+ {
+ action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
+ action_grammar_rules.actions[spec_id]->capturing_group.back().get();
+ group->group_id = location_match_id;
+ group->entity_field.reset(new FlatbufferFieldPathT);
+ group->entity_field->field.emplace_back(new FlatbufferFieldT);
+ group->entity_field->field.back()->field_name = "location";
+ }
+
+ rules.Add("<location>", {"there"});
+ rules.Add("<location>", {"here"});
+ rules.AddValueMapping("<captured_location>", {"<location>"},
+ /*value=*/location_match_id);
+ rules.AddValueMapping("<greeting>", {"hello", "<captured_location>"},
+ /*value=*/greeting_match_id);
+ rules.Add(
+ "<test>", {"<^>", "<greeting>", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({spec_id}, &action_grammar_rules));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get(),
+ entity_data_builder_.get());
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));
+
+  // Check the produced smart replies.
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ result[0].serialized_entity_data.data()));
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "Hello there");
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
+ "there");
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Kenobi");
+}
+
+TEST_F(GrammarActionsTest, SetsFixedEntityDataFromCapturingGroups) {
+ // Create test rules.
+ // Rule: ^hello there$
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+
+ // Create smart reply.
+ const int spec_id =
+ AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
+ action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
+ action_grammar_rules.actions[spec_id]->capturing_group.back().get();
+ group->group_id = 0;
+ group->entity_data.reset(new ActionsEntityDataT);
+ group->entity_data->text = "You are a bold one.";
+
+ rules.AddValueMapping("<greeting>", {"<^>", "hello", "there", "<$>"},
+ /*value=*/0);
+ rules.Add(
+ "<test>", {"<greeting>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({spec_id}, &action_grammar_rules));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get(),
+ entity_data_builder_.get());
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));
+
+  // Check the produced smart replies.
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ result[0].serialized_entity_data.data()));
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "You are a bold one.");
+}
+
+TEST_F(GrammarActionsTest, ProducesActionsWithAnnotations) {
+ // Create test rules.
+ // Rule: please dial <phone>
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add(
+ "<call_phone>", {"please", "dial", "<phone>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
+ /*annotations=*/
+ {{0 /*value*/, "phone",
+ /*use_annotation_match=*/true}},
+ &action_grammar_rules)},
+ &action_grammar_rules));
+ rules.AddValueMapping("<phone>", {"<phone_annotation>"},
+ /*value=*/0);
+
+ grammar::Ir ir = rules.Finalize(
+ /*predefined_nonterminals=*/{"<phone_annotation>"});
+ ir.Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+
+ // Map "phone" annotation to "<phone_annotation>" nonterminal.
+ action_grammar_rules.rules->nonterminals->annotation_nt.emplace_back(
+ new grammar::RulesSet_::Nonterminals_::AnnotationNtEntryT);
+ action_grammar_rules.rules->nonterminals->annotation_nt.back()->key = "phone";
+ action_grammar_rules.rules->nonterminals->annotation_nt.back()->value =
+ ir.GetNonterminalForName("<phone_annotation>");
+
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
+
+ std::vector<ActionSuggestion> result;
+
+  // Sanity check that no results are produced when no annotations are
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
+ &result));
+ EXPECT_THAT(result, IsEmpty());
+
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{
+ {/*user_id=*/0,
+ /*text=*/"Please dial +41 79 123 45 67",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"UTC",
+ /*annotations=*/
+ {{CodepointSpan{12, 28}, {ClassificationResult{"phone", 1.0}}}}}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
+ EXPECT_THAT(result.front().annotations,
+ ElementsAre(IsActionSuggestionAnnotation(
+ "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
+}
+
+TEST_F(GrammarActionsTest, HandlesExclusions) {
+ // Create test rules.
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add("<excluded>", {"be", "safe"});
+ rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
+ /*excluded_nonterminal=*/"<excluded>");
+
+ rules.Add(
+ "<set_reminder>",
+ {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddActionSpec("set_reminder", /*response_text=*/"",
+ /*annotations=*/
+ {}, &action_grammar_rules)},
+ &action_grammar_rules));
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get(),
+ entity_data_builder_.get());
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{
+ {/*user_id=*/0, /*text=*/"do not forget to bring milk"}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
+ }
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"do not forget to be there!"}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
+ }
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{
+ {/*user_id=*/0, /*text=*/"do not forget to buy safe or vault!"}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
+ }
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"do not forget to be safe!"}}},
+ &result));
+ EXPECT_THAT(result, IsEmpty());
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/actions/lua-ranker_test.cc b/native/actions/lua-ranker_test.cc
index a790042..939617b 100644
--- a/native/actions/lua-ranker_test.cc
+++ b/native/actions/lua-ranker_test.cc
@@ -19,7 +19,7 @@
#include <string>
#include "actions/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -229,8 +229,8 @@
flatbuffers::GetRoot<reflection::Schema>(serialized_schema.data());
// Create test entity data.
- ReflectiveFlatbufferBuilder builder(entity_data_schema);
- std::unique_ptr<ReflectiveFlatbuffer> buffer = builder.NewRoot();
+ MutableFlatbufferBuilder builder(entity_data_schema);
+ std::unique_ptr<MutableFlatbuffer> buffer = builder.NewRoot();
buffer->Set("test", "value_a");
const std::string serialized_entity_data_a = buffer->Serialize();
buffer->Set("test", "value_b");
diff --git a/native/actions/ngram-model.cc b/native/actions/ngram-model.cc
index fb3992c..94ec8b2 100644
--- a/native/actions/ngram-model.cc
+++ b/native/actions/ngram-model.cc
@@ -60,7 +60,7 @@
} // anonymous namespace
-std::unique_ptr<NGramModel> NGramModel::Create(
+std::unique_ptr<NGramSensitiveModel> NGramSensitiveModel::Create(
const UniLib* unilib, const NGramLinearRegressionModel* model,
const Tokenizer* tokenizer) {
if (model == nullptr) {
@@ -70,12 +70,13 @@
TC3_LOG(ERROR) << "No tokenizer options specified.";
return nullptr;
}
- return std::unique_ptr<NGramModel>(new NGramModel(unilib, model, tokenizer));
+ return std::unique_ptr<NGramSensitiveModel>(
+ new NGramSensitiveModel(unilib, model, tokenizer));
}
-NGramModel::NGramModel(const UniLib* unilib,
- const NGramLinearRegressionModel* model,
- const Tokenizer* tokenizer)
+NGramSensitiveModel::NGramSensitiveModel(
+ const UniLib* unilib, const NGramLinearRegressionModel* model,
+ const Tokenizer* tokenizer)
: model_(model) {
// Create new tokenizer if options are specified, reuse feature processor
// tokenizer otherwise.
@@ -88,9 +89,10 @@
}
// Returns whether a given n-gram matches the token stream.
-bool NGramModel::IsNGramMatch(const uint32* tokens, size_t num_tokens,
- const uint32* ngram_tokens,
- size_t num_ngram_tokens, int max_skips) const {
+bool NGramSensitiveModel::IsNGramMatch(const uint32* tokens, size_t num_tokens,
+ const uint32* ngram_tokens,
+ size_t num_ngram_tokens,
+ int max_skips) const {
int token_idx = 0, ngram_token_idx = 0, skip_remain = 0;
for (; token_idx < num_tokens && ngram_token_idx < num_ngram_tokens;) {
if (tokens[token_idx] == ngram_tokens[ngram_token_idx]) {
@@ -112,8 +114,9 @@
// Calculates the total number of skip-grams that can be created for a stream
// with the given number of tokens.
-uint64 NGramModel::GetNumSkipGrams(int num_tokens, int max_ngram_length,
- int max_skips) {
+uint64 NGramSensitiveModel::GetNumSkipGrams(int num_tokens,
+ int max_ngram_length,
+ int max_skips) {
// Start with unigrams.
uint64 total = num_tokens;
for (int ngram_len = 2;
@@ -138,7 +141,8 @@
return total;
}
-std::pair<int, int> NGramModel::GetFirstTokenMatches(uint32 token_hash) const {
+std::pair<int, int> NGramSensitiveModel::GetFirstTokenMatches(
+ uint32 token_hash) const {
const int num_ngrams = model_->ngram_weights()->size();
const auto start_it = FirstTokenIterator(model_, 0);
const auto end_it = FirstTokenIterator(model_, num_ngrams);
@@ -147,15 +151,13 @@
return std::make_pair(start, end);
}
-bool NGramModel::Eval(const UnicodeText& text, float* score) const {
+std::pair<bool, float> NGramSensitiveModel::Eval(
+ const UnicodeText& text) const {
const std::vector<Token> raw_tokens = tokenizer_->Tokenize(text);
// If we have no tokens, then just bail early.
if (raw_tokens.empty()) {
- if (score != nullptr) {
- *score = model_->default_token_weight();
- }
- return false;
+ return std::make_pair(false, model_->default_token_weight());
}
// Hash the tokens.
@@ -201,25 +203,25 @@
const float internal_score =
(weight_matches + (model_->default_token_weight() * num_misses)) /
num_candidates;
- if (score != nullptr) {
- *score = internal_score;
- }
- return internal_score > model_->threshold();
+ return std::make_pair(internal_score > model_->threshold(), internal_score);
}
-bool NGramModel::EvalConversation(const Conversation& conversation,
- const int num_messages) const {
+std::pair<bool, float> NGramSensitiveModel::EvalConversation(
+ const Conversation& conversation, const int num_messages) const {
+ float score = 0.0;
for (int i = 1; i <= num_messages; i++) {
const std::string& message =
conversation.messages[conversation.messages.size() - i].text;
const UnicodeText message_unicode(
UTF8ToUnicodeText(message, /*do_copy=*/false));
// Run ngram linear regression model.
- if (Eval(message_unicode)) {
- return true;
+ const auto prediction = Eval(message_unicode);
+ if (prediction.first) {
+ return prediction;
}
+ score = std::max(score, prediction.second);
}
- return false;
+ return std::make_pair(false, score);
}
} // namespace libtextclassifier3
diff --git a/native/actions/ngram-model.h b/native/actions/ngram-model.h
index a9072cd..32fd54b 100644
--- a/native/actions/ngram-model.h
+++ b/native/actions/ngram-model.h
@@ -20,6 +20,7 @@
#include <memory>
#include "actions/actions_model_generated.h"
+#include "actions/sensitive-classifier-base.h"
#include "actions/types.h"
#include "utils/tokenizer.h"
#include "utils/utf8/unicodetext.h"
@@ -27,29 +28,30 @@
namespace libtextclassifier3 {
-class NGramModel {
+class NGramSensitiveModel : public SensitiveTopicModelBase {
public:
- static std::unique_ptr<NGramModel> Create(
+ static std::unique_ptr<NGramSensitiveModel> Create(
const UniLib* unilib, const NGramLinearRegressionModel* model,
const Tokenizer* tokenizer);
// Evaluates an n-gram linear regression model, and tests against the
// threshold. Returns true in case of a positive classification. The caller
// may also optionally query the score.
- bool Eval(const UnicodeText& text, float* score = nullptr) const;
+ std::pair<bool, float> Eval(const UnicodeText& text) const override;
// Evaluates an n-gram linear regression model against all messages in a
// conversation and returns true in case of any positive classification.
- bool EvalConversation(const Conversation& conversation,
- const int num_messages) const;
+ std::pair<bool, float> EvalConversation(const Conversation& conversation,
+ int num_messages) const override;
// Exposed for testing only.
static uint64 GetNumSkipGrams(int num_tokens, int max_ngram_length,
int max_skips);
private:
- NGramModel(const UniLib* unilib, const NGramLinearRegressionModel* model,
- const Tokenizer* tokenizer);
+ explicit NGramSensitiveModel(const UniLib* unilib,
+ const NGramLinearRegressionModel* model,
+ const Tokenizer* tokenizer);
// Returns the (begin,end] range of n-grams where the first hashed token
// matches the given value.
diff --git a/native/actions/ranker.cc b/native/actions/ranker.cc
index 5a03da5..d52ecaa 100644
--- a/native/actions/ranker.cc
+++ b/native/actions/ranker.cc
@@ -20,11 +20,15 @@
#include <set>
#include <vector>
+#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-ranker.h"
+#endif
#include "actions/zlib-utils.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
+#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
+#endif
namespace libtextclassifier3 {
namespace {
@@ -215,6 +219,7 @@
return false;
}
+#if !defined(TC3_DISABLE_LUA)
std::string lua_ranking_script;
if (GetUncompressedString(options_->lua_ranking_script(),
options_->compressed_lua_ranking_script(),
@@ -225,6 +230,7 @@
return false;
}
}
+#endif
return true;
}
@@ -336,6 +342,7 @@
SortByScoreAndType(&response->actions);
}
+#if !defined(TC3_DISABLE_LUA)
// Run lua ranking snippet, if provided.
if (!lua_bytecode_.empty()) {
auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
@@ -346,6 +353,7 @@
return false;
}
}
+#endif
return true;
}
diff --git a/native/actions/ranker.h b/native/actions/ranker.h
index 2ab3146..5af7c38 100644
--- a/native/actions/ranker.h
+++ b/native/actions/ranker.h
@@ -22,6 +22,7 @@
#include "actions/actions_model_generated.h"
#include "actions/types.h"
#include "utils/zlib/zlib.h"
+#include "flatbuffers/reflection.h"
namespace libtextclassifier3 {
diff --git a/native/actions/regex-actions.cc b/native/actions/regex-actions.cc
index 7d5a4b2..9d91c73 100644
--- a/native/actions/regex-actions.cc
+++ b/native/actions/regex-actions.cc
@@ -93,6 +93,9 @@
bool RegexActions::InitializeRulesModel(
const RulesModel* rules, ZlibDecompressor* decompressor,
std::vector<CompiledRule>* compiled_rules) const {
+ if (rules->regex_rule() == nullptr) {
+ return true;
+ }
for (const RulesModel_::RegexRule* rule : *rules->regex_rule()) {
std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
UncompressMakeRegexPattern(
@@ -189,7 +192,7 @@
bool RegexActions::SuggestActions(
const Conversation& conversation,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
std::vector<ActionSuggestion>* actions) const {
// Create actions based on rules checking the last message.
const int message_index = conversation.messages.size() - 1;
@@ -206,7 +209,7 @@
const ActionSuggestionSpec* action = rule_action->action();
std::vector<ActionSuggestionAnnotation> annotations;
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder != nullptr ? entity_data_builder->NewRoot()
: nullptr;
diff --git a/native/actions/regex-actions.h b/native/actions/regex-actions.h
index 871f08b..ee0b186 100644
--- a/native/actions/regex-actions.h
+++ b/native/actions/regex-actions.h
@@ -23,7 +23,7 @@
#include "actions/actions_model_generated.h"
#include "actions/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/utf8/unilib.h"
#include "utils/zlib/zlib.h"
@@ -55,7 +55,7 @@
// Suggests actions for a conversation from a message stream using the regex
// rules.
bool SuggestActions(const Conversation& conversation,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
std::vector<ActionSuggestion>* actions) const;
private:
diff --git a/native/actions/sensitive-classifier-base.h b/native/actions/sensitive-classifier-base.h
new file mode 100644
index 0000000..b0ecacd
--- /dev/null
+++ b/native/actions/sensitive-classifier-base.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_SENSITIVE_CLASSIFIER_BASE_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_SENSITIVE_CLASSIFIER_BASE_H_
+
+#include <memory>
+#include <utility>
+
+#include "actions/types.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+class SensitiveTopicModelBase {
+ public:
+ // Returns a pair: a boolean, which is true if the topic is sensitive, and a
+ // score.
+ virtual std::pair<bool, float> Eval(const UnicodeText& text) const = 0;
+ virtual std::pair<bool, float> EvalConversation(
+ const Conversation& conversation, int num_messages) const = 0;
+
+ virtual ~SensitiveTopicModelBase() {}
+};
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_SENSITIVE_CLASSIFIER_BASE_H_
diff --git a/native/actions/test-utils.cc b/native/actions/test-utils.cc
index 9b003dd..426989d 100644
--- a/native/actions/test-utils.cc
+++ b/native/actions/test-utils.cc
@@ -16,6 +16,8 @@
#include "actions/test-utils.h"
+#include "flatbuffers/reflection.h"
+
namespace libtextclassifier3 {
std::string TestEntityDataSchema() {
diff --git a/native/actions/test-utils.h b/native/actions/test-utils.h
index c05d6a9..e27f510 100644
--- a/native/actions/test-utils.h
+++ b/native/actions/test-utils.h
@@ -20,7 +20,6 @@
#include <string>
#include "actions/actions_model_generated.h"
-#include "utils/flatbuffers.h"
#include "gmock/gmock.h"
namespace libtextclassifier3 {
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
new file mode 100644
index 0000000..d122687
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.hashgram.model b/native/actions/test_data/actions_suggestions_test.hashgram.model
new file mode 100644
index 0000000..cdc6bdc
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.hashgram.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
new file mode 100644
index 0000000..2d97bc8
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
new file mode 100644
index 0000000..567828b
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
new file mode 100644
index 0000000..99f9040
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
new file mode 100644
index 0000000..504d8e0
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
new file mode 100644
index 0000000..33926c2
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
new file mode 100644
index 0000000..730f603
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
new file mode 100644
index 0000000..29fe077
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
Binary files differ
diff --git a/native/actions/test_data/en_sensitive_topic_2019117.tflite b/native/actions/test_data/en_sensitive_topic_2019117.tflite
new file mode 100644
index 0000000..48edfbd
--- /dev/null
+++ b/native/actions/test_data/en_sensitive_topic_2019117.tflite
Binary files differ
diff --git a/native/actions/tflite-sensitive-model.cc b/native/actions/tflite-sensitive-model.cc
new file mode 100644
index 0000000..e68d1d5
--- /dev/null
+++ b/native/actions/tflite-sensitive-model.cc
@@ -0,0 +1,128 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/tflite-sensitive-model.h"
+
+#include <utility>
+
+#include "actions/actions_model_generated.h"
+#include "actions/types.h"
+
+namespace libtextclassifier3 {
+namespace {
+const char kNotSensitive[] = "NOT_SENSITIVE";
+} // namespace
+
+std::unique_ptr<TFLiteSensitiveModel> TFLiteSensitiveModel::Create(
+ const TFLiteSensitiveClassifierConfig* model_config) {
+ auto result_model = std::unique_ptr<TFLiteSensitiveModel>(
+ new TFLiteSensitiveModel(model_config));
+ if (result_model->model_executor_ == nullptr) {
+ return nullptr;
+ }
+ return result_model;
+}
+
+std::pair<bool, float> TFLiteSensitiveModel::Eval(
+ const UnicodeText& text) const {
+ // Create a conversation with one message and classify it.
+ Conversation conversation;
+ conversation.messages.emplace_back();
+ conversation.messages.front().text = text.ToUTF8String();
+
+ return EvalConversation(conversation, 1);
+}
+
+std::pair<bool, float> TFLiteSensitiveModel::EvalConversation(
+ const Conversation& conversation, int num_messages) const {
+ if (model_executor_ == nullptr) {
+ return std::make_pair(false, 0.0f);
+ }
+ const auto interpreter = model_executor_->CreateInterpreter();
+
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+    // TODO(mgubin): Report an error when tensors can't be allocated.
+ return std::make_pair(false, 0.0f);
+ }
+ // The sensitive model is actually an ordinary TFLite model with Lingua API,
+  // prepare texts and user_ids in a similar way; it doesn't use timediffs.
+ std::vector<std::string> context;
+ std::vector<int> user_ids;
+ context.reserve(num_messages);
+ user_ids.reserve(num_messages);
+
+ // Gather last `num_messages` messages from the conversation.
+ for (int i = conversation.messages.size() - num_messages;
+ i < conversation.messages.size(); i++) {
+ const ConversationMessage& message = conversation.messages[i];
+ context.push_back(message.text);
+ user_ids.push_back(message.user_id);
+ }
+
+  // Set the input tensors.
+ //
+
+ if (model_config_->model_spec()->input_context() >= 0) {
+ if (model_config_->model_spec()->input_length_to_pad() > 0) {
+ context.resize(model_config_->model_spec()->input_length_to_pad());
+ }
+ model_executor_->SetInput<std::string>(
+ model_config_->model_spec()->input_context(), context,
+ interpreter.get());
+ }
+ if (model_config_->model_spec()->input_context_length() >= 0) {
+ model_executor_->SetInput<int>(
+ model_config_->model_spec()->input_context_length(), context.size(),
+ interpreter.get());
+ }
+
+ // Num suggestions is always locked to 3.
+ if (model_config_->model_spec()->input_num_suggestions() > 0) {
+ model_executor_->SetInput<int>(
+ model_config_->model_spec()->input_num_suggestions(), 3,
+ interpreter.get());
+ }
+
+ if (interpreter->Invoke() != kTfLiteOk) {
+    // TODO(mgubin): Report an error about the failed Invoke() call.
+ return std::make_pair(false, 0.0f);
+ }
+
+ // Check that the prediction is not-sensitive.
+ const std::vector<tflite::StringRef> replies =
+ model_executor_->Output<tflite::StringRef>(
+ model_config_->model_spec()->output_replies(), interpreter.get());
+ const TensorView<float> scores = model_executor_->OutputView<float>(
+ model_config_->model_spec()->output_replies_scores(), interpreter.get());
+ for (int i = 0; i < replies.size(); ++i) {
+ const auto reply = replies[i];
+ if (reply.len != sizeof(kNotSensitive) - 1 &&
+ 0 != memcmp(reply.str, kNotSensitive, sizeof(kNotSensitive))) {
+ const auto score = scores.data()[i];
+ if (score >= model_config_->threshold()) {
+ return std::make_pair(true, score);
+ }
+ }
+ }
+ return std::make_pair(false, 1.0);
+}
+
+TFLiteSensitiveModel::TFLiteSensitiveModel(
+ const TFLiteSensitiveClassifierConfig* model_config)
+ : model_config_(model_config),
+ model_executor_(TfLiteModelExecutor::FromBuffer(
+ model_config->model_spec()->tflite_model())) {}
+} // namespace libtextclassifier3
diff --git a/native/actions/tflite-sensitive-model.h b/native/actions/tflite-sensitive-model.h
new file mode 100644
index 0000000..2f161a8
--- /dev/null
+++ b/native/actions/tflite-sensitive-model.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_TFLITE_SENSITIVE_MODEL_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_TFLITE_SENSITIVE_MODEL_H_
+
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "actions/sensitive-classifier-base.h"
+#include "utils/tflite-model-executor.h"
+
+namespace libtextclassifier3 {
+class TFLiteSensitiveModel : public SensitiveTopicModelBase {
+ public:
+ // The object keeps but doesn't own model_config.
+ static std::unique_ptr<TFLiteSensitiveModel> Create(
+ const TFLiteSensitiveClassifierConfig* model_config);
+
+ std::pair<bool, float> Eval(const UnicodeText& text) const override;
+ std::pair<bool, float> EvalConversation(const Conversation& conversation,
+ int num_messages) const override;
+
+ private:
+ explicit TFLiteSensitiveModel(
+ const TFLiteSensitiveClassifierConfig* model_config);
+ const TFLiteSensitiveClassifierConfig* model_config_ = nullptr; // not owned.
+ std::unique_ptr<const TfLiteModelExecutor> model_executor_;
+};
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_TFLITE_SENSITIVE_MODEL_H_
diff --git a/native/actions/types.h b/native/actions/types.h
index e7d384f..c400bb2 100644
--- a/native/actions/types.h
+++ b/native/actions/types.h
@@ -19,11 +19,12 @@
#include <map>
#include <string>
+#include <unordered_map>
#include <vector>
#include "actions/actions-entity-data_generated.h"
#include "annotator/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
namespace libtextclassifier3 {
@@ -55,6 +56,13 @@
std::string name;
};
+// A slot associated with an action.
+struct Slot {
+ std::string type;
+ MessageTextSpan span;
+ float confidence_score;
+};
+
// Action suggestion that contains a response text and the type of the response.
struct ActionSuggestion {
// Text of the action suggestion.
@@ -75,12 +83,21 @@
// Extras information.
std::string serialized_entity_data;
- const ActionsEntityData* entity_data() {
+ // Slots corresponding to the action suggestion.
+ std::vector<Slot> slots;
+
+ const ActionsEntityData* entity_data() const {
return LoadAndVerifyFlatbuffer<ActionsEntityData>(
serialized_entity_data.data(), serialized_entity_data.size());
}
};
+// Options for suggesting actions.
+struct ActionSuggestionOptions {
+ static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
+ std::unordered_map<std::string, Variant> model_parameters;
+};
+
// Actions suggestions result containing meta - information and the suggested
// actions.
struct ActionsSuggestionsResponse {
@@ -88,8 +105,8 @@
float sensitivity_score = -1.f;
float triggering_score = -1.f;
- // Whether the output was suppressed by the sensitivity threshold.
- bool output_filtered_sensitivity = false;
+ // Whether the input conversation is considered as sensitive.
+ bool is_sensitive = false;
// Whether the output was suppressed by the triggering score threshold.
bool output_filtered_min_triggering_score = false;
diff --git a/native/actions/utils.cc b/native/actions/utils.cc
index 96f6f1f..648f04d 100644
--- a/native/actions/utils.cc
+++ b/native/actions/utils.cc
@@ -16,14 +16,19 @@
#include "actions/utils.h"
+#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/normalization.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
+// Name for a datetime annotation that only includes time but no date.
+const std::string& kTimeAnnotation =
+ *[]() { return new std::string("time"); }();
+
void FillSuggestionFromSpec(const ActionSuggestionSpec* action,
- ReflectiveFlatbuffer* entity_data,
+ MutableFlatbuffer* entity_data,
ActionSuggestion* suggestion) {
if (action != nullptr) {
suggestion->score = action->score();
@@ -52,7 +57,7 @@
}
void SuggestTextRepliesFromCapturingMatch(
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
const UnicodeText& match_text, const std::string& smart_reply_action_type,
std::vector<ActionSuggestion>* actions) {
@@ -60,7 +65,7 @@
ActionSuggestion suggestion;
suggestion.response_text = match_text.ToUTF8String();
suggestion.type = smart_reply_action_type;
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder != nullptr ? entity_data_builder->NewRoot()
: nullptr;
FillSuggestionFromSpec(group->text_reply(), entity_data.get(), &suggestion);
@@ -72,13 +77,18 @@
const UniLib& unilib,
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
StringPiece match_text) {
- UnicodeText normalized_match_text =
- UTF8ToUnicodeText(match_text, /*do_copy=*/false);
- if (group->normalization_options() != nullptr) {
- normalized_match_text = NormalizeText(
- unilib, group->normalization_options(), normalized_match_text);
+ return NormalizeMatchText(unilib, group,
+ UTF8ToUnicodeText(match_text, /*do_copy=*/false));
+}
+
+UnicodeText NormalizeMatchText(
+ const UniLib& unilib,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ const UnicodeText match_text) {
+ if (group->normalization_options() == nullptr) {
+ return match_text;
}
- return normalized_match_text;
+ return NormalizeText(unilib, group->normalization_options(), match_text);
}
bool FillAnnotationFromCapturingMatch(
@@ -104,7 +114,7 @@
bool MergeEntityDataFromCapturingMatch(
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
- StringPiece match_text, ReflectiveFlatbuffer* buffer) {
+ StringPiece match_text, MutableFlatbuffer* buffer) {
if (group->entity_field() != nullptr) {
if (!buffer->ParseAndSet(group->entity_field(), match_text.ToString())) {
TC3_LOG(ERROR) << "Could not set entity data from rule capturing group.";
@@ -121,4 +131,29 @@
return true;
}
+void ConvertDatetimeToTime(std::vector<AnnotatedSpan>* annotations) {
+ for (int i = 0; i < annotations->size(); i++) {
+ ClassificationResult* classification =
+ &(*annotations)[i].classification.front();
+ // Specialize datetime annotation to time annotation if no date
+ // component is present.
+ if (classification->collection == Collections::DateTime() &&
+ classification->datetime_parse_result.IsSet()) {
+ bool has_only_time = true;
+ for (const DatetimeComponent& component :
+ classification->datetime_parse_result.datetime_components) {
+ if (component.component_type !=
+ DatetimeComponent::ComponentType::UNSPECIFIED &&
+ component.component_type < DatetimeComponent::ComponentType::HOUR) {
+ has_only_time = false;
+ break;
+ }
+ }
+ if (has_only_time) {
+ classification->collection = kTimeAnnotation;
+ }
+ }
+ }
+}
+
} // namespace libtextclassifier3
diff --git a/native/actions/utils.h b/native/actions/utils.h
index 820c79d..4838464 100644
--- a/native/actions/utils.h
+++ b/native/actions/utils.h
@@ -25,7 +25,8 @@
#include "actions/actions_model_generated.h"
#include "actions/types.h"
#include "annotator/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
@@ -33,12 +34,12 @@
// Fills an action suggestion from a template.
void FillSuggestionFromSpec(const ActionSuggestionSpec* action,
- ReflectiveFlatbuffer* entity_data,
+ MutableFlatbuffer* entity_data,
ActionSuggestion* suggestion);
// Creates text replies from capturing matches.
void SuggestTextRepliesFromCapturingMatch(
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
const UnicodeText& match_text, const std::string& smart_reply_action_type,
std::vector<ActionSuggestion>* actions);
@@ -49,6 +50,11 @@
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
StringPiece match_text);
+UnicodeText NormalizeMatchText(
+ const UniLib& unilib,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ const UnicodeText match_text);
+
// Fills the fields in an annotation from a capturing match.
bool FillAnnotationFromCapturingMatch(
const CodepointSpan& span,
@@ -60,7 +66,11 @@
// Parses and sets values from the text and merges fixed data.
bool MergeEntityDataFromCapturingMatch(
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
- StringPiece match_text, ReflectiveFlatbuffer* buffer);
+ StringPiece match_text, MutableFlatbuffer* buffer);
+
+// Changes datetime classifications to a time type if no date component is
+// present. Modifies classifications in-place.
+void ConvertDatetimeToTime(std::vector<AnnotatedSpan>* annotations);
} // namespace libtextclassifier3
diff --git a/native/actions/zlib-utils.cc b/native/actions/zlib-utils.cc
index c8ad4e7..4525a21 100644
--- a/native/actions/zlib-utils.cc
+++ b/native/actions/zlib-utils.cc
@@ -20,7 +20,6 @@
#include "utils/base/logging.h"
#include "utils/intents/zlib-utils.h"
-#include "utils/resources.h"
namespace libtextclassifier3 {
@@ -76,11 +75,6 @@
model->ranking_options->compressed_lua_ranking_script.get());
}
- // Compress resources.
- if (model->resources != nullptr) {
- CompressResources(model->resources.get());
- }
-
// Compress intent generator.
if (model->android_intent_options != nullptr) {
CompressIntentModel(model->android_intent_options.get());
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 6ee983f..e296a64 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -20,23 +20,32 @@
#include <cmath>
#include <cstddef>
#include <iterator>
+#include <limits>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>
#include "annotator/collections.h"
+#include "annotator/datetime/grammar-parser.h"
+#include "annotator/datetime/regex-parser.h"
+#include "annotator/flatbuffer-utils.h"
+#include "annotator/knowledge/knowledge-engine-types.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/base/status.h"
#include "utils/base/statusor.h"
+#include "utils/calendar/calendar.h"
#include "utils/checksum.h"
+#include "utils/grammar/analyzer.h"
+#include "utils/i18n/locale-list.h"
#include "utils/i18n/locale.h"
#include "utils/math/softmax.h"
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/regex-match.h"
+#include "utils/strings/append.h"
#include "utils/strings/numbers.h"
#include "utils/strings/split.h"
#include "utils/utf8/unicodetext.h"
@@ -102,12 +111,8 @@
}
// Returns whether the provided input is valid:
-// * Valid utf8 text.
// * Sane span indices.
-bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan span) {
- if (!context.is_valid()) {
- return false;
- }
+bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
return (span.first >= 0 && span.first < span.second &&
span.second <= context.size_codepoints());
}
@@ -124,37 +129,6 @@
return ints_set;
}
-DateAnnotationOptions ToDateAnnotationOptions(
- const GrammarDatetimeModel_::AnnotationOptions* fb_annotation_options,
- const std::string& reference_timezone, const int64 reference_time_ms_utc) {
- DateAnnotationOptions result_annotation_options;
- result_annotation_options.base_timestamp_millis = reference_time_ms_utc;
- result_annotation_options.reference_timezone = reference_timezone;
- if (fb_annotation_options != nullptr) {
- result_annotation_options.enable_special_day_offset =
- fb_annotation_options->enable_special_day_offset();
- result_annotation_options.merge_adjacent_components =
- fb_annotation_options->merge_adjacent_components();
- result_annotation_options.enable_date_range =
- fb_annotation_options->enable_date_range();
- result_annotation_options.include_preposition =
- fb_annotation_options->include_preposition();
- if (fb_annotation_options->extra_requested_dates() != nullptr) {
- for (const auto& extra_requested_date :
- *fb_annotation_options->extra_requested_dates()) {
- result_annotation_options.extra_requested_dates.push_back(
- extra_requested_date->str());
- }
- }
- if (fb_annotation_options->ignored_spans() != nullptr) {
- for (const auto& ignored_span : *fb_annotation_options->ignored_spans()) {
- result_annotation_options.ignored_spans.push_back(ignored_span->str());
- }
- }
- }
- return result_annotation_options;
-}
-
} // namespace
tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
@@ -187,8 +161,32 @@
return nullptr;
}
- auto classifier =
- std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
+ auto classifier = std::unique_ptr<Annotator>(new Annotator());
+ unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
+ calendarlib =
+ MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
+ classifier->ValidateAndInitialize(model, unilib, calendarlib);
+ if (!classifier->IsInitialized()) {
+ return nullptr;
+ }
+
+ return classifier;
+}
+
+std::unique_ptr<Annotator> Annotator::FromString(
+ const std::string& buffer, const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ auto classifier = std::unique_ptr<Annotator>(new Annotator());
+ classifier->owned_buffer_ = buffer;
+ const Model* model = LoadAndVerifyModel(classifier->owned_buffer_.data(),
+ classifier->owned_buffer_.size());
+ if (model == nullptr) {
+ return nullptr;
+ }
+ unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
+ calendarlib =
+ MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
+ classifier->ValidateAndInitialize(model, unilib, calendarlib);
if (!classifier->IsInitialized()) {
return nullptr;
}
@@ -211,8 +209,12 @@
return nullptr;
}
- auto classifier = std::unique_ptr<Annotator>(
- new Annotator(mmap, model, unilib, calendarlib));
+ auto classifier = std::unique_ptr<Annotator>(new Annotator());
+ classifier->mmap_ = std::move(*mmap);
+ unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
+ calendarlib =
+ MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
+ classifier->ValidateAndInitialize(model, unilib, calendarlib);
if (!classifier->IsInitialized()) {
return nullptr;
}
@@ -235,8 +237,12 @@
return nullptr;
}
- auto classifier = std::unique_ptr<Annotator>(
- new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
+ auto classifier = std::unique_ptr<Annotator>(new Annotator());
+ classifier->mmap_ = std::move(*mmap);
+ classifier->owned_unilib_ = std::move(unilib);
+ classifier->owned_calendarlib_ = std::move(calendarlib);
+ classifier->ValidateAndInitialize(model, classifier->owned_unilib_.get(),
+ classifier->owned_calendarlib_.get());
if (!classifier->IsInitialized()) {
return nullptr;
}
@@ -285,40 +291,12 @@
return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
}
-Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
- const UniLib* unilib, const CalendarLib* calendarlib)
- : model_(model),
- mmap_(std::move(*mmap)),
- owned_unilib_(nullptr),
- unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
- owned_calendarlib_(nullptr),
- calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
- ValidateAndInitialize();
-}
+void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ model_ = model;
+ unilib_ = unilib;
+ calendarlib_ = calendarlib;
-Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
- std::unique_ptr<UniLib> unilib,
- std::unique_ptr<CalendarLib> calendarlib)
- : model_(model),
- mmap_(std::move(*mmap)),
- owned_unilib_(std::move(unilib)),
- unilib_(owned_unilib_.get()),
- owned_calendarlib_(std::move(calendarlib)),
- calendarlib_(owned_calendarlib_.get()) {
- ValidateAndInitialize();
-}
-
-Annotator::Annotator(const Model* model, const UniLib* unilib,
- const CalendarLib* calendarlib)
- : model_(model),
- owned_unilib_(nullptr),
- unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
- owned_calendarlib_(nullptr),
- calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
- ValidateAndInitialize();
-}
-
-void Annotator::ValidateAndInitialize() {
initialized_ = false;
if (model_ == nullptr) {
@@ -439,25 +417,19 @@
return;
}
}
- if (model_->grammar_datetime_model() &&
- model_->grammar_datetime_model()->datetime_rules()) {
- cfg_datetime_parser_.reset(new dates::CfgDatetimeAnnotator(
- unilib_,
- /*tokenizer_options=*/
- model_->grammar_datetime_model()->grammar_tokenizer_options(),
- calendarlib_,
- /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules(),
- model_->grammar_datetime_model()->target_classification_score(),
- model_->grammar_datetime_model()->priority_score()));
- if (!cfg_datetime_parser_) {
- TC3_LOG(ERROR) << "Could not initialize context free grammar based "
- "datetime parser.";
- return;
- }
- }
- if (model_->datetime_model()) {
- datetime_parser_ = DatetimeParser::Instance(
+ if (model_->datetime_grammar_model()) {
+ if (model_->datetime_grammar_model()->rules()) {
+ analyzer_ = std::make_unique<grammar::Analyzer>(
+ unilib_, model_->datetime_grammar_model()->rules());
+ datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_);
+ datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
+ *analyzer_, *datetime_grounder_,
+ /*target_classification_score=*/1.0,
+ /*priority_score=*/1.0);
+ }
+ } else if (model_->datetime_model()) {
+ datetime_parser_ = RegexDatetimeParser::Instance(
model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
if (!datetime_parser_) {
TC3_LOG(ERROR) << "Could not initialize datetime parser.";
@@ -504,6 +476,26 @@
selection_feature_processor_.get(), unilib_));
}
+ if (model_->grammar_model()) {
+ grammar_annotator_.reset(new GrammarAnnotator(
+ unilib_, model_->grammar_model(), entity_data_builder_.get()));
+ }
+
+ // The following #ifdef is here to aid quality evaluation of a situation, when
+ // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
+ // it.
+#if !defined(TC3_DISABLE_POD_NER)
+ if (model_->pod_ner_model()) {
+ pod_ner_annotator_ =
+ PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
+ }
+#endif
+
+ if (model_->vocab_model()) {
+ vocab_annotator_ = VocabAnnotator::Create(
+ model_->vocab_model(), *selection_feature_processor_, *unilib_);
+ }
+
if (model_->entity_data_schema()) {
entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
model_->entity_data_schema()->Data(),
@@ -514,17 +506,12 @@
}
entity_data_builder_.reset(
- new ReflectiveFlatbufferBuilder(entity_data_schema_));
+ new MutableFlatbufferBuilder(entity_data_schema_));
} else {
entity_data_schema_ = nullptr;
entity_data_builder_ = nullptr;
}
- if (model_->grammar_model()) {
- grammar_annotator_.reset(new GrammarAnnotator(
- unilib_, model_->grammar_model(), entity_data_builder_.get()));
- }
-
if (model_->triggering_locales() &&
!ParseLocales(model_->triggering_locales()->c_str(),
&model_triggering_locales_)) {
@@ -640,7 +627,11 @@
return true;
}
-void Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
+bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
+ if (lang_id == nullptr) {
+ return false;
+ }
+
lang_id_ = lang_id;
if (lang_id_ != nullptr && model_->translate_annotator_options() &&
model_->translate_annotator_options()->enabled()) {
@@ -649,6 +640,7 @@
} else {
translate_annotator_.reset(nullptr);
}
+ return true;
}
bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
@@ -706,31 +698,14 @@
return false;
}
-namespace {
-
-int CountDigits(const std::string& str, CodepointSpan selection_indices) {
- int count = 0;
- int i = 0;
- const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
- for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
- if (i >= selection_indices.first && i < selection_indices.second &&
- IsDigit(*it)) {
- ++count;
- }
- }
- return count;
-}
-
-} // namespace
-
namespace internal {
// Helper function, which if the initial 'span' contains only white-spaces,
// moves the selection to a single-codepoint selection on a left or right side
// of this space.
-CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
+CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
const UnicodeText& context_unicode,
const UniLib& unilib) {
- TC3_CHECK(ValidNonEmptySpan(span));
+ TC3_CHECK(span.IsValid() && !span.IsEmpty());
UnicodeText::const_iterator it;
@@ -743,10 +718,8 @@
}
}
- CodepointSpan result;
-
// Try moving left.
- result = span;
+ CodepointSpan result = span;
it = context_unicode.begin();
std::advance(it, span.first);
while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
@@ -833,6 +806,11 @@
CodepointSpan Annotator::SuggestSelection(
const std::string& context, CodepointSpan click_indices,
const SelectionOptions& options) const {
+ if (context.size() > std::numeric_limits<int>::max()) {
+ TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
+ return {};
+ }
+
CodepointSpan original_click_indices = click_indices;
if (!initialized_) {
TC3_LOG(ERROR) << "Not initialized";
@@ -864,6 +842,11 @@
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
+ if (!unilib_->IsValidUtf8(context_unicode)) {
+ TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
+ return original_click_indices;
+ }
+
if (!IsValidSpanInput(context_unicode, click_indices)) {
TC3_VLOG(1)
<< "Trying to run SuggestSelection with invalid input, indices: "
@@ -888,60 +871,74 @@
click_indices, context_unicode, *unilib_);
}
- std::vector<AnnotatedSpan> candidates;
+ Annotations candidates;
+ // As we process a single string of context, the candidates will only
+ // contain one vector of AnnotatedSpan.
+ candidates.annotated_spans.resize(1);
InterpreterManager interpreter_manager(selection_executor_.get(),
classification_executor_.get());
std::vector<Token> tokens;
if (!ModelSuggestSelection(context_unicode, click_indices,
detected_text_language_tags, &interpreter_manager,
- &tokens, &candidates)) {
+ &tokens, &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Model suggest selection failed.";
return original_click_indices;
}
- if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
- /*is_serialized_entity_data_enabled=*/false)) {
+ const std::unordered_set<std::string> set;
+ const EnabledEntityTypes is_entity_type_enabled(set);
+ if (!RegexChunk(context_unicode, selection_regex_patterns_,
+ /*is_serialized_entity_data_enabled=*/false,
+ is_entity_type_enabled, options.annotation_usecase,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Regex suggest selection failed.";
return original_click_indices;
}
- if (!DatetimeChunk(
- UTF8ToUnicodeText(context, /*do_copy=*/false),
- /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
- options.locales, ModeFlag_SELECTION, options.annotation_usecase,
- /*is_serialized_entity_data_enabled=*/false, &candidates)) {
+ if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
+ options.locales, ModeFlag_SELECTION,
+ options.annotation_usecase,
+ /*is_serialized_entity_data_enabled=*/false,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Datetime suggest selection failed.";
return original_click_indices;
}
if (knowledge_engine_ != nullptr &&
- !knowledge_engine_->Chunk(context, options.annotation_usecase,
- options.location_context, Permissions(),
- &candidates)) {
+ !knowledge_engine_
+ ->Chunk(context, options.annotation_usecase,
+ options.location_context, Permissions(),
+ AnnotateMode::kEntityAnnotation, &candidates)
+ .ok()) {
TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
return original_click_indices;
}
if (contact_engine_ != nullptr &&
- !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ !contact_engine_->Chunk(context_unicode, tokens,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Contact suggest selection failed.";
return original_click_indices;
}
if (installed_app_engine_ != nullptr &&
- !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ !installed_app_engine_->Chunk(context_unicode, tokens,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Installed app suggest selection failed.";
return original_click_indices;
}
if (number_annotator_ != nullptr &&
!number_annotator_->FindAll(context_unicode, options.annotation_usecase,
- &candidates)) {
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
return original_click_indices;
}
if (duration_annotator_ != nullptr &&
!duration_annotator_->FindAll(context_unicode, tokens,
- options.annotation_usecase, &candidates)) {
+ options.annotation_usecase,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
return original_click_indices;
}
if (person_name_engine_ != nullptr &&
- !person_name_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ !person_name_engine_->Chunk(context_unicode, tokens,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Person name suggest selection failed.";
return original_click_indices;
}
@@ -951,25 +948,34 @@
grammar_annotator_->SuggestSelection(detected_text_language_tags,
context_unicode, click_indices,
&grammar_suggested_span)) {
- candidates.push_back(grammar_suggested_span);
+ candidates.annotated_spans[0].push_back(grammar_suggested_span);
+ }
+
+ AnnotatedSpan pod_ner_suggested_span;
+ if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
+ pod_ner_annotator_->SuggestSelection(context_unicode, click_indices,
+ &pod_ner_suggested_span)) {
+ candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
}
if (experimental_annotator_ != nullptr) {
- candidates.push_back(experimental_annotator_->SuggestSelection(
- context_unicode, click_indices));
+ candidates.annotated_spans[0].push_back(
+ experimental_annotator_->SuggestSelection(context_unicode,
+ click_indices));
}
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
// contiguous block.
- std::sort(candidates.begin(), candidates.end(),
+ std::sort(candidates.annotated_spans[0].begin(),
+ candidates.annotated_spans[0].end(),
[](const AnnotatedSpan& a, const AnnotatedSpan& b) {
return a.span.first < b.span.first;
});
std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, tokens,
- detected_text_language_tags, options.annotation_usecase,
+ if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
+ detected_text_language_tags, options,
&interpreter_manager, &candidate_indices)) {
TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
return original_click_indices;
@@ -977,32 +983,44 @@
std::sort(candidate_indices.begin(), candidate_indices.end(),
[this, &candidates](int a, int b) {
- return GetPriorityScore(candidates[a].classification) >
- GetPriorityScore(candidates[b].classification);
+ return GetPriorityScore(
+ candidates.annotated_spans[0][a].classification) >
+ GetPriorityScore(
+ candidates.annotated_spans[0][b].classification);
});
for (const int i : candidate_indices) {
- if (SpansOverlap(candidates[i].span, click_indices) &&
- SpansOverlap(candidates[i].span, original_click_indices)) {
+ if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
+ SpansOverlap(candidates.annotated_spans[0][i].span,
+ original_click_indices)) {
// Run model classification if not present but requested and there's a
// classification collection filter specified.
- if (candidates[i].classification.empty() &&
+ if (candidates.annotated_spans[0][i].classification.empty() &&
model_->selection_options()->always_classify_suggested_selection() &&
!filtered_collections_selection_.empty()) {
- if (!ModelClassifyText(context, detected_text_language_tags,
- candidates[i].span, &interpreter_manager,
+ if (!ModelClassifyText(context, /*cached_tokens=*/{},
+ detected_text_language_tags,
+ candidates.annotated_spans[0][i].span, options,
+ &interpreter_manager,
/*embedding_cache=*/nullptr,
- &candidates[i].classification)) {
+ &candidates.annotated_spans[0][i].classification,
+ /*tokens=*/nullptr)) {
return original_click_indices;
}
}
// Ignore if span classification is filtered.
- if (FilteredForSelection(candidates[i])) {
+ if (FilteredForSelection(candidates.annotated_spans[0][i])) {
return original_click_indices;
}
- return candidates[i].span;
+ // We return a suggested span contains the original span.
+ // This compensates for "select all" selection that may come from
+ // other apps. See http://b/179890518.
+ if (SpanContains(candidates.annotated_spans[0][i].span,
+ original_click_indices)) {
+ return candidates.annotated_spans[0][i].span;
+ }
}
}
@@ -1035,8 +1053,8 @@
const std::vector<AnnotatedSpan>& candidates, const std::string& context,
const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- AnnotationUsecase annotation_usecase,
- InterpreterManager* interpreter_manager, std::vector<int>* result) const {
+ const BaseOptions& options, InterpreterManager* interpreter_manager,
+ std::vector<int>* result) const {
result->clear();
result->reserve(candidates.size());
for (int i = 0; i < candidates.size();) {
@@ -1048,8 +1066,8 @@
std::vector<int> candidate_indices;
if (!ResolveConflict(context, cached_tokens, candidates,
detected_text_language_tags, i,
- first_non_overlapping, annotation_usecase,
- interpreter_manager, &candidate_indices)) {
+ first_non_overlapping, options, interpreter_manager,
+ &candidate_indices)) {
return false;
}
result->insert(result->end(), candidate_indices.begin(),
@@ -1115,7 +1133,7 @@
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<AnnotatedSpan>& candidates,
const std::vector<Locale>& detected_text_language_tags, int start_index,
- int end_index, AnnotationUsecase annotation_usecase,
+ int end_index, const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<int>* chosen_indices) const {
std::vector<int> conflicting_indices;
@@ -1136,8 +1154,9 @@
// classification to determine its priority:
std::vector<ClassificationResult> classification;
if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
- candidates[i].span, interpreter_manager,
- /*embedding_cache=*/nullptr, &classification)) {
+ candidates[i].span, options, interpreter_manager,
+ /*embedding_cache=*/nullptr, &classification,
+ /*tokens=*/nullptr)) {
return false;
}
@@ -1178,11 +1197,13 @@
}
const bool needs_conflict_resolution =
- annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_SMART ||
- (annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
+ options.annotation_usecase ==
+ AnnotationUsecase_ANNOTATION_USECASE_SMART ||
+ (options.annotation_usecase ==
+ AnnotationUsecase_ANNOTATION_USECASE_RAW &&
do_conflict_resolution_in_raw_mode_);
if (needs_conflict_resolution &&
- DoSourcesConflict(annotation_usecase, source_set_pair.first,
+ DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
candidates[considered_candidate].source) &&
DoesCandidateConflict(considered_candidate, candidates,
source_set_pair.second)) {
@@ -1220,7 +1241,7 @@
}
bool Annotator::ModelSuggestSelection(
- const UnicodeText& context_unicode, CodepointSpan click_indices,
+ const UnicodeText& context_unicode, const CodepointSpan& click_indices,
const std::vector<Locale>& detected_text_language_tags,
InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const {
@@ -1237,8 +1258,10 @@
int click_pos;
*tokens = selection_feature_processor_->Tokenize(context_unicode);
+ const auto [click_begin, click_end] =
+ CodepointSpanToUnicodeTextRange(context_unicode, click_indices);
selection_feature_processor_->RetokenizeAndFindClick(
- context_unicode, click_indices,
+ context_unicode, click_begin, click_end, click_indices,
selection_feature_processor_->GetOptions()->only_use_line_with_click(),
tokens, &click_pos);
if (click_pos == kInvalidIndex) {
@@ -1254,11 +1277,11 @@
// The symmetry context span is the clicked token with symmetry_context_size
// tokens on either side.
- const TokenSpan symmetry_context_span = IntersectTokenSpans(
- ExpandTokenSpan(SingleTokenSpan(click_pos),
- /*num_tokens_left=*/symmetry_context_size,
- /*num_tokens_right=*/symmetry_context_size),
- {0, tokens->size()});
+ const TokenSpan symmetry_context_span =
+ IntersectTokenSpans(TokenSpan(click_pos).Expand(
+ /*num_tokens_left=*/symmetry_context_size,
+ /*num_tokens_right=*/symmetry_context_size),
+ AllOf(*tokens));
// Compute the extraction span based on the model type.
TokenSpan extraction_span;
@@ -1269,22 +1292,21 @@
// the bounds of the selection.
const int max_selection_span =
selection_feature_processor_->GetOptions()->max_selection_span();
- extraction_span =
- ExpandTokenSpan(symmetry_context_span,
- /*num_tokens_left=*/max_selection_span +
- bounds_sensitive_features->num_tokens_before(),
- /*num_tokens_right=*/max_selection_span +
- bounds_sensitive_features->num_tokens_after());
+ extraction_span = symmetry_context_span.Expand(
+ /*num_tokens_left=*/max_selection_span +
+ bounds_sensitive_features->num_tokens_before(),
+ /*num_tokens_right=*/max_selection_span +
+ bounds_sensitive_features->num_tokens_after());
} else {
// The extraction span is the symmetry context span expanded to include
// context_size tokens on either side.
const int context_size =
selection_feature_processor_->GetOptions()->context_size();
- extraction_span = ExpandTokenSpan(symmetry_context_span,
- /*num_tokens_left=*/context_size,
- /*num_tokens_right=*/context_size);
+ extraction_span = symmetry_context_span.Expand(
+ /*num_tokens_left=*/context_size,
+ /*num_tokens_right=*/context_size);
}
- extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
+ extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
*tokens, extraction_span)) {
@@ -1330,20 +1352,9 @@
return true;
}
-bool Annotator::ModelClassifyText(
- const std::string& context,
- const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
- FeatureProcessor::EmbeddingCache* embedding_cache,
- std::vector<ClassificationResult>* classification_results) const {
- return ModelClassifyText(context, {}, detected_text_language_tags,
- selection_indices, interpreter_manager,
- embedding_cache, classification_results);
-}
-
namespace internal {
std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
- CodepointSpan selection_indices,
+ const CodepointSpan& selection_indices,
TokenSpan tokens_around_selection_to_copy) {
const auto first_selection_token = std::upper_bound(
cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
@@ -1407,19 +1418,29 @@
bool Annotator::ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
+ InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
- std::vector<ClassificationResult>* classification_results) const {
- std::vector<Token> tokens;
- return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
- selection_indices, interpreter_manager,
- embedding_cache, classification_results, &tokens);
+ std::vector<ClassificationResult>* classification_results,
+ std::vector<Token>* tokens) const {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ const auto [span_begin, span_end] =
+ CodepointSpanToUnicodeTextRange(context_unicode, selection_indices);
+ return ModelClassifyText(context_unicode, cached_tokens,
+ detected_text_language_tags, span_begin, span_end,
+ /*line=*/nullptr, selection_indices, options,
+ interpreter_manager, embedding_cache,
+ classification_results, tokens);
}
bool Annotator::ModelClassifyText(
- const std::string& context, const std::vector<Token>& cached_tokens,
+ const UnicodeText& context_unicode, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, const UnicodeTextRange* line,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
+ InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results,
std::vector<Token>* tokens) const {
@@ -1435,8 +1456,13 @@
return true;
}
+ std::vector<Token> local_tokens;
+ if (tokens == nullptr) {
+ tokens = &local_tokens;
+ }
+
if (cached_tokens.empty()) {
- *tokens = classification_feature_processor_->Tokenize(context);
+ *tokens = classification_feature_processor_->Tokenize(context_unicode);
} else {
*tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
ClassifyTextUpperBoundNeededTokens());
@@ -1444,13 +1470,13 @@
int click_pos;
classification_feature_processor_->RetokenizeAndFindClick(
- context, selection_indices,
+ context_unicode, span_begin, span_end, selection_indices,
classification_feature_processor_->GetOptions()
->only_use_line_with_click(),
tokens, &click_pos);
const TokenSpan selection_token_span =
CodepointSpanToTokenSpan(*tokens, selection_indices);
- const int selection_num_tokens = TokenSpanSize(selection_token_span);
+ const int selection_num_tokens = selection_token_span.Size();
if (model_->classification_options()->max_num_tokens() > 0 &&
model_->classification_options()->max_num_tokens() <
selection_num_tokens) {
@@ -1473,8 +1499,7 @@
if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
// The extraction span is the selection span expanded to include a relevant
// number of tokens outside of the bounds of the selection.
- extraction_span = ExpandTokenSpan(
- selection_token_span,
+ extraction_span = selection_token_span.Expand(
/*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
/*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
} else {
@@ -1486,11 +1511,11 @@
// either side.
const int context_size =
classification_feature_processor_->GetOptions()->context_size();
- extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
- /*num_tokens_left=*/context_size,
- /*num_tokens_right=*/context_size);
+ extraction_span = TokenSpan(click_pos).Expand(
+ /*num_tokens_left=*/context_size,
+ /*num_tokens_right=*/context_size);
}
- extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
+ extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
*tokens, extraction_span)) {
@@ -1548,7 +1573,7 @@
// Sanity checks.
if (top_collection == Collections::Phone()) {
- const int digit_count = CountDigits(context, selection_indices);
+ const int digit_count = std::count_if(span_begin, span_end, IsDigit);
if (digit_count <
model_->classification_options()->phone_min_num_digits() ||
digit_count >
@@ -1563,14 +1588,14 @@
return true;
}
} else if (top_collection == Collections::Dictionary()) {
- if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ if ((options.use_vocab_annotator && vocab_annotator_) ||
+ !Locale::IsAnyLocaleSupported(detected_text_language_tags,
dictionary_locales_,
/*default_value=*/false)) {
*classification_results = {{Collections::Other(), 1.0}};
return true;
}
}
-
*classification_results = {{top_collection, /*arg_score=*/1.0,
/*arg_priority_score=*/scores[best_score_index]}};
@@ -1588,7 +1613,7 @@
}
bool Annotator::RegexClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
std::vector<ClassificationResult>* classification_result) const {
const std::string selection_text =
UTF8ToUnicodeText(context, /*do_copy=*/false)
@@ -1643,45 +1668,13 @@
}
}
-std::string CreateDatetimeSerializedEntityData(
- const DatetimeParseResult& parse_result) {
- EntityDataT entity_data;
- entity_data.datetime.reset(new EntityData_::DatetimeT());
- entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
- entity_data.datetime->granularity =
- static_cast<EntityData_::Datetime_::Granularity>(
- parse_result.granularity);
-
- for (const auto& c : parse_result.datetime_components) {
- EntityData_::Datetime_::DatetimeComponentT datetime_component;
- datetime_component.absolute_value = c.value;
- datetime_component.relative_count = c.relative_count;
- datetime_component.component_type =
- static_cast<EntityData_::Datetime_::DatetimeComponent_::ComponentType>(
- c.component_type);
- datetime_component.relation_type =
- EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE;
- if (c.relative_qualifier !=
- DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
- datetime_component.relation_type =
- EntityData_::Datetime_::DatetimeComponent_::RelationType_RELATIVE;
- }
- entity_data.datetime->datetime_component.emplace_back(
- new EntityData_::Datetime_::DatetimeComponentT(datetime_component));
- }
- flatbuffers::FlatBufferBuilder builder;
- FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
} // namespace
bool Annotator::DatetimeClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
const ClassificationOptions& options,
std::vector<ClassificationResult>* classification_results) const {
- if (!datetime_parser_ && !cfg_datetime_parser_) {
+ if (!datetime_parser_) {
return true;
}
@@ -1689,39 +1682,24 @@
UTF8ToUnicodeText(context, /*do_copy=*/false)
.UTF8Substring(selection_indices.first, selection_indices.second);
- std::vector<DatetimeParseResultSpan> datetime_spans;
-
- if (cfg_datetime_parser_) {
- if (!(model_->grammar_datetime_model()->enabled_modes() &
- ModeFlag_CLASSIFICATION)) {
- return true;
- }
- std::vector<Locale> parsed_locales;
- ParseLocales(options.locales, &parsed_locales);
- cfg_datetime_parser_->Parse(
- selection_text,
- ToDateAnnotationOptions(
- model_->grammar_datetime_model()->annotation_options(),
- options.reference_timezone, options.reference_time_ms_utc),
- parsed_locales, &datetime_spans);
+ LocaleList locale_list = LocaleList::ParseFrom(options.locales);
+ StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
+ datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
+ options.reference_timezone, locale_list,
+ ModeFlag_CLASSIFICATION,
+ options.annotation_usecase,
+ /*anchor_start_end=*/true);
+ if (!result_status.ok()) {
+ TC3_LOG(ERROR) << "Error during parsing datetime.";
+ return false;
}
- if (datetime_parser_) {
- if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
- options.reference_timezone, options.locales,
- ModeFlag_CLASSIFICATION,
- options.annotation_usecase,
- /*anchor_start_end=*/true, &datetime_spans)) {
- TC3_LOG(ERROR) << "Error during parsing datetime.";
- return false;
- }
- }
-
- for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+ for (const DatetimeParseResultSpan& datetime_span :
+ result_status.ValueOrDie()) {
// Only consider the result valid if the selection and extracted datetime
// spans exactly match.
- if (std::make_pair(datetime_span.span.first + selection_indices.first,
- datetime_span.span.second + selection_indices.first) ==
+ if (CodepointSpan(datetime_span.span.first + selection_indices.first,
+ datetime_span.span.second + selection_indices.first) ==
selection_indices) {
for (const DatetimeParseResult& parse_result : datetime_span.data) {
classification_results->emplace_back(
@@ -1740,8 +1718,12 @@
}
std::vector<ClassificationResult> Annotator::ClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
const ClassificationOptions& options) const {
+ if (context.size() > std::numeric_limits<int>::max()) {
+ TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
+ return {};
+ }
if (!initialized_) {
TC3_LOG(ERROR) << "Not initialized";
return {};
@@ -1769,11 +1751,17 @@
return {};
}
- if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
- selection_indices)) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ if (!unilib_->IsValidUtf8(context_unicode)) {
+ TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
+ return {};
+ }
+
+ if (!IsValidSpanInput(context_unicode, selection_indices)) {
TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
- << std::get<0>(selection_indices) << " "
- << std::get<1>(selection_indices);
+ << selection_indices.first << " " << selection_indices.second;
return {};
}
@@ -1785,9 +1773,11 @@
// TODO(b/126579108): Propagate error status.
ClassificationResult knowledge_result;
if (knowledge_engine_ &&
- knowledge_engine_->ClassifyText(
- context, selection_indices, options.annotation_usecase,
- options.location_context, Permissions(), &knowledge_result)) {
+ knowledge_engine_
+ ->ClassifyText(context, selection_indices, options.annotation_usecase,
+ options.location_context, Permissions(),
+ &knowledge_result)
+ .ok()) {
candidates.push_back({selection_indices, {knowledge_result}});
candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
}
@@ -1845,9 +1835,6 @@
candidates.back().source = AnnotatedSpan::Source::DATETIME;
}
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
-
// Try the number annotator.
// TODO(b/126579108): Propagate error status.
ClassificationResult number_annotator_result;
@@ -1885,11 +1872,25 @@
candidates.push_back({selection_indices, {grammar_annotator_result}});
}
- ClassificationResult experimental_annotator_result;
- if (experimental_annotator_ &&
- experimental_annotator_->ClassifyText(context_unicode, selection_indices,
- &experimental_annotator_result)) {
- candidates.push_back({selection_indices, {experimental_annotator_result}});
+ ClassificationResult pod_ner_annotator_result;
+ if (pod_ner_annotator_ && options.use_pod_ner &&
+ pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
+ &pod_ner_annotator_result)) {
+ candidates.push_back({selection_indices, {pod_ner_annotator_result}});
+ }
+
+ ClassificationResult vocab_annotator_result;
+ if (vocab_annotator_ && options.use_vocab_annotator &&
+ vocab_annotator_->ClassifyText(
+ context_unicode, selection_indices, detected_text_language_tags,
+ options.trigger_dictionary_on_beginner_words,
+ &vocab_annotator_result)) {
+ candidates.push_back({selection_indices, {vocab_annotator_result}});
+ }
+
+ if (experimental_annotator_) {
+ experimental_annotator_->ClassifyText(context_unicode, selection_indices,
+ candidates);
}
// Try the ML model.
@@ -1903,7 +1904,7 @@
std::vector<Token> tokens;
if (!ModelClassifyText(
context, /*cached_tokens=*/{}, detected_text_language_tags,
- selection_indices, &interpreter_manager,
+ selection_indices, options, &interpreter_manager,
/*embedding_cache=*/nullptr, &model_results, &tokens)) {
return {};
}
@@ -1913,7 +1914,7 @@
std::vector<int> candidate_indices;
if (!ResolveConflicts(candidates, context, tokens,
- detected_text_language_tags, options.annotation_usecase,
+ detected_text_language_tags, options,
&interpreter_manager, &candidate_indices)) {
TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
return {};
@@ -1943,8 +1944,8 @@
bool Annotator::ModelAnnotate(
const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
- std::vector<AnnotatedSpan>* result) const {
+ const AnnotationOptions& options, InterpreterManager* interpreter_manager,
+ std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
if (model_->triggering_options() == nullptr ||
!(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
return true;
@@ -1977,23 +1978,26 @@
const std::string line_str =
UnicodeText::UTF8Substring(line.first, line.second);
- *tokens = selection_feature_processor_->Tokenize(line_str);
+ std::vector<Token> line_tokens;
+ line_tokens = selection_feature_processor_->Tokenize(line_str);
+
selection_feature_processor_->RetokenizeAndFindClick(
line_str, {0, std::distance(line.first, line.second)},
selection_feature_processor_->GetOptions()->only_use_line_with_click(),
- tokens,
+ &line_tokens,
/*click_pos=*/nullptr);
- const TokenSpan full_line_span = {0, tokens->size()};
+ const TokenSpan full_line_span = {
+ 0, static_cast<TokenIndex>(line_tokens.size())};
// TODO(zilka): Add support for greater granularity of this check.
if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
- *tokens, full_line_span)) {
+ line_tokens, full_line_span)) {
continue;
}
std::unique_ptr<CachedFeatures> cached_features;
if (!selection_feature_processor_->ExtractFeatures(
- *tokens, full_line_span,
+ line_tokens, full_line_span,
/*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
embedding_executor_.get(),
/*embedding_cache=*/nullptr,
@@ -2005,7 +2009,7 @@
}
std::vector<TokenSpan> local_chunks;
- if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
+ if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
interpreter_manager->SelectionInterpreter(),
*cached_features, &local_chunks)) {
TC3_LOG(ERROR) << "Could not chunk.";
@@ -2013,21 +2017,68 @@
}
const int offset = std::distance(context_unicode.begin(), line.first);
+ UnicodeText line_unicode;
+ std::vector<UnicodeText::const_iterator> line_codepoints;
+ if (options.enable_optimization) {
+ if (local_chunks.empty()) {
+ continue;
+ }
+ line_unicode = UTF8ToUnicodeText(line_str, /*do_copy=*/false);
+ line_codepoints = line_unicode.Codepoints();
+ line_codepoints.push_back(line_unicode.end());
+ }
for (const TokenSpan& chunk : local_chunks) {
- const CodepointSpan codepoint_span =
- selection_feature_processor_->StripBoundaryCodepoints(
- line_str, TokenSpanToCodepointSpan(*tokens, chunk));
+ CodepointSpan codepoint_span =
+ TokenSpanToCodepointSpan(line_tokens, chunk);
+ if (options.enable_optimization) {
+ if (!codepoint_span.IsValid() ||
+ codepoint_span.second > line_codepoints.size()) {
+ continue;
+ }
+ codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
+ /*span_begin=*/line_codepoints[codepoint_span.first],
+ /*span_end=*/line_codepoints[codepoint_span.second],
+ codepoint_span);
+ if (model_->selection_options()->strip_unpaired_brackets()) {
+ codepoint_span = StripUnpairedBrackets(
+ /*span_begin=*/line_codepoints[codepoint_span.first],
+ /*span_end=*/line_codepoints[codepoint_span.second],
+ codepoint_span, *unilib_);
+ }
+ } else {
+ codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
+ line_str, codepoint_span);
+ if (model_->selection_options()->strip_unpaired_brackets()) {
+ codepoint_span =
+ StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
+ }
+ }
// Skip empty spans.
if (codepoint_span.first != codepoint_span.second) {
std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
- codepoint_span, interpreter_manager,
- &embedding_cache, &classification)) {
- TC3_LOG(ERROR) << "Could not classify text: "
- << (codepoint_span.first + offset) << " "
- << (codepoint_span.second + offset);
- return false;
+ if (options.enable_optimization) {
+ if (!ModelClassifyText(
+ line_unicode, line_tokens, detected_text_language_tags,
+ /*span_begin=*/line_codepoints[codepoint_span.first],
+ /*span_end=*/line_codepoints[codepoint_span.second], &line,
+ codepoint_span, options, interpreter_manager,
+ &embedding_cache, &classification, /*tokens=*/nullptr)) {
+ TC3_LOG(ERROR) << "Could not classify text: "
+ << (codepoint_span.first + offset) << " "
+ << (codepoint_span.second + offset);
+ return false;
+ }
+ } else {
+ if (!ModelClassifyText(line_str, line_tokens,
+ detected_text_language_tags, codepoint_span,
+ options, interpreter_manager, &embedding_cache,
+ &classification, /*tokens=*/nullptr)) {
+ TC3_LOG(ERROR) << "Could not classify text: "
+ << (codepoint_span.first + offset) << " "
+ << (codepoint_span.second + offset);
+ return false;
+ }
}
// Do not include the span if it's classified as "other".
@@ -2041,6 +2092,16 @@
}
}
}
+
+ // If we are going line-by-line, we need to insert the tokens for each line.
+ // But if not, we can optimize and just std::move the current line vector to
+ // the output.
+ if (selection_feature_processor_->GetOptions()
+ ->only_use_line_with_click()) {
+ tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
+ } else {
+ *tokens = std::move(line_tokens);
+ }
}
return true;
}
@@ -2103,10 +2164,6 @@
const UnicodeText context_unicode =
UTF8ToUnicodeText(context, /*do_copy=*/false);
- if (!context_unicode.is_valid()) {
- return Status(StatusCode::INVALID_ARGUMENT,
- "Context string isn't valid UTF8.");
- }
std::vector<Locale> detected_text_language_tags;
if (!ParseLocales(options.detected_text_language_tags,
@@ -2126,22 +2183,43 @@
InterpreterManager interpreter_manager(selection_executor_.get(),
classification_executor_.get());
+ const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
+ const bool is_raw_usecase =
+ options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
// Annotate with the selection model.
+ const bool model_annotations_enabled =
+ !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
std::vector<Token> tokens;
- if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
- &tokens, candidates)) {
+ if (model_annotations_enabled &&
+ !ModelAnnotate(context, detected_text_language_tags, options,
+ &interpreter_manager, &tokens, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
+ } else if (!model_annotations_enabled) {
+ // If the ML model didn't run, we need to tokenize to support the other
+ // annotators that depend on the tokens.
+ // Optimization could be made to only do this when an annotator that uses
+ // the tokens is enabled, but it's unclear if the added complexity is worth
+ // it.
+ if (selection_feature_processor_ != nullptr) {
+ tokens = selection_feature_processor_->Tokenize(context_unicode);
+ }
}
// Annotate with the regular expression models.
- if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
- annotation_regex_patterns_, candidates,
- options.is_serialized_entity_data_enabled)) {
+ const bool regex_annotations_enabled =
+ !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
+ if (regex_annotations_enabled &&
+ !RegexChunk(
+ UTF8ToUnicodeText(context, /*do_copy=*/false),
+ annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
+ is_entity_type_enabled, options.annotation_usecase, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
}
// Annotate with the datetime model.
- const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
+ // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
+ // relatively slow for some clients.
if ((is_entity_type_enabled(Collections::Date()) ||
is_entity_type_enabled(Collections::DateTime())) &&
!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
@@ -2153,20 +2231,27 @@
}
// Annotate with the contact engine.
- if (contact_engine_ &&
+ const bool contact_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
+ if (contact_annotations_enabled && contact_engine_ &&
!contact_engine_->Chunk(context_unicode, tokens, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
}
// Annotate with the installed app engine.
- if (installed_app_engine_ &&
+ const bool app_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::App());
+ if (app_annotations_enabled && installed_app_engine_ &&
!installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run installed app engine Chunk.");
}
// Annotate with the number annotator.
- if (number_annotator_ != nullptr &&
+ const bool number_annotations_enabled =
+ !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
+ is_entity_type_enabled(Collections::Percentage()));
+ if (number_annotations_enabled && number_annotator_ != nullptr &&
!number_annotator_->FindAll(context_unicode, options.annotation_usecase,
candidates)) {
return Status(StatusCode::INTERNAL,
@@ -2174,8 +2259,9 @@
}
// Annotate with the duration annotator.
- if (is_entity_type_enabled(Collections::Duration()) &&
- duration_annotator_ != nullptr &&
+ const bool duration_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
+ if (duration_annotations_enabled && duration_annotator_ != nullptr &&
!duration_annotator_->FindAll(context_unicode, tokens,
options.annotation_usecase, candidates)) {
return Status(StatusCode::INTERNAL,
@@ -2183,8 +2269,9 @@
}
// Annotate with the person name engine.
- if (is_entity_type_enabled(Collections::PersonName()) &&
- person_name_engine_ &&
+ const bool person_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
+ if (person_annotations_enabled && person_name_engine_ &&
!person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run person name engine Chunk.");
@@ -2197,6 +2284,27 @@
return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
}
+ // Annotate with the POD NER annotator.
+ const bool pod_ner_annotations_enabled =
+ !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
+ if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
+ options.use_pod_ner &&
+ !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
+ return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
+ }
+
+ // Annotate with the vocab annotator.
+ const bool vocab_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
+ if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
+ options.use_vocab_annotator &&
+ !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
+ options.trigger_dictionary_on_beginner_words,
+ candidates)) {
+ return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
+ }
+
+ // Annotate with the experimental annotator.
if (experimental_annotator_ != nullptr &&
!experimental_annotator_->Annotate(context_unicode, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
@@ -2224,7 +2332,7 @@
std::vector<int> candidate_indices;
if (!ResolveConflicts(*candidates, context, tokens,
- detected_text_language_tags, options.annotation_usecase,
+ detected_text_language_tags, options,
&interpreter_manager, &candidate_indices)) {
return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
}
@@ -2267,40 +2375,54 @@
return Status::OK;
}
-StatusOr<std::vector<std::vector<AnnotatedSpan>>>
-Annotator::AnnotateStructuredInput(
+StatusOr<Annotations> Annotator::AnnotateStructuredInput(
const std::vector<InputFragment>& string_fragments,
const AnnotationOptions& options) const {
- std::vector<std::vector<AnnotatedSpan>> annotation_candidates(
- string_fragments.size());
+ Annotations annotation_candidates;
+ annotation_candidates.annotated_spans.resize(string_fragments.size());
std::vector<std::string> text_to_annotate;
text_to_annotate.reserve(string_fragments.size());
+ std::vector<FragmentMetadata> fragment_metadata;
+ fragment_metadata.reserve(string_fragments.size());
for (const auto& string_fragment : string_fragments) {
text_to_annotate.push_back(string_fragment.text);
+ fragment_metadata.push_back(
+ {.relative_bounding_box_top = string_fragment.bounding_box_top,
+ .relative_bounding_box_height = string_fragment.bounding_box_height});
}
// KnowledgeEngine is special, because it supports annotation of multiple
// fragments at once.
if (knowledge_engine_ &&
!knowledge_engine_
- ->ChunkMultipleSpans(text_to_annotate, options.annotation_usecase,
+ ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
+ options.annotation_usecase,
options.location_context, options.permissions,
- &annotation_candidates)
+ options.annotate_mode, &annotation_candidates)
.ok()) {
return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
}
// The annotator engines shouldn't change the number of annotation vectors.
- if (annotation_candidates.size() != text_to_annotate.size()) {
+ if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
<< " texts to annotate but generated a different number of "
"lists of annotations:"
- << annotation_candidates.size();
+ << annotation_candidates.annotated_spans.size();
return Status(StatusCode::INTERNAL,
"Number of annotation candidates differs from "
"number of texts to annotate.");
}
+ // As an optimization, if the only annotated type is Entity, we skip all the
+ // other annotators than the KnowledgeEngine. This only happens in the raw
+ // mode, to make sure it does not affect the result.
+ if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
+ options.entity_types.size() == 1 &&
+ *options.entity_types.begin() == Collections::Entity()) {
+ return annotation_candidates;
+ }
+
// Other annotators run on each fragment independently.
for (int i = 0; i < text_to_annotate.size(); ++i) {
AnnotationOptions annotation_options = options;
@@ -2314,10 +2436,11 @@
}
AddContactMetadataToKnowledgeClassificationResults(
- &annotation_candidates[i]);
+ &annotation_candidates.annotated_spans[i]);
- Status annotation_status = AnnotateSingleInput(
- text_to_annotate[i], annotation_options, &annotation_candidates[i]);
+ Status annotation_status =
+ AnnotateSingleInput(text_to_annotate[i], annotation_options,
+ &annotation_candidates.annotated_spans[i]);
if (!annotation_status.ok()) {
return annotation_status;
}
@@ -2327,16 +2450,28 @@
std::vector<AnnotatedSpan> Annotator::Annotate(
const std::string& context, const AnnotationOptions& options) const {
+ if (context.size() > std::numeric_limits<int>::max()) {
+ TC3_LOG(ERROR) << "Rejecting too long input.";
+ return {};
+ }
+
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ if (!unilib_->IsValidUtf8(context_unicode)) {
+ TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
+ return {};
+ }
+
std::vector<InputFragment> string_fragments;
string_fragments.push_back({.text = context});
- StatusOr<std::vector<std::vector<AnnotatedSpan>>> annotations =
+ StatusOr<Annotations> annotations =
AnnotateStructuredInput(string_fragments, options);
if (!annotations.ok()) {
TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
<< annotations.status().error_message();
return {};
}
- return annotations.ValueOrDie()[0];
+ return annotations.ValueOrDie().annotated_spans[0];
}
CodepointSpan Annotator::ComputeSelectionBoundaries(
@@ -2408,7 +2543,7 @@
}
TC3_CHECK(entity_data_builder_ != nullptr);
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_->NewRoot();
TC3_CHECK(entity_data != nullptr);
@@ -2490,8 +2625,44 @@
return whole_amount;
}
+void Annotator::GetMoneyQuantityFromCapturingGroup(
+ const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
+ const UnicodeText& context_unicode, std::string* quantity,
+ int* exponent) const {
+ if (config->capturing_group() == nullptr) {
+ *exponent = 0;
+ return;
+ }
+
+ const int num_groups = config->capturing_group()->size();
+ for (int i = 0; i < num_groups; i++) {
+ int status = UniLib::RegexMatcher::kNoError;
+ const int group_start = match->Start(i, &status);
+ const int group_end = match->End(i, &status);
+ if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
+ continue;
+ }
+
+ *quantity =
+ unilib_
+ ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
+ group_end, /*do_copy=*/false))
+ .ToUTF8String();
+
+ if (auto entry = model_->money_parsing_options()
+ ->quantities_name_to_exponent()
+ ->LookupByKey((*quantity).c_str())) {
+ *exponent = entry->value();
+ return;
+ }
+ }
+ *exponent = 0;
+}
+
bool Annotator::ParseAndFillInMoneyAmount(
- std::string* serialized_entity_data) const {
+ std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
+ const RegexModel_::Pattern* config,
+ const UnicodeText& context_unicode) const {
std::unique_ptr<EntityDataT> data =
LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
*serialized_entity_data);
@@ -2543,20 +2714,49 @@
<< data->money->unnormalized_amount;
return false;
}
+
if (it_decimal_separator == amount.end()) {
data->money->amount_decimal_part = 0;
+ data->money->nanos = 0;
} else {
const int amount_codepoints_size = amount.size_codepoints();
- if (!unilib_->ParseInt32(
- UnicodeText::Substring(
- amount, amount_codepoints_size - separator_back_index,
- amount_codepoints_size, /*do_copy=*/false),
- &data->money->amount_decimal_part)) {
+ const UnicodeText decimal_part = UnicodeText::Substring(
+ amount, amount_codepoints_size - separator_back_index,
+ amount_codepoints_size, /*do_copy=*/false);
+ if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
"the amount: "
<< data->money->unnormalized_amount;
return false;
}
+ data->money->nanos = data->money->amount_decimal_part *
+ pow(10, 9 - decimal_part.size_codepoints());
+ }
+
+ if (model_->money_parsing_options()->quantities_name_to_exponent() !=
+ nullptr) {
+ int quantity_exponent;
+ std::string quantity;
+ GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
+ &quantity, &quantity_exponent);
+ if (quantity_exponent > 0 && quantity_exponent <= 9) {
+ const double amount_whole_part =
+ data->money->amount_whole_part * pow(10, quantity_exponent) +
+ data->money->nanos / pow(10, 9 - quantity_exponent);
+ // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64
+ // (and `std::numeric_limits<int>::max()` to
+ // `std::numeric_limits<int64>::max()`).
+ if (amount_whole_part < std::numeric_limits<int>::max()) {
+ data->money->amount_whole_part = amount_whole_part;
+ data->money->nanos = data->money->nanos %
+ static_cast<int>(pow(10, 9 - quantity_exponent)) *
+ pow(10, quantity_exponent);
+ }
+ }
+ if (quantity_exponent > 0) {
+ data->money->unnormalized_amount = strings::JoinStrings(
+ " ", {data->money->unnormalized_amount, quantity});
+ }
}
*serialized_entity_data =
@@ -2564,12 +2764,71 @@
return true;
}
+bool Annotator::IsAnyModelEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const {
+ if (model_->classification_feature_options() == nullptr ||
+ model_->classification_feature_options()->collections() == nullptr) {
+ return false;
+ }
+ for (int i = 0;
+ i < model_->classification_feature_options()->collections()->size();
+ i++) {
+ if (is_entity_type_enabled(model_->classification_feature_options()
+ ->collections()
+ ->Get(i)
+ ->str())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool Annotator::IsAnyRegexEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const {
+ if (model_->regex_model() == nullptr ||
+ model_->regex_model()->patterns() == nullptr) {
+ return false;
+ }
+ for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
+ if (is_entity_type_enabled(model_->regex_model()
+ ->patterns()
+ ->Get(i)
+ ->collection_name()
+ ->str())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool Annotator::IsAnyPodNerEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const {
+ if (pod_ner_annotator_ == nullptr) {
+ return false;
+ }
+
+ for (const std::string& collection :
+ pod_ner_annotator_->GetSupportedCollections()) {
+ if (is_entity_type_enabled(collection)) {
+ return true;
+ }
+ }
+ return false;
+}
+
bool Annotator::RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
- std::vector<AnnotatedSpan>* result,
- bool is_serialized_entity_data_enabled) const {
+ bool is_serialized_entity_data_enabled,
+ const EnabledEntityTypes& enabled_entity_types,
+ const AnnotationUsecase& annotation_usecase,
+ std::vector<AnnotatedSpan>* result) const {
for (int pattern_id : rules) {
const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
+ if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
+ annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
+ // No regex annotation type has been requested, skip regex annotation.
+ continue;
+ }
const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
if (!matcher) {
TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
@@ -2596,12 +2855,14 @@
return false;
}
- // Further parsing unnormalized_amount for money into amount_whole_part
- // and amount_decimal_part. Can't do this with regexes because we cannot
- // have empty groups (amount_decimal_part might be an empty group).
+ // Further parsing of money amount. Need this since regexes cannot have
+ // empty groups that fill in entity data (amount_decimal_part and
+ // quantity might be empty groups).
if (regex_pattern.config->collection_name()->str() ==
Collections::Money()) {
- if (!ParseAndFillInMoneyAmount(&serialized_entity_data)) {
+ if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
+ regex_pattern.config,
+ context_unicode)) {
if (model_->version() >= 706) {
// This way of parsing money entity data is enabled for models
// newer than v706 => logging errors only for them (b/156634162).
@@ -2639,11 +2900,11 @@
// The inference span is the span of interest expanded to include
// max_selection_span tokens on either side, which is how far a selection can
// stretch from the click.
- const TokenSpan inference_span = IntersectTokenSpans(
- ExpandTokenSpan(span_of_interest,
- /*num_tokens_left=*/max_selection_span,
- /*num_tokens_right=*/max_selection_span),
- {0, num_tokens});
+ const TokenSpan inference_span =
+ IntersectTokenSpans(span_of_interest.Expand(
+ /*num_tokens_left=*/max_selection_span,
+ /*num_tokens_right=*/max_selection_span),
+ {0, num_tokens});
std::vector<ScoredChunk> scored_chunks;
if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
@@ -2670,7 +2931,7 @@
// Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
// them greedily as long as they do not overlap with any previously picked
// chunks.
- std::vector<bool> token_used(TokenSpanSize(inference_span));
+ std::vector<bool> token_used(inference_span.Size());
chunks->clear();
for (const ScoredChunk& scored_chunk : scored_chunks) {
bool feasible = true;
@@ -2766,9 +3027,8 @@
TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
return false;
}
- const TokenSpan candidate_span = ExpandTokenSpan(
- SingleTokenSpan(click_pos), relative_token_span.first,
- relative_token_span.second);
+ const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
+ relative_token_span.first, relative_token_span.second);
if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
UpdateMax(&chunk_scores, candidate_span, scores[j]);
}
@@ -2803,7 +3063,7 @@
scored_chunks->clear();
if (score_single_token_spans_as_zero) {
- scored_chunks->reserve(TokenSpanSize(span_of_interest));
+ scored_chunks->reserve(span_of_interest.Size());
}
// Prepare all chunk candidates into one batch:
@@ -2819,8 +3079,7 @@
end <= inference_span.second && end - start <= max_chunk_length;
++end) {
const TokenSpan candidate_span = {start, end};
- if (score_single_token_spans_as_zero &&
- TokenSpanSize(candidate_span) == 1) {
+ if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
// Do not include the single token span in the batch, add a zero score
// for it directly to the output.
scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
@@ -2880,31 +3139,21 @@
AnnotationUsecase annotation_usecase,
bool is_serialized_entity_data_enabled,
std::vector<AnnotatedSpan>* result) const {
- std::vector<DatetimeParseResultSpan> datetime_spans;
- if (cfg_datetime_parser_) {
- if (!(model_->grammar_datetime_model()->enabled_modes() & mode)) {
- return true;
- }
- std::vector<Locale> parsed_locales;
- ParseLocales(locales, &parsed_locales);
- cfg_datetime_parser_->Parse(
- context_unicode.ToUTF8String(),
- ToDateAnnotationOptions(
- model_->grammar_datetime_model()->annotation_options(),
- reference_timezone, reference_time_ms_utc),
- parsed_locales, &datetime_spans);
+ if (!datetime_parser_) {
+ return true;
+ }
+ LocaleList locale_list = LocaleList::ParseFrom(locales);
+ StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
+ datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
+ reference_timezone, locale_list, mode,
+ annotation_usecase,
+ /*anchor_start_end=*/false);
+ if (!result_status.ok()) {
+ return false;
}
- if (datetime_parser_) {
- if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
- reference_timezone, locales, mode,
- annotation_usecase,
- /*anchor_start_end=*/false, &datetime_spans)) {
- return false;
- }
- }
-
- for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+ for (const DatetimeParseResultSpan& datetime_span :
+ result_status.ValueOrDie()) {
AnnotatedSpan annotated_span;
annotated_span.span = datetime_span.span;
for (const DatetimeParseResult& parse_result : datetime_span.data) {
@@ -2937,10 +3186,22 @@
return LoadAndVerifyModel(buffer, size);
}
-bool Annotator::LookUpKnowledgeEntity(
- const std::string& id, std::string* serialized_knowledge_result) const {
- return knowledge_engine_ &&
- knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
+StatusOr<std::string> Annotator::LookUpKnowledgeEntity(
+ const std::string& id) const {
+ if (!knowledge_engine_) {
+ return Status(StatusCode::FAILED_PRECONDITION,
+ "knowledge_engine_ is nullptr");
+ }
+ return knowledge_engine_->LookUpEntity(id);
+}
+
+StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty(
+ const std::string& mid_str, const std::string& property) const {
+ if (!knowledge_engine_) {
+ return Status(StatusCode::FAILED_PRECONDITION,
+ "knowledge_engine_ is nullptr");
+ }
+ return knowledge_engine_->LookUpEntityProperty(mid_str, property);
}
} // namespace libtextclassifier3
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
index ebd762c..d69fe32 100644
--- a/native/annotator/annotator.h
+++ b/native/annotator/annotator.h
@@ -26,11 +26,11 @@
#include <vector>
#include "annotator/contact/contact-engine.h"
+#include "annotator/datetime/datetime-grounder.h"
#include "annotator/datetime/parser.h"
#include "annotator/duration/duration.h"
#include "annotator/experimental/experimental.h"
#include "annotator/feature-processor.h"
-#include "annotator/grammar/dates/cfg-datetime-annotator.h"
#include "annotator/grammar/grammar-annotator.h"
#include "annotator/installed_app/installed-app-engine.h"
#include "annotator/knowledge/knowledge-engine.h"
@@ -38,15 +38,20 @@
#include "annotator/model_generated.h"
#include "annotator/number/number.h"
#include "annotator/person_name/person-name-engine.h"
+#include "annotator/pod_ner/pod-ner.h"
#include "annotator/strip-unpaired-brackets.h"
#include "annotator/translate/translate.h"
#include "annotator/types.h"
+#include "annotator/vocab/vocab-annotator.h"
#include "annotator/zlib-utils.h"
#include "utils/base/status.h"
#include "utils/base/statusor.h"
-#include "utils/flatbuffers.h"
+#include "utils/calendar/calendar.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/i18n/locale.h"
#include "utils/memory/mmap.h"
+#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "utils/zlib/zlib.h"
#include "lang_id/lang-id.h"
@@ -105,6 +110,10 @@
static std::unique_ptr<Annotator> FromUnownedBuffer(
const char* buffer, int size, const UniLib* unilib = nullptr,
const CalendarLib* calendarlib = nullptr);
+ // Copies the underlying model buffer string.
+ static std::unique_ptr<Annotator> FromString(
+ const std::string& buffer, const UniLib* unilib = nullptr,
+ const CalendarLib* calendarlib = nullptr);
// Takes ownership of the mmap.
static std::unique_ptr<Annotator> FromScopedMmap(
std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr,
@@ -167,7 +176,7 @@
bool InitializeExperimentalAnnotators();
// Sets up the lang-id instance that should be used.
- void SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id);
+ bool SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id);
// Runs inference for given a context and current selection (i.e. index
// of the first and one past last selected characters (utf8 codepoint
@@ -184,7 +193,7 @@
// Classifies the selected text given the context string.
// Returns an empty result if an error occurs.
std::vector<ClassificationResult> ClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
const ClassificationOptions& options = ClassificationOptions()) const;
// Annotates the given structed input request. Models which handle the full
@@ -197,7 +206,7 @@
// of input fragments. The order of annotation span vectors will match the
// order of input fragments. If annotation is not possible for any of the
// annotators, no annotation is returned.
- StatusOr<std::vector<std::vector<AnnotatedSpan>>> AnnotateStructuredInput(
+ StatusOr<Annotations> AnnotateStructuredInput(
const std::vector<InputFragment>& string_fragments,
const AnnotationOptions& options = AnnotationOptions()) const;
@@ -207,10 +216,13 @@
const std::string& context,
const AnnotationOptions& options = AnnotationOptions()) const;
- // Looks up a knowledge entity by its id. If successful, populates the
- // serialized knowledge result and returns true.
- bool LookUpKnowledgeEntity(const std::string& id,
- std::string* serialized_knowledge_result) const;
+ // Looks up a knowledge entity by its id. Returns the serialized knowledge
+ // result.
+ StatusOr<std::string> LookUpKnowledgeEntity(const std::string& id) const;
+
+ // Looks up an entity's property.
+ StatusOr<std::string> LookUpKnowledgeEntityProperty(
+ const std::string& mid_str, const std::string& property) const;
const Model* model() const;
const reflection::Schema* entity_data_schema() const;
@@ -234,22 +246,14 @@
float score;
};
- // Constructs and initializes text classifier from given model.
- // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
- Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
- const UniLib* unilib, const CalendarLib* calendarlib);
- Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
- std::unique_ptr<UniLib> unilib,
- std::unique_ptr<CalendarLib> calendarlib);
-
- // Constructs, validates and initializes text classifier from given model.
- // Does not own the buffer that backs 'model'.
- Annotator(const Model* model, const UniLib* unilib,
- const CalendarLib* calendarlib);
+ // NOTE: ValidateAndInitialize needs to be called before any other method.
+ Annotator() : initialized_(false) {}
// Checks that model contains all required fields, and initializes internal
// datastructures.
- void ValidateAndInitialize();
+ // Needs to be called before any other method is.
+ void ValidateAndInitialize(const Model* model, const UniLib* unilib,
+ const CalendarLib* calendarlib);
// Initializes regular expressions for the regex model.
bool InitializeRegexModel(ZlibDecompressor* decompressor);
@@ -262,7 +266,7 @@
const std::string& context,
const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- AnnotationUsecase annotation_usecase,
+ const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<int>* result) const;
@@ -274,7 +278,7 @@
const std::vector<AnnotatedSpan>& candidates,
const std::vector<Locale>& detected_text_language_tags,
int start_index, int end_index,
- AnnotationUsecase annotation_usecase,
+ const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<int>* chosen_indices) const;
@@ -282,37 +286,44 @@
// Provides the tokens produced during tokenization of the context string for
// reuse.
bool ModelSuggestSelection(
- const UnicodeText& context_unicode, CodepointSpan click_indices,
+ const UnicodeText& context_unicode, const CodepointSpan& click_indices,
const std::vector<Locale>& detected_text_language_tags,
InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const;
// Classifies the selected text given the context string with the
// classification model.
+ // The following arguments are optional:
+ // - cached_tokens - can be given as empty
+ // - embedding_cache - can be given as nullptr
+ // - tokens - can be given as nullptr
// Returns true if no error occurred.
bool ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
- const std::vector<Locale>& locales, CodepointSpan selection_indices,
+ const std::vector<Locale>& detected_text_language_tags,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results,
std::vector<Token>* tokens) const;
- // Same as above but doesn't output tokens.
+ // Same as above, but (for optimization) takes the context as UnicodeText and
+ // takes the following extra arguments:
+ // - span_begin, span_end - iterators in context_unicode corresponding to
+ // selection_indices
+ // - line - a UnicodeTextRange within context_unicode corresponding to the
+ // line containing the selection - optional, can be given as nullptr
bool ModelClassifyText(
- const std::string& context, const std::vector<Token>& cached_tokens,
+ const UnicodeText& context_unicode,
+ const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, const UnicodeTextRange* line,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
+ InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
- std::vector<ClassificationResult>* classification_results) const;
-
- // Same as above but doesn't take cached tokens and doesn't output tokens.
- bool ModelClassifyText(
- const std::string& context,
- const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
- FeatureProcessor::EmbeddingCache* embedding_cache,
- std::vector<ClassificationResult>* classification_results) const;
+ std::vector<ClassificationResult>* classification_results,
+ std::vector<Token>* tokens) const;
// Returns a relative token span that represents how many tokens on the left
// from the selection and right from the selection are needed for the
@@ -322,13 +333,13 @@
// Classifies the selected text with the regular expressions models.
// Returns true if no error happened, false otherwise.
bool RegexClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
std::vector<ClassificationResult>* classification_result) const;
// Classifies the selected text with the date time model.
// Returns true if no error happened, false otherwise.
bool DatetimeClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
const ClassificationOptions& options,
std::vector<ClassificationResult>* classification_results) const;
@@ -340,6 +351,7 @@
// reuse.
bool ModelAnnotate(const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
+ const AnnotationOptions& options,
InterpreterManager* interpreter_manager,
std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const;
@@ -379,8 +391,11 @@
// Produces chunks isolated by a set of regular expressions.
bool RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
- std::vector<AnnotatedSpan>* result,
- bool is_serialized_entity_data_enabled) const;
+ bool is_serialized_entity_data_enabled,
+ const EnabledEntityTypes& enabled_entity_types,
+ const AnnotationUsecase& annotation_usecase,
+
+ std::vector<AnnotatedSpan>* result) const;
// Produces chunks from the datetime parser.
bool DatetimeChunk(const UnicodeText& context_unicode,
@@ -434,11 +449,15 @@
std::unique_ptr<const FeatureProcessor> selection_feature_processor_;
std::unique_ptr<const FeatureProcessor> classification_feature_processor_;
+ std::unique_ptr<const grammar::Analyzer> analyzer_;
+ std::unique_ptr<const DatetimeGrounder> datetime_grounder_;
std::unique_ptr<const DatetimeParser> datetime_parser_;
- std::unique_ptr<const dates::CfgDatetimeAnnotator> cfg_datetime_parser_;
-
std::unique_ptr<const GrammarAnnotator> grammar_annotator_;
+ std::string owned_buffer_;
+ std::unique_ptr<UniLib> owned_unilib_;
+ std::unique_ptr<CalendarLib> owned_calendarlib_;
+
private:
struct CompiledRegexPattern {
const RegexModel_::Pattern* config;
@@ -462,7 +481,31 @@
// Parses the money amount into whole and decimal part and fills in the
// entity data information.
- bool ParseAndFillInMoneyAmount(std::string* serialized_entity_data) const;
+ bool ParseAndFillInMoneyAmount(std::string* serialized_entity_data,
+ const UniLib::RegexMatcher* match,
+ const RegexModel_::Pattern* config,
+ const UnicodeText& context_unicode) const;
+
+ // Given the regex capturing groups, extract the one representing the money
+ // quantity and fills in the actual string and the power of 10 the amount
+ // should be multiplied with.
+ void GetMoneyQuantityFromCapturingGroup(const UniLib::RegexMatcher* match,
+ const RegexModel_::Pattern* config,
+ const UnicodeText& context_unicode,
+ std::string* quantity,
+ int* exponent) const;
+
+ // Returns true if any of the ff-model entity types is enabled.
+ bool IsAnyModelEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const;
+
+ // Returns true if any of the regex entity types is enabled.
+ bool IsAnyRegexEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const;
+
+ // Returns true if any of the POD NER entity types is enabled.
+ bool IsAnyPodNerEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const;
std::unique_ptr<ScopedMmap> mmap_;
bool initialized_ = false;
@@ -479,9 +522,7 @@
std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
selection_regex_patterns_;
- std::unique_ptr<UniLib> owned_unilib_;
const UniLib* unilib_;
- std::unique_ptr<CalendarLib> owned_calendarlib_;
const CalendarLib* calendarlib_;
std::unique_ptr<const KnowledgeEngine> knowledge_engine_;
@@ -491,11 +532,13 @@
std::unique_ptr<const DurationAnnotator> duration_annotator_;
std::unique_ptr<const PersonNameEngine> person_name_engine_;
std::unique_ptr<const TranslateAnnotator> translate_annotator_;
+ std::unique_ptr<const PodNerAnnotator> pod_ner_annotator_;
std::unique_ptr<const ExperimentalAnnotator> experimental_annotator_;
+ std::unique_ptr<const VocabAnnotator> vocab_annotator_;
// Builder for creating extra data.
const reflection::Schema* entity_data_schema_;
- std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
+ std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
// Locales for which the entire model triggers.
std::vector<Locale> model_triggering_locales_;
@@ -526,7 +569,7 @@
// Helper function, which if the initial 'span' contains only white-spaces,
// moves the selection to a single-codepoint selection on the left side
// of this block of white-space.
-CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
+CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
const UnicodeText& context_unicode,
const UniLib& unilib);
@@ -534,7 +577,7 @@
// 'tokens_around_selection_to_copy' (on the left, and right) tokens distant
// from the tokens that correspond to 'selection_indices'.
std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
- CodepointSpan selection_indices,
+ const CodepointSpan& selection_indices,
TokenSpan tokens_around_selection_to_copy);
} // namespace internal
diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc
index 3e04f7f..6e7eeab 100644
--- a/native/annotator/annotator_jni.cc
+++ b/native/annotator/annotator_jni.cc
@@ -21,10 +21,12 @@
#include <jni.h>
#include <type_traits>
+#include <utility>
#include <vector>
#include "annotator/annotator.h"
#include "annotator/annotator_jni_common.h"
+#include "annotator/knowledge/knowledge-engine-types.h"
#include "annotator/types.h"
#include "utils/base/integral_types.h"
#include "utils/base/status_macros.h"
@@ -35,7 +37,6 @@
#include "utils/intents/remote-action-template.h"
#include "utils/java/jni-cache.h"
#include "utils/java/jni-helper.h"
-#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"
#include "utils/strings/stringpiece.h"
#include "utils/utf8/unilib.h"
@@ -49,6 +50,7 @@
#endif
using libtextclassifier3::AnnotatedSpan;
+using libtextclassifier3::Annotations;
using libtextclassifier3::Annotator;
using libtextclassifier3::ClassificationResult;
using libtextclassifier3::CodepointSpan;
@@ -81,11 +83,9 @@
std::unique_ptr<IntentGenerator> intent_generator =
IntentGenerator::Create(model->model()->intent_options(),
model->model()->resources(), jni_cache);
- std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
- libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
- if (template_handler == nullptr) {
- return nullptr;
- }
+ TC3_ASSIGN_OR_RETURN_NULL(
+ std::unique_ptr<RemoteActionTemplatesHandler> template_handler,
+ libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache));
return new AnnotatorJniContext(jni_cache, std::move(model),
std::move(intent_generator),
@@ -151,10 +151,11 @@
TC3_ASSIGN_OR_RETURN(serialized_knowledge_result,
JniHelper::NewByteArray(
env, serialized_knowledge_result_string.size()));
- env->SetByteArrayRegion(serialized_knowledge_result.get(), 0,
- serialized_knowledge_result_string.size(),
- reinterpret_cast<const jbyte*>(
- serialized_knowledge_result_string.data()));
+ TC3_RETURN_IF_ERROR(JniHelper::SetByteArrayRegion(
+ env, serialized_knowledge_result.get(), 0,
+ serialized_knowledge_result_string.size(),
+ reinterpret_cast<const jbyte*>(
+ serialized_knowledge_result_string.data())));
}
ScopedLocalRef<jstring> contact_name;
@@ -204,6 +205,22 @@
env, classification_result.contact_phone_number.c_str()));
}
+ ScopedLocalRef<jstring> contact_account_type;
+ if (!classification_result.contact_account_type.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_account_type,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_account_type.c_str()));
+ }
+
+ ScopedLocalRef<jstring> contact_account_name;
+ if (!classification_result.contact_account_name.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_account_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_account_name.c_str()));
+ }
+
ScopedLocalRef<jstring> contact_id;
if (!classification_result.contact_id.empty()) {
TC3_ASSIGN_OR_RETURN(
@@ -242,11 +259,11 @@
serialized_entity_data,
JniHelper::NewByteArray(
env, classification_result.serialized_entity_data.size()));
- env->SetByteArrayRegion(
- serialized_entity_data.get(), 0,
+ TC3_RETURN_IF_ERROR(JniHelper::SetByteArrayRegion(
+ env, serialized_entity_data.get(), 0,
classification_result.serialized_entity_data.size(),
reinterpret_cast<const jbyte*>(
- classification_result.serialized_entity_data.data()));
+ classification_result.serialized_entity_data.data())));
}
ScopedLocalRef<jobjectArray> remote_action_templates_result;
@@ -274,7 +291,8 @@
row_datetime_parse.get(), serialized_knowledge_result.get(),
contact_name.get(), contact_given_name.get(), contact_family_name.get(),
contact_nickname.get(), contact_email_address.get(),
- contact_phone_number.get(), contact_id.get(), app_name.get(),
+ contact_phone_number.get(), contact_account_type.get(),
+ contact_account_name.get(), contact_id.get(), app_name.get(),
app_package_name.get(), extras.get(), serialized_entity_data.get(),
remote_action_templates_result.get(), classification_result.duration_ms,
classification_result.numeric_value,
@@ -303,13 +321,23 @@
JniHelper::GetMethodID(
env, result_class.get(), "<init>",
"(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/"
- "String;"
- "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
- "String;"
- "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
- "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH
- "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V"));
+ "$DatetimeResult;"
+ "[B"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "[L" TC3_PACKAGE_PATH "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";"
+ "[B"
+ "[L" TC3_PACKAGE_PATH "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";"
+ "JJD)V"));
TC3_ASSIGN_OR_RETURN(const jmethodID datetime_parse_class_constructor,
JniHelper::GetMethodID(env, datetime_parse_class.get(),
"<init>", "(JI)V"));
@@ -340,7 +368,7 @@
return ClassificationResultsWithIntentsToJObjectArray(
env, model_context,
/*(unused) app_context=*/nullptr,
- /*(unused) devide_locale=*/nullptr,
+ /*(unused) device_locale=*/nullptr,
/*(unusued) options=*/nullptr,
/*(unused) selection_text=*/"",
/*(unused) selection_indices=*/{kInvalidIndex, kInvalidIndex},
@@ -348,9 +376,9 @@
/*generate_intents=*/false);
}
-CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
- CodepointSpan orig_indices,
- bool from_utf8) {
+std::pair<int, int> ConvertIndicesBMPUTF8(
+ const std::string& utf8_str, const std::pair<int, int>& orig_indices,
+ bool from_utf8) {
const libtextclassifier3::UnicodeText unicode_str =
libtextclassifier3::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
@@ -367,7 +395,7 @@
target_index = &unicode_index;
}
- CodepointSpan result{-1, -1};
+ std::pair<int, int> result = std::make_pair(-1, -1);
std::function<void()> assign_indices_fn = [&result, &orig_indices,
&source_index, &target_index]() {
if (orig_indices.first == *source_index) {
@@ -396,13 +424,17 @@
} // namespace
CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
- CodepointSpan bmp_indices) {
- return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
+ const std::pair<int, int>& bmp_indices) {
+ const std::pair<int, int> utf8_indices =
+ ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
+ return CodepointSpan(utf8_indices.first, utf8_indices.second);
}
-CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
- CodepointSpan utf8_indices) {
- return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
+std::pair<int, int> ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
+ const CodepointSpan& utf8_indices) {
+ return ConvertIndicesBMPUTF8(
+ utf8_str, std::make_pair(utf8_indices.first, utf8_indices.second),
+ /*from_utf8=*/true);
}
StatusOr<ScopedLocalRef<jstring>> GetLocalesFromMmap(
@@ -456,10 +488,10 @@
using libtextclassifier3::FromJavaInputFragment;
using libtextclassifier3::FromJavaSelectionOptions;
using libtextclassifier3::InputFragment;
-using libtextclassifier3::ToStlString;
+using libtextclassifier3::JStringToUtf8String;
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
-(JNIEnv* env, jobject thiz, jint fd) {
+(JNIEnv* env, jobject clazz, jint fd) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
#ifdef TC3_USE_JAVAICU
@@ -475,8 +507,9 @@
}
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
-(JNIEnv* env, jobject thiz, jstring path) {
- TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
+(JNIEnv* env, jobject clazz, jstring path) {
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str,
+ JStringToUtf8String(env, path));
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
#ifdef TC3_USE_JAVAICU
@@ -492,7 +525,7 @@
}
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
#ifdef TC3_USE_JAVAICU
@@ -516,13 +549,9 @@
Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- std::string serialized_config_string;
- TC3_ASSIGN_OR_RETURN_FALSE(jsize length,
- JniHelper::GetArrayLength(env, serialized_config));
- serialized_config_string.resize(length);
- env->GetByteArrayRegion(serialized_config, 0, length,
- reinterpret_cast<jbyte*>(const_cast<char*>(
- serialized_config_string.data())));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ const std::string serialized_config_string,
+ libtextclassifier3::JByteArrayToString(env, serialized_config));
return model->InitializeKnowledgeEngine(serialized_config_string);
}
@@ -536,13 +565,9 @@
Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- std::string serialized_config_string;
- TC3_ASSIGN_OR_RETURN_FALSE(jsize length,
- JniHelper::GetArrayLength(env, serialized_config));
- serialized_config_string.resize(length);
- env->GetByteArrayRegion(serialized_config, 0, length,
- reinterpret_cast<jbyte*>(const_cast<char*>(
- serialized_config_string.data())));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ const std::string serialized_config_string,
+ libtextclassifier3::JByteArrayToString(env, serialized_config));
return model->InitializeContactEngine(serialized_config_string);
}
@@ -556,13 +581,9 @@
Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- std::string serialized_config_string;
- TC3_ASSIGN_OR_RETURN_FALSE(jsize length,
- JniHelper::GetArrayLength(env, serialized_config));
- serialized_config_string.resize(length);
- env->GetByteArrayRegion(serialized_config, 0, length,
- reinterpret_cast<jbyte*>(const_cast<char*>(
- serialized_config_string.data())));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ const std::string serialized_config_string,
+ libtextclassifier3::JByteArrayToString(env, serialized_config));
return model->InitializeInstalledAppEngine(serialized_config_string);
}
@@ -608,20 +629,23 @@
}
const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
- ToStlString(env, context));
- CodepointSpan input_indices =
+ JStringToUtf8String(env, context));
+ const CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
TC3_ASSIGN_OR_RETURN_NULL(
libtextclassifier3::SelectionOptions selection_options,
FromJavaSelectionOptions(env, options));
CodepointSpan selection =
model->SuggestSelection(context_utf8, input_indices, selection_options);
- selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
+ const std::pair<int, int> selection_bmp =
+ ConvertIndicesUTF8ToBMP(context_utf8, selection);
TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jintArray> result,
JniHelper::NewIntArray(env, 2));
- env->SetIntArrayRegion(result.get(), 0, 1, &(std::get<0>(selection)));
- env->SetIntArrayRegion(result.get(), 1, 1, &(std::get<1>(selection)));
+ TC3_RETURN_NULL_IF_ERROR(JniHelper::SetIntArrayRegion(
+ env, result.get(), 0, 1, &(selection_bmp.first)));
+ TC3_RETURN_NULL_IF_ERROR(JniHelper::SetIntArrayRegion(
+ env, result.get(), 1, 1, &(selection_bmp.second)));
return result.release();
}
@@ -636,7 +660,7 @@
reinterpret_cast<AnnotatorJniContext*>(ptr);
TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
- ToStlString(env, context));
+ JStringToUtf8String(env, context));
const CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
TC3_ASSIGN_OR_RETURN_NULL(
@@ -672,7 +696,7 @@
const AnnotatorJniContext* model_context =
reinterpret_cast<AnnotatorJniContext*>(ptr);
TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
- ToStlString(env, context));
+ JStringToUtf8String(env, context));
TC3_ASSIGN_OR_RETURN_NULL(
libtextclassifier3::AnnotationOptions annotation_options,
FromJavaAnnotationOptions(env, options));
@@ -696,7 +720,7 @@
JniHelper::NewObjectArray(env, annotations.size(), result_class.get()));
for (int i = 0; i < annotations.size(); ++i) {
- CodepointSpan span_bmp =
+ const std::pair<int, int> span_bmp =
ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
TC3_ASSIGN_OR_RETURN_NULL(
@@ -718,8 +742,7 @@
return results.release();
}
-TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME,
- nativeAnnotateStructuredInput)
+TC3_JNI_METHOD(jobject, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotateStructuredInput)
(JNIEnv* env, jobject thiz, jlong ptr, jobjectArray jinput_fragments,
jobject options) {
if (!ptr) {
@@ -743,7 +766,7 @@
TC3_ASSIGN_OR_RETURN_NULL(
libtextclassifier3::AnnotationOptions annotation_options,
FromJavaAnnotationOptions(env, options));
- const StatusOr<std::vector<std::vector<AnnotatedSpan>>> annotations_or =
+ const StatusOr<Annotations> annotations_or =
model_context->model()->AnnotateStructuredInput(string_fragments,
annotation_options);
if (!annotations_or.ok()) {
@@ -752,8 +775,20 @@
return nullptr;
}
- std::vector<std::vector<AnnotatedSpan>> annotations =
- std::move(annotations_or.ValueOrDie());
+ Annotations annotations = std::move(annotations_or.ValueOrDie());
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jclass> annotations_class,
+ JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$Annotations"));
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ jmethodID annotations_class_constructor,
+ JniHelper::GetMethodID(
+ env, annotations_class.get(), "<init>",
+ "([[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$AnnotatedSpan;[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$ClassificationResult;)V"));
+
TC3_ASSIGN_OR_RETURN_NULL(
ScopedLocalRef<jclass> span_class,
JniHelper::FindClass(
@@ -773,26 +808,28 @@
"$AnnotatedSpan;"));
TC3_ASSIGN_OR_RETURN_NULL(
- ScopedLocalRef<jobjectArray> results,
+ ScopedLocalRef<jobjectArray> annotated_spans,
JniHelper::NewObjectArray(env, input_size, span_class_array.get()));
- for (int fragment_index = 0; fragment_index < annotations.size();
- ++fragment_index) {
+ for (int fragment_index = 0;
+ fragment_index < annotations.annotated_spans.size(); ++fragment_index) {
TC3_ASSIGN_OR_RETURN_NULL(
ScopedLocalRef<jobjectArray> jfragmentAnnotations,
- JniHelper::NewObjectArray(env, annotations[fragment_index].size(),
- span_class.get()));
+ JniHelper::NewObjectArray(
+ env, annotations.annotated_spans[fragment_index].size(),
+ span_class.get()));
for (int annotation_index = 0;
- annotation_index < annotations[fragment_index].size();
+ annotation_index < annotations.annotated_spans[fragment_index].size();
++annotation_index) {
- CodepointSpan span_bmp = ConvertIndicesUTF8ToBMP(
+ const std::pair<int, int> span_bmp = ConvertIndicesUTF8ToBMP(
string_fragments[fragment_index].text,
- annotations[fragment_index][annotation_index].span);
+ annotations.annotated_spans[fragment_index][annotation_index].span);
TC3_ASSIGN_OR_RETURN_NULL(
ScopedLocalRef<jobjectArray> classification_results,
ClassificationResultsToJObjectArray(
env, model_context,
- annotations[fragment_index][annotation_index].classification));
+ annotations.annotated_spans[fragment_index][annotation_index]
+ .classification));
TC3_ASSIGN_OR_RETURN_NULL(
ScopedLocalRef<jobject> single_annotation,
JniHelper::NewObject(env, span_class.get(), span_class_constructor,
@@ -808,14 +845,26 @@
}
}
- if (!JniHelper::SetObjectArrayElement(env, results.get(), fragment_index,
+ if (!JniHelper::SetObjectArrayElement(env, annotated_spans.get(),
+ fragment_index,
jfragmentAnnotations.get())
.ok()) {
return nullptr;
}
}
- return results.release();
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> topicality_results,
+ ClassificationResultsToJObjectArray(env, model_context,
+ annotations.topicality_results));
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobject> annotations_result,
+ JniHelper::NewObject(env, annotations_class.get(),
+ annotations_class_constructor, annotated_spans.get(),
+ topicality_results.get()));
+
+ return annotations_result.release();
}
TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
@@ -825,18 +874,21 @@
return nullptr;
}
const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- TC3_ASSIGN_OR_RETURN_NULL(const std::string id_utf8, ToStlString(env, id));
- std::string serialized_knowledge_result;
- if (!model->LookUpKnowledgeEntity(id_utf8, &serialized_knowledge_result)) {
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string id_utf8,
+ JStringToUtf8String(env, id));
+ auto serialized_knowledge_result_so = model->LookUpKnowledgeEntity(id_utf8);
+ if (!serialized_knowledge_result_so.ok()) {
return nullptr;
}
+ std::string serialized_knowledge_result =
+ serialized_knowledge_result_so.ValueOrDie();
TC3_ASSIGN_OR_RETURN_NULL(
ScopedLocalRef<jbyteArray> result,
JniHelper::NewByteArray(env, serialized_knowledge_result.size()));
- env->SetByteArrayRegion(
- result.get(), 0, serialized_knowledge_result.size(),
- reinterpret_cast<const jbyte*>(serialized_knowledge_result.data()));
+ TC3_RETURN_NULL_IF_ERROR(JniHelper::SetByteArrayRegion(
+ env, result.get(), 0, serialized_knowledge_result.size(),
+ reinterpret_cast<const jbyte*>(serialized_knowledge_result.data())));
return result.release();
}
@@ -864,7 +916,7 @@
}
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
@@ -880,7 +932,7 @@
}
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersionWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
return GetVersionFromMmap(env, mmap.get());
@@ -896,7 +948,7 @@
}
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
diff --git a/native/annotator/annotator_jni.h b/native/annotator/annotator_jni.h
index 39a9d9a..0abaf46 100644
--- a/native/annotator/annotator_jni.h
+++ b/native/annotator/annotator_jni.h
@@ -29,13 +29,13 @@
// SmartSelection.
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
-(JNIEnv* env, jobject thiz, jint fd);
+(JNIEnv* env, jobject clazz, jint fd);
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
-(JNIEnv* env, jobject thiz, jstring path);
+(JNIEnv* env, jobject clazz, jstring path);
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
nativeInitializeKnowledgeEngine)
@@ -68,8 +68,7 @@
jint selection_end, jobject options, jobject app_context,
jstring device_locales);
-TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME,
- nativeAnnotateStructuredInput)
+TC3_JNI_METHOD(jobject, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotateStructuredInput)
(JNIEnv* env, jobject thiz, jlong ptr, jobjectArray jinput_fragments,
jobject options);
@@ -91,19 +90,19 @@
(JNIEnv* env, jobject clazz, jint fd);
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd);
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersionWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd);
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
#ifdef __cplusplus
}
@@ -114,13 +113,13 @@
// Given a utf8 string and a span expressed in Java BMP (basic multilingual
// plane) codepoints, converts it to a span expressed in utf8 codepoints.
libtextclassifier3::CodepointSpan ConvertIndicesBMPToUTF8(
- const std::string& utf8_str, libtextclassifier3::CodepointSpan bmp_indices);
+ const std::string& utf8_str, const std::pair<int, int>& bmp_indices);
// Given a utf8 string and a span expressed in utf8 codepoints, converts it to a
// span expressed in Java BMP (basic multilingual plane) codepoints.
-libtextclassifier3::CodepointSpan ConvertIndicesUTF8ToBMP(
+std::pair<int, int> ConvertIndicesUTF8ToBMP(
const std::string& utf8_str,
- libtextclassifier3::CodepointSpan utf8_indices);
+ const libtextclassifier3::CodepointSpan& utf8_indices);
} // namespace libtextclassifier3
diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc
index de58b70..a6f636f 100644
--- a/native/annotator/annotator_jni_common.cc
+++ b/native/annotator/annotator_jni_common.cc
@@ -16,6 +16,7 @@
#include "annotator/annotator_jni_common.h"
+#include "annotator/knowledge/knowledge-engine-types.h"
#include "utils/java/jni-base.h"
#include "utils/java/jni-helper.h"
@@ -26,13 +27,14 @@
JNIEnv* env, const jobject& jobject) {
std::unordered_set<std::string> entity_types;
jobjectArray jentity_types = reinterpret_cast<jobjectArray>(jobject);
- const int size = env->GetArrayLength(jentity_types);
+ TC3_ASSIGN_OR_RETURN(const int size,
+ JniHelper::GetArrayLength(env, jentity_types));
for (int i = 0; i < size; ++i) {
TC3_ASSIGN_OR_RETURN(
ScopedLocalRef<jstring> jentity_type,
JniHelper::GetObjectArrayElement<jstring>(env, jentity_types, i));
TC3_ASSIGN_OR_RETURN(std::string entity_type,
- ToStlString(env, jentity_type.get()));
+ JStringToUtf8String(env, jentity_type.get()));
entity_types.insert(entity_type);
}
return entity_types;
@@ -117,17 +119,35 @@
JniHelper::CallFloatMethod(
env, joptions, get_user_location_accuracy_meters));
+ // .getUsePodNer()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_use_pod_ner,
+ JniHelper::GetMethodID(env, options_class.get(), "getUsePodNer", "()Z"));
+ TC3_ASSIGN_OR_RETURN(bool use_pod_ner, JniHelper::CallBooleanMethod(
+ env, joptions, get_use_pod_ner));
+
+ // .getUseVocabAnnotator()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_use_vocab_annotator,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getUseVocabAnnotator", "()Z"));
+ TC3_ASSIGN_OR_RETURN(
+ bool use_vocab_annotator,
+ JniHelper::CallBooleanMethod(env, joptions, get_use_vocab_annotator));
T options;
- TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get()));
+ TC3_ASSIGN_OR_RETURN(options.locales,
+ JStringToUtf8String(env, locales.get()));
TC3_ASSIGN_OR_RETURN(options.reference_timezone,
- ToStlString(env, reference_timezone.get()));
+ JStringToUtf8String(env, reference_timezone.get()));
options.reference_time_ms_utc = reference_time;
- TC3_ASSIGN_OR_RETURN(options.detected_text_language_tags,
- ToStlString(env, detected_text_language_tags.get()));
+ TC3_ASSIGN_OR_RETURN(
+ options.detected_text_language_tags,
+ JStringToUtf8String(env, detected_text_language_tags.get()));
options.annotation_usecase =
static_cast<AnnotationUsecase>(annotation_usecase);
options.location_context = {user_location_lat, user_location_lng,
user_location_accuracy_meters};
+ options.use_pod_ner = use_pod_ner;
+ options.use_vocab_annotator = use_vocab_annotator;
return options;
}
} // namespace
@@ -154,6 +174,16 @@
ScopedLocalRef<jstring> locales,
JniHelper::CallObjectMethod<jstring>(env, joptions, get_locales));
+ // .getDetectedTextLanguageTags()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_detected_text_language_tags_method,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getDetectedTextLanguageTags",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> detected_text_language_tags,
+ JniHelper::CallObjectMethod<jstring>(
+ env, joptions, get_detected_text_language_tags_method));
+
// .getAnnotationUsecase()
TC3_ASSIGN_OR_RETURN(jmethodID get_annotation_usecase,
JniHelper::GetMethodID(env, options_class.get(),
@@ -162,11 +192,49 @@
int32 annotation_usecase,
JniHelper::CallIntMethod(env, joptions, get_annotation_usecase));
+ // .getUserLocationLat()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_user_location_lat,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getUserLocationLat", "()D"));
+ TC3_ASSIGN_OR_RETURN(
+ double user_location_lat,
+ JniHelper::CallDoubleMethod(env, joptions, get_user_location_lat));
+
+ // .getUserLocationLng()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_user_location_lng,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getUserLocationLng", "()D"));
+ TC3_ASSIGN_OR_RETURN(
+ double user_location_lng,
+ JniHelper::CallDoubleMethod(env, joptions, get_user_location_lng));
+
+ // .getUserLocationAccuracyMeters()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_user_location_accuracy_meters,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getUserLocationAccuracyMeters", "()F"));
+ TC3_ASSIGN_OR_RETURN(float user_location_accuracy_meters,
+ JniHelper::CallFloatMethod(
+ env, joptions, get_user_location_accuracy_meters));
+
+ // .getUsePodNer()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_use_pod_ner,
+ JniHelper::GetMethodID(env, options_class.get(), "getUsePodNer", "()Z"));
+ TC3_ASSIGN_OR_RETURN(bool use_pod_ner, JniHelper::CallBooleanMethod(
+ env, joptions, get_use_pod_ner));
+
SelectionOptions options;
- TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get()));
+ TC3_ASSIGN_OR_RETURN(options.locales,
+ JStringToUtf8String(env, locales.get()));
options.annotation_usecase =
static_cast<AnnotationUsecase>(annotation_usecase);
-
+ TC3_ASSIGN_OR_RETURN(
+ options.detected_text_language_tags,
+ JStringToUtf8String(env, detected_text_language_tags.get()));
+ options.location_context = {user_location_lat, user_location_lng,
+ user_location_accuracy_meters};
+ options.use_pod_ner = use_pod_ner;
return options;
}
@@ -197,8 +265,19 @@
JniHelper::CallObjectMethod<jstring>(
env, joptions, get_user_familiar_language_tags));
- TC3_ASSIGN_OR_RETURN(classifier_options.user_familiar_language_tags,
- ToStlString(env, user_familiar_language_tags.get()));
+ TC3_ASSIGN_OR_RETURN(
+ classifier_options.user_familiar_language_tags,
+ JStringToUtf8String(env, user_familiar_language_tags.get()));
+
+ // .getTriggerDictionaryOnBeginnerWords()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_trigger_dictionary_on_beginner_words,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getTriggerDictionaryOnBeginnerWords", "()Z"));
+ TC3_ASSIGN_OR_RETURN(
+ classifier_options.trigger_dictionary_on_beginner_words,
+ JniHelper::CallBooleanMethod(env, joptions,
+ get_trigger_dictionary_on_beginner_words));
return classifier_options;
}
@@ -252,6 +331,13 @@
bool has_personalization_permission,
JniHelper::CallBooleanMethod(env, joptions,
has_personalization_permission_method));
+ // .getAnnotateMode()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_annotate_mode,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getAnnotateMode", "()I"));
+ TC3_ASSIGN_OR_RETURN(
+ int32 annotate_mode,
+ JniHelper::CallIntMethod(env, joptions, get_annotate_mode));
TC3_ASSIGN_OR_RETURN(
AnnotationOptions annotation_options,
@@ -266,6 +352,17 @@
has_location_permission;
annotation_options.permissions.has_personalization_permission =
has_personalization_permission;
+ annotation_options.annotate_mode = static_cast<AnnotateMode>(annotate_mode);
+
+ // .getTriggerDictionaryOnBeginnerWords()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_trigger_dictionary_on_beginner_words,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getTriggerDictionaryOnBeginnerWords", "()Z"));
+ TC3_ASSIGN_OR_RETURN(
+ annotation_options.trigger_dictionary_on_beginner_words,
+ JniHelper::CallBooleanMethod(env, joptions,
+ get_trigger_dictionary_on_beginner_words));
return annotation_options;
}
@@ -290,7 +387,7 @@
ScopedLocalRef<jstring> text,
JniHelper::CallObjectMethod<jstring>(env, jfragment, get_text));
- TC3_ASSIGN_OR_RETURN(fragment.text, ToStlString(env, text.get()));
+ TC3_ASSIGN_OR_RETURN(fragment.text, JStringToUtf8String(env, text.get()));
// .hasDatetimeOptions()
TC3_ASSIGN_OR_RETURN(jmethodID has_date_time_options_method,
@@ -323,13 +420,31 @@
env, jfragment, get_reference_timezone_method));
TC3_ASSIGN_OR_RETURN(std::string reference_timezone,
- ToStlString(env, jreference_timezone.get()));
+ JStringToUtf8String(env, jreference_timezone.get()));
fragment.datetime_options =
DatetimeOptions{.reference_time_ms_utc = reference_time,
.reference_timezone = reference_timezone};
}
+ // .getBoundingBoxHeight()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_bounding_box_height,
+ JniHelper::GetMethodID(env, fragment_class.get(),
+ "getBoundingBoxHeight", "()F"));
+ TC3_ASSIGN_OR_RETURN(
+ float bounding_box_height,
+ JniHelper::CallFloatMethod(env, jfragment, get_bounding_box_height));
+
+ fragment.bounding_box_height = bounding_box_height;
+
+ // .getBoundingBoxTop()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_bounding_box_top,
+ JniHelper::GetMethodID(env, fragment_class.get(),
+ "getBoundingBoxTop", "()F"));
+ TC3_ASSIGN_OR_RETURN(
+ float bounding_box_top,
+ JniHelper::CallFloatMethod(env, jfragment, get_bounding_box_top));
+ fragment.bounding_box_top = bounding_box_top;
return fragment;
}
} // namespace libtextclassifier3
diff --git a/native/annotator/annotator_jni_test.cc b/native/annotator/annotator_jni_test.cc
index 929fb59..a48f173 100644
--- a/native/annotator/annotator_jni_test.cc
+++ b/native/annotator/annotator_jni_test.cc
@@ -24,52 +24,52 @@
TEST(Annotator, ConvertIndicesBMPUTF8) {
// Test boundary cases.
- EXPECT_EQ(ConvertIndicesBMPToUTF8("hello", {0, 5}), std::make_pair(0, 5));
+ EXPECT_EQ(ConvertIndicesBMPToUTF8("hello", {0, 5}), CodepointSpan(0, 5));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello", {0, 5}), std::make_pair(0, 5));
EXPECT_EQ(ConvertIndicesBMPToUTF8("hello world", {0, 5}),
- std::make_pair(0, 5));
+ CodepointSpan(0, 5));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello world", {0, 5}),
std::make_pair(0, 5));
EXPECT_EQ(ConvertIndicesBMPToUTF8("😁ello world", {0, 6}),
- std::make_pair(0, 5));
+ CodepointSpan(0, 5));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁ello world", {0, 5}),
std::make_pair(0, 6));
EXPECT_EQ(ConvertIndicesBMPToUTF8("hello world", {6, 11}),
- std::make_pair(6, 11));
+ CodepointSpan(6, 11));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello world", {6, 11}),
std::make_pair(6, 11));
EXPECT_EQ(ConvertIndicesBMPToUTF8("hello worl😁", {6, 12}),
- std::make_pair(6, 11));
+ CodepointSpan(6, 11));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello worl😁", {6, 11}),
std::make_pair(6, 12));
// Simple example where the longer character is before the selection.
// character 😁 is 0x1f601
EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hello World.", {3, 8}),
- std::make_pair(2, 7));
+ CodepointSpan(2, 7));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hello World.", {2, 7}),
std::make_pair(3, 8));
// Longer character is before and in selection.
EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hell😁 World.", {3, 9}),
- std::make_pair(2, 7));
+ CodepointSpan(2, 7));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hell😁 World.", {2, 7}),
std::make_pair(3, 9));
// Longer character is before and after selection.
EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hello😁World.", {3, 8}),
- std::make_pair(2, 7));
+ CodepointSpan(2, 7));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hello😁World.", {2, 7}),
std::make_pair(3, 8));
// Longer character is before in after selection.
EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hell😁😁World.", {3, 9}),
- std::make_pair(2, 7));
+ CodepointSpan(2, 7));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hell😁😁World.", {2, 7}),
std::make_pair(3, 9));
diff --git a/native/annotator/annotator_test-include.cc b/native/annotator/annotator_test-include.cc
new file mode 100644
index 0000000..3ecc201
--- /dev/null
+++ b/native/annotator/annotator_test-include.cc
@@ -0,0 +1,3158 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/annotator_test-include.h"
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <type_traits>
+
+#include "annotator/annotator.h"
+#include "annotator/collections.h"
+#include "annotator/model_generated.h"
+#include "annotator/test-utils.h"
+#include "annotator/types-test-util.h"
+#include "annotator/types.h"
+#include "utils/grammar/utils/locale-shard-map.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/testing/annotator.h"
+#include "lang_id/fb_model/lang-id-from-fb.h"
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+using ::testing::Contains;
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::Eq;
+using ::testing::IsEmpty;
+using ::testing::UnorderedElementsAreArray;
+
+std::string GetTestModelPath() { return GetModelPath() + "test_model.fb"; }
+
+std::string GetModelWithVocabPath() {
+ return GetModelPath() + "test_vocab_model.fb";
+}
+
+std::string GetTestModelWithDatetimeRegEx() {
+ std::string model_buffer = ReadFile(GetTestModelPath());
+ model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+ model->datetime_grammar_model.reset(nullptr);
+ });
+ return model_buffer;
+}
+
+void ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan>& result,
+ const std::string& currency,
+ const std::string& amount, const int whole_part,
+ const int decimal_part, const int nanos) {
+ ASSERT_GT(result.size(), 0);
+ ASSERT_GT(result[0].classification.size(), 0);
+ ASSERT_EQ(result[0].classification[0].collection, "money");
+
+ const EntityData* entity_data =
+ GetEntityData(result[0].classification[0].serialized_entity_data.data());
+ ASSERT_NE(entity_data, nullptr);
+ ASSERT_NE(entity_data->money(), nullptr);
+ EXPECT_EQ(entity_data->money()->unnormalized_currency()->str(), currency);
+ EXPECT_EQ(entity_data->money()->unnormalized_amount()->str(), amount);
+ EXPECT_EQ(entity_data->money()->amount_whole_part(), whole_part);
+ EXPECT_EQ(entity_data->money()->amount_decimal_part(), decimal_part);
+ EXPECT_EQ(entity_data->money()->nanos(), nanos);
+}
+
+TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromPath(GetModelPath() + "wrong_embeddings.fb", unilib_.get(),
+ calendarlib_.get());
+ EXPECT_FALSE(classifier);
+}
+
+void VerifyClassifyText(const Annotator* classifier) {
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("other",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at", {15, 27})));
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ // More lines.
+ EXPECT_EQ("other",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {15, 27})));
+ EXPECT_EQ("phone",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {90, 103})));
+
+ // Single word.
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
+
+ // Junk. These should not crash the test.
+ classifier->ClassifyText("", {0, 0});
+ classifier->ClassifyText("asdf", {0, 0});
+ classifier->ClassifyText("asdf", {0, 27});
+ classifier->ClassifyText("asdf", {-30, 300});
+ classifier->ClassifyText("asdf", {-10, -1});
+ classifier->ClassifyText("asdf", {100, 17});
+ classifier->ClassifyText("a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5});
+
+ // Test invalid utf8 input.
+ EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
+ "\xf0\x9f\x98\x8b\x8b", {0, 0})));
+}
+
+TEST_F(AnnotatorTest, ClassifyText) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ VerifyClassifyText(classifier.get());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextLocalesAndDictionary) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("isotope", {0, 7})));
+
+ ClassificationOptions classification_options;
+ classification_options.detected_text_language_tags = "en";
+ EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText(
+ "isotope", {0, 7}, classification_options)));
+
+ classification_options.detected_text_language_tags = "uz";
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "isotope", {0, 7}, classification_options)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextUseVocabAnnotatorWithoutVocabModel) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ ClassificationOptions classification_options;
+ classification_options.detected_text_language_tags = "en";
+ classification_options.use_vocab_annotator = true;
+
+ EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText(
+ "isotope", {0, 7}, classification_options)));
+}
+
+#ifdef TC3_VOCAB_ANNOTATOR_IMPL
+TEST_F(AnnotatorTest, ClassifyTextWithVocabModel) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetModelWithVocabPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ ClassificationOptions classification_options;
+ classification_options.detected_text_language_tags = "en";
+
+ // The FFModel model does not annotate "integrity" as "dictionary", but the
+ // vocab annotator does. So we can use that to check if the vocab annotator is
+ // in use.
+ classification_options.use_vocab_annotator = true;
+ EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText(
+ "integrity", {0, 9}, classification_options)));
+ classification_options.use_vocab_annotator = false;
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "integrity", {0, 9}, classification_options)));
+}
+#endif // TC3_VOCAB_ANNOTATOR_IMPL
+
+TEST_F(AnnotatorTest, ClassifyTextDisabledFail) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ TC3_CHECK(unpacked_model != nullptr);
+
+ unpacked_model->classification_model.clear();
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+
+ // The classification model is still needed for selection scores.
+ ASSERT_FALSE(classifier);
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDisabled) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ unpacked_model->enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_THAT(
+ classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
+ IsEmpty());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextFilteredCollections) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone classification
+ unpacked_model->output_options->filtered_collections_classification.push_back(
+ "phone");
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ // Check that the address classification still passes.
+ EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
+ "350 Third Street, Cambridge", {0, 27})));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextRegularExpression) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", "Barack Obama", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));
+ std::unique_ptr<RegexModel_::PatternT> verified_pattern =
+ MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
+ /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false,
+ /*enabled_for_annotation=*/false, 1.0);
+ verified_pattern->verification_options.reset(new VerificationOptionsT);
+ verified_pattern->verification_options->verify_luhn_checksum = true;
+ unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("flight",
+ FirstResult(classifier->ClassifyText(
+ "Your flight LX373 is delayed by 3 hours.", {12, 17})));
+ EXPECT_EQ("person",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at", {15, 27})));
+ EXPECT_EQ("email",
+ FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
+ EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
+ "Contact me at you@android.com", {14, 29})));
+
+ EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
+ "Visit www.google.com every today!", {6, 20})));
+
+ EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
+ EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
+ {7, 12})));
+ EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
+ "cc: 4012 8888 8888 1881", {4, 23})));
+ EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
+ "2221 0067 4735 6281", {0, 19})));
+ // Luhn check fails.
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282",
+ {0, 19})));
+
+ // More lines.
+ EXPECT_EQ("url",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {51, 65})));
+}
+
+#ifndef TC3_DISABLE_LUA
+TEST_F(AnnotatorTest, ClassifyTextRegularExpressionLuaVerification) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ std::unique_ptr<RegexModel_::PatternT> verified_pattern =
+ MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
+ /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false,
+ /*enabled_for_annotation=*/false, 1.0);
+ verified_pattern->verification_options.reset(new VerificationOptionsT);
+ verified_pattern->verification_options->lua_verifier = 0;
+ unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
+ unpacked_model->regex_model->lua_verifier.push_back(
+ "return match[2].text==\"99\"");
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // Custom rule triggers and is correctly verified.
+ EXPECT_EQ("parcel_tracking", FirstResult(classifier->ClassifyText(
+ "99-00-123456-12345678", {0, 21})));
+
+ // Custom verification fails.
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "90-00-123456-12345678", {0, 21})));
+}
+#endif // TC3_DISABLE_LUA
+
+TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityData) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add fake entity schema metadata.
+ AddTestEntitySchemaData(unpacked_model.get());
+
+ AddTestRegexModel(unpacked_model.get());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // Check with full name.
+ {
+ auto classifications =
+ classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
+ EXPECT_EQ(1, classifications.size());
+ EXPECT_EQ("person_with_age", classifications[0].collection);
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ classifications[0].serialized_entity_data.data()));
+ EXPECT_EQ(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "Barack");
+ EXPECT_EQ(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Obama");
+ // Check `age`.
+ EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
+
+ // Check `is_alive`.
+ EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
+
+ // Check `former_us_president`.
+ EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
+ }
+
+ // Check only with first name.
+ {
+ auto classifications =
+ classifier->ClassifyText("Barack is 57 years old", {0, 22});
+ EXPECT_EQ(1, classifications.size());
+ EXPECT_EQ("person_with_age", classifications[0].collection);
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ classifications[0].serialized_entity_data.data()));
+ EXPECT_EQ(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "Barack");
+
+ // Check `age`.
+ EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
+
+ // Check `is_alive`.
+ EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
+
+ // Check `former_us_president`.
+ EXPECT_FALSE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
+ }
+}
+
+TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityDataNormalization) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add fake entity schema metadata.
+ AddTestEntitySchemaData(unpacked_model.get());
+
+ AddTestRegexModel(unpacked_model.get());
+
+ // Upper case last name as post-processing.
+ RegexModel_::PatternT* pattern =
+ unpacked_model->regex_model->patterns.back().get();
+ pattern->capturing_group[2]->normalization_options.reset(
+ new NormalizationOptionsT);
+ pattern->capturing_group[2]
+ ->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ auto classifications =
+ classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
+ EXPECT_EQ(1, classifications.size());
+ EXPECT_EQ("person_with_age", classifications[0].collection);
+
+ // Check entity data normalization.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ classifications[0].serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "OBAMA");
+}
+
+TEST_F(AnnotatorTest, ClassifyTextPriorityResolution) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.clear();
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight1", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
+ /*score=*/1.0, /*priority_score=*/1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight2", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
+ /*score=*/1.0, /*priority_score=*/0.0));
+
+ {
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("flight1",
+ FirstResult(classifier->ClassifyText(
+ "Your flight LX373 is delayed by 3 hours.", {12, 17})));
+ }
+
+ unpacked_model->regex_model->patterns.back()->priority_score = 3.0;
+ {
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("flight2",
+ FirstResult(classifier->ClassifyText(
+ "Your flight LX373 is delayed by 3 hours.", {12, 17})));
+ }
+}
+
+TEST_F(AnnotatorTest, AnnotatePriorityResolution) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));
+ // Add test regex models. One of them has higher priority score than
+ // the other. We'll test that always the one with higher priority score
+ // ends up winning.
+ unpacked_model->regex_model->patterns.clear();
+ const std::string flight_regex = "([a-zA-Z]{2}\\d{2,4})";
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", flight_regex, /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
+ /*score=*/1.0, /*priority_score=*/1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", flight_regex, /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
+ /*score=*/1.0, /*priority_score=*/0.0));
+
+ // "flight" that wins should have a priority score of 1.0.
+ {
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::vector<AnnotatedSpan> results =
+ classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
+ ASSERT_THAT(results, Not(IsEmpty()));
+ EXPECT_THAT(results[0].classification, Not(IsEmpty()));
+ EXPECT_GE(results[0].classification[0].priority_score, 0.9);
+ }
+
+ // When we increase the priority score, the "flight" that wins should have a
+ // priority score of 3.0.
+ unpacked_model->regex_model->patterns.back()->priority_score = 3.0;
+ {
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::vector<AnnotatedSpan> results =
+ classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
+ ASSERT_THAT(results, Not(IsEmpty()));
+ EXPECT_THAT(results[0].classification, Not(IsEmpty()));
+ EXPECT_GE(results[0].classification[0].priority_score, 2.9);
+ }
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionRegularExpression) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
+ std::unique_ptr<RegexModel_::PatternT> verified_pattern =
+ MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
+ /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true,
+ /*enabled_for_annotation=*/false, 1.0);
+ verified_pattern->verification_options.reset(new VerificationOptionsT);
+ verified_pattern->verification_options->verify_luhn_checksum = true;
+ unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // Check regular expression selection.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
+ CodepointSpan(12, 19));
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon Barack Obama gave a speech at", {15, 21}),
+ CodepointSpan(15, 27));
+ EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
+ CodepointSpan(4, 23));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ std::unique_ptr<RegexModel_::PatternT> custom_selection_bounds_pattern =
+ MakePattern("date_range",
+ "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to "
+ "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)",
+ /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true,
+ /*enabled_for_annotation=*/false, 1.0);
+ custom_selection_bounds_pattern->capturing_group.emplace_back(
+ new CapturingGroupT);
+ custom_selection_bounds_pattern->capturing_group.emplace_back(
+ new CapturingGroupT);
+ custom_selection_bounds_pattern->capturing_group.emplace_back(
+ new CapturingGroupT);
+ custom_selection_bounds_pattern->capturing_group.emplace_back(
+ new CapturingGroupT);
+ custom_selection_bounds_pattern->capturing_group[0]->extend_selection = false;
+ custom_selection_bounds_pattern->capturing_group[1]->extend_selection = true;
+ custom_selection_bounds_pattern->capturing_group[2]->extend_selection = true;
+ custom_selection_bounds_pattern->capturing_group[3]->extend_selection = true;
+ unpacked_model->regex_model->patterns.push_back(
+ std::move(custom_selection_bounds_pattern));
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // Check regular expression selection.
+ EXPECT_EQ(classifier->SuggestSelection("it's from 04/30/1789 to 03/04/1797",
+ {21, 23}),
+ CodepointSpan(10, 34));
+ EXPECT_EQ(classifier->SuggestSelection("it takes for ever", {9, 12}),
+ CodepointSpan(9, 17));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.back()->priority_score = 0.5;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // Check conflict resolution.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
+ {55, 57}),
+ CodepointSpan(26, 62));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // Check conflict resolution.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
+ {55, 57}),
+ CodepointSpan(55, 62));
+}
+
+TEST_F(AnnotatorTest, AnnotateRegex) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));
+ std::unique_ptr<RegexModel_::PatternT> verified_pattern =
+ MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
+ /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false,
+ /*enabled_for_annotation=*/true, 1.0);
+ verified_pattern->verification_options.reset(new VerificationOptionsT);
+ verified_pattern->verification_options->verify_luhn_checksum = true;
+ unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ IsAnnotatedSpan(107, 126, "payment_card")}));
+}
+
+TEST_F(AnnotatorTest, AnnotatesFlightNumbers) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // ICAO is only used for selected airlines.
+ // Expected: LX373, EZY1234 and U21234.
+ const std::string test_string = "flights LX373, SWR373, EZY1234, U21234";
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({IsAnnotatedSpan(8, 13, "flight"),
+ IsAnnotatedSpan(23, 30, "flight"),
+ IsAnnotatedSpan(32, 38, "flight")}));
+}
+
+#ifndef TC3_DISABLE_LUA
+TEST_F(AnnotatorTest, AnnotateRegexLuaVerification) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ std::unique_ptr<RegexModel_::PatternT> verified_pattern =
+ MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
+ /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/true,
+ /*enabled_for_annotation=*/true, 1.0);
+ verified_pattern->verification_options.reset(new VerificationOptionsT);
+ verified_pattern->verification_options->lua_verifier = 0;
+ unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
+ unpacked_model->regex_model->lua_verifier.push_back(
+ "return match[2].text==\"99\"");
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "your parcel is on the way: 99-00-123456-12345678";
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({IsAnnotatedSpan(27, 48, "parcel_tracking")}));
+}
+#endif // TC3_DISABLE_LUA
+
+TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityData) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add fake entity schema metadata.
+ AddTestEntitySchemaData(unpacked_model.get());
+
+ AddTestRegexModel(unpacked_model.get());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ AnnotationOptions options;
+ options.is_serialized_entity_data_enabled = true;
+ auto annotations =
+ classifier->Annotate("Barack Obama is 57 years old", options);
+ EXPECT_EQ(1, annotations.size());
+ EXPECT_EQ(1, annotations[0].classification.size());
+ EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ annotations[0].classification[0].serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "Barack");
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Obama");
+ // Check `age`.
+ EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
+
+ // Check `is_alive`.
+ EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
+
+ // Check `former_us_president`.
+ EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
+}
+
+TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataNormalization) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add fake entity schema metadata.
+ AddTestEntitySchemaData(unpacked_model.get());
+
+ AddTestRegexModel(unpacked_model.get());
+
+ // Upper case last name as post-processing.
+ RegexModel_::PatternT* pattern =
+ unpacked_model->regex_model->patterns.back().get();
+ pattern->capturing_group[2]->normalization_options.reset(
+ new NormalizationOptionsT);
+ pattern->capturing_group[2]
+ ->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ AnnotationOptions options;
+ options.is_serialized_entity_data_enabled = true;
+ auto annotations =
+ classifier->Annotate("Barack Obama is 57 years old", options);
+ EXPECT_EQ(1, annotations.size());
+ EXPECT_EQ(1, annotations[0].classification.size());
+ EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);
+
+ // Check normalization.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ annotations[0].classification[0].serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "OBAMA");
+}
+
+TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataDisabled) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add fake entity schema metadata.
+ AddTestEntitySchemaData(unpacked_model.get());
+
+ AddTestRegexModel(unpacked_model.get());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ AnnotationOptions options;
+ options.is_serialized_entity_data_enabled = false;
+ auto annotations =
+ classifier->Annotate("Barack Obama is 57 years old", options);
+ EXPECT_EQ(1, annotations.size());
+ EXPECT_EQ(1, annotations[0].classification.size());
+ EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);
+
+ // Check entity data.
+ EXPECT_EQ("", annotations[0].classification[0].serialized_entity_data);
+}
+
+TEST_F(AnnotatorTest, PhoneFiltering) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "phone: (123) 456 789", {7, 20})));
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "phone: (123) 456 789,0001112", {7, 25})));
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "phone: (123) 456 789,0001112", {7, 28})));
+}
+
+TEST_F(AnnotatorTest, SuggestSelection) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon Barack Obama gave a speech at", {15, 21}),
+ CodepointSpan(15, 21));
+
+ // Try passing whole string.
+ // If more than 1 token is specified, we should return back what entered.
+ EXPECT_EQ(
+ classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
+ CodepointSpan(0, 27));
+
+ // Single letter.
+ EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), CodepointSpan(0, 1));
+
+ // Single word.
+ EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), CodepointSpan(0, 4));
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ CodepointSpan(11, 23));
+
+ // Unpaired bracket stripping.
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at (857) 225 3556 today", {12, 14}),
+ CodepointSpan(11, 25));
+ EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {12, 14}),
+ CodepointSpan(12, 15));
+ EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {12, 14}),
+ CodepointSpan(11, 15));
+ EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {12, 14}),
+ CodepointSpan(12, 15));
+
+ // If the resulting selection would be empty, the original span is returned.
+ EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}),
+ CodepointSpan(11, 13));
+ EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}),
+ CodepointSpan(11, 12));
+ EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}),
+ CodepointSpan(11, 12));
+
+ // If the original span is larger than the found selection, the original span
+ // is returned.
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {5, 24}),
+ CodepointSpan(5, 24));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionDisabledFail) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Disable the selection model.
+ unpacked_model->selection_model.clear();
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ // Selection model needs to be present for annotation.
+ ASSERT_FALSE(classifier);
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionDisabled) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Disable the selection model.
+ unpacked_model->selection_model.clear();
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
+ unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION;
+
+ // Disable the number annotator. With the selection model disabled, there is
+ // no feature processor, which is required for the number annotator.
+ unpacked_model->number_annotator_options->enabled = false;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ CodepointSpan(11, 14));
+
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "call me at (800) 123-456 today", {11, 24})));
+
+ EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"),
+ IsEmpty());
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionFilteredCollections) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ CodepointSpan(11, 23));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone selection
+ unpacked_model->output_options->filtered_collections_selection.push_back(
+ "phone");
+ // We need to force this for filtering.
+ unpacked_model->selection_options->always_classify_suggested_selection = true;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ CodepointSpan(11, 14));
+
+ // Address selection should still work.
+ EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
+ CodepointSpan(0, 27));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionsAreSymmetric) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}),
+ CodepointSpan(0, 27));
+ EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
+ CodepointSpan(0, 27));
+ EXPECT_EQ(
+ classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}),
+ CodepointSpan(0, 27));
+ EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
+ {16, 22}),
+ CodepointSpan(6, 33));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionWithNewLine) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}),
+ CodepointSpan(4, 16));
+ EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}),
+ CodepointSpan(0, 12));
+
+ SelectionOptions options;
+ EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
+ CodepointSpan(0, 12));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionWithPunctuation) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // From the right.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon BarackObama, gave a speech at", {15, 26}),
+ CodepointSpan(15, 26));
+
+ // From the right multiple.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
+ CodepointSpan(15, 26));
+
+ // From the left multiple.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
+ CodepointSpan(21, 32));
+
+ // From both sides.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon !BarackObama,- gave a speech at", {16, 27}),
+ CodepointSpan(16, 27));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // Try passing in bunch of invalid selections.
+ EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), CodepointSpan(0, 27));
+ EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}),
+ CodepointSpan(-10, 27));
+ EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}),
+ CodepointSpan(0, 27));
+ EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}),
+ CodepointSpan(-30, 300));
+ EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}),
+ CodepointSpan(-10, -1));
+ EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
+ CodepointSpan(100, 17));
+
+ // Try passing invalid utf8.
+ EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}),
+ CodepointSpan(-1, -1));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionSelectSpace) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
+ CodepointSpan(11, 23));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
+ CodepointSpan(10, 11));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
+ CodepointSpan(23, 24));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
+ CodepointSpan(23, 24));
+ EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today",
+ {14, 17}),
+ CodepointSpan(11, 25));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
+ CodepointSpan(11, 23));
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
+ CodepointSpan(14, 40));
+ EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
+ CodepointSpan(4, 5));
+ EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
+ CodepointSpan(7, 8));
+
+ // With a punctuation around the selected whitespace.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
+ CodepointSpan(14, 41));
+
+ // When all's whitespace, should return the original indices.
+ EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}),
+ CodepointSpan(0, 1));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}),
+ CodepointSpan(0, 3));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}),
+ CodepointSpan(2, 3));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}),
+ CodepointSpan(5, 6));
+}
+
+TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
+ UnicodeText text;
+
+ text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
+ CodepointSpan(3, 4));
+ text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
+ CodepointSpan(3, 4));
+
+ // Nothing on the left.
+ text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
+ CodepointSpan(4, 5));
+ text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, *unilib_),
+ CodepointSpan(0, 1));
+
+ // Whitespace only.
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, *unilib_),
+ CodepointSpan(2, 3));
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
+ CodepointSpan(4, 5));
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, *unilib_),
+ CodepointSpan(0, 1));
+}
+
+TEST_F(AnnotatorTest, Annotate) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ }));
+
+ AnnotationOptions options;
+ EXPECT_THAT(classifier->Annotate("853 225 3556", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
+ EXPECT_THAT(classifier->Annotate("853 225\n3556", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
+ // Try passing invalid utf8.
+ EXPECT_TRUE(
+ classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
+ .empty());
+}
+
+TEST_F(AnnotatorTest, AnnotatesWithBracketStripping) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_THAT(classifier->Annotate("call me at (0845) 100 1000 today"),
+ ElementsAreArray({
+ IsAnnotatedSpan(11, 26, "phone"),
+ }));
+
+ // Unpaired bracket stripping.
+ EXPECT_THAT(classifier->Annotate("call me at (07038201818 today"),
+ ElementsAreArray({
+ IsAnnotatedSpan(12, 23, "phone"),
+ }));
+ EXPECT_THAT(classifier->Annotate("call me at 07038201818) today"),
+ ElementsAreArray({
+ IsAnnotatedSpan(11, 22, "phone"),
+ }));
+ EXPECT_THAT(classifier->Annotate("call me at )07038201818( today"),
+ ElementsAreArray({
+ IsAnnotatedSpan(12, 23, "phone"),
+ }));
+}
+
+TEST_F(AnnotatorTest, AnnotatesWithBracketStrippingOptimized) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ AnnotationOptions options;
+ options.enable_optimization = true;
+
+ EXPECT_THAT(classifier->Annotate("call me at (0845) 100 1000 today", options),
+ ElementsAreArray({
+ IsAnnotatedSpan(11, 26, "phone"),
+ }));
+
+ // Unpaired bracket stripping.
+ EXPECT_THAT(classifier->Annotate("call me at (07038201818 today", options),
+ ElementsAreArray({
+ IsAnnotatedSpan(12, 23, "phone"),
+ }));
+ EXPECT_THAT(classifier->Annotate("call me at 07038201818) today", options),
+ ElementsAreArray({
+ IsAnnotatedSpan(11, 22, "phone"),
+ }));
+ EXPECT_THAT(classifier->Annotate("call me at )07038201818( today", options),
+ ElementsAreArray({
+ IsAnnotatedSpan(12, 23, "phone"),
+ }));
+}
+
+TEST_F(AnnotatorTest, AnnotatesOverlappingNumbers) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+ AnnotationOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
+ // Number, float number and percentage annotator.
+ EXPECT_THAT(
+ classifier->Annotate("853 225 3556 and then turn it up 99%, 99 "
+ "number, 12345.12345 float number",
+ options),
+ UnorderedElementsAreArray(
+ {IsAnnotatedSpan(0, 12, "phone"), IsAnnotatedSpan(0, 3, "number"),
+ IsAnnotatedSpan(4, 7, "number"), IsAnnotatedSpan(8, 12, "number"),
+ IsAnnotatedSpan(33, 35, "number"),
+ IsAnnotatedSpan(33, 36, "percentage"),
+ IsAnnotatedSpan(38, 40, "number"), IsAnnotatedSpan(49, 60, "number"),
+ IsAnnotatedSpan(49, 60, "phone")}));
+}
+
+TEST_F(AnnotatorTest, DoesNotAnnotateNumbersInSmartUsecase) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+ AnnotationOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
+
+ EXPECT_THAT(classifier->Annotate(
+ "853 225 3556 and then turn it up 99%, 99 number", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone"),
+ IsAnnotatedSpan(33, 36, "percentage")}));
+}
+
+void VerifyAnnotatesDurationsInRawMode(const Annotator* classifier) {
+ ASSERT_TRUE(classifier);
+ AnnotationOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
+ // Duration annotator.
+ EXPECT_THAT(classifier->Annotate(
+ "it took 9 minutes and 7 seconds to get there", options),
+ Contains(IsDurationSpan(
+ /*start=*/8, /*end=*/31,
+ /*duration_ms=*/9 * 60 * 1000 + 7 * 1000)));
+}
+
+TEST_F(AnnotatorTest, AnnotatesDurationsInRawMode) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ VerifyAnnotatesDurationsInRawMode(classifier.get());
+}
+
+void VerifyDurationAndRelativeTimeCanOverlapInRawMode(
+ const Annotator* classifier) {
+ ASSERT_TRUE(classifier);
+ AnnotationOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ options.locales = "en";
+
+ const std::vector<AnnotatedSpan> annotations =
+ classifier->Annotate("let's meet in 3 hours", options);
+
+ EXPECT_THAT(annotations,
+ Contains(IsDatetimeSpan(/*start=*/11, /*end=*/21,
+ /*time_ms_utc=*/10800000L,
+ DatetimeGranularity::GRANULARITY_HOUR)));
+ EXPECT_THAT(annotations,
+ Contains(IsDurationSpan(/*start=*/14, /*end=*/21,
+ /*duration_ms=*/3 * 60 * 60 * 1000)));
+}
+
+TEST_F(AnnotatorTest, DurationAndRelativeTimeCanOverlapInRawMode) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ VerifyDurationAndRelativeTimeCanOverlapInRawMode(classifier.get());
+}
+
+TEST_F(AnnotatorTest,
+ DurationAndRelativeTimeCanOverlapInRawModeWithDatetimeRegEx) {
+ std::string model_buffer = GetTestModelWithDatetimeRegEx();
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+ unilib_.get(), calendarlib_.get());
+ VerifyDurationAndRelativeTimeCanOverlapInRawMode(classifier.get());
+}
+
+TEST_F(AnnotatorTest, AnnotateSplitLines) {
+ std::string model_buffer = ReadFile(GetTestModelPath());
+ model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+ model->selection_feature_options->only_use_line_with_click = true;
+ });
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+ unilib_.get(), calendarlib_.get());
+
+ ASSERT_TRUE(classifier);
+
+ const std::string str1 =
+ "hey, sorry, just finished up. i didn't hear back from you in time.";
+ const std::string str2 = "2000 Main Avenue, Apt #201, San Mateo";
+
+ const int kAnnotationLength = 26;
+ EXPECT_THAT(classifier->Annotate(str1), IsEmpty());
+ EXPECT_THAT(
+ classifier->Annotate(str2),
+ ElementsAreArray({IsAnnotatedSpan(0, kAnnotationLength, "address")}));
+
+ const std::string str3 = str1 + "\n" + str2;
+ EXPECT_THAT(
+ classifier->Annotate(str3),
+ ElementsAreArray({IsAnnotatedSpan(
+ str1.size() + 1, str1.size() + 1 + kAnnotationLength, "address")}));
+}
+
+TEST_F(AnnotatorTest, UsePipeAsNewLineCharacterShouldAnnotateSplitLines) {
+ std::string model_buffer = ReadFile(GetTestModelPath());
+ model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+ model->selection_feature_options->only_use_line_with_click = true;
+ model->selection_feature_options->use_pipe_character_for_newline = true;
+ });
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+ unilib_.get(), calendarlib_.get());
+
+ ASSERT_TRUE(classifier);
+
+ const std::string str1 = "hey, this is my phone number 853 225 3556";
+ const std::string str2 = "2000 Main Avenue, Apt #201, San Mateo";
+ const std::string str3 = str1 + "|" + str2;
+ const int kAnnotationLengthPhone = 12;
+ const int kAnnotationLengthAddress = 26;
+ // Splitting the lines on `str3` should have the same behavior (e.g. find the
+ // phone and address spans) as if we would annotate `str1` and `str2`
+ // individually.
+ const std::vector<AnnotatedSpan>& annotated_spans =
+ classifier->Annotate(str3);
+ EXPECT_THAT(annotated_spans,
+ ElementsAreArray(
+ {IsAnnotatedSpan(29, 29 + kAnnotationLengthPhone, "phone"),
+ IsAnnotatedSpan(static_cast<int>(str1.size()) + 1,
+ static_cast<int>(str1.size() + 1 +
+ kAnnotationLengthAddress),
+ "address")}));
+}
+
+TEST_F(AnnotatorTest,
+ NotUsingPipeAsNewLineCharacterShouldNotAnnotateSplitLines) {
+ std::string model_buffer = ReadFile(GetTestModelPath());
+ model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+ model->selection_feature_options->only_use_line_with_click = true;
+ model->selection_feature_options->use_pipe_character_for_newline = false;
+ });
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+ unilib_.get(), calendarlib_.get());
+
+ ASSERT_TRUE(classifier);
+
+ const std::string str1 = "hey, this is my phone number 853 225 3556";
+ const std::string str2 = "2000 Main Avenue, Apt #201, San Mateo";
+ const std::string str3 = str1 + "|" + str2;
+ const std::vector<AnnotatedSpan>& annotated_spans =
+ classifier->Annotate(str3);
+ // Note: We only check that we get a single annotated span here when the '|'
+ // character is not used to split lines. The reason behind this is that the
+  // model is not precise for such an example and the resulting annotated span might
+ // change when the model changes.
+ EXPECT_THAT(annotated_spans.size(), 1);
+}
+
+TEST_F(AnnotatorTest, AnnotateSmallBatches) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Set the batch size.
+ unpacked_model->selection_options->batch_size = 4;
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ }));
+
+ AnnotationOptions options;
+ EXPECT_THAT(classifier->Annotate("853 225 3556", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
+ EXPECT_THAT(classifier->Annotate("853 225\n3556", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
+}
+
+TEST_F(AnnotatorTest, AnnotateFilteringDiscardAll) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ // Add test threshold.
+ unpacked_model->triggering_options->min_annotate_confidence =
+ 2.f; // Discards all results.
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+
+ EXPECT_EQ(classifier->Annotate(test_string).size(), 0);
+}
+
+TEST_F(AnnotatorTest, AnnotateFilteringKeepAll) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test thresholds.
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->min_annotate_confidence =
+ 0.f; // Keeps all results.
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+ EXPECT_EQ(classifier->Annotate(test_string).size(), 2);
+}
+
+TEST_F(AnnotatorTest, AnnotateDisabled) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Disable the model for annotation.
+ unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+ EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
+}
+
+TEST_F(AnnotatorTest, AnnotateFilteredCollections) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ }));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone annotation
+ unpacked_model->output_options->filtered_collections_annotation.push_back(
+ "phone");
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+ IsAnnotatedSpan(28, 55, "address"),
+ }));
+}
+
+TEST_F(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ }));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // We add a custom annotator that wins against the phone classification
+ // below and that we subsequently suppress.
+ unpacked_model->output_options->filtered_collections_annotation.push_back(
+ "suppress");
+
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "suppress", "(\\d{3} ?\\d{4})",
+ /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+ IsAnnotatedSpan(28, 55, "address"),
+ }));
+}
+
+void VerifyClassifyTextDateInZurichTimezone(const Annotator* classifier) {
+ EXPECT_TRUE(classifier);
+ ClassificationOptions options;
+ options.reference_timezone = "Europe/Zurich";
+ options.locales = "en";
+
+ std::vector<ClassificationResult> result =
+ classifier->ClassifyText("january 1, 2017", {0, 15}, options);
+
+ EXPECT_THAT(result,
+ ElementsAre(IsDateResult(1483225200000,
+ DatetimeGranularity::GRANULARITY_DAY)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezone) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ VerifyClassifyTextDateInZurichTimezone(classifier.get());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezoneWithDatetimeRegEx) {
+ std::string model_buffer = GetTestModelWithDatetimeRegEx();
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+ unilib_.get(), calendarlib_.get());
+ VerifyClassifyTextDateInZurichTimezone(classifier.get());
+}
+
+void VerifyClassifyTextDateInLATimezone(const Annotator* classifier) {
+ EXPECT_TRUE(classifier);
+ ClassificationOptions options;
+ options.reference_timezone = "America/Los_Angeles";
+ options.locales = "en";
+
+ std::vector<ClassificationResult> result =
+ classifier->ClassifyText("march 1, 2017", {0, 13}, options);
+
+ EXPECT_THAT(result,
+ ElementsAre(IsDateResult(1488355200000,
+ DatetimeGranularity::GRANULARITY_DAY)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateInLATimezoneWithDatetimeRegEx) {
+ std::string model_buffer = GetTestModelWithDatetimeRegEx();
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+ unilib_.get(), calendarlib_.get());
+ VerifyClassifyTextDateInLATimezone(classifier.get());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateInLATimezone) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ VerifyClassifyTextDateInLATimezone(classifier.get());
+}
+
+void VerifyClassifyTextDateOnAotherLine(const Annotator* classifier) {
+ EXPECT_TRUE(classifier);
+ ClassificationOptions options;
+ options.reference_timezone = "Europe/Zurich";
+ options.locales = "en";
+
+ std::vector<ClassificationResult> result = classifier->ClassifyText(
+ "hello world this is the first line\n"
+ "january 1, 2017",
+ {35, 50}, options);
+
+ EXPECT_THAT(result,
+ ElementsAre(IsDateResult(1483225200000,
+ DatetimeGranularity::GRANULARITY_DAY)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLineWithDatetimeRegEx) {
+ std::string model_buffer = GetTestModelWithDatetimeRegEx();
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+ unilib_.get(), calendarlib_.get());
+ VerifyClassifyTextDateOnAotherLine(classifier.get());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLine) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ VerifyClassifyTextDateOnAotherLine(classifier.get());
+}
+
+void VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(
+ const Annotator* classifier) {
+ EXPECT_TRUE(classifier);
+ std::vector<ClassificationResult> result;
+ ClassificationOptions options;
+
+ options.reference_timezone = "Europe/Zurich";
+ options.locales = "en-US";
+ result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
+
+ // In US, the date should be interpreted as <month>.<day>.
+ EXPECT_THAT(result,
+ ElementsAre(IsDatetimeResult(
+ 5439600000, DatetimeGranularity::GRANULARITY_MINUTE)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextWhenLocaleUSParsesDateAsMonthDay) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(classifier.get());
+}
+
+TEST_F(AnnotatorTest,
+ ClassifyTextWhenLocaleUSParsesDateAsMonthDayWithDatetimeRegEx) {
+ std::string model_buffer = GetTestModelWithDatetimeRegEx();
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+ unilib_.get(), calendarlib_.get());
+ VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(classifier.get());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextWhenLocaleGermanyParsesDateAsMonthDay) {
+ std::string model_buffer = GetTestModelWithDatetimeRegEx();
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+ unilib_.get(), calendarlib_.get());
+ EXPECT_TRUE(classifier);
+ std::vector<ClassificationResult> result;
+ ClassificationOptions options;
+
+ options.reference_timezone = "Europe/Zurich";
+ options.locales = "de";
+ result = classifier->ClassifyText("03.05.1970 00:00vorm", {0, 20}, options);
+
+ // In Germany, the date should be interpreted as <day>.<month>.
+ EXPECT_THAT(result,
+ ElementsAre(IsDatetimeResult(
+ 10537200000, DatetimeGranularity::GRANULARITY_MINUTE)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextAmbiguousDatetime) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ EXPECT_TRUE(classifier);
+ ClassificationOptions options;
+ options.reference_timezone = "Europe/Zurich";
+ options.locales = "en-US";
+ const std::vector<ClassificationResult> result =
+ classifier->ClassifyText("set an alarm for 10:30", {17, 22}, options);
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
+ IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
+}
+
+TEST_F(AnnotatorTest, AnnotateAmbiguousDatetime) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ EXPECT_TRUE(classifier);
+ AnnotationOptions options;
+ options.reference_timezone = "Europe/Zurich";
+ options.locales = "en-US";
+ const std::vector<AnnotatedSpan> spans =
+ classifier->Annotate("set an alarm for 10:30", options);
+
+ ASSERT_EQ(spans.size(), 1);
+ const std::vector<ClassificationResult> result = spans[0].classification;
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
+ IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
+}
+
+TEST_F(AnnotatorTest, SuggestTextDateDisabled) {
+ std::string test_model = GetTestModelWithDatetimeRegEx();
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Disable the patterns for selection.
+ for (int i = 0; i < unpacked_model->datetime_model->patterns.size(); i++) {
+ unpacked_model->datetime_model->patterns[i]->enabled_modes =
+ ModeFlag_ANNOTATION_AND_CLASSIFICATION;
+ }
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+ EXPECT_EQ("date",
+ FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
+ EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
+ CodepointSpan(0, 7));
+ EXPECT_THAT(classifier->Annotate("january 1, 2017"),
+ ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
+}
+
+TEST_F(AnnotatorTest, AnnotatesWithGrammarModel) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test grammar model.
+ unpacked_model->grammar_model.reset(new GrammarModelT);
+ GrammarModelT* grammar_model = unpacked_model->grammar_model.get();
+ grammar_model->tokenizer_options.reset(new GrammarTokenizerOptionsT);
+ grammar_model->tokenizer_options->tokenization_type = TokenizationType_ICU;
+ grammar_model->tokenizer_options->icu_preserve_whitespace_tokens = false;
+ grammar_model->tokenizer_options->tokenize_on_script_change = true;
+
+ // Add test rules.
+ grammar_model->rules.reset(new grammar::RulesSetT);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add("<tv_detective>", {"jessica", "fletcher"});
+ rules.Add("<tv_detective>", {"columbo"});
+ rules.Add("<tv_detective>", {"magnum"});
+ rules.Add(
+ "<famous_person>", {"<tv_detective>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/0 /* rule classification result */);
+
+ // Set result.
+ grammar_model->rule_classification_result.emplace_back(
+ new GrammarModel_::RuleClassificationResultT);
+ GrammarModel_::RuleClassificationResultT* result =
+ grammar_model->rule_classification_result.back().get();
+ result->collection_name = "famous person";
+ result->enabled_modes = ModeFlag_ALL;
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model->rules.get());
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "Did you see the Novel Connection episode where Jessica Fletcher helps "
+ "Magnum solve the case? I thought that was with Columbo ...";
+
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAre(IsAnnotatedSpan(47, 63, "famous person"),
+ IsAnnotatedSpan(70, 76, "famous person"),
+ IsAnnotatedSpan(117, 124, "famous person")));
+ EXPECT_THAT(FirstResult(classifier->ClassifyText("Jessica Fletcher",
+ CodepointSpan{0, 16})),
+ Eq("famous person"));
+ EXPECT_THAT(classifier->SuggestSelection("Jessica Fletcher", {0, 7}),
+ Eq(CodepointSpan{0, 16}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
+ TestingAnnotator classifier(unilib_.get(), calendarlib_.get());
+
+ std::vector<AnnotatedSpan> candidates{
+ {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsSequence) {
+ TestingAnnotator classifier(unilib_.get(), calendarlib_.get());
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 1}, "phone", 1.0),
+ MakeAnnotatedSpan({1, 2}, "phone", 1.0),
+ MakeAnnotatedSpan({2, 3}, "phone", 1.0),
+ MakeAnnotatedSpan({3, 4}, "phone", 1.0),
+ MakeAnnotatedSpan({4, 5}, "phone", 1.0),
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
+ TestingAnnotator classifier(unilib_.get(), calendarlib_.get());
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 3}, "phone", 1.0),
+      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loser!
+ MakeAnnotatedSpan({3, 7}, "phone", 1.0),
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
+ TestingAnnotator classifier(unilib_.get(), calendarlib_.get());
+
+ std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loser!
+      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
+      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loser!
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({1}));
+}
+
+TEST_F(AnnotatorTest, DoesNotPrioritizeLongerSpanWhenDoingConflictResolution) {
+ TestingAnnotator classifier(unilib_.get(), calendarlib_.get());
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({3, 7}, "unit", 1),
+      MakeAnnotatedSpan({5, 13}, "unit", 1),  // Loser!
+      MakeAnnotatedSpan({5, 30}, "url", 1),   // Loser!
+ MakeAnnotatedSpan({14, 20}, "email", 1),
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ // Picks the first and the last annotations because they do not overlap.
+ EXPECT_THAT(chosen, ElementsAreArray({0, 3}));
+}
+
+TEST_F(AnnotatorTest, PrioritizeLongerSpanWhenDoingConflictResolution) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));
+ unpacked_model->conflict_resolution_options.reset(
+ new Model_::ConflictResolutionOptionsT);
+ unpacked_model->conflict_resolution_options->prioritize_longest_annotation =
+ true;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TestingAnnotator> classifier =
+ TestingAnnotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ TC3_CHECK(classifier != nullptr);
+
+ std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({3, 7}, "unit", 1),     // Loser!
+      MakeAnnotatedSpan({5, 13}, "unit", 1),    // Loser!
+      MakeAnnotatedSpan({5, 30}, "url", 1),     // Pick longest match.
+      MakeAnnotatedSpan({14, 20}, "email", 1),  // Loser!
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
+ std::vector<int> chosen;
+ classifier->ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({2}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
+ TestingAnnotator classifier(unilib_.get(), calendarlib_.get());
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 3}, "phone", 0.5),
+      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loser!
+      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
+      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loser!
+ MakeAnnotatedSpan({11, 15}, "phone", 0.9),
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeFirst) {
+ TestingAnnotator classifier(unilib_.get(), calendarlib_.get());
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 15}, "entity", 0.7,
+ AnnotatedSpan::Source::KNOWLEDGE),
+ MakeAnnotatedSpan({5, 10}, "address", 0.6),
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeSecond) {
+ TestingAnnotator classifier(unilib_.get(), calendarlib_.get());
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 15}, "address", 0.7),
+ MakeAnnotatedSpan({5, 10}, "entity", 0.6,
+ AnnotatedSpan::Source::KNOWLEDGE),
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedBothKnowledge) {
+ TestingAnnotator classifier(unilib_.get(), calendarlib_.get());
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 15}, "entity", 0.7,
+ AnnotatedSpan::Source::KNOWLEDGE),
+ MakeAnnotatedSpan({5, 10}, "entity", 0.6,
+ AnnotatedSpan::Source::KNOWLEDGE),
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsNotAllowed) {
+ TestingAnnotator classifier(unilib_.get(), calendarlib_.get());
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 15}, "address", 0.7),
+ MakeAnnotatedSpan({5, 10}, "date", 0.6),
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsRawModeGeneralOverlapsAllowed) {
+ TestingAnnotator classifier(
+ unilib_.get(), calendarlib_.get(), [](ModelT* model) {
+ model->conflict_resolution_options.reset(
+ new Model_::ConflictResolutionOptionsT);
+ model->conflict_resolution_options->do_conflict_resolution_in_raw_mode =
+ false;
+ });
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 15}, "address", 0.7),
+ MakeAnnotatedSpan({5, 10}, "date", 0.6),
+ }};
+ std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+ locales, options,
+ /*interpreter_manager=*/nullptr, &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
+}
+
+void VerifyLongInput(const Annotator* classifier) {
+ ASSERT_TRUE(classifier);
+
+ for (const auto& type_value_pair :
+ std::vector<std::pair<std::string, std::string>>{
+ {"address", "350 Third Street, Cambridge"},
+ {"phone", "123 456-7890"},
+ {"url", "www.google.com"},
+ {"email", "someone@gmail.com"},
+ {"flight", "LX 38"},
+ {"date", "September 1, 2018"}}) {
+ const std::string input_100k = std::string(50000, ' ') +
+ type_value_pair.second +
+ std::string(50000, ' ');
+ const int value_length = type_value_pair.second.size();
+
+ AnnotationOptions annotation_options;
+ annotation_options.locales = "en";
+ EXPECT_THAT(classifier->Annotate(input_100k, annotation_options),
+ ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length,
+ type_value_pair.first)}));
+ SelectionOptions selection_options;
+ selection_options.locales = "en";
+ EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001},
+ selection_options),
+ CodepointSpan(50000, 50000 + value_length));
+
+ ClassificationOptions classification_options;
+ classification_options.locales = "en";
+ EXPECT_EQ(type_value_pair.first,
+ FirstResult(classifier->ClassifyText(
+ input_100k, {50000, 50000 + value_length},
+ classification_options)));
+ }
+}
+
+TEST_F(AnnotatorTest, LongInput) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ VerifyLongInput(classifier.get());
+}
+
+// Long-input handling with the regex-datetime variant of the test model.
+TEST_F(AnnotatorTest, LongInputWithRegExDatetime) {
+  const std::string buffer = GetTestModelWithDatetimeRegEx();
+  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      buffer.data(), buffer.size(), unilib_.get(), calendarlib_.get());
+  VerifyLongInput(annotator.get());
+}
+
+// These coarse tests are there only to make sure the execution happens in
+// reasonable amount of time.
+TEST_F(AnnotatorTest, LongInputNoResultCheck) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  // Results are intentionally not checked; only that the calls complete.
+  for (const std::string& value :
+       std::vector<std::string>{"http://www.aaaaaaaaaaaaaaaaaaaa.com "}) {
+    const std::string input_100k =
+        std::string(50000, ' ') + value + std::string(50000, ' ');
+    const int value_length = value.size();
+
+    classifier->Annotate(input_100k);
+    classifier->SuggestSelection(input_100k, {50000, 50001});
+    classifier->ClassifyText(input_100k, {50000, 50000 + value_length});
+  }
+}
+
+// ClassifyText should fall back to "other" when the selection contains more
+// tokens than classification_options.max_num_tokens allows.
+TEST_F(AnnotatorTest, MaxTokenLength) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  std::unique_ptr<Annotator> classifier;
+
+  // With unrestricted number of tokens should behave normally.
+  unpacked_model->classification_options->max_num_tokens = -1;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "address");
+
+  // Lower the maximum number of tokens to suppress the classification.
+  unpacked_model->classification_options->max_num_tokens = 3;
+
+  flatbuffers::FlatBufferBuilder builder2;
+  FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder2.GetBufferPointer()),
+      builder2.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "other");
+}
+
+// ClassifyText should fall back to "other" when the selection has fewer
+// tokens than classification_options.address_min_num_tokens requires.
+TEST_F(AnnotatorTest, MinAddressTokenLength) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  std::unique_ptr<Annotator> classifier;
+
+  // With unrestricted number of address tokens should behave normally.
+  unpacked_model->classification_options->address_min_num_tokens = 0;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "address");
+
+  // Raise number of address tokens to suppress the address classification.
+  unpacked_model->classification_options->address_min_num_tokens = 5;
+
+  flatbuffers::FlatBufferBuilder builder2;
+  FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder2.GetBufferPointer()),
+      builder2.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "other");
+}
+
+// With a high "other" collection priority score, the catch-all collection
+// should win over the flight annotation.
+TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighOtherIsPreferredToFlight) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+  model->triggering_options->other_collection_priority_score = 1.0;
+
+  flatbuffers::FlatBufferBuilder fbb;
+  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
+  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
+      unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  EXPECT_EQ(FirstResult(annotator->ClassifyText("LX37", {0, 4})), "other");
+}
+
+// With a very low "other" collection priority score, the flight annotation
+// should win over the catch-all collection.
+TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighFlightIsPreferredToOther) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+  model->triggering_options->other_collection_priority_score = -100.0;
+
+  flatbuffers::FlatBufferBuilder fbb;
+  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
+  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
+      unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  EXPECT_EQ(FirstResult(annotator->ClassifyText("LX37", {0, 4})), "flight");
+}
+
+// VisitAnnotatorModel should call the visitor with a non-null Model for a
+// valid model file and with nullptr for a missing one.
+TEST_F(AnnotatorTest, VisitAnnotatorModel) {
+  const auto model_is_valid = [](const Model* model) {
+    return model != nullptr;
+  };
+  EXPECT_TRUE(VisitAnnotatorModel<bool>(GetTestModelPath(), model_is_valid));
+  EXPECT_FALSE(VisitAnnotatorModel<bool>(
+      GetModelPath() + "non_existing_model.fb", model_is_valid));
+}
+
+// With triggering_locales set but no detected language in the options, all
+// three APIs should still trigger.
+TEST_F(AnnotatorTest, TriggersWhenNoLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(
+      model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  EXPECT_THAT(classifier->Annotate("(555) 225-3556"),
+              ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")}));
+  EXPECT_EQ("phone",
+            FirstResult(classifier->ClassifyText("(555) 225-3556", {0, 14})));
+  EXPECT_EQ(classifier->SuggestSelection("(555) 225-3556", {6, 9}),
+            CodepointSpan(0, 14));
+}
+
+// Annotate should trigger when the detected language is in triggering_locales.
+TEST_F(AnnotatorTest, AnnotateTriggersWhenSupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(
+      model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  AnnotationOptions options;
+  options.detected_text_language_tags = "cs";
+
+  EXPECT_THAT(classifier->Annotate("(555) 225-3556", options),
+              ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")}));
+}
+
+// Annotate should stay silent when the detected language is not supported.
+TEST_F(AnnotatorTest, AnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(
+      model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  AnnotationOptions options;
+  options.detected_text_language_tags = "de";
+
+  EXPECT_THAT(classifier->Annotate("(555) 225-3556", options), IsEmpty());
+}
+
+// ClassifyText should trigger for a supported detected language.
+TEST_F(AnnotatorTest, ClassifyTextTriggersWhenSupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(
+      model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  ClassificationOptions options;
+  options.detected_text_language_tags = "cs";
+
+  EXPECT_EQ("phone", FirstResult(classifier->ClassifyText("(555) 225-3556",
+                                                          {0, 14}, options)));
+}
+
+// ClassifyText should return nothing for an unsupported detected language.
+TEST_F(AnnotatorTest,
+       ClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(
+      model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  ClassificationOptions options;
+  options.detected_text_language_tags = "de";
+
+  EXPECT_THAT(classifier->ClassifyText("(555) 225-3556", {0, 14}, options),
+              IsEmpty());
+}
+
+// SuggestSelection should expand the selection for a supported language.
+TEST_F(AnnotatorTest, SuggestSelectionTriggersWhenSupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(
+      model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  SelectionOptions options;
+  options.detected_text_language_tags = "cs";
+
+  EXPECT_EQ(classifier->SuggestSelection("(555) 225-3556", {6, 9}, options),
+            CodepointSpan(0, 14));
+}
+
+// SuggestSelection should leave the selection unchanged for an unsupported
+// detected language.
+TEST_F(AnnotatorTest,
+       SuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(
+      model_buffer, [](ModelT* model) { model->triggering_locales = "en,cs"; });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  SelectionOptions options;
+  options.detected_text_language_tags = "de";
+
+  EXPECT_EQ(classifier->SuggestSelection("(555) 225-3556", {6, 9}, options),
+            CodepointSpan(6, 9));
+}
+
+// Same as TriggersWhenNoLanguageDetected, but with the ML-model triggering
+// locales also restricted; all three APIs should still trigger.
+TEST_F(AnnotatorTest, MlModelTriggersWhenNoLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+    model->triggering_locales = "en,cs";
+    model->triggering_options->locales = "en,cs";
+  });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  EXPECT_THAT(classifier->Annotate("350 Third Street, Cambridge"),
+              ElementsAreArray({IsAnnotatedSpan(0, 27, "address")}));
+  EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
+                           "350 Third Street, Cambridge", {0, 27})));
+  EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
+            CodepointSpan(0, 27));
+}
+
+// ML-model Annotate should trigger for a supported detected language.
+TEST_F(AnnotatorTest, MlModelAnnotateTriggersWhenSupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+    model->triggering_locales = "en,cs";
+    model->triggering_options->locales = "en,cs";
+  });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  AnnotationOptions options;
+  options.detected_text_language_tags = "cs";
+
+  EXPECT_THAT(classifier->Annotate("350 Third Street, Cambridge", options),
+              ElementsAreArray({IsAnnotatedSpan(0, 27, "address")}));
+}
+
+// ML-model Annotate should stay silent for an unsupported detected language.
+TEST_F(AnnotatorTest,
+       MlModelAnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+    model->triggering_locales = "en,cs";
+    model->triggering_options->locales = "en,cs";
+  });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  AnnotationOptions options;
+  options.detected_text_language_tags = "de";
+
+  EXPECT_THAT(classifier->Annotate("350 Third Street, Cambridge", options),
+              IsEmpty());
+}
+
+// ML-model ClassifyText should trigger for a supported detected language.
+TEST_F(AnnotatorTest,
+       MlModelClassifyTextTriggersWhenSupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+    model->triggering_locales = "en,cs";
+    model->triggering_options->locales = "en,cs";
+  });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  ClassificationOptions options;
+  options.detected_text_language_tags = "cs";
+
+  EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
+                           "350 Third Street, Cambridge", {0, 27}, options)));
+}
+
+// ML-model ClassifyText should return nothing for an unsupported language.
+TEST_F(AnnotatorTest,
+       MlModelClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+    model->triggering_locales = "en,cs";
+    model->triggering_options->locales = "en,cs";
+  });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  ClassificationOptions options;
+  options.detected_text_language_tags = "de";
+
+  EXPECT_THAT(
+      classifier->ClassifyText("350 Third Street, Cambridge", {0, 27}, options),
+      IsEmpty());
+}
+
+// ML-model SuggestSelection should expand the selection for a supported
+// detected language.
+TEST_F(AnnotatorTest,
+       MlModelSuggestSelectionTriggersWhenSupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+    model->triggering_locales = "en,cs";
+    model->triggering_options->locales = "en,cs";
+  });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  SelectionOptions options;
+  options.detected_text_language_tags = "cs";
+
+  EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9},
+                                         options),
+            CodepointSpan(0, 27));
+}
+
+// ML-model SuggestSelection should leave the selection unchanged for an
+// unsupported detected language.
+TEST_F(AnnotatorTest,
+       MlModelSuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+    model->triggering_locales = "en,cs";
+    model->triggering_options->locales = "en,cs";
+  });
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  SelectionOptions options;
+  options.detected_text_language_tags = "de";
+
+  EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9},
+                                         options),
+            CodepointSpan(4, 9));
+}
+
+// Classifies a datetime string and checks the serialized datetime entity
+// data (timestamp, granularity and all six datetime components).
+void VerifyClassifyTextOutputsDatetimeEntityData(const Annotator* classifier) {
+  // ASSERT (not EXPECT): the classifier is dereferenced below, so a null
+  // pointer must abort the helper instead of crashing.
+  ASSERT_TRUE(classifier);
+  std::vector<ClassificationResult> result;
+  ClassificationOptions options;
+  options.locales = "en-US";
+
+  result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
+
+  // size() is unsigned, so "GE 0" would be vacuously true; require a
+  // non-empty result before indexing result[0].
+  ASSERT_FALSE(result.empty());
+  const EntityData* entity_data =
+      GetEntityData(result[0].serialized_entity_data.data());
+  ASSERT_NE(entity_data, nullptr);
+  ASSERT_NE(entity_data->datetime(), nullptr);
+  EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 5443200000L);
+  EXPECT_EQ(entity_data->datetime()->granularity(),
+            EntityData_::Datetime_::Granularity_GRANULARITY_MINUTE);
+  EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 6);
+
+  auto* meridiem = entity_data->datetime()->datetime_component()->Get(0);
+  EXPECT_EQ(meridiem->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_MERIDIEM);
+  EXPECT_EQ(meridiem->absolute_value(), 0);
+  EXPECT_EQ(meridiem->relative_count(), 0);
+  EXPECT_EQ(meridiem->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* minute = entity_data->datetime()->datetime_component()->Get(1);
+  EXPECT_EQ(minute->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_MINUTE);
+  EXPECT_EQ(minute->absolute_value(), 0);
+  EXPECT_EQ(minute->relative_count(), 0);
+  EXPECT_EQ(minute->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* hour = entity_data->datetime()->datetime_component()->Get(2);
+  EXPECT_EQ(hour->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_HOUR);
+  EXPECT_EQ(hour->absolute_value(), 0);
+  EXPECT_EQ(hour->relative_count(), 0);
+  EXPECT_EQ(hour->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* day = entity_data->datetime()->datetime_component()->Get(3);
+  EXPECT_EQ(
+      day->component_type(),
+      EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
+  EXPECT_EQ(day->absolute_value(), 5);
+  EXPECT_EQ(day->relative_count(), 0);
+  EXPECT_EQ(day->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* month = entity_data->datetime()->datetime_component()->Get(4);
+  EXPECT_EQ(month->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
+  EXPECT_EQ(month->absolute_value(), 3);
+  EXPECT_EQ(month->relative_count(), 0);
+  EXPECT_EQ(month->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* year = entity_data->datetime()->datetime_component()->Get(5);
+  EXPECT_EQ(year->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
+  EXPECT_EQ(year->absolute_value(), 1970);
+  EXPECT_EQ(year->relative_count(), 0);
+  EXPECT_EQ(year->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+}
+
+// Datetime entity data via ClassifyText with the default test model.
+TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityData) {
+  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextOutputsDatetimeEntityData(annotator.get());
+}
+
+// Datetime entity data via ClassifyText with the regex-datetime model.
+TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityDataWithDatetimeRegEx) {
+  const std::string buffer = GetTestModelWithDatetimeRegEx();
+  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      buffer.data(), buffer.size(), unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextOutputsDatetimeEntityData(annotator.get());
+}
+
+// Annotates a date string and checks the serialized datetime entity data
+// (timestamp, granularity and the three datetime components).
+void VerifyAnnotateOutputsDatetimeEntityData(const Annotator* classifier) {
+  // ASSERT (not EXPECT): the classifier is dereferenced below, so a null
+  // pointer must abort the helper instead of crashing.
+  ASSERT_TRUE(classifier);
+  std::vector<AnnotatedSpan> result;
+  AnnotationOptions options;
+  options.is_serialized_entity_data_enabled = true;
+  options.locales = "en";
+
+  result = classifier->Annotate("September 1, 2019", options);
+
+  // size() is unsigned, so "GE 0" would be vacuously true; require non-empty
+  // containers before indexing them.
+  ASSERT_FALSE(result.empty());
+  ASSERT_FALSE(result[0].classification.empty());
+  ASSERT_EQ(result[0].classification[0].collection, "date");
+  const EntityData* entity_data =
+      GetEntityData(result[0].classification[0].serialized_entity_data.data());
+  ASSERT_NE(entity_data, nullptr);
+  ASSERT_NE(entity_data->datetime(), nullptr);
+  EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 1567296000000L);
+  EXPECT_EQ(entity_data->datetime()->granularity(),
+            EntityData_::Datetime_::Granularity_GRANULARITY_DAY);
+  EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 3);
+
+  auto* day = entity_data->datetime()->datetime_component()->Get(0);
+  EXPECT_EQ(
+      day->component_type(),
+      EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
+  EXPECT_EQ(day->absolute_value(), 1);
+  EXPECT_EQ(day->relative_count(), 0);
+  EXPECT_EQ(day->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* month = entity_data->datetime()->datetime_component()->Get(1);
+  EXPECT_EQ(month->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
+  EXPECT_EQ(month->absolute_value(), 9);
+  EXPECT_EQ(month->relative_count(), 0);
+  EXPECT_EQ(month->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* year = entity_data->datetime()->datetime_component()->Get(2);
+  EXPECT_EQ(year->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
+  EXPECT_EQ(year->absolute_value(), 2019);
+  EXPECT_EQ(year->relative_count(), 0);
+  EXPECT_EQ(year->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+}
+
+// Datetime entity data via Annotate with the default test model.
+TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityData) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyAnnotateOutputsDatetimeEntityData(classifier.get());
+}
+
+// Datetime entity data via Annotate with the regex-datetime model.
+// NOTE(review): "Datatime" in the test name is a typo for "Datetime"; left
+// as-is because test names may be referenced by test filters.
+TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityDataWithDatatimeRegEx) {
+  std::string model_buffer = GetTestModelWithDatetimeRegEx();
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  VerifyAnnotateOutputsDatetimeEntityData(classifier.get());
+}
+
+// Money annotations should carry serialized entity data with the currency,
+// unnormalized amount string, whole/decimal parts and nanos, across a variety
+// of currency positions, separators and quantity suffixes (k/m/million/...).
+TEST_F(AnnotatorTest, AnnotateOutputsMoneyEntityData) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  // ASSERT (not EXPECT): the classifier is dereferenced below.
+  ASSERT_TRUE(classifier);
+  AnnotationOptions options;
+  options.is_serialized_entity_data_enabled = true;
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("3.5 CHF", options), "CHF",
+                           /*amount=*/"3.5", /*whole_part=*/3,
+                           /*decimal_part=*/5, /*nanos=*/500000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.5", options), "CHF",
+                           /*amount=*/"3.5", /*whole_part=*/3,
+                           /*decimal_part=*/5, /*nanos=*/500000000);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("For online purchase of CHF 23.00 enter", options),
+      "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
+      /*nanos=*/0);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("For online purchase of 23.00 CHF enter", options),
+      "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
+      /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("4.8198£", options), "£",
+                           /*amount=*/"4.8198", /*whole_part=*/4,
+                           /*decimal_part=*/8198, /*nanos=*/819800000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("£4.8198", options), "£",
+                           /*amount=*/"4.8198", /*whole_part=*/4,
+                           /*decimal_part=*/8198, /*nanos=*/819800000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$",
+                           /*amount=*/"0.0255", /*whole_part=*/0,
+                           /*decimal_part=*/255, /*nanos=*/25500000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$",
+                           /*amount=*/"0.0255", /*whole_part=*/0,
+                           /*decimal_part=*/255, /*nanos=*/25500000);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("for txn of INR 000.00 at RAZOR-PAY ZOMATO ONLINE "
+                           "OR on card ending 0000.",
+                           options),
+      "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
+      /*nanos=*/0);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("for txn of 000.00 INR at RAZOR-PAY ZOMATO ONLINE "
+                           "OR on card ending 0000.",
+                           options),
+      "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
+      /*nanos=*/0);
+
+  // Integral amounts, both currency positions.
+  ExpectFirstEntityIsMoney(classifier->Annotate("35 CHF", options), "CHF",
+                           /*amount=*/"35",
+                           /*whole_part=*/35, /*decimal_part=*/0, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 35", options), "CHF",
+                           /*amount=*/"35", /*whole_part=*/35,
+                           /*decimal_part=*/0, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("and win back up to CHF 150 - with digitec",
+                           options),
+      "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
+      /*nanos=*/0);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("and win back up to 150 CHF - with digitec",
+                           options),
+      "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
+      /*nanos=*/0);
+
+  // Dot and comma used as thousands separators.
+  ExpectFirstEntityIsMoney(classifier->Annotate("3.555.333 CHF", options),
+                           "CHF", /*amount=*/"3.555.333",
+                           /*whole_part=*/3555333, /*decimal_part=*/0,
+                           /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.555.333", options),
+                           "CHF", /*amount=*/"3.555.333",
+                           /*whole_part=*/3555333, /*decimal_part=*/0,
+                           /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("10,000 CHF", options), "CHF",
+                           /*amount=*/"10,000", /*whole_part=*/10000,
+                           /*decimal_part=*/0, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 10,000", options), "CHF",
+                           /*amount=*/"10,000", /*whole_part=*/10000,
+                           /*decimal_part=*/0, /*nanos=*/0);
+
+  // Mixed thousands separator and decimal point.
+  ExpectFirstEntityIsMoney(classifier->Annotate("3,555.33 CHF", options), "CHF",
+                           /*amount=*/"3,555.33", /*whole_part=*/3555,
+                           /*decimal_part=*/33, /*nanos=*/330000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3,555.33", options), "CHF",
+                           /*amount=*/"3,555.33", /*whole_part=*/3555,
+                           /*decimal_part=*/33, /*nanos=*/330000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("$3,000.00", options), "$",
+                           /*amount=*/"3,000.00", /*whole_part=*/3000,
+                           /*decimal_part=*/0, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("3,000.00$", options), "$",
+                           /*amount=*/"3,000.00", /*whole_part=*/3000,
+                           /*decimal_part=*/0, /*nanos=*/0);
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("1.2 CHF", options), "CHF",
+                           /*amount=*/"1.2", /*whole_part=*/1,
+                           /*decimal_part=*/2, /*nanos=*/200000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF1.2", options), "CHF",
+                           /*amount=*/"1.2", /*whole_part=*/1,
+                           /*decimal_part=*/2, /*nanos=*/200000000);
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("$1.123456789", options), "$",
+                           /*amount=*/"1.123456789", /*whole_part=*/1,
+                           /*decimal_part=*/123456789, /*nanos=*/123456789);
+  ExpectFirstEntityIsMoney(classifier->Annotate("10.01 CHF", options), "CHF",
+                           /*amount=*/"10.01", /*whole_part=*/10,
+                           /*decimal_part=*/1, /*nanos=*/10000000);
+
+  // Quantity suffixes scale the whole part.
+  ExpectFirstEntityIsMoney(classifier->Annotate("$59 Million", options), "$",
+                           /*amount=*/"59 million", /*whole_part=*/59000000,
+                           /*decimal_part=*/0, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("7.05k €", options), "€",
+                           /*amount=*/"7.05 k", /*whole_part=*/7050,
+                           /*decimal_part=*/5, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("7.123456789m €", options), "€",
+                           /*amount=*/"7.123456789 m", /*whole_part=*/7123456,
+                           /*decimal_part=*/123456789, /*nanos=*/789000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("7.000056789k €", options), "€",
+                           /*amount=*/"7.000056789 k", /*whole_part=*/7000,
+                           /*decimal_part=*/56789, /*nanos=*/56789000);
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("$59.3 Billion", options), "$",
+                           /*amount=*/"59.3 billion", /*whole_part=*/59,
+                           /*decimal_part=*/3, /*nanos=*/300000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("$1.5 Billion", options), "$",
+                           /*amount=*/"1.5 billion", /*whole_part=*/1500000000,
+                           /*decimal_part=*/5, /*nanos=*/0);
+}
+
+// With a LangId model attached and a user language of "de", classifying an
+// English span should produce a "translate" action classification.
+TEST_F(AnnotatorTest, TranslateAction) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  // Guard before dereferencing; FromPath returns null on load failure.
+  ASSERT_TRUE(classifier);
+  std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
+      libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(GetModelPath() +
+                                                                 "lang_id.smfb");
+  ASSERT_TRUE(langid_model);
+  classifier->SetLangId(langid_model.get());
+
+  ClassificationOptions options;
+  options.user_familiar_language_tags = "de";
+
+  std::vector<ClassificationResult> classifications =
+      classifier->ClassifyText("hello, how are you doing?", {11, 14}, options);
+  EXPECT_EQ(classifications.size(), 1);
+  EXPECT_EQ(classifications[0].collection, "translate");
+}
+
+// AnnotateStructuredInput should annotate each input fragment independently,
+// here exercising the money annotator on one fragment and the datetime
+// annotator on the other.
+TEST_F(AnnotatorTest, AnnotateStructuredInputCallsMultipleAnnotators) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  // Guard before dereferencing, consistent with the other tests in this file.
+  ASSERT_TRUE(classifier);
+
+  std::vector<InputFragment> string_fragments = {
+      {.text = "He owes me 3.5 CHF."},
+      {.text = "...was born on 13/12/1989."},
+  };
+
+  AnnotationOptions annotation_options;
+  annotation_options.locales = "en";
+  StatusOr<Annotations> annotations_status =
+      classifier->AnnotateStructuredInput(string_fragments, annotation_options);
+  ASSERT_TRUE(annotations_status.ok());
+  Annotations annotations = annotations_status.ValueOrDie();
+  ASSERT_EQ(annotations.annotated_spans.size(), 2);
+  EXPECT_THAT(annotations.annotated_spans[0],
+              ElementsAreArray({IsAnnotatedSpan(11, 18, "money")}));
+  EXPECT_THAT(annotations.annotated_spans[1],
+              ElementsAreArray({IsAnnotatedSpan(15, 25, "date")}));
+}
+
+// A per-fragment reference timestamp should override the reference time from
+// the AnnotationOptions for that fragment only.
+void VerifyInputFragmentTimestampOverridesAnnotationOptions(
+    const Annotator* classifier) {
+  // Guard before dereferencing, like the other Verify* helpers in this file.
+  ASSERT_TRUE(classifier);
+  AnnotationOptions annotation_options;
+  annotation_options.locales = "en";
+  annotation_options.reference_time_ms_utc =
+      1554465190000;  // 04/05/2019 11:53 am
+  int64 fragment_reference_time = 946727580000;  // 01/01/2000 11:53 am
+  std::vector<InputFragment> string_fragments = {
+      {.text = "New event at 17:20"},
+      {
+          .text = "New event at 17:20",
+          .datetime_options = Optional<DatetimeOptions>(
+              {.reference_time_ms_utc = fragment_reference_time}),
+      }};
+  StatusOr<Annotations> annotations_status =
+      classifier->AnnotateStructuredInput(string_fragments, annotation_options);
+  ASSERT_TRUE(annotations_status.ok());
+  Annotations annotations = annotations_status.ValueOrDie();
+  ASSERT_EQ(annotations.annotated_spans.size(), 2);
+  // First fragment resolves against the options-level reference time...
+  EXPECT_THAT(annotations.annotated_spans[0],
+              ElementsAreArray({IsDatetimeSpan(
+                  /*start=*/13, /*end=*/18, /*time_ms_utc=*/1554484800000,
+                  DatetimeGranularity::GRANULARITY_MINUTE)}));
+  // ...the second against its own fragment-level reference time.
+  EXPECT_THAT(annotations.annotated_spans[1],
+              ElementsAreArray({IsDatetimeSpan(
+                  /*start=*/13, /*end=*/18, /*time_ms_utc=*/946747200000,
+                  DatetimeGranularity::GRANULARITY_MINUTE)}));
+}
+
+// Fragment-level timestamp override, with the regex-datetime model.
+TEST_F(AnnotatorTest,
+       InputFragmentTimestampOverridesAnnotationOptionsWithDatetimeRegEx) {
+  std::string model_buffer = GetTestModelWithDatetimeRegEx();
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  VerifyInputFragmentTimestampOverridesAnnotationOptions(classifier.get());
+}
+
+// Fragment-level timestamp override, with the default test model.
+TEST_F(AnnotatorTest, InputFragmentTimestampOverridesAnnotationOptions) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyInputFragmentTimestampOverridesAnnotationOptions(classifier.get());
+}
+
+// A per-fragment reference timezone should override the default timezone for
+// that fragment only (same local time, different UTC instants).
+void VerifyInputFragmentTimezoneOverridesAnnotationOptions(
+    const Annotator* classifier) {
+  // Guard before dereferencing, like the other Verify* helpers in this file.
+  ASSERT_TRUE(classifier);
+  std::vector<InputFragment> string_fragments = {
+      {.text = "11/12/2020 17:20"},
+      {
+          .text = "11/12/2020 17:20",
+          .datetime_options = Optional<DatetimeOptions>(
+              {.reference_timezone = "Europe/Zurich"}),
+      }};
+  AnnotationOptions annotation_options;
+  annotation_options.locales = "en-US";
+  StatusOr<Annotations> annotations_status =
+      classifier->AnnotateStructuredInput(string_fragments, annotation_options);
+  ASSERT_TRUE(annotations_status.ok());
+  Annotations annotations = annotations_status.ValueOrDie();
+  ASSERT_EQ(annotations.annotated_spans.size(), 2);
+  EXPECT_THAT(annotations.annotated_spans[0],
+              ElementsAreArray({IsDatetimeSpan(
+                  /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605201600000,
+                  DatetimeGranularity::GRANULARITY_MINUTE)}));
+  EXPECT_THAT(annotations.annotated_spans[1],
+              ElementsAreArray({IsDatetimeSpan(
+                  /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605198000000,
+                  DatetimeGranularity::GRANULARITY_MINUTE)}));
+}
+
+// Fragment-level timezone override, with the default test model.
+TEST_F(AnnotatorTest, InputFragmentTimezoneOverridesAnnotationOptions) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyInputFragmentTimezoneOverridesAnnotationOptions(classifier.get());
+}
+
+// Fragment-level timezone override, with the regex-datetime model.
+TEST_F(AnnotatorTest,
+       InputFragmentTimezoneOverridesAnnotationOptionsWithDatetimeRegEx) {
+  std::string model_buffer = GetTestModelWithDatetimeRegEx();
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  VerifyInputFragmentTimezoneOverridesAnnotationOptions(classifier.get());
+}
+
+namespace {
+void AddDummyRegexDatetimeModel(ModelT* unpacked_model) {
+ unpacked_model->datetime_model.reset(new DatetimeModelT);
+ // This needs to be false otherwise we'd have to define some extractor. When
+ // this is false, the 0-th capturing group (whole match) from the pattern is
+ // used to come up with the indices.
+ unpacked_model->datetime_model->use_extractors_for_locating = false;
+ unpacked_model->datetime_model->locales.push_back("en-US");
+ unpacked_model->datetime_model->default_locales.push_back(0); // en-US
+ unpacked_model->datetime_model->patterns.push_back(
+ std::unique_ptr<DatetimeModelPatternT>(new DatetimeModelPatternT));
+ unpacked_model->datetime_model->patterns.back()->locales.push_back(
+ 0); // en-US
+ unpacked_model->datetime_model->patterns.back()->regexes.push_back(
+ std::unique_ptr<DatetimeModelPattern_::RegexT>(
+ new DatetimeModelPattern_::RegexT));
+ unpacked_model->datetime_model->patterns.back()->regexes.back()->pattern =
+ "THIS_MATCHES_IN_REGEX_MODEL";
+ unpacked_model->datetime_model->patterns.back()
+ ->regexes.back()
+ ->groups.push_back(DatetimeGroupType_GROUP_UNUSED);
+}
+} // namespace
+
+TEST_F(AnnotatorTest, AnnotateFiltersOutExactDuplicates) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // This test assumes that both ML model and Regex model trigger on the
+ // following text and output "phone" annotation for it.
+ const std::string test_string = "1000000000";
+ AnnotationOptions options;
+ options.annotation_usecase = ANNOTATION_USECASE_RAW;
+ int num_phones = 0;
+ for (const AnnotatedSpan& span : classifier->Annotate(test_string, options)) {
+ if (span.classification[0].collection == "phone") {
+ num_phones++;
+ }
+ }
+
+ EXPECT_EQ(num_phones, 1);
+}
+
+// This test tests the optimizations in Annotator, which make some of the
+// annotators not run in the RAW mode when not requested. We test here that the
+// results indeed don't contain such annotations. However, this is a bit hacky,
+// since one could also add post-filtering, in which case these tests would
+// trivially pass.
+TEST_F(AnnotatorTest, RawModeOptimizationWorks) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ AnnotationOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ // Requesting a non-existing type to avoid overlap with existing types.
+ options.entity_types.insert("some_unknown_entity_type");
+
+ // Normally, the following command would produce the following annotations:
+ //   Span(19, 24, date, 1.000000),
+ //   Span(53, 56, number, 1.000000),
+ //   Span(53, 80, address, 1.000000),
+ //   Span(128, 142, phone, 1.000000),
+ //   Span(129, 132, number, 1.000000),
+ //   Span(192, 200, phone, 1.000000),
+ //   Span(192, 206, datetime, 1.000000),
+ //   Span(246, 253, number, 1.000000),
+ //   Span(246, 253, phone, 1.000000),
+ //   Span(292, 293, number, 1.000000),
+ //   Span(292, 301, duration, 1.000000) }
+ // But because of the optimizations, it doesn't produce anything, since
+ // we didn't request any of these entities.
+ EXPECT_THAT(classifier->Annotate(R"--(I saw Barack Obama today
+                            350 Third Street, Cambridge
+                            my phone number is (853) 225-3556
+                            this is when we met: 1.9.2021 13:00
+                            my number: 1234567
+                            duration: 3 minutes
+                            )--",
+ options),
+ IsEmpty());
+}
+
+void VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(
+ const Annotator* classifier) {
+ ASSERT_TRUE(classifier);
+ struct Example {
+ std::string collection;
+ std::string text;
+ };
+
+ // These examples contain one example per annotator, to check that each of
+ // the annotators can work in the RAW mode on its own.
+ //
+ // WARNING: This list doesn't contain yet entries for the app, contact, and
+ // person annotators. Hopefully this won't be needed once b/155214735 is
+ // fixed and the piping shared across annotators.
+ std::vector<Example> examples{
+ // ML Model.
+ {.collection = Collections::Address(),
+ .text = "... 350 Third Street, Cambridge ..."},
+ // Datetime annotator.
+ {.collection = Collections::DateTime(), .text = "... 1.9.2020 10:00 ..."},
+ // Duration annotator.
+ {.collection = Collections::Duration(),
+ .text = "... 3 hours and 9 seconds ..."},
+ // Regex annotator.
+ {.collection = Collections::Email(),
+ .text = "... platypus@theanimal.org ..."},
+ // Number annotator.
+ {.collection = Collections::Number(), .text = "... 100 ..."},
+ };
+
+ for (const Example& example : examples) {
+ AnnotationOptions options;
+ options.locales = "en";
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ options.entity_types.insert(example.collection);
+
+ EXPECT_THAT(classifier->Annotate(example.text, options),
+ Contains(IsAnnotationWithType(example.collection)))
+ << " text: '" << example.text
+ << "', collection: " << example.collection;
+ }
+}
+
+TEST_F(AnnotatorTest, AnnotateSupportsPointwiseCollectionFilteringInRawMode) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(classifier.get());
+}
+
+TEST_F(AnnotatorTest,
+ AnnotateSupportsPointwiseCollectionFilteringInRawModeWithDatetimeRegEx) {
+ std::string model_buffer = GetTestModelWithDatetimeRegEx();
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+ unilib_.get(), calendarlib_.get());
+ VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(classifier.get());
+}
+
+TEST_F(AnnotatorTest, InitializeFromString) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromString(test_model, unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+ EXPECT_THAT(classifier->Annotate("(857) 225-3556"), Not(IsEmpty()));
+}
+
+// Regression test for cl/338280366. Enabling only_use_line_with_click had
+// the effect, that some annotators in the previous code releases would
+// receive only the last line of the input text. This test has the entity on the
+// first line (duration).
+TEST_F(AnnotatorTest, RegressionTestOnlyUseLineWithClickLastLine) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ std::unique_ptr<Annotator> classifier;
+
+ // With unrestricted number of tokens should behave normally.
+ unpacked_model->selection_feature_options->only_use_line_with_click = true;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ AnnotationOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
+ const std::vector<AnnotatedSpan> annotations =
+ classifier->Annotate("let's meet in 3 hours\nbut not now", options);
+
+ EXPECT_THAT(annotations, Contains(IsDurationSpan(
+ /*start=*/14, /*end=*/21,
+ /*duration_ms=*/3 * 60 * 60 * 1000)));
+}
+
+TEST_F(AnnotatorTest, DoesntProcessInvalidUtf8) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ const std::string invalid_utf8_text_with_phone_number =
+ "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80";
+
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromString(test_model, unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+ EXPECT_THAT(classifier->Annotate(invalid_utf8_text_with_phone_number),
+ IsEmpty());
+ EXPECT_THAT(
+ classifier->SuggestSelection(invalid_utf8_text_with_phone_number, {1, 4}),
+ Eq(CodepointSpan{1, 4}));
+ EXPECT_THAT(
+ classifier->ClassifyText(invalid_utf8_text_with_phone_number, {0, 14}),
+ IsEmpty());
+}
+
+} // namespace test_internal
+} // namespace libtextclassifier3
diff --git a/native/annotator/annotator_test-include.h b/native/annotator/annotator_test-include.h
new file mode 100644
index 0000000..bcbb9e9
--- /dev/null
+++ b/native/annotator/annotator_test-include.h
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_TEST_INCLUDE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_TEST_INCLUDE_H_
+
+#include <fstream>
+#include <string>
+
+#include "annotator/annotator.h"
+#include "utils/base/logging.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/test-data-test-utils.h"
+#include "utils/testing/annotator.h"
+#include "utils/utf8/unilib.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+inline std::string GetModelPath() {
+ return GetTestDataPath("annotator/test_data/");
+}
+
+class TestingAnnotator : public Annotator {
+ public:
+ TestingAnnotator(
+ const UniLib* unilib, const CalendarLib* calendarlib,
+ const std::function<void(ModelT*)> model_update_fn = [](ModelT* model) {
+ }) {
+ owned_buffer_ = CreateEmptyModel(model_update_fn);
+ ValidateAndInitialize(libtextclassifier3::ViewModel(owned_buffer_.data(),
+ owned_buffer_.size()),
+ unilib, calendarlib);
+ }
+
+ static std::unique_ptr<TestingAnnotator> FromUnownedBuffer(
+ const char* buffer, int size, const UniLib* unilib = nullptr,
+ const CalendarLib* calendarlib = nullptr) {
+ // Safe to downcast from Annotator* to TestingAnnotator* because the
+ // subclass is not adding any new members.
+ return std::unique_ptr<TestingAnnotator>(
+ reinterpret_cast<TestingAnnotator*>(
+ Annotator::FromUnownedBuffer(buffer, size, unilib, calendarlib)
+ .release()));
+ }
+
+ using Annotator::ResolveConflicts;
+};
+
+class AnnotatorTest : public ::testing::TestWithParam<const char*> {
+ protected:
+ AnnotatorTest()
+ : unilib_(CreateUniLibForTesting()),
+ calendarlib_(CreateCalendarLibForTesting()) {}
+
+ std::unique_ptr<UniLib> unilib_;
+ std::unique_ptr<CalendarLib> calendarlib_;
+};
+
+} // namespace test_internal
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_TEST_INCLUDE_H_
diff --git a/native/annotator/cached-features.cc b/native/annotator/cached-features.cc
index 480c044..1a14a42 100644
--- a/native/annotator/cached-features.cc
+++ b/native/annotator/cached-features.cc
@@ -88,10 +88,9 @@
click_pos -= extraction_span_.first;
AppendFeaturesInternal(
- /*intended_span=*/ExpandTokenSpan(SingleTokenSpan(click_pos),
- options_->context_size(),
- options_->context_size()),
- /*read_mask_span=*/{0, TokenSpanSize(extraction_span_)}, output_features);
+ /*intended_span=*/TokenSpan(click_pos).Expand(options_->context_size(),
+ options_->context_size()),
+ /*read_mask_span=*/{0, extraction_span_.Size()}, output_features);
}
void CachedFeatures::AppendBoundsSensitiveFeaturesForSpan(
@@ -118,16 +117,15 @@
/*intended_span=*/{selected_span.second -
config->num_tokens_inside_right(),
selected_span.second + config->num_tokens_after()},
- /*read_mask_span=*/{selected_span.first, TokenSpanSize(extraction_span_)},
- output_features);
+ /*read_mask_span=*/
+ {selected_span.first, extraction_span_.Size()}, output_features);
if (config->include_inside_bag()) {
AppendBagFeatures(selected_span, output_features);
}
if (config->include_inside_length()) {
- output_features->push_back(
- static_cast<float>(TokenSpanSize(selected_span)));
+ output_features->push_back(static_cast<float>(selected_span.Size()));
}
}
@@ -161,7 +159,7 @@
for (int i = bag_span.first; i < bag_span.second; ++i) {
for (int j = 0; j < NumFeaturesPerToken(); ++j) {
(*output_features)[offset + j] +=
- (*features_)[i * NumFeaturesPerToken() + j] / TokenSpanSize(bag_span);
+ (*features_)[i * NumFeaturesPerToken() + j] / bag_span.Size();
}
}
}
diff --git a/native/annotator/datetime/datetime-grounder.cc b/native/annotator/datetime/datetime-grounder.cc
new file mode 100644
index 0000000..7d5f440
--- /dev/null
+++ b/native/annotator/datetime/datetime-grounder.cc
@@ -0,0 +1,273 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/datetime-grounder.h"
+
+#include <limits>
+#include <unordered_map>
+#include <vector>
+
+#include "annotator/datetime/datetime_generated.h"
+#include "annotator/datetime/utils.h"
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/status.h"
+#include "utils/base/status_macros.h"
+
+using ::libtextclassifier3::grammar::datetime::AbsoluteDateTime;
+using ::libtextclassifier3::grammar::datetime::ComponentType;
+using ::libtextclassifier3::grammar::datetime::Meridiem;
+using ::libtextclassifier3::grammar::datetime::RelativeDateTime;
+using ::libtextclassifier3::grammar::datetime::RelativeDatetimeComponent;
+using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
+using ::libtextclassifier3::grammar::datetime::RelativeDatetimeComponent_::
+ Modifier;
+
+namespace libtextclassifier3 {
+
+namespace {
+
+const std::unordered_map<int, int> kMonthDefaultLastDayMap(
+ {{/*no_month*/ 0, 31},
+ {/*January*/ 1, 31},
+ {/*February*/ 2, 29},
+ {/*March*/ 3, 31},
+ {/*April*/ 4, 30},
+ {/*May*/ 5, 31},
+ {/*June*/ 6, 30},
+ {/*July*/ 7, 31},
+ {/*August*/ 8, 31},
+ {/*September*/ 9, 30},
+ {/*October*/ 10, 31},
+ {/*November*/ 11, 30},
+ {/*December*/ 12, 31}});
+
+bool IsValidDatetime(const AbsoluteDateTime* absolute_datetime) {
+ // Sanity Checks.
+ if (absolute_datetime->minute() > 59 || absolute_datetime->second() > 59 ||
+ absolute_datetime->hour() > 23 || absolute_datetime->month() > 12 ||
+ absolute_datetime->month() == 0) {
+ return false;
+ }
+ if (absolute_datetime->day() >= 0) {
+ int min_day_value = 1;
+ int max_day_value = 31;
+ if (absolute_datetime->month() >= 0 && absolute_datetime->month() <= 12) {
+ max_day_value = kMonthDefaultLastDayMap.at(absolute_datetime->month());
+ if (absolute_datetime->day() < min_day_value ||
+ absolute_datetime->day() > max_day_value) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+bool IsValidDatetime(const RelativeDateTime* relative_datetime) {
+ if (relative_datetime->base()) {
+ return IsValidDatetime(relative_datetime->base());
+ }
+ return true;
+}
+
+StatusOr<DatetimeComponent::RelativeQualifier> ToRelativeQualifier(
+ const Modifier& modifier) {
+ switch (modifier) {
+ case Modifier::Modifier_THIS:
+ return DatetimeComponent::RelativeQualifier::THIS;
+ case Modifier::Modifier_LAST:
+ return DatetimeComponent::RelativeQualifier::LAST;
+ case Modifier::Modifier_NEXT:
+ return DatetimeComponent::RelativeQualifier::NEXT;
+ case Modifier::Modifier_NOW:
+ return DatetimeComponent::RelativeQualifier::NOW;
+ case Modifier::Modifier_TOMORROW:
+ return DatetimeComponent::RelativeQualifier::TOMORROW;
+ case Modifier::Modifier_YESTERDAY:
+ return DatetimeComponent::RelativeQualifier::YESTERDAY;
+ case Modifier::Modifier_PAST:
+ return DatetimeComponent::RelativeQualifier::PAST;
+ case Modifier::Modifier_FUTURE:
+ return DatetimeComponent::RelativeQualifier::FUTURE;
+ case Modifier::Modifier_UNSPECIFIED:
+ return DatetimeComponent::RelativeQualifier::UNSPECIFIED;
+ default:
+ return Status(StatusCode::INTERNAL,
+ "Couldn't parse the Modifier to RelativeQualifier.");
+ }
+}
+
+StatusOr<DatetimeComponent::ComponentType> ToComponentType(
+ const grammar::datetime::ComponentType component_type) {
+ switch (component_type) {
+ case grammar::datetime::ComponentType_YEAR:
+ return DatetimeComponent::ComponentType::YEAR;
+ case grammar::datetime::ComponentType_MONTH:
+ return DatetimeComponent::ComponentType::MONTH;
+ case grammar::datetime::ComponentType_WEEK:
+ return DatetimeComponent::ComponentType::WEEK;
+ case grammar::datetime::ComponentType_DAY_OF_WEEK:
+ return DatetimeComponent::ComponentType::DAY_OF_WEEK;
+ case grammar::datetime::ComponentType_DAY_OF_MONTH:
+ return DatetimeComponent::ComponentType::DAY_OF_MONTH;
+ case grammar::datetime::ComponentType_HOUR:
+ return DatetimeComponent::ComponentType::HOUR;
+ case grammar::datetime::ComponentType_MINUTE:
+ return DatetimeComponent::ComponentType::MINUTE;
+ case grammar::datetime::ComponentType_SECOND:
+ return DatetimeComponent::ComponentType::SECOND;
+ case grammar::datetime::ComponentType_MERIDIEM:
+ return DatetimeComponent::ComponentType::MERIDIEM;
+ case grammar::datetime::ComponentType_UNSPECIFIED:
+ return DatetimeComponent::ComponentType::UNSPECIFIED;
+ default:
+ return Status(StatusCode::INTERNAL,
+ "Couldn't parse the DatetimeComponent's ComponentType from "
+ "grammar's datetime ComponentType.");
+ }
+}
+
+void FillAbsoluteDateTimeComponents(
+ const grammar::datetime::AbsoluteDateTime* absolute_datetime,
+ DatetimeParsedData* datetime_parsed_data) {
+ if (absolute_datetime->year() >= 0) {
+ datetime_parsed_data->SetAbsoluteValue(
+ DatetimeComponent::ComponentType::YEAR,
+ GetAdjustedYear(absolute_datetime->year()));
+ }
+ if (absolute_datetime->month() >= 0) {
+ datetime_parsed_data->SetAbsoluteValue(
+ DatetimeComponent::ComponentType::MONTH, absolute_datetime->month());
+ }
+ if (absolute_datetime->day() >= 0) {
+ datetime_parsed_data->SetAbsoluteValue(
+ DatetimeComponent::ComponentType::DAY_OF_MONTH,
+ absolute_datetime->day());
+ }
+ if (absolute_datetime->week_day() >= 0) {
+ datetime_parsed_data->SetAbsoluteValue(
+ DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ absolute_datetime->week_day());
+ }
+ if (absolute_datetime->hour() >= 0) {
+ datetime_parsed_data->SetAbsoluteValue(
+ DatetimeComponent::ComponentType::HOUR, absolute_datetime->hour());
+ }
+ if (absolute_datetime->minute() >= 0) {
+ datetime_parsed_data->SetAbsoluteValue(
+ DatetimeComponent::ComponentType::MINUTE, absolute_datetime->minute());
+ }
+ if (absolute_datetime->second() >= 0) {
+ datetime_parsed_data->SetAbsoluteValue(
+ DatetimeComponent::ComponentType::SECOND, absolute_datetime->second());
+ }
+ if (absolute_datetime->meridiem() != grammar::datetime::Meridiem_UNKNOWN) {
+ datetime_parsed_data->SetAbsoluteValue(
+ DatetimeComponent::ComponentType::MERIDIEM,
+ absolute_datetime->meridiem() == grammar::datetime::Meridiem_AM ? 0
+ : 1);
+ }
+ if (absolute_datetime->time_zone()) {
+ datetime_parsed_data->SetAbsoluteValue(
+ DatetimeComponent::ComponentType::ZONE_OFFSET,
+ absolute_datetime->time_zone()->utc_offset_mins());
+ }
+}
+
+StatusOr<DatetimeParsedData> FillRelativeDateTimeComponents(
+ const grammar::datetime::RelativeDateTime* relative_datetime) {
+ DatetimeParsedData datetime_parsed_data;
+ for (const RelativeDatetimeComponent* relative_component :
+ *relative_datetime->relative_datetime_component()) {
+ TC3_ASSIGN_OR_RETURN(const DatetimeComponent::ComponentType component_type,
+ ToComponentType(relative_component->component_type()));
+ datetime_parsed_data.SetRelativeCount(component_type,
+ relative_component->value());
+ TC3_ASSIGN_OR_RETURN(
+ const DatetimeComponent::RelativeQualifier relative_qualifier,
+ ToRelativeQualifier(relative_component->modifier()));
+ datetime_parsed_data.SetRelativeValue(component_type, relative_qualifier);
+ }
+ if (relative_datetime->base()) {
+ FillAbsoluteDateTimeComponents(relative_datetime->base(),
+ &datetime_parsed_data);
+ }
+ return datetime_parsed_data;
+}
+
+} // namespace
+
+DatetimeGrounder::DatetimeGrounder(const CalendarLib* calendarlib)
+ : calendarlib_(*calendarlib) {}
+
+StatusOr<std::vector<DatetimeParseResult>> DatetimeGrounder::Ground(
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale,
+ const grammar::datetime::UngroundedDatetime* ungrounded_datetime) const {
+ DatetimeParsedData datetime_parsed_data;
+ if (ungrounded_datetime->absolute_datetime()) {
+ FillAbsoluteDateTimeComponents(ungrounded_datetime->absolute_datetime(),
+ &datetime_parsed_data);
+ } else if (ungrounded_datetime->relative_datetime()) {
+ TC3_ASSIGN_OR_RETURN(datetime_parsed_data,
+ FillRelativeDateTimeComponents(
+ ungrounded_datetime->relative_datetime()));
+ }
+ std::vector<DatetimeParsedData> interpretations;
+ FillInterpretations(datetime_parsed_data,
+ calendarlib_.GetGranularity(datetime_parsed_data),
+ &interpretations);
+ std::vector<DatetimeParseResult> datetime_parse_result;
+
+ for (const DatetimeParsedData& interpretation : interpretations) {
+ std::vector<DatetimeComponent> date_components;
+ interpretation.GetDatetimeComponents(&date_components);
+ DatetimeParseResult result;
+ // Text classifier only provides ambiguity limited to “AM/PM” which is
+ // encoded in the pair of DatetimeParseResult; both corresponding to the
+ // same date, but one corresponding to “AM” and the other one corresponding
+ // to “PM”.
+ if (!calendarlib_.InterpretParseData(
+ interpretation, reference_time_ms_utc, reference_timezone,
+ reference_locale, /*prefer_future_for_unspecified_date=*/true,
+ &(result.time_ms_utc), &(result.granularity))) {
+ return Status(
+ StatusCode::INTERNAL,
+ "Couldn't parse the UngroundedDatetime to DatetimeParseResult.");
+ }
+
+ // Sort the date time units by component type.
+ std::sort(date_components.begin(), date_components.end(),
+ [](DatetimeComponent a, DatetimeComponent b) {
+ return a.component_type > b.component_type;
+ });
+ result.datetime_components.swap(date_components);
+ datetime_parse_result.push_back(result);
+ }
+ return datetime_parse_result;
+}
+
+bool DatetimeGrounder::IsValidUngroundedDatetime(
+ const UngroundedDatetime* ungrounded_datetime) const {
+ if (ungrounded_datetime->absolute_datetime()) {
+ return IsValidDatetime(ungrounded_datetime->absolute_datetime());
+ } else if (ungrounded_datetime->relative_datetime()) {
+ return IsValidDatetime(ungrounded_datetime->relative_datetime());
+ }
+ return false;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/datetime-grounder.h b/native/annotator/datetime/datetime-grounder.h
new file mode 100644
index 0000000..6a6f5e4
--- /dev/null
+++ b/native/annotator/datetime/datetime-grounder.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_DATETIME_GROUNDER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_DATETIME_GROUNDER_H_
+
+#include <vector>
+
+#include "annotator/datetime/datetime_generated.h"
+#include "annotator/types.h"
+#include "utils/base/statusor.h"
+#include "utils/calendar/calendar.h"
+
+namespace libtextclassifier3 {
+
+// Utility class to resolve and complete an ungrounded datetime specification.
+class DatetimeGrounder {
+ public:
+ explicit DatetimeGrounder(const CalendarLib* calendarlib);
+
+ // Resolves ambiguities and produces concrete datetime results from an
+ // ungrounded datetime specification.
+ StatusOr<std::vector<DatetimeParseResult>> Ground(
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale,
+ const grammar::datetime::UngroundedDatetime* ungrounded_datetime) const;
+
+ bool IsValidUngroundedDatetime(
+ const grammar::datetime::UngroundedDatetime* ungrounded_datetime) const;
+
+ private:
+ const CalendarLib& calendarlib_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_DATETIME_GROUNDER_H_
diff --git a/native/annotator/datetime/datetime-grounder_test.cc b/native/annotator/datetime/datetime-grounder_test.cc
new file mode 100644
index 0000000..121aae8
--- /dev/null
+++ b/native/annotator/datetime/datetime-grounder_test.cc
@@ -0,0 +1,292 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/datetime-grounder.h"
+
+#include "annotator/datetime/datetime_generated.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/jvm-test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::libtextclassifier3::grammar::datetime::AbsoluteDateTimeT;
+using ::libtextclassifier3::grammar::datetime::ComponentType;
+using ::libtextclassifier3::grammar::datetime::Meridiem;
+using ::libtextclassifier3::grammar::datetime::RelativeDatetimeComponentT;
+using ::libtextclassifier3::grammar::datetime::RelativeDateTimeT;
+using ::libtextclassifier3::grammar::datetime::TimeZoneT;
+using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
+using ::libtextclassifier3::grammar::datetime::UngroundedDatetimeT;
+using ::libtextclassifier3::grammar::datetime::RelativeDatetimeComponent_::
+ Modifier;
+using ::testing::SizeIs;
+
+namespace libtextclassifier3 {
+
+class DatetimeGrounderTest : public testing::Test {
+ public:
+ void SetUp() override {
+ calendarlib_ = CreateCalendarLibForTesting();
+ datetime_grounder_.reset(new DatetimeGrounder(calendarlib_.get()));
+ }
+
+ protected:
+ OwnedFlatbuffer<UngroundedDatetime, std::string> BuildAbsoluteDatetime(
+ const int year, const int month, const int day, const int hour,
+ const int minute, const int second, const Meridiem meridiem) {
+ grammar::datetime::UngroundedDatetimeT ungrounded_datetime;
+ ungrounded_datetime.absolute_datetime.reset(new AbsoluteDateTimeT);
+
+ // Set absolute datetime value.
+ ungrounded_datetime.absolute_datetime->year = year;
+ ungrounded_datetime.absolute_datetime->month = month;
+ ungrounded_datetime.absolute_datetime->day = day;
+ ungrounded_datetime.absolute_datetime->hour = hour;
+ ungrounded_datetime.absolute_datetime->minute = minute;
+ ungrounded_datetime.absolute_datetime->second = second;
+ ungrounded_datetime.absolute_datetime->meridiem = meridiem;
+
+ return OwnedFlatbuffer<UngroundedDatetime, std::string>(
+ PackFlatbuffer<UngroundedDatetime>(&ungrounded_datetime));
+ }
+
+ OwnedFlatbuffer<UngroundedDatetime, std::string> BuildRelativeDatetime(
+ const ComponentType component_type, const Modifier modifier,
+ const int relative_count) {
+ UngroundedDatetimeT ungrounded_datetime;
+ ungrounded_datetime.relative_datetime.reset(new RelativeDateTimeT);
+ ungrounded_datetime.relative_datetime->relative_datetime_component
+ .emplace_back(new RelativeDatetimeComponentT);
+ ungrounded_datetime.relative_datetime->relative_datetime_component.back()
+ ->modifier = modifier;
+ ungrounded_datetime.relative_datetime->relative_datetime_component.back()
+ ->component_type = component_type;
+ ungrounded_datetime.relative_datetime->relative_datetime_component.back()
+ ->value = relative_count;
+ ungrounded_datetime.relative_datetime->base.reset(new AbsoluteDateTimeT);
+ ungrounded_datetime.relative_datetime->base->year = 2020;
+ ungrounded_datetime.relative_datetime->base->month = 6;
+ ungrounded_datetime.relative_datetime->base->day = 30;
+
+ return OwnedFlatbuffer<UngroundedDatetime, std::string>(
+ PackFlatbuffer<UngroundedDatetime>(&ungrounded_datetime));
+ }
+
+ void VerifyValidUngroundedDatetime(
+ const UngroundedDatetime* ungrounded_datetime) {
+ EXPECT_TRUE(
+ datetime_grounder_->IsValidUngroundedDatetime(ungrounded_datetime));
+ }
+
+ void VerifyInValidUngroundedDatetime(
+ const UngroundedDatetime* ungrounded_datetime) {
+ EXPECT_FALSE(
+ datetime_grounder_->IsValidUngroundedDatetime(ungrounded_datetime));
+ }
+
+ std::unique_ptr<DatetimeGrounder> datetime_grounder_;
+ std::unique_ptr<CalendarLib> calendarlib_;
+};
+
+TEST_F(DatetimeGrounderTest, AbsoluteDatetimeTest) {
+ const OwnedFlatbuffer<UngroundedDatetime, std::string> datetime =
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/03, /*day=*/30,
+ /*hour=*/11, /*minute=*/59, /*second=*/59,
+ grammar::datetime::Meridiem_AM);
+ const std::vector<DatetimeParseResult> data =
+ datetime_grounder_
+ ->Ground(
+ /*reference_time_ms_utc=*/0, "Europe/Zurich", "en-US",
+ datetime.get())
+ .ValueOrDie();
+
+ EXPECT_THAT(data, SizeIs(1));
+ EXPECT_EQ(data[0].granularity, DatetimeGranularity::GRANULARITY_SECOND);
+
+ // Meridiem
+ EXPECT_EQ(data[0].datetime_components[0].component_type,
+ DatetimeComponent::ComponentType::MERIDIEM);
+ EXPECT_EQ(data[0].datetime_components[0].value, 0);
+
+ EXPECT_EQ(data[0].datetime_components[1].component_type,
+ DatetimeComponent::ComponentType::SECOND);
+ // Check the SECOND component's value (was a copy-pasted type check before).
+ EXPECT_EQ(data[0].datetime_components[1].value, 59);
+
+ EXPECT_EQ(data[0].datetime_components[2].component_type,
+ DatetimeComponent::ComponentType::MINUTE);
+ EXPECT_EQ(data[0].datetime_components[2].value, 59);
+
+ EXPECT_EQ(data[0].datetime_components[3].component_type,
+ DatetimeComponent::ComponentType::HOUR);
+ EXPECT_EQ(data[0].datetime_components[3].value, 11);
+
+ EXPECT_EQ(data[0].datetime_components[4].component_type,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH);
+ EXPECT_EQ(data[0].datetime_components[4].value, 30);
+
+ EXPECT_EQ(data[0].datetime_components[5].component_type,
+ DatetimeComponent::ComponentType::MONTH);
+ EXPECT_EQ(data[0].datetime_components[5].value, 3);
+
+ EXPECT_EQ(data[0].datetime_components[6].component_type,
+ DatetimeComponent::ComponentType::YEAR);
+ EXPECT_EQ(data[0].datetime_components[6].value, 2000);
+}
+
+TEST_F(DatetimeGrounderTest, InterpretDatetimeTest) {
+ const OwnedFlatbuffer<UngroundedDatetime, std::string> datetime =
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/03, /*day=*/30,
+ /*hour=*/11, /*minute=*/59, /*second=*/59,
+ grammar::datetime::Meridiem_UNKNOWN);
+ const std::vector<DatetimeParseResult> data =
+ datetime_grounder_
+ ->Ground(
+ /*reference_time_ms_utc=*/0, "Europe/Zurich", "en-US",
+ datetime.get())
+ .ValueOrDie();
+
+ EXPECT_THAT(data, SizeIs(2));
+ EXPECT_EQ(data[0].granularity, DatetimeGranularity::GRANULARITY_SECOND);
+ EXPECT_EQ(data[1].granularity, DatetimeGranularity::GRANULARITY_SECOND);
+
+ // Check Meridiem's values
+ EXPECT_EQ(data[0].datetime_components[0].component_type,
+ DatetimeComponent::ComponentType::MERIDIEM);
+ EXPECT_EQ(data[0].datetime_components[0].value, 0);
+ EXPECT_EQ(data[1].datetime_components[0].component_type,
+ DatetimeComponent::ComponentType::MERIDIEM);
+ EXPECT_EQ(data[1].datetime_components[0].value, 1);
+}
+
+TEST_F(DatetimeGrounderTest, RelativeDatetimeTest) {
+ const OwnedFlatbuffer<UngroundedDatetime, std::string> datetime =
+ BuildRelativeDatetime(ComponentType::ComponentType_DAY_OF_MONTH,
+ Modifier::Modifier_NEXT, 1);
+ const std::vector<DatetimeParseResult> data =
+ datetime_grounder_
+ ->Ground(
+ /*reference_time_ms_utc=*/0, "Europe/Zurich", "en-US",
+ datetime.get())
+ .ValueOrDie();
+
+ EXPECT_THAT(data, SizeIs(1));
+ EXPECT_EQ(data[0].granularity, DatetimeGranularity::GRANULARITY_DAY);
+
+ EXPECT_EQ(data[0].datetime_components[0].component_type,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH);
+ EXPECT_EQ(data[0].datetime_components[0].relative_qualifier,
+ DatetimeComponent::RelativeQualifier::NEXT);
+ EXPECT_EQ(data[0].datetime_components[0].relative_count, 1);
+ EXPECT_EQ(data[0].datetime_components[1].component_type,
+ DatetimeComponent::ComponentType::MONTH);
+ EXPECT_EQ(data[0].datetime_components[2].component_type,
+ DatetimeComponent::ComponentType::YEAR);
+}
+
+TEST_F(DatetimeGrounderTest, TimeZoneTest) {
+ grammar::datetime::UngroundedDatetimeT ungrounded_datetime;
+ ungrounded_datetime.absolute_datetime.reset(new AbsoluteDateTimeT);
+ ungrounded_datetime.absolute_datetime->time_zone.reset(new TimeZoneT);
+ ungrounded_datetime.absolute_datetime->time_zone->utc_offset_mins = 120;
+ const OwnedFlatbuffer<UngroundedDatetime, std::string> timezone(
+ PackFlatbuffer<UngroundedDatetime>(&ungrounded_datetime));
+
+ const std::vector<DatetimeParseResult> data =
+ datetime_grounder_
+ ->Ground(
+ /*reference_time_ms_utc=*/0, "Europe/Zurich", "en-US",
+ timezone.get())
+ .ValueOrDie();
+
+ EXPECT_THAT(data, SizeIs(1));
+ EXPECT_EQ(data[0].granularity, DatetimeGranularity::GRANULARITY_UNKNOWN);
+ EXPECT_EQ(data[0].datetime_components[0].component_type,
+ DatetimeComponent::ComponentType::ZONE_OFFSET);
+ EXPECT_EQ(data[0].datetime_components[0].value, 120);
+}
+
+TEST_F(DatetimeGrounderTest, InValidUngroundedDatetime) {
+ VerifyInValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/23, /*day=*/30,
+ /*hour=*/11, /*minute=*/59, /*second=*/59,
+ grammar::datetime::Meridiem_AM)
+ .get());
+
+ VerifyInValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/03, /*day=*/33,
+ /*hour=*/11, /*minute=*/59, /*second=*/59,
+ grammar::datetime::Meridiem_AM)
+ .get());
+
+ VerifyInValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/02, /*day=*/30,
+ /*hour=*/11, /*minute=*/59, /*second=*/59,
+ grammar::datetime::Meridiem_AM)
+ .get());
+
+ VerifyInValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/07, /*day=*/31,
+ /*hour=*/24, /*minute=*/59, /*second=*/59,
+ grammar::datetime::Meridiem_AM)
+ .get());
+
+ VerifyInValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/02, /*day=*/28,
+ /*hour=*/24, /*minute=*/59, /*second=*/59,
+ grammar::datetime::Meridiem_AM)
+ .get());
+
+ VerifyInValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/02, /*day=*/28,
+ /*hour=*/11, /*minute=*/69, /*second=*/59,
+ grammar::datetime::Meridiem_AM)
+ .get());
+
+ VerifyInValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/02, /*day=*/28,
+ /*hour=*/11, /*minute=*/59, /*second=*/99,
+ grammar::datetime::Meridiem_AM)
+ .get());
+
+ VerifyInValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/00, /*day=*/28,
+ /*hour=*/11, /*minute=*/59, /*second=*/99,
+ grammar::datetime::Meridiem_AM)
+ .get());
+}
+
+TEST_F(DatetimeGrounderTest, ValidUngroundedDatetime) {
+ VerifyValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/2, /*day=*/29,
+ /*hour=*/23, /*minute=*/59, /*second=*/59,
+ grammar::datetime::Meridiem_AM)
+ .get());
+
+ VerifyValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/7, /*day=*/31,
+ /*hour=*/23, /*minute=*/59, /*second=*/59,
+ grammar::datetime::Meridiem_AM)
+ .get());
+
+ VerifyValidUngroundedDatetime(
+ BuildAbsoluteDatetime(/*year=*/2000, /*month=*/10, /*day=*/31,
+ /*hour=*/23, /*minute=*/59, /*second=*/59,
+ grammar::datetime::Meridiem_AM)
+ .get());
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/datetime.fbs b/native/annotator/datetime/datetime.fbs
new file mode 100644
index 0000000..9a96bae
--- /dev/null
+++ b/native/annotator/datetime/datetime.fbs
@@ -0,0 +1,153 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Meridiem field.
+namespace libtextclassifier3.grammar.datetime;
+enum Meridiem : int {
+ UNKNOWN = 0,
+
+ // Ante meridiem: Before noon
+ AM = 1,
+
+ // Post meridiem: After noon
+ PM = 2,
+}
+
+// Enum represents a unit of date and time in the expression.
+// Next field: 10
+namespace libtextclassifier3.grammar.datetime;
+enum ComponentType : int {
+ UNSPECIFIED = 0,
+
+ // Year of the date seen in the text match.
+ YEAR = 1,
+
+ // Month of the year starting with January = 1.
+ MONTH = 2,
+
+ // Week (7 days).
+ WEEK = 3,
+
+ // Day of week, start of the week is Sunday & its value is 1.
+ DAY_OF_WEEK = 4,
+
+ // Day of the month starting with 1.
+ DAY_OF_MONTH = 5,
+
+ // Hour of the day.
+ HOUR = 6,
+
+ // Minute of the hour with a range of 0-59.
+ MINUTE = 7,
+
+ // Seconds of the minute with a range of 0-59.
+ SECOND = 8,
+
+ // Meridiem field i.e. AM/PM.
+ MERIDIEM = 9,
+}
+
+namespace libtextclassifier3.grammar.datetime;
+table TimeZone {
+  // Offset from UTC/GMT in minutes.
+ utc_offset_mins:int;
+}
+
+namespace libtextclassifier3.grammar.datetime.RelativeDatetimeComponent_;
+enum Modifier : int {
+ UNSPECIFIED = 0,
+ NEXT = 1,
+ THIS = 2,
+ LAST = 3,
+ NOW = 4,
+ TOMORROW = 5,
+ YESTERDAY = 6,
+ PAST = 7,
+ FUTURE = 8,
+}
+
+// Message for representing the relative date-time component in date-time
+// expressions.
+// Next field: 4
+namespace libtextclassifier3.grammar.datetime;
+table RelativeDatetimeComponent {
+ component_type:ComponentType = UNSPECIFIED;
+ modifier:RelativeDatetimeComponent_.Modifier = UNSPECIFIED;
+ value:int;
+}
+
+// AbsoluteDateTime represents date-time expressions that are not ambiguous.
+// Next field: 11
+namespace libtextclassifier3.grammar.datetime;
+table AbsoluteDateTime {
+ // Year value of the date seen in the text match.
+ year:int = -1;
+
+ // Month value of the year starting with January = 1.
+ month:int = -1;
+
+ // Day value of the month starting with 1.
+ day:int = -1;
+
+ // Day of week, start of the week is Sunday and its value is 1.
+ week_day:int = -1;
+
+ // Hour value of the day.
+ hour:int = -1;
+
+ // Minute value of the hour with a range of 0-59.
+ minute:int = -1;
+
+ // Seconds value of the minute with a range of 0-59.
+ second:int = -1;
+
+ partial_second:double = -1;
+
+ // Meridiem field i.e. AM/PM.
+ meridiem:Meridiem;
+
+ time_zone:TimeZone;
+}
+
+// Message to represent relative datetime expressions.
+// It encodes expressions
+// - where a modifier such as before/after shifts the date, e.g. [three days
+//   ago], [2 days after March 1st].
+// - where a prefix makes the expression relative, e.g. [next weekend],
+//   [last Monday].
+// Next field: 3
+namespace libtextclassifier3.grammar.datetime;
+table RelativeDateTime {
+ relative_datetime_component:[RelativeDatetimeComponent];
+
+ // The base could be an absolute datetime point for example: "March 1", a
+ // relative datetime point, for example: "2 days before March 1"
+ base:AbsoluteDateTime;
+}
+
+// Datetime result.
+namespace libtextclassifier3.grammar.datetime;
+table UngroundedDatetime {
+ absolute_datetime:AbsoluteDateTime;
+ relative_datetime:RelativeDateTime;
+
+ // The annotation usecases.
+ // There are two modes.
+  // 1- SMART - Datetime results which are optimized for smart selection.
+  // 2- RAW - Results are optimized for annotating as much text as possible.
+ annotation_usecases:uint = 4294967295;
+}
+
diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc
index b8e1b7a..867c886 100644
--- a/native/annotator/datetime/extractor.cc
+++ b/native/annotator/datetime/extractor.cc
@@ -16,6 +16,9 @@
#include "annotator/datetime/extractor.h"
+#include "annotator/datetime/utils.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
#include "utils/base/logging.h"
namespace libtextclassifier3 {
@@ -162,6 +165,18 @@
}
break;
}
+ case DatetimeGroupType_GROUP_ABSOLUTETIME: {
+ std::unordered_map<DatetimeComponent::ComponentType, int> values;
+ if (!ParseAbsoluteDateValues(group_text, &values)) {
+ TC3_LOG(ERROR) << "Couldn't extract Component values.";
+ return false;
+ }
+ for (const std::pair<const DatetimeComponent::ComponentType, int>&
+ date_time_pair : values) {
+ result->SetAbsoluteValue(date_time_pair.first, date_time_pair.second);
+ }
+ break;
+ }
case DatetimeGroupType_GROUP_DUMMY1:
case DatetimeGroupType_GROUP_DUMMY2:
break;
@@ -376,15 +391,7 @@
if (!ParseDigits(input, parsed_year)) {
return false;
}
-
- // Logic to decide if XX will be 20XX or 19XX
- if (*parsed_year < 100) {
- if (*parsed_year < 50) {
- *parsed_year += 2000;
- } else {
- *parsed_year += 1900;
- }
- }
+ *parsed_year = GetAdjustedYear(*parsed_year);
return true;
}
@@ -417,6 +424,26 @@
return false;
}
+bool DatetimeExtractor::ParseAbsoluteDateValues(
+ const UnicodeText& input,
+ std::unordered_map<DatetimeComponent::ComponentType, int>* values) const {
+ if (MapInput(input,
+ {
+ {DatetimeExtractorType_NOON,
+ {{DatetimeComponent::ComponentType::MERIDIEM, 1},
+ {DatetimeComponent::ComponentType::MINUTE, 0},
+ {DatetimeComponent::ComponentType::HOUR, 12}}},
+ {DatetimeExtractorType_MIDNIGHT,
+ {{DatetimeComponent::ComponentType::MERIDIEM, 0},
+ {DatetimeComponent::ComponentType::MINUTE, 0},
+ {DatetimeComponent::ComponentType::HOUR, 0}}},
+ },
+ values)) {
+ return true;
+ }
+ return false;
+}
+
bool DatetimeExtractor::ParseMeridiem(const UnicodeText& input,
int* parsed_meridiem) const {
return MapInput(input,
diff --git a/native/annotator/datetime/extractor.h b/native/annotator/datetime/extractor.h
index 0f92b2a..3f2b755 100644
--- a/native/annotator/datetime/extractor.h
+++ b/native/annotator/datetime/extractor.h
@@ -96,9 +96,19 @@
const UnicodeText& input,
DatetimeComponent::ComponentType* parsed_field_type) const;
bool ParseDayOfWeek(const UnicodeText& input, int* parsed_day_of_week) const;
+
bool ParseRelationAndConvertToRelativeCount(const UnicodeText& input,
int* relative_count) const;
+ // There are some special words which represent multiple date time components
+ // e.g. if the text says “by noon” it clearly indicates that the hour is 12,
+  // minute is 0 and meridiem is PM.
+ // The method handles such tokens and translates them into multiple date time
+ // components.
+ bool ParseAbsoluteDateValues(
+ const UnicodeText& input,
+ std::unordered_map<DatetimeComponent::ComponentType, int>* values) const;
+
const CompiledRule& rule_;
const UniLib::RegexMatcher& matcher_;
int locale_id_;
diff --git a/native/annotator/datetime/grammar-parser.cc b/native/annotator/datetime/grammar-parser.cc
new file mode 100644
index 0000000..6d51c19
--- /dev/null
+++ b/native/annotator/datetime/grammar-parser.cc
@@ -0,0 +1,121 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/grammar-parser.h"
+
+#include <set>
+#include <unordered_set>
+
+#include "annotator/datetime/datetime-grounder.h"
+#include "annotator/types.h"
+#include "utils/grammar/analyzer.h"
+#include "utils/grammar/evaluated-derivation.h"
+#include "utils/grammar/parsing/derivation.h"
+
+using ::libtextclassifier3::grammar::EvaluatedDerivation;
+using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
+
+namespace libtextclassifier3 {
+
+GrammarDatetimeParser::GrammarDatetimeParser(
+ const grammar::Analyzer& analyzer,
+ const DatetimeGrounder& datetime_grounder,
+ const float target_classification_score, const float priority_score)
+ : analyzer_(analyzer),
+ datetime_grounder_(datetime_grounder),
+ target_classification_score_(target_classification_score),
+ priority_score_(priority_score) {}
+
+StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
+ const std::string& input, const int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const {
+ return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
+ reference_time_ms_utc, reference_timezone, locale_list, mode,
+ annotation_usecase, anchor_start_end);
+}
+
+StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
+ const UnicodeText& input, const int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const {
+ std::vector<DatetimeParseResultSpan> results;
+ UnsafeArena arena(/*block_size=*/16 << 10);
+ std::vector<Locale> locales = locale_list.GetLocales();
+  // If the locale list is empty, the datetime regex expressions would still
+  // execute, but in the grammar-based parser the rules are associated with a
+  // locale and the engine will not run if the locale list is empty. In the
+  // unlikely scenario where no locale is given, fall back to en-*.
+ if (locales.empty()) {
+ locales.emplace_back(Locale::FromBCP47("en"));
+ }
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<EvaluatedDerivation> evaluated_derivations,
+ analyzer_.Parse(input, locales, &arena,
+ /*deduplicate_derivations=*/false));
+
+ std::vector<EvaluatedDerivation> valid_evaluated_derivations;
+ for (const EvaluatedDerivation& evaluated_derivation :
+ evaluated_derivations) {
+ if (evaluated_derivation.value) {
+ if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
+ const UngroundedDatetime* ungrounded_datetime =
+ evaluated_derivation.value->Table<UngroundedDatetime>();
+ if (datetime_grounder_.IsValidUngroundedDatetime(ungrounded_datetime)) {
+ valid_evaluated_derivations.emplace_back(evaluated_derivation);
+ }
+ }
+ }
+ }
+ valid_evaluated_derivations =
+ grammar::DeduplicateDerivations(valid_evaluated_derivations);
+ for (const EvaluatedDerivation& evaluated_derivation :
+ valid_evaluated_derivations) {
+ if (evaluated_derivation.value) {
+ if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
+ const UngroundedDatetime* ungrounded_datetime =
+ evaluated_derivation.value->Table<UngroundedDatetime>();
+ if ((ungrounded_datetime->annotation_usecases() &
+ (1 << annotation_usecase)) == 0) {
+ continue;
+ }
+ const StatusOr<std::vector<DatetimeParseResult>>&
+ datetime_parse_results = datetime_grounder_.Ground(
+ reference_time_ms_utc, reference_timezone,
+ locale_list.GetReferenceLocale(), ungrounded_datetime);
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<DatetimeParseResult>& parse_datetime,
+ datetime_parse_results);
+ DatetimeParseResultSpan datetime_parse_result_span;
+ datetime_parse_result_span.target_classification_score =
+ target_classification_score_;
+ datetime_parse_result_span.priority_score = priority_score_;
+ datetime_parse_result_span.data.reserve(parse_datetime.size());
+ datetime_parse_result_span.data.insert(
+ datetime_parse_result_span.data.end(), parse_datetime.begin(),
+ parse_datetime.end());
+ datetime_parse_result_span.span =
+ evaluated_derivation.parse_tree->codepoint_span;
+
+ results.emplace_back(datetime_parse_result_span);
+ }
+ }
+ }
+ return results;
+}
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/grammar-parser.h b/native/annotator/datetime/grammar-parser.h
new file mode 100644
index 0000000..6ff4b46
--- /dev/null
+++ b/native/annotator/datetime/grammar-parser.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_GRAMMAR_PARSER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_GRAMMAR_PARSER_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/datetime/datetime-grounder.h"
+#include "annotator/datetime/parser.h"
+#include "annotator/types.h"
+#include "utils/base/statusor.h"
+#include "utils/grammar/analyzer.h"
+#include "utils/i18n/locale-list.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+// Parses datetime expressions in the input and resolves them to actual absolute
+// time.
+class GrammarDatetimeParser : public DatetimeParser {
+ public:
+ explicit GrammarDatetimeParser(const grammar::Analyzer& analyzer,
+ const DatetimeGrounder& datetime_grounder,
+ const float target_classification_score,
+ const float priority_score);
+
+ // Parses the dates in 'input' and fills result. Makes sure that the results
+ // do not overlap.
+ // If 'anchor_start_end' is true the extracted results need to start at the
+ // beginning of 'input' and end at the end of it.
+ StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const std::string& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const override;
+
+ // Same as above but takes UnicodeText.
+ StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const override;
+
+ private:
+ const grammar::Analyzer& analyzer_;
+ const DatetimeGrounder& datetime_grounder_;
+ const float target_classification_score_;
+ const float priority_score_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_GRAMMAR_PARSER_H_
diff --git a/native/annotator/datetime/grammar-parser_test.cc b/native/annotator/datetime/grammar-parser_test.cc
new file mode 100644
index 0000000..cf2dffd
--- /dev/null
+++ b/native/annotator/datetime/grammar-parser_test.cc
@@ -0,0 +1,554 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/grammar-parser.h"
+
+#include <memory>
+#include <string>
+
+#include "annotator/datetime/datetime-grounder.h"
+#include "annotator/datetime/testing/base-parser-test.h"
+#include "annotator/datetime/testing/datetime-component-builder.h"
+#include "utils/grammar/analyzer.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/test-data-test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::libtextclassifier3::grammar::Analyzer;
+using ::libtextclassifier3::grammar::RulesSet;
+
+namespace libtextclassifier3 {
+namespace {
+std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
+
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+class GrammarDatetimeParserTest : public DateTimeParserTest {
+ public:
+ void SetUp() override {
+ grammar_buffer_ = ReadFile(GetModelPath() + "datetime.fb");
+ unilib_ = CreateUniLibForTesting();
+ calendarlib_ = CreateCalendarLibForTesting();
+ analyzer_ = std::make_unique<Analyzer>(
+ unilib_.get(), flatbuffers::GetRoot<RulesSet>(grammar_buffer_.data()));
+ datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_.get());
+ parser_.reset(new GrammarDatetimeParser(*analyzer_, *datetime_grounder_,
+ /*target_classification_score=*/1.0,
+ /*priority_score=*/1.0));
+ }
+
+ // Exposes the date time parser for tests and evaluations.
+ const DatetimeParser* DatetimeParserForTests() const override {
+ return parser_.get();
+ }
+
+ private:
+ std::string grammar_buffer_;
+ std::unique_ptr<UniLib> unilib_;
+ std::unique_ptr<CalendarLib> calendarlib_;
+ std::unique_ptr<Analyzer> analyzer_;
+ std::unique_ptr<DatetimeGrounder> datetime_grounder_;
+ std::unique_ptr<DatetimeParser> parser_;
+};
+
+TEST_F(GrammarDatetimeParserTest, ParseShort) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "{01/02/2020}", 1580511600000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2020)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"en-GB"));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{01/02/2020}", 1577919600000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2020)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+}
+
+TEST_F(GrammarDatetimeParserTest, Parse) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 31 2018}", 1517353200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 31)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "foo {1 january 2018} bar", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{09/Mar/2004 22:02:40}", 1078866160000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 40)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 02)
+ .Add(DatetimeComponent::ComponentType::HOUR, 22)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2004)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Dec 2, 2010 2:39:58 AM}", 1291253998000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 58)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 39)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Jun 09 2011 15:28:14}", 1307626094000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 14)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 28)
+ .Add(DatetimeComponent::ComponentType::HOUR, 15)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{2010-06-26 02:31:29}", {1277512289000, 1277555489000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{2006/01/22 04:11:05}", {1137899465000, 1137942665000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build()}));
+ EXPECT_TRUE(
+ ParsesCorrectly("{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23/Apr/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23-Apr-2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23 Apr 2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "Are sentiments apartments decisively the especially alteration. "
+ "Thrown shy denote ten ladies though ask saw. Or by to he going "
+ "think order event music. Incommode so intention defective at "
+ "convinced. Led income months itself and houses you. After nor "
+ "you leave might share court balls. {19/apr/2010 06:36:15} Are "
+ "sentiments apartments decisively the especially alteration. "
+ "Thrown shy denote ten ladies though ask saw. Or by to he going "
+ "think order event music. Incommode so intention defective at "
+ "convinced. Led income months itself and houses you. After nor "
+ "you leave might share court balls. ",
+ {1271651775000, 1271694975000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30}", {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30 am}", 1514777400000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4pm}", 1514818800000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{today at 0:00}", {-3600000, 39600000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{today at 0:00}", {-57600000, -14400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()},
+ /*anchor_start_end=*/false, "America/Los_Angeles"));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 4:00}", {97200000, 140400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 4am}", 97200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "last seen {today at 9:01 PM}", 72060000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 9)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()}));
+ EXPECT_TRUE(
+ ParsesCorrectly("set an alarm for {7 a.m}", 21600000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 7)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{04/23/15 11:42:35}", {1429782155000, 1429825355000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{04/23/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{9/28/2011 2:23:15 PM}", 1317212595000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 23)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 28)
+ .Add(DatetimeComponent::ComponentType::MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
+}
+
+TEST_F(GrammarDatetimeParserTest, DateValidation) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "{01/02/2020}", 1577919600000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2020)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{21/02/2020}", 1582239600000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 21)
+ .Add(DatetimeComponent::ComponentType::MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2020)
+ .Build()}));
+}
+
+TEST_F(GrammarDatetimeParserTest, OnlyRelativeDatetime) {
+ EXPECT_TRUE(
+ ParsesCorrectly("{in 3 hours}", 10800000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::HOUR, 0,
+ DatetimeComponent::RelativeQualifier::FUTURE, 3)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{wednesday at 4am}", 529200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 4,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "set an alarm for {7am tomorrow}", 108000000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 7)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "called you {last Saturday}",
+ -432000000 /* Fri 1969-12-26 16:00:00 PST */, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 7,
+ DatetimeComponent::RelativeQualifier::PAST, -1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+}
+
+TEST_F(GrammarDatetimeParserTest, NamedMonthDate) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "{march 1, 2017}", 1488355200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2017)
+ .Build()},
+ false, "America/Los_Angeles", "en-US",
+ AnnotationUsecase_ANNOTATION_USECASE_SMART));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/parser.h b/native/annotator/datetime/parser.h
index 8b58388..3b3e578 100644
--- a/native/annotator/datetime/parser.h
+++ b/native/annotator/datetime/parser.h
@@ -19,18 +19,13 @@
#include <memory>
#include <string>
-#include <unordered_map>
-#include <unordered_set>
#include <vector>
-#include "annotator/datetime/extractor.h"
-#include "annotator/model_generated.h"
#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/calendar/calendar.h"
+#include "utils/base/statusor.h"
+#include "utils/i18n/locale-list.h"
+#include "utils/i18n/locale.h"
#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
@@ -38,87 +33,25 @@
// time.
class DatetimeParser {
public:
- static std::unique_ptr<DatetimeParser> Instance(
- const DatetimeModel* model, const UniLib* unilib,
- const CalendarLib* calendarlib, ZlibDecompressor* decompressor);
+ virtual ~DatetimeParser() = default;
// Parses the dates in 'input' and fills result. Makes sure that the results
// do not overlap.
// If 'anchor_start_end' is true the extracted results need to start at the
// beginning of 'input' and end at the end of it.
- bool Parse(const std::string& input, int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const;
+ virtual StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const std::string& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const = 0;
// Same as above but takes UnicodeText.
- bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const;
-
- protected:
- explicit DatetimeParser(const DatetimeModel* model, const UniLib* unilib,
- const CalendarLib* calendarlib,
- ZlibDecompressor* decompressor);
-
- // Returns a list of locale ids for given locale spec string (comma-separated
- // locale names). Assigns the first parsed locale to reference_locale.
- std::vector<int> ParseAndExpandLocales(const std::string& locales,
- std::string* reference_locale) const;
-
- // Helper function that finds datetime spans, only using the rules associated
- // with the given locales.
- bool FindSpansUsingLocales(
- const std::vector<int>& locale_ids, const UnicodeText& input,
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ virtual StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end, const std::string& reference_locale,
- std::unordered_set<int>* executed_rules,
- std::vector<DatetimeParseResultSpan>* found_spans) const;
-
- bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, const int locale_id,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* result) const;
-
- // Converts the current match in 'matcher' into DatetimeParseResult.
- bool ExtractDatetime(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResult>* results,
- CodepointSpan* result_span) const;
-
- // Parse and extract information from current match in 'matcher'.
- bool HandleParseMatch(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResultSpan>* result) const;
-
- private:
- bool initialized_;
- const UniLib& unilib_;
- const CalendarLib& calendarlib_;
- std::vector<CompiledRule> rules_;
- std::unordered_map<int, std::vector<int>> locale_to_rules_;
- std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
- std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
- type_and_locale_to_extractor_rule_;
- std::unordered_map<std::string, int> locale_string_to_id_;
- std::vector<int> default_locale_ids_;
- bool use_extractors_for_locating_;
- bool generate_alternative_interpretations_when_ambiguous_;
- bool prefer_future_for_unspecified_date_;
+ bool anchor_start_end) const = 0;
};
-
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
diff --git a/native/annotator/datetime/parser.cc b/native/annotator/datetime/regex-parser.cc
similarity index 69%
rename from native/annotator/datetime/parser.cc
rename to native/annotator/datetime/regex-parser.cc
index 72fd3ab..4dc9c56 100644
--- a/native/annotator/datetime/parser.cc
+++ b/native/annotator/datetime/regex-parser.cc
@@ -14,33 +14,36 @@
* limitations under the License.
*/
-#include "annotator/datetime/parser.h"
+#include "annotator/datetime/regex-parser.h"
+#include <iterator>
#include <set>
#include <unordered_set>
#include "annotator/datetime/extractor.h"
#include "annotator/datetime/utils.h"
+#include "utils/base/statusor.h"
#include "utils/calendar/calendar.h"
#include "utils/i18n/locale.h"
#include "utils/strings/split.h"
#include "utils/zlib/zlib_regex.h"
namespace libtextclassifier3 {
-std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
+std::unique_ptr<DatetimeParser> RegexDatetimeParser::Instance(
const DatetimeModel* model, const UniLib* unilib,
const CalendarLib* calendarlib, ZlibDecompressor* decompressor) {
- std::unique_ptr<DatetimeParser> result(
- new DatetimeParser(model, unilib, calendarlib, decompressor));
+ std::unique_ptr<RegexDatetimeParser> result(
+ new RegexDatetimeParser(model, unilib, calendarlib, decompressor));
if (!result->initialized_) {
result.reset();
}
return result;
}
-DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib* unilib,
- const CalendarLib* calendarlib,
- ZlibDecompressor* decompressor)
+RegexDatetimeParser::RegexDatetimeParser(const DatetimeModel* model,
+ const UniLib* unilib,
+ const CalendarLib* calendarlib,
+ ZlibDecompressor* decompressor)
: unilib_(*unilib), calendarlib_(*calendarlib) {
initialized_ = false;
@@ -113,23 +116,24 @@
initialized_ = true;
}
-bool DatetimeParser::Parse(
+StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const {
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const {
return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
- reference_time_ms_utc, reference_timezone, locales, mode,
- annotation_usecase, anchor_start_end, results);
+ reference_time_ms_utc, reference_timezone, locale_list, mode,
+ annotation_usecase, anchor_start_end);
}
-bool DatetimeParser::FindSpansUsingLocales(
+StatusOr<std::vector<DatetimeParseResultSpan>>
+RegexDatetimeParser::FindSpansUsingLocales(
const std::vector<int>& locale_ids, const UnicodeText& input,
const int64 reference_time_ms_utc, const std::string& reference_timezone,
ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
const std::string& reference_locale,
- std::unordered_set<int>* executed_rules,
- std::vector<DatetimeParseResultSpan>* found_spans) const {
+ std::unordered_set<int>* executed_rules) const {
+ std::vector<DatetimeParseResultSpan> found_spans;
for (const int locale_id : locale_ids) {
auto rules_it = locale_to_rules_.find(locale_id);
if (rules_it == locale_to_rules_.end()) {
@@ -152,34 +156,33 @@
}
executed_rules->insert(rule_id);
-
- if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- anchor_start_end, found_spans)) {
- return false;
- }
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<DatetimeParseResultSpan>& found_spans_per_rule,
+ ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id,
+ anchor_start_end));
+ found_spans.insert(std::end(found_spans),
+ std::begin(found_spans_per_rule),
+ std::end(found_spans_per_rule));
}
}
- return true;
+ return found_spans;
}
-bool DatetimeParser::Parse(
+StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
const UnicodeText& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const {
- std::vector<DatetimeParseResultSpan> found_spans;
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const {
std::unordered_set<int> executed_rules;
- std::string reference_locale;
const std::vector<int> requested_locales =
- ParseAndExpandLocales(locales, &reference_locale);
- if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
- reference_timezone, mode, annotation_usecase,
- anchor_start_end, reference_locale,
- &executed_rules, &found_spans)) {
- return false;
- }
-
+ ParseAndExpandLocales(locale_list.GetLocaleTags());
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<DatetimeParseResultSpan>& found_spans,
+ FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
+ reference_timezone, mode, annotation_usecase,
+ anchor_start_end, locale_list.GetReferenceLocale(),
+ &executed_rules));
std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
indexed_found_spans.reserve(found_spans.size());
for (int i = 0; i < found_spans.size(); i++) {
@@ -200,39 +203,46 @@
}
});
- found_spans.clear();
+ std::vector<DatetimeParseResultSpan> results;
+ std::vector<DatetimeParseResultSpan> resolved_found_spans;
+ resolved_found_spans.reserve(indexed_found_spans.size());
for (auto& span_index_pair : indexed_found_spans) {
- found_spans.push_back(span_index_pair.first);
+ resolved_found_spans.push_back(span_index_pair.first);
}
std::set<int, std::function<bool(int, int)>> chosen_indices_set(
- [&found_spans](int a, int b) {
- return found_spans[a].span.first < found_spans[b].span.first;
+ [&resolved_found_spans](int a, int b) {
+ return resolved_found_spans[a].span.first <
+ resolved_found_spans[b].span.first;
});
- for (int i = 0; i < found_spans.size(); ++i) {
- if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
+ for (int i = 0; i < resolved_found_spans.size(); ++i) {
+ if (!DoesCandidateConflict(i, resolved_found_spans, chosen_indices_set)) {
chosen_indices_set.insert(i);
- results->push_back(found_spans[i]);
+ results.push_back(resolved_found_spans[i]);
}
}
-
- return true;
+ return results;
}
-bool DatetimeParser::HandleParseMatch(
- const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc, const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResultSpan>* result) const {
+StatusOr<std::vector<DatetimeParseResultSpan>>
+RegexDatetimeParser::HandleParseMatch(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale,
+ int locale_id) const {
+ std::vector<DatetimeParseResultSpan> results;
int status = UniLib::RegexMatcher::kNoError;
const int start = matcher.Start(&status);
if (status != UniLib::RegexMatcher::kNoError) {
- return false;
+ return Status(StatusCode::INTERNAL,
+ "Failed to gets the start offset of the last match.");
}
const int end = matcher.End(&status);
if (status != UniLib::RegexMatcher::kNoError) {
- return false;
+ return Status(StatusCode::INTERNAL,
+ "Failed to gets the end offset of the last match.");
}
DatetimeParseResultSpan parse_result;
@@ -240,7 +250,7 @@
if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
reference_locale, locale_id, &alternatives,
&parse_result.span)) {
- return false;
+ return Status(StatusCode::INTERNAL, "Failed to extract Datetime.");
}
if (!use_extractors_for_locating_) {
@@ -257,49 +267,44 @@
parse_result.data.push_back(alternative);
}
}
- result->push_back(parse_result);
- return true;
+ results.push_back(parse_result);
+ return results;
}
-bool DatetimeParser::ParseWithRule(
- const CompiledRule& rule, const UnicodeText& input,
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
- const std::string& reference_locale, const int locale_id,
- bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
+StatusOr<std::vector<DatetimeParseResultSpan>>
+RegexDatetimeParser::ParseWithRule(const CompiledRule& rule,
+ const UnicodeText& input,
+ const int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale,
+ const int locale_id,
+ bool anchor_start_end) const {
+ std::vector<DatetimeParseResultSpan> results;
std::unique_ptr<UniLib::RegexMatcher> matcher =
rule.compiled_regex->Matcher(input);
int status = UniLib::RegexMatcher::kNoError;
if (anchor_start_end) {
if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- result)) {
- return false;
- }
+ return HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id);
}
} else {
while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- result)) {
- return false;
- }
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<DatetimeParseResultSpan>& pattern_occurrence,
+ HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id));
+ results.insert(std::end(results), std::begin(pattern_occurrence),
+ std::end(pattern_occurrence));
}
}
- return true;
+ return results;
}
-std::vector<int> DatetimeParser::ParseAndExpandLocales(
- const std::string& locales, std::string* reference_locale) const {
- std::vector<StringPiece> split_locales = strings::Split(locales, ',');
- if (!split_locales.empty()) {
- *reference_locale = split_locales[0].ToString();
- } else {
- *reference_locale = "";
- }
-
+std::vector<int> RegexDatetimeParser::ParseAndExpandLocales(
+ const std::vector<StringPiece>& locales) const {
std::vector<int> result;
- for (const StringPiece& locale_str : split_locales) {
+ for (const StringPiece& locale_str : locales) {
auto locale_it = locale_string_to_id_.find(locale_str.ToString());
if (locale_it != locale_string_to_id_.end()) {
result.push_back(locale_it->second);
@@ -348,14 +353,12 @@
return result;
}
-bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- const int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale,
- int locale_id,
- std::vector<DatetimeParseResult>* results,
- CodepointSpan* result_span) const {
+bool RegexDatetimeParser::ExtractDatetime(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResult>* results,
+ CodepointSpan* result_span) const {
DatetimeParsedData parse;
DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_,
extractor_rules_,
diff --git a/native/annotator/datetime/regex-parser.h b/native/annotator/datetime/regex-parser.h
new file mode 100644
index 0000000..e820c21
--- /dev/null
+++ b/native/annotator/datetime/regex-parser.h
@@ -0,0 +1,123 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/datetime/extractor.h"
+#include "annotator/datetime/parser.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/statusor.h"
+#include "utils/calendar/calendar.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Parses datetime expressions in the input and resolves them to actual absolute
+// time.
+class RegexDatetimeParser : public DatetimeParser {
+ public:
+ static std::unique_ptr<DatetimeParser> Instance(
+ const DatetimeModel* model, const UniLib* unilib,
+ const CalendarLib* calendarlib, ZlibDecompressor* decompressor);
+
+ // Parses the dates in 'input' and fills result. Makes sure that the results
+ // do not overlap.
+ // If 'anchor_start_end' is true the extracted results need to start at the
+ // beginning of 'input' and end at the end of it.
+ StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const std::string& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const override;
+
+ // Same as above but takes UnicodeText.
+ StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const override;
+
+ protected:
+ explicit RegexDatetimeParser(const DatetimeModel* model, const UniLib* unilib,
+ const CalendarLib* calendarlib,
+ ZlibDecompressor* decompressor);
+
+ // Returns a list of locale ids for given locale spec string (collection of
+ // locale names).
+ std::vector<int> ParseAndExpandLocales(
+ const std::vector<StringPiece>& locales) const;
+
+ // Helper function that finds datetime spans, only using the rules associated
+ // with the given locales.
+ StatusOr<std::vector<DatetimeParseResultSpan>> FindSpansUsingLocales(
+ const std::vector<int>& locale_ids, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end, const std::string& reference_locale,
+ std::unordered_set<int>* executed_rules) const;
+
+ StatusOr<std::vector<DatetimeParseResultSpan>> ParseWithRule(
+ const CompiledRule& rule, const UnicodeText& input,
+ int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, const int locale_id,
+ bool anchor_start_end) const;
+
+ // Converts the current match in 'matcher' into DatetimeParseResult.
+ bool ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResult>* results,
+ CodepointSpan* result_span) const;
+
+ // Parse and extract information from current match in 'matcher'.
+ StatusOr<std::vector<DatetimeParseResultSpan>> HandleParseMatch(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id) const;
+
+ private:
+ bool initialized_;
+ const UniLib& unilib_;
+ const CalendarLib& calendarlib_;
+ std::vector<CompiledRule> rules_;
+ std::unordered_map<int, std::vector<int>> locale_to_rules_;
+ std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
+ std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
+ type_and_locale_to_extractor_rule_;
+ std::unordered_map<std::string, int> locale_string_to_id_;
+ std::vector<int> default_locale_ids_;
+ bool use_extractors_for_locating_;
+ bool generate_alternative_interpretations_when_ambiguous_;
+ bool prefer_future_for_unspecified_date_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_
diff --git a/native/annotator/datetime/regex-parser_test.cc b/native/annotator/datetime/regex-parser_test.cc
new file mode 100644
index 0000000..33f14a4
--- /dev/null
+++ b/native/annotator/datetime/regex-parser_test.cc
@@ -0,0 +1,1378 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/regex-parser.h"
+
+#include <time.h>
+
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "annotator/annotator.h"
+#include "annotator/datetime/testing/base-parser-test.h"
+#include "annotator/datetime/testing/datetime-component-builder.h"
+#include "annotator/model_generated.h"
+#include "annotator/types-test-util.h"
+#include "utils/i18n/locale-list.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/test-data-test-utils.h"
+#include "utils/testing/annotator.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using std::vector;
+
+namespace libtextclassifier3 {
+namespace {
+std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
+
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+class RegexDatetimeParserTest : public DateTimeParserTest {
+ public:
+ void SetUp() override {
+ // Loads default unmodified model. Individual tests can call LoadModel to
+ // make changes.
+ LoadModel(
+ [](ModelT* model) { model->datetime_grammar_model.reset(nullptr); });
+ }
+
+ template <typename Fn>
+ void LoadModel(Fn model_visitor_fn) {
+ std::string model_buffer = ReadFile(GetModelPath() + "test_model.fb");
+ model_buffer_ = ModifyAnnotatorModel(model_buffer, model_visitor_fn);
+ unilib_ = CreateUniLibForTesting();
+ calendarlib_ = CreateCalendarLibForTesting();
+ classifier_ =
+ Annotator::FromUnownedBuffer(model_buffer_.data(), model_buffer_.size(),
+ unilib_.get(), calendarlib_.get());
+ TC3_CHECK(classifier_);
+ parser_ = classifier_->DatetimeParserForTests();
+ TC3_CHECK(parser_);
+ }
+
+ // Exposes the date time parser for tests and evaluations.
+ const DatetimeParser* DatetimeParserForTests() const override {
+ return classifier_->DatetimeParserForTests();
+ }
+
+ private:
+ std::string model_buffer_;
+ std::unique_ptr<Annotator> classifier_;
+ const DatetimeParser* parser_;
+ std::unique_ptr<UniLib> unilib_;
+ std::unique_ptr<CalendarLib> calendarlib_;
+};
+
+// Test with just a few cases to make debugging of general failures easier.
+TEST_F(RegexDatetimeParserTest, ParseShort) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+}
+
+TEST_F(RegexDatetimeParserTest, Parse) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 31 2018}", 1517353200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 31)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "lorem {1 january 2018} ipsum", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{09/Mar/2004 22:02:40}", 1078866160000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 40)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 02)
+ .Add(DatetimeComponent::ComponentType::HOUR, 22)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2004)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Dec 2, 2010 2:39:58 AM}", 1291253998000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 58)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 39)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Jun 09 2011 15:28:14}", 1307626094000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 14)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 28)
+ .Add(DatetimeComponent::ComponentType::HOUR, 15)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Mar 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{2010-06-26 02:31:29}", {1277512289000, 1277555489000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{2006/01/22 04:11:05}", {1137899465000, 1137942665000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build()}));
+ EXPECT_TRUE(
+ ParsesCorrectly("{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23/Apr/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23-Apr-2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23 Apr 2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{04/23/15 11:42:35}", {1429782155000, 1429825355000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{04/23/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{9/28/2011 2:23:15 PM}", 1317212595000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 23)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 28)
+ .Add(DatetimeComponent::ComponentType::MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "Are sentiments apartments decisively the especially alteration. "
+ "Thrown shy denote ten ladies though ask saw. Or by to he going "
+ "think order event music. Incommode so intention defective at "
+ "convinced. Led income months itself and houses you. After nor "
+ "you leave might share court balls. {19/apr/2010 06:36:15} Are "
+ "sentiments apartments decisively the especially alteration. "
+ "Thrown shy denote ten ladies though ask saw. Or by to he going "
+ "think order event music. Incommode so intention defective at "
+ "convinced. Led income months itself and houses you. After nor "
+ "you leave might share court balls. ",
+ {1271651775000, 1271694975000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30}", {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30 am}", 1514777400000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4pm}", 1514818800000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{today at 0:00}", {-3600000, 39600000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{today at 0:00}", {-57600000, -14400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()},
+ /*anchor_start_end=*/false, "America/Los_Angeles"));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 4:00}", {97200000, 140400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 4am}", 97200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{wednesday at 4am}", 529200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 4,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "last seen {today at 9:01 PM}", 72060000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 9)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "set an alarm for {7am tomorrow}", 108000000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 7)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(
+ ParsesCorrectly("set an alarm for {7 a.m}", 21600000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 7)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "lets meet by {noon}", 39600000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "at {midnight}", 82800000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+}
+
+TEST_F(RegexDatetimeParserTest, ParseWithAnchor) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()},
+ /*anchor_start_end=*/false));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()},
+ /*anchor_start_end=*/true));
+ EXPECT_TRUE(ParsesCorrectly(
+ "lorem {1 january 2018} ipsum", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()},
+ /*anchor_start_end=*/false));
+ EXPECT_TRUE(HasNoResult("lorem 1 january 2018 ipsum",
+ /*anchor_start_end=*/true));
+}
+
+TEST_F(RegexDatetimeParserTest, ParseWithRawUsecase) {
+ // Annotated for RAW usecase.
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow}", 82800000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me {in two hours}", 7200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::HOUR, 0,
+ DatetimeComponent::RelativeQualifier::FUTURE, 2)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me {next month}", 2674800000, GRANULARITY_MONTH,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NEXT, 1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+ EXPECT_TRUE(ParsesCorrectly(
+ "what's the time {now}", -3600000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me on {Saturday}", 169200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 7,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ // Not annotated for Smart usecase.
+ EXPECT_TRUE(HasNoResult(
+ "{tomorrow}", /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_SMART));
+}
+
+// For details please see b/155437137
+TEST_F(RegexDatetimeParserTest, PastRelativeDatetime) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "called you {last Saturday}",
+ -432000000 /* Fri 1969-12-26 16:00:00 PST */, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 7,
+ DatetimeComponent::RelativeQualifier::PAST, -1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "called you {last year}", -31536000000 /* Tue 1968-12-31 16:00:00 PST */,
+ GRANULARITY_YEAR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::YEAR, 0,
+ DatetimeComponent::RelativeQualifier::PAST, -1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "called you {last month}", -2678400000 /* Sun 1969-11-30 16:00:00 PST */,
+ GRANULARITY_MONTH,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MONTH, 0,
+ DatetimeComponent::RelativeQualifier::PAST, -1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "called you {yesterday}", -90000000, /* Tue 1969-12-30 15:00:00 PST */
+ GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::YESTERDAY, -1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+}
+
+TEST_F(RegexDatetimeParserTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
+ // ParsesCorrectly uses 0 as the reference time, which corresponds to:
+ // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
+ // it is in the past, and so the parser should move this to the next day ->
+ // "Fri Jan 02 1970 00:30:00" Zurich time (b/139112907).
+ EXPECT_TRUE(ParsesCorrectly(
+ "{0:30am}", 84600000L /* 23.5 hours from reference time */,
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Build()}));
+}
+
+TEST_F(RegexDatetimeParserTest,
+ DoesNotAddADayWhenTimeInThePastAndDayNotSpecifiedDisabled) {
+ // ParsesCorrectly uses 0 as the reference time, which corresponds to:
+ // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
+ // it is in the past. The parameter prefer_future_when_unspecified_day is
+ // disabled, so the parser should annotate this to the same day: "Thu Jan 01
+ // 1970 00:30:00" Zurich time.
+ LoadModel([](ModelT* model) {
+ // In the test model, the prefer_future_for_unspecified_date is true; make
+ // it false only for this test.
+ model->datetime_model->prefer_future_for_unspecified_date = false;
+ model->datetime_grammar_model.reset(nullptr);
+ });
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{0:30am}", -1800000L /* -30 minutes from reference time */,
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Build()}));
+}
+
+TEST_F(RegexDatetimeParserTest, ParsesNoonAndMidnightCorrectly) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988 12:30am}", 567991800000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988 12:30pm}", 568035000000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 12:00 am}", 82800000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+}
+
+TEST_F(RegexDatetimeParserTest, ParseGerman) {
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{Januar 1 2018}", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{1/2/2018}", 1517439600000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "lorem {1 Januar 2018} ipsum", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{19/Apr/2010:06:36:15}", {1271651775000, 1271694975000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{09/März/2004 22:02:40}", 1078866160000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 40)
+             .Add(DatetimeComponent::ComponentType::MINUTE, 2)
+ .Add(DatetimeComponent::ComponentType::HOUR, 22)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2004)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{Dez 2, 2010 2:39:58}", {1291253998000, 1291297198000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 58)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 39)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 58)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 39)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{Juni 09 2011 15:28:14}", 1307626094000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 14)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 28)
+ .Add(DatetimeComponent::ComponentType::HOUR, 15)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{März 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{2010-06-26 02:31:29}", {1277512289000, 1277555489000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{2006/01/22 04:11:05}", {1137899465000, 1137942665000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr/2015:11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23-Apr-2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23 Apr 2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{04/23/15 11:42:35}", {1429782155000, 1429825355000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{04/23/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{19/apr/2010:06:36:15}", {1271651775000, 1271694975000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{januar 1 2018 um 4:30}", {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{januar 1 2018 um 4:30 nachm}", 1514820600000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{januar 1 2018 um 4 nachm}", 1514818800000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{14.03.2017}", 1489446000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 14)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2017)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{morgen 0:00}", {82800000, 126000000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{morgen um 4:00}", {97200000, 140400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+}
+
+TEST_F(RegexDatetimeParserTest, ParseChinese) {
+ EXPECT_TRUE(ParsesCorrectlyChinese(
+ "{明天 7 上午}", 108000000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 7)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+}
+
+TEST_F(RegexDatetimeParserTest, ParseNonUs) {
+ auto first_may_2015 =
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 5)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build();
+
+ EXPECT_TRUE(ParsesCorrectly("{1/5/2015}", 1430431200000, GRANULARITY_DAY,
+ {first_may_2015},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"en-GB"));
+ EXPECT_TRUE(ParsesCorrectly("{1/5/2015}", 1430431200000, GRANULARITY_DAY,
+ {first_may_2015},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en"));
+}
+
+TEST_F(RegexDatetimeParserTest, ParseUs) {
+ auto five_january_2015 =
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 5)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build();
+
+ EXPECT_TRUE(ParsesCorrectly("{1/5/2015}", 1420412400000, GRANULARITY_DAY,
+ {five_january_2015},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"en-US"));
+ EXPECT_TRUE(ParsesCorrectly("{1/5/2015}", 1420412400000, GRANULARITY_DAY,
+ {five_january_2015},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"es-US"));
+}
+
+TEST_F(RegexDatetimeParserTest, ParseUnknownLanguage) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "bylo to {31. 12. 2015} v 6 hodin", 1451516400000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 31)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
+}
+
+TEST_F(RegexDatetimeParserTest, WhenAlternativesEnabledGeneratesAlternatives) {
+ LoadModel([](ModelT* model) {
+ model->datetime_model->generate_alternative_interpretations_when_ambiguous =
+ true;
+ });
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30}", {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{monday 3pm}", 396000000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 3)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{monday 3:00}", {352800000, 396000000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 3)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 3)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()}));
+}
+
+TEST_F(RegexDatetimeParserTest,
+ WhenAlternativesDisabledDoesNotGenerateAlternatives) {
+ LoadModel([](ModelT* model) {
+ model->datetime_model->generate_alternative_interpretations_when_ambiguous =
+ false;
+ model->datetime_grammar_model.reset(nullptr);
+ });
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30}", 1514777400000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+}
+
+class ParserLocaleTest : public testing::Test {
+ public:
+ void SetUp() override;
+ bool HasResult(const std::string& input, const std::string& locales);
+
+ protected:
+ std::unique_ptr<UniLib> unilib_;
+ std::unique_ptr<CalendarLib> calendarlib_;
+ flatbuffers::FlatBufferBuilder builder_;
+ std::unique_ptr<DatetimeParser> parser_;
+};
+
+void AddPattern(const std::string& regex, int locale,
+ std::vector<std::unique_ptr<DatetimeModelPatternT>>* patterns) {
+ patterns->emplace_back(new DatetimeModelPatternT);
+ patterns->back()->regexes.emplace_back(new DatetimeModelPattern_::RegexT);
+ patterns->back()->regexes.back()->pattern = regex;
+ patterns->back()->regexes.back()->groups.push_back(
+ DatetimeGroupType_GROUP_UNUSED);
+ patterns->back()->locales.push_back(locale);
+}
+
+void ParserLocaleTest::SetUp() {
+ DatetimeModelT model;
+ model.use_extractors_for_locating = false;
+ model.locales.clear();
+ model.locales.push_back("en-US");
+ model.locales.push_back("en-CH");
+ model.locales.push_back("zh-Hant");
+ model.locales.push_back("en-*");
+ model.locales.push_back("zh-Hant-*");
+ model.locales.push_back("*-CH");
+ model.locales.push_back("default");
+ model.default_locales.push_back(6);
+
+ AddPattern(/*regex=*/"en-US", /*locale=*/0, &model.patterns);
+ AddPattern(/*regex=*/"en-CH", /*locale=*/1, &model.patterns);
+ AddPattern(/*regex=*/"zh-Hant", /*locale=*/2, &model.patterns);
+ AddPattern(/*regex=*/"en-all", /*locale=*/3, &model.patterns);
+ AddPattern(/*regex=*/"zh-Hant-all", /*locale=*/4, &model.patterns);
+ AddPattern(/*regex=*/"all-CH", /*locale=*/5, &model.patterns);
+ AddPattern(/*regex=*/"default", /*locale=*/6, &model.patterns);
+
+ builder_.Finish(DatetimeModel::Pack(builder_, &model));
+ const DatetimeModel* model_fb =
+ flatbuffers::GetRoot<DatetimeModel>(builder_.GetBufferPointer());
+ ASSERT_TRUE(model_fb);
+
+ unilib_ = CreateUniLibForTesting();
+ calendarlib_ = CreateCalendarLibForTesting();
+ parser_ =
+ RegexDatetimeParser::Instance(model_fb, unilib_.get(), calendarlib_.get(),
+ /*decompressor=*/nullptr);
+ ASSERT_TRUE(parser_);
+}
+
+bool ParserLocaleTest::HasResult(const std::string& input,
+ const std::string& locales) {
+ StatusOr<std::vector<DatetimeParseResultSpan>> results = parser_->Parse(
+ input, /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"", LocaleList::ParseFrom(locales),
+ ModeFlag_ANNOTATION, AnnotationUsecase_ANNOTATION_USECASE_SMART, false);
+ EXPECT_TRUE(results.ok());
+ return results.ValueOrDie().size() == 1;
+}
+
+TEST_F(ParserLocaleTest, English) {
+ EXPECT_TRUE(HasResult("en-US", /*locales=*/"en-US"));
+ EXPECT_FALSE(HasResult("en-CH", /*locales=*/"en-US"));
+ EXPECT_FALSE(HasResult("en-US", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("en-CH", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
+}
+
+TEST_F(ParserLocaleTest, TraditionalChinese) {
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant"));
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-TW"));
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-SG"));
+ EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh-SG"));
+ EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"zh"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"zh-Hant-SG"));
+}
+
+TEST_F(ParserLocaleTest, SwissEnglish) {
+ EXPECT_TRUE(HasResult("all-CH", /*locales=*/"de-CH"));
+ EXPECT_TRUE(HasResult("all-CH", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("en-all", /*locales=*/"en-CH"));
+ EXPECT_FALSE(HasResult("all-CH", /*locales=*/"de-DE"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"de-CH"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/testing/base-parser-test.cc b/native/annotator/datetime/testing/base-parser-test.cc
new file mode 100644
index 0000000..d8dd723
--- /dev/null
+++ b/native/annotator/datetime/testing/base-parser-test.cc
@@ -0,0 +1,162 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/testing/base-parser-test.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/i18n/locale-list.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using std::vector;
+using testing::ElementsAreArray;
+
+namespace libtextclassifier3 {
+
+bool DateTimeParserTest::HasNoResult(const std::string& text,
+ bool anchor_start_end,
+ const std::string& timezone,
+ AnnotationUsecase annotation_usecase) {
+ StatusOr<std::vector<DatetimeParseResultSpan>> results_status =
+ DatetimeParserForTests()->Parse(
+ text, 0, timezone, LocaleList::ParseFrom(/*locale_tags=*/""),
+ ModeFlag_ANNOTATION, annotation_usecase, anchor_start_end);
+ if (!results_status.ok()) {
+ TC3_LOG(ERROR) << text;
+ TC3_CHECK(false);
+ }
+ return results_status.ValueOrDie().empty();
+}
+
+bool DateTimeParserTest::ParsesCorrectly(
+ const std::string& marked_text, const vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end, const std::string& timezone,
+ const std::string& locales, AnnotationUsecase annotation_usecase) {
+ const UnicodeText marked_text_unicode =
+ UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
+ auto brace_open_it =
+ std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
+ auto brace_end_it =
+ std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
+ TC3_CHECK(brace_open_it != marked_text_unicode.end());
+ TC3_CHECK(brace_end_it != marked_text_unicode.end());
+
+ std::string text;
+ text +=
+ UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
+ text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
+ text += UnicodeText::UTF8Substring(std::next(brace_end_it),
+ marked_text_unicode.end());
+
+ StatusOr<std::vector<DatetimeParseResultSpan>> results_status =
+ DatetimeParserForTests()->Parse(
+ text, 0, timezone, LocaleList::ParseFrom(locales),
+ ModeFlag_ANNOTATION, annotation_usecase, anchor_start_end);
+ if (!results_status.ok()) {
+ TC3_LOG(ERROR) << text;
+ TC3_CHECK(false);
+ }
+  // Note: the parse results are read directly via results_status.ValueOrDie()
+  // below; no local alias is needed.
+ if (results_status.ValueOrDie().empty()) {
+ TC3_LOG(ERROR) << "No results.";
+ return false;
+ }
+
+ const int expected_start_index =
+ std::distance(marked_text_unicode.begin(), brace_open_it);
+ // The -1 below is to account for the opening bracket character.
+ const int expected_end_index =
+ std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
+
+ std::vector<DatetimeParseResultSpan> filtered_results;
+ for (const DatetimeParseResultSpan& result : results_status.ValueOrDie()) {
+ if (SpansOverlap(result.span, {expected_start_index, expected_end_index})) {
+ filtered_results.push_back(result);
+ }
+ }
+ std::vector<DatetimeParseResultSpan> expected{
+ {{expected_start_index, expected_end_index},
+ {},
+ /*target_classification_score=*/1.0,
+ /*priority_score=*/1.0}};
+ expected[0].data.resize(expected_ms_utcs.size());
+ for (int i = 0; i < expected_ms_utcs.size(); i++) {
+ expected[0].data[i] = {expected_ms_utcs[i], expected_granularity,
+ datetime_components[i]};
+ }
+
+ const bool matches =
+ testing::Matches(ElementsAreArray(expected))(filtered_results);
+ if (!matches) {
+ TC3_LOG(ERROR) << "Expected: " << expected[0];
+ if (filtered_results.empty()) {
+ TC3_LOG(ERROR) << "But got no results.";
+ }
+ TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
+ }
+
+ return matches;
+}
+
+bool DateTimeParserTest::ParsesCorrectly(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end, const std::string& timezone,
+ const std::string& locales, AnnotationUsecase annotation_usecase) {
+ return ParsesCorrectly(marked_text, vector<int64>{expected_ms_utc},
+ expected_granularity, datetime_components,
+ anchor_start_end, timezone, locales,
+ annotation_usecase);
+}
+
+bool DateTimeParserTest::ParsesCorrectlyGerman(
+ const std::string& marked_text, const vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
+ return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
+ datetime_components,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+}
+
+bool DateTimeParserTest::ParsesCorrectlyGerman(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
+ return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+ datetime_components,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+}
+
+bool DateTimeParserTest::ParsesCorrectlyChinese(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
+ return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+ datetime_components,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"zh");
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/testing/base-parser-test.h b/native/annotator/datetime/testing/base-parser-test.h
new file mode 100644
index 0000000..3465a04
--- /dev/null
+++ b/native/annotator/datetime/testing/base-parser-test.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_BASE_PARSER_TEST_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_BASE_PARSER_TEST_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/datetime/parser.h"
+#include "utils/base/integral_types.h"
+#include "annotator/types.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+class DateTimeParserTest : public testing::Test {
+ public:
+ bool HasNoResult(const std::string& text, bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART);
+
+ bool ParsesCorrectly(
+ const std::string& marked_text,
+ const std::vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ const std::string& locales = "en-US",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART);
+
+ bool ParsesCorrectly(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ const std::string& locales = "en-US",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART);
+
+ bool ParsesCorrectlyGerman(
+ const std::string& marked_text,
+ const std::vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components);
+
+ bool ParsesCorrectlyGerman(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components);
+
+ bool ParsesCorrectlyChinese(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components);
+
+ // Exposes the date time parser for tests and evaluations.
+ virtual const DatetimeParser* DatetimeParserForTests() const = 0;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_BASE_PARSER_TEST_H_
diff --git a/native/annotator/datetime/testing/datetime-component-builder.cc b/native/annotator/datetime/testing/datetime-component-builder.cc
new file mode 100644
index 0000000..f0764da
--- /dev/null
+++ b/native/annotator/datetime/testing/datetime-component-builder.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/testing/datetime-component-builder.h"
+
+namespace libtextclassifier3 {
+
+DatetimeComponentsBuilder DatetimeComponentsBuilder::Add(
+ DatetimeComponent::ComponentType type, int value) {
+ DatetimeComponent component;
+ component.component_type = type;
+ component.value = value;
+ return AddComponent(component);
+}
+
+DatetimeComponentsBuilder DatetimeComponentsBuilder::Add(
+ DatetimeComponent::ComponentType type, int value,
+ DatetimeComponent::RelativeQualifier relative_qualifier,
+ int relative_count) {
+ DatetimeComponent component;
+ component.component_type = type;
+ component.value = value;
+ component.relative_qualifier = relative_qualifier;
+ component.relative_count = relative_count;
+ return AddComponent(component);
+}
+
+std::vector<DatetimeComponent> DatetimeComponentsBuilder::Build() {
+ return std::move(datetime_components_);
+}
+
+DatetimeComponentsBuilder DatetimeComponentsBuilder::AddComponent(
+ const DatetimeComponent& datetime_component) {
+ datetime_components_.push_back(datetime_component);
+ return *this;
+}
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/testing/datetime-component-builder.h b/native/annotator/datetime/testing/datetime-component-builder.h
new file mode 100644
index 0000000..a6a9f36
--- /dev/null
+++ b/native/annotator/datetime/testing/datetime-component-builder.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_DATETIME_COMPONENT_BUILDER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_DATETIME_COMPONENT_BUILDER_H_
+
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+// Builder class to construct the DatetimeComponents and make the test readable.
+class DatetimeComponentsBuilder {
+ public:
+ DatetimeComponentsBuilder Add(DatetimeComponent::ComponentType type,
+ int value);
+
+ DatetimeComponentsBuilder Add(
+ DatetimeComponent::ComponentType type, int value,
+ DatetimeComponent::RelativeQualifier relative_qualifier,
+ int relative_count);
+
+ std::vector<DatetimeComponent> Build();
+
+ private:
+ DatetimeComponentsBuilder AddComponent(
+ const DatetimeComponent& datetime_component);
+ std::vector<DatetimeComponent> datetime_components_;
+};
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_DATETIME_COMPONENT_BUILDER_H_
diff --git a/native/annotator/datetime/utils.cc b/native/annotator/datetime/utils.cc
index 30a99a1..d772809 100644
--- a/native/annotator/datetime/utils.cc
+++ b/native/annotator/datetime/utils.cc
@@ -64,4 +64,15 @@
}
}
+int GetAdjustedYear(const int parsed_year) {
+ if (parsed_year < 100) {
+ if (parsed_year < 50) {
+ return parsed_year + 2000;
+ } else {
+ return parsed_year + 1900;
+ }
+ }
+ return parsed_year;
+}
+
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/utils.h b/native/annotator/datetime/utils.h
index cdf1c8b..297ed1d 100644
--- a/native/annotator/datetime/utils.h
+++ b/native/annotator/datetime/utils.h
@@ -30,6 +30,8 @@
const DatetimeGranularity& granularity,
std::vector<DatetimeParsedData>* interpretations);
+// Maps a two-digit year XX to 20XX when XX < 50, otherwise to 19XX;
+int GetAdjustedYear(const int parsed_year);
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_UTILS_H_
diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc
index 07b9885..c59b8e0 100644
--- a/native/annotator/duration/duration.cc
+++ b/native/annotator/duration/duration.cc
@@ -22,6 +22,7 @@
#include "annotator/collections.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
+#include "utils/base/macros.h"
#include "utils/strings/numbers.h"
#include "utils/utf8/unicodetext.h"
@@ -100,6 +101,24 @@
return result;
}
+// Get the dangling quantity unit e.g. for 2 hours 10, 10 would have the unit
+// "minute".
+DurationUnit GetDanglingQuantityUnit(const DurationUnit main_unit) {
+ switch (main_unit) {
+ case DurationUnit::HOUR:
+ return DurationUnit::MINUTE;
+ case DurationUnit::MINUTE:
+ return DurationUnit::SECOND;
+ case DurationUnit::UNKNOWN:
+ TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
+ TC3_FALLTHROUGH_INTENDED;
+ case DurationUnit::WEEK:
+ case DurationUnit::DAY:
+ case DurationUnit::SECOND:
+ // We only support dangling units for hours and minutes.
+ return DurationUnit::UNKNOWN;
+ }
+}
} // namespace internal
bool DurationAnnotator::ClassifyText(
@@ -201,26 +220,32 @@
const bool parse_ended_without_unit_for_last_mentioned_quantity =
has_quantity;
+ if (parse_ended_without_unit_for_last_mentioned_quantity) {
+ const DurationUnit main_unit = parsed_duration_atoms.rbegin()->unit;
+ if (parsed_duration.plus_half) {
+ // Process "and half" suffix.
+ end_index = quantity_end_index;
+ ParsedDurationAtom atom = ParsedDurationAtom::Half();
+ atom.unit = main_unit;
+ parsed_duration_atoms.push_back(atom);
+ } else if (options_->enable_dangling_quantity_interpretation()) {
+ // Process dangling quantity.
+ ParsedDurationAtom atom;
+ atom.value = parsed_duration.value;
+ atom.unit = GetDanglingQuantityUnit(main_unit);
+ if (atom.unit != DurationUnit::UNKNOWN) {
+ end_index = quantity_end_index;
+ parsed_duration_atoms.push_back(atom);
+ }
+ }
+ }
+
ClassificationResult classification{Collections::Duration(),
options_->score()};
classification.priority_score = options_->priority_score();
classification.duration_ms =
ParsedDurationAtomsToMillis(parsed_duration_atoms);
- // Process suffix expressions like "and half" that don't have the
- // duration_unit explicitly mentioned.
- if (parse_ended_without_unit_for_last_mentioned_quantity) {
- if (parsed_duration.plus_half) {
- end_index = quantity_end_index;
- ParsedDurationAtom atom = ParsedDurationAtom::Half();
- atom.unit = parsed_duration_atoms.rbegin()->unit;
- classification.duration_ms += ParsedDurationAtomsToMillis({atom});
- } else if (options_->enable_dangling_quantity_interpretation()) {
- end_index = quantity_end_index;
- // TODO(b/144752747) Add dangling quantity to duration_ms.
- }
- }
-
result->span = feature_processor_->StripBoundaryCodepoints(
context, {start_index, end_index});
result->classification.push_back(classification);
@@ -256,7 +281,7 @@
break;
}
- int64 value = atom.value;
+ double value = atom.value;
// This condition handles expressions like "an hour", where the quantity is
// not specified. In this case we assume quantity 1. Except for cases like
// "half hour".
@@ -287,8 +312,8 @@
return true;
}
- int32 parsed_value;
- if (ParseInt32(lowercase_token_value.c_str(), &parsed_value)) {
+ double parsed_value;
+ if (ParseDouble(lowercase_token_value.c_str(), &parsed_value)) {
value->value = parsed_value;
return true;
}
diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h
index db4bdae..1a42ac3 100644
--- a/native/annotator/duration/duration.h
+++ b/native/annotator/duration/duration.h
@@ -98,7 +98,7 @@
internal::DurationUnit unit = internal::DurationUnit::UNKNOWN;
// Quantity of the duration unit.
- int value = 0;
+ double value = 0;
// True, if half an unit was specified (either in addition, or exclusively).
// E.g. "hour and a half".
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
index a0985a2..7c07a72 100644
--- a/native/annotator/duration/duration_test.cc
+++ b/native/annotator/duration/duration_test.cc
@@ -23,7 +23,7 @@
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
#include "annotator/types.h"
-#include "utils/test-utils.h"
+#include "utils/tokenizer-utils.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "gmock/gmock.h"
@@ -435,21 +435,57 @@
3.5 * 60 * 1000)))))));
}
-TEST_F(DurationAnnotatorTest, CorrectlyAnnotatesSpanWithDanglingQuantity) {
+TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
const UnicodeText text = UTF8ToUnicodeText("20 minutes 10");
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
- // TODO(b/144752747) Include test for duration_ms.
EXPECT_THAT(
result,
ElementsAre(
AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 13)),
Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(Field(&ClassificationResult::collection,
- "duration")))))));
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 20 * 60 * 1000 + 10 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) {
+ const UnicodeText text = UTF8ToUnicodeText("20 seconds 10");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(AllOf(
+ Field(&AnnotatedSpan::span, CodepointSpan(0, 10)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms, 20 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) {
+ const UnicodeText text = UTF8ToUnicodeText("in 10.2 hours");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(3, 13)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 10 * 60 * 60 * 1000 + 12 * 60 * 1000)))))));
}
const DurationAnnotatorOptions* TestingJapaneseDurationAnnotatorOptions() {
@@ -472,7 +508,7 @@
options.half_expressions.push_back("半");
options.require_quantity = true;
- options.enable_dangling_quantity_interpretation = false;
+ options.enable_dangling_quantity_interpretation = true;
flatbuffers::FlatBufferBuilder builder;
builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
@@ -545,7 +581,7 @@
EXPECT_THAT(result, IsEmpty());
}
-TEST_F(JapaneseDurationAnnotatorTest, IgnoresDanglingQuantity) {
+TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
const UnicodeText text = UTF8ToUnicodeText("2 分 10 の アラーム");
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
@@ -555,12 +591,12 @@
EXPECT_THAT(
result,
ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 3)),
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 6)),
Field(&AnnotatedSpan::classification,
ElementsAre(AllOf(
Field(&ClassificationResult::collection, "duration"),
Field(&ClassificationResult::duration_ms,
- 2 * 60 * 1000)))))));
+ 2 * 60 * 1000 + 10 * 1000)))))));
}
} // namespace
diff --git a/native/annotator/entity-data.fbs b/native/annotator/entity-data.fbs
old mode 100755
new mode 100644
index 4c02f6d..f82eb44
--- a/native/annotator/entity-data.fbs
+++ b/native/annotator/entity-data.fbs
@@ -175,11 +175,24 @@
// Whole part of the amount (e.g. 123 from "CHF 123.45").
amount_whole_part:int;
- // Decimal part of the amount (e.g. 45 from "CHF 123.45").
+ // Decimal part of the amount (e.g. 45 from "CHF 123.45"). Will be
+ // deprecated, use nanos instead.
amount_decimal_part:int;
// Money amount (e.g. 123.45 from "CHF 123.45").
unnormalized_amount:string (shared);
+
+ // Number of nano (10^-9) units of the amount fractional part.
+ // The value must be between -999,999,999 and +999,999,999 inclusive.
+ // If `amount_whole_part` is positive, `nanos` must be positive or zero.
+ // If `amount_whole_part` is zero, `nanos` can be positive, zero, or negative.
+ // If `amount_whole_part` is negative, `nanos` must be negative or zero.
+ // For example $-1.75 is represented as `amount_whole_part`=-1 and
+ // `nanos`=-750,000,000.
+ nanos:int;
+
+ // Money quantity (e.g. k from "CHF 123.45k").
+ quantity:string (shared);
}
namespace libtextclassifier3.EntityData_.Translate_;
diff --git a/native/annotator/experimental/experimental-dummy.h b/native/annotator/experimental/experimental-dummy.h
index 389aae1..28eec5f 100644
--- a/native/annotator/experimental/experimental-dummy.h
+++ b/native/annotator/experimental/experimental-dummy.h
@@ -39,7 +39,7 @@
bool Annotate(const UnicodeText& context,
std::vector<AnnotatedSpan>* candidates) const {
- return false;
+ return true;
}
AnnotatedSpan SuggestSelection(const UnicodeText& context,
@@ -47,8 +47,8 @@
return {click, {}};
}
- bool ClassifyText(const UnicodeText& context, CodepointSpan click,
- ClassificationResult* result) const {
+ bool ClassifyText(const UnicodeText& context, CodepointSpan selection_indices,
+ std::vector<AnnotatedSpan>& candidates) const {
return false;
}
};
diff --git a/native/annotator/experimental/experimental.fbs b/native/annotator/experimental/experimental.fbs
old mode 100755
new mode 100644
diff --git a/native/annotator/feature-processor.cc b/native/annotator/feature-processor.cc
index 8d08574..93c3636 100644
--- a/native/annotator/feature-processor.cc
+++ b/native/annotator/feature-processor.cc
@@ -67,8 +67,8 @@
extractor_options.extract_selection_mask_feature =
options->extract_selection_mask_feature();
if (options->regexp_feature() != nullptr) {
- for (const auto& regexp_feauture : *options->regexp_feature()) {
- extractor_options.regexp_features.push_back(regexp_feauture->str());
+ for (const auto& regexp_feature : *options->regexp_feature()) {
+ extractor_options.regexp_features.push_back(regexp_feature->str());
}
}
extractor_options.remap_digits = options->remap_digits();
@@ -82,7 +82,7 @@
return extractor_options;
}
-void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection,
std::vector<Token>* tokens) {
for (auto it = tokens->begin(); it != tokens->end(); ++it) {
const UnicodeText token_word =
@@ -137,30 +137,27 @@
} // namespace internal
void FeatureProcessor::StripTokensFromOtherLines(
- const std::string& context, CodepointSpan span,
+ const std::string& context, const CodepointSpan& span,
std::vector<Token>* tokens) const {
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
- StripTokensFromOtherLines(context_unicode, span, tokens);
+ const auto [span_begin, span_end] =
+ CodepointSpanToUnicodeTextRange(context_unicode, span);
+ StripTokensFromOtherLines(context_unicode, span_begin, span_end, span,
+ tokens);
}
void FeatureProcessor::StripTokensFromOtherLines(
- const UnicodeText& context_unicode, CodepointSpan span,
+ const UnicodeText& context_unicode,
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, const CodepointSpan& span,
std::vector<Token>* tokens) const {
std::vector<UnicodeTextRange> lines =
SplitContext(context_unicode, options_->use_pipe_character_for_newline());
- auto span_start = context_unicode.begin();
- if (span.first > 0) {
- std::advance(span_start, span.first);
- }
- auto span_end = context_unicode.begin();
- if (span.second > 0) {
- std::advance(span_end, span.second);
- }
for (const UnicodeTextRange& line : lines) {
// Find the line that completely contains the span.
- if (line.first <= span_start && line.second >= span_end) {
+ if (line.first <= span_begin && line.second >= span_end) {
const CodepointIndex last_line_begin_index =
std::distance(context_unicode.begin(), line.first);
const CodepointIndex last_line_end_index =
@@ -198,9 +195,9 @@
return tokenizer_.Tokenize(text_unicode);
}
-bool FeatureProcessor::LabelToSpan(
- const int label, const VectorSpan<Token>& tokens,
- std::pair<CodepointIndex, CodepointIndex>* span) const {
+bool FeatureProcessor::LabelToSpan(const int label,
+ const VectorSpan<Token>& tokens,
+ CodepointSpan* span) const {
if (tokens.size() != GetNumContextTokens()) {
return false;
}
@@ -221,7 +218,7 @@
if (result_begin_codepoint == kInvalidIndex ||
result_end_codepoint == kInvalidIndex) {
- *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
+ *span = CodepointSpan::kInvalid;
} else {
const UnicodeText token_begin_unicode =
UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
@@ -241,8 +238,8 @@
if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
*span = {result_begin_codepoint, result_begin_codepoint};
} else {
- *span = CodepointSpan({result_begin_codepoint + begin_ignored,
- result_end_codepoint - end_ignored});
+ *span = CodepointSpan(result_begin_codepoint + begin_ignored,
+ result_end_codepoint - end_ignored);
}
}
return true;
@@ -258,9 +255,9 @@
}
}
-bool FeatureProcessor::SpanToLabel(
- const std::pair<CodepointIndex, CodepointIndex>& span,
- const std::vector<Token>& tokens, int* label) const {
+bool FeatureProcessor::SpanToLabel(const CodepointSpan& span,
+ const std::vector<Token>& tokens,
+ int* label) const {
if (tokens.size() != GetNumContextTokens()) {
return false;
}
@@ -323,8 +320,8 @@
return true;
}
-int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
- auto it = selection_to_label_.find(span);
+int FeatureProcessor::TokenSpanToLabel(const TokenSpan& token_span) const {
+ auto it = selection_to_label_.find(token_span);
if (it != selection_to_label_.end()) {
return it->second;
} else {
@@ -333,10 +330,10 @@
}
TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
- CodepointSpan codepoint_span,
+ const CodepointSpan& codepoint_span,
bool snap_boundaries_to_containing_tokens) {
- const int codepoint_start = std::get<0>(codepoint_span);
- const int codepoint_end = std::get<1>(codepoint_span);
+ const int codepoint_start = codepoint_span.first;
+ const int codepoint_end = codepoint_span.second;
TokenIndex start_token = kInvalidIndex;
TokenIndex end_token = kInvalidIndex;
@@ -360,18 +357,31 @@
}
CodepointSpan TokenSpanToCodepointSpan(
- const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
+ const std::vector<Token>& selectable_tokens, const TokenSpan& token_span) {
return {selectable_tokens[token_span.first].start,
selectable_tokens[token_span.second - 1].end};
}
+UnicodeTextRange CodepointSpanToUnicodeTextRange(
+ const UnicodeText& unicode_text, const CodepointSpan& span) {
+ auto begin = unicode_text.begin();
+ if (span.first > 0) {
+ std::advance(begin, span.first);
+ }
+ auto end = unicode_text.begin();
+ if (span.second > 0) {
+ std::advance(end, span.second);
+ }
+ return {begin, end};
+}
+
namespace {
// Finds a single token that completely contains the given span.
int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
- CodepointSpan codepoint_span) {
- const int codepoint_start = std::get<0>(codepoint_span);
- const int codepoint_end = std::get<1>(codepoint_span);
+ const CodepointSpan& codepoint_span) {
+ const int codepoint_start = codepoint_span.first;
+ const int codepoint_end = codepoint_span.second;
for (int i = 0; i < selectable_tokens.size(); ++i) {
if (codepoint_start >= selectable_tokens[i].start &&
@@ -386,12 +396,12 @@
namespace internal {
-int CenterTokenFromClick(CodepointSpan span,
+int CenterTokenFromClick(const CodepointSpan& span,
const std::vector<Token>& selectable_tokens) {
- int range_begin;
- int range_end;
- std::tie(range_begin, range_end) =
+ const TokenSpan token_span =
CodepointSpanToTokenSpan(selectable_tokens, span);
+ int range_begin = token_span.first;
+ int range_end = token_span.second;
// If no exact match was found, try finding a token that completely contains
// the click span. This is useful e.g. when Android builds the selection
@@ -414,11 +424,11 @@
}
int CenterTokenFromMiddleOfSelection(
- CodepointSpan span, const std::vector<Token>& selectable_tokens) {
- int range_begin;
- int range_end;
- std::tie(range_begin, range_end) =
+ const CodepointSpan& span, const std::vector<Token>& selectable_tokens) {
+ const TokenSpan token_span =
CodepointSpanToTokenSpan(selectable_tokens, span);
+ const int range_begin = token_span.first;
+ const int range_end = token_span.second;
// Center the clicked token in the selection range.
if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
@@ -430,7 +440,7 @@
} // namespace internal
-int FeatureProcessor::FindCenterToken(CodepointSpan span,
+int FeatureProcessor::FindCenterToken(const CodepointSpan& span,
const std::vector<Token>& tokens) const {
if (options_->center_token_selection_method() ==
FeatureProcessorOptions_::
@@ -464,7 +474,7 @@
const VectorSpan<Token> tokens,
std::vector<CodepointSpan>* selection_label_spans) const {
for (int i = 0; i < label_to_selection_.size(); ++i) {
- CodepointSpan span;
+ CodepointSpan span = CodepointSpan::kInvalid;
if (!LabelToSpan(i, tokens, &span)) {
TC3_LOG(ERROR) << "Could not convert label to span: " << i;
return false;
@@ -474,6 +484,13 @@
return true;
}
+bool FeatureProcessor::SelectionLabelRelativeTokenSpans(
+ std::vector<TokenSpan>* selection_label_relative_token_spans) const {
+ selection_label_relative_token_spans->assign(label_to_selection_.begin(),
+ label_to_selection_.end());
+ return true;
+}
+
void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
if (options_->ignored_span_boundary_codepoints() != nullptr) {
for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
@@ -486,15 +503,6 @@
const UnicodeText::const_iterator& span_start,
const UnicodeText::const_iterator& span_end,
bool count_from_beginning) const {
- return CountIgnoredSpanBoundaryCodepoints(span_start, span_end,
- count_from_beginning,
- ignored_span_boundary_codepoints_);
-}
-
-int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
- const UnicodeText::const_iterator& span_start,
- const UnicodeText::const_iterator& span_end, bool count_from_beginning,
- const std::unordered_set<int>& ignored_span_boundary_codepoints) const {
if (span_start == span_end) {
return 0;
}
@@ -517,8 +525,8 @@
// Move until we encounter a non-ignored character.
int num_ignored = 0;
- while (ignored_span_boundary_codepoints.find(*it) !=
- ignored_span_boundary_codepoints.end()) {
+ while (ignored_span_boundary_codepoints_.find(*it) !=
+ ignored_span_boundary_codepoints_.end()) {
++num_ignored;
if (it == it_last) {
@@ -571,74 +579,36 @@
}
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const std::string& context, CodepointSpan span) const {
- return StripBoundaryCodepoints(context, span,
- ignored_span_boundary_codepoints_,
- ignored_span_boundary_codepoints_);
-}
-
-CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const std::string& context, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const {
+ const std::string& context, const CodepointSpan& span) const {
const UnicodeText context_unicode =
UTF8ToUnicodeText(context, /*do_copy=*/false);
- return StripBoundaryCodepoints(context_unicode, span,
- ignored_prefix_span_boundary_codepoints,
- ignored_suffix_span_boundary_codepoints);
+ return StripBoundaryCodepoints(context_unicode, span);
}
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const UnicodeText& context_unicode, CodepointSpan span) const {
- return StripBoundaryCodepoints(context_unicode, span,
- ignored_span_boundary_codepoints_,
- ignored_span_boundary_codepoints_);
-}
-
-CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const UnicodeText& context_unicode, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const {
- if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
+ const UnicodeText& context_unicode, const CodepointSpan& span) const {
+ if (context_unicode.empty() || !span.IsValid() || span.IsEmpty()) {
return span;
}
- UnicodeText::const_iterator span_begin = context_unicode.begin();
- std::advance(span_begin, span.first);
- UnicodeText::const_iterator span_end = context_unicode.begin();
- std::advance(span_end, span.second);
+ const auto [span_begin, span_end] =
+ CodepointSpanToUnicodeTextRange(context_unicode, span);
- return StripBoundaryCodepoints(span_begin, span_end, span,
- ignored_prefix_span_boundary_codepoints,
- ignored_suffix_span_boundary_codepoints);
+ return StripBoundaryCodepoints(span_begin, span_end, span);
}
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
const UnicodeText::const_iterator& span_begin,
- const UnicodeText::const_iterator& span_end, CodepointSpan span) const {
- return StripBoundaryCodepoints(span_begin, span_end, span,
- ignored_span_boundary_codepoints_,
- ignored_span_boundary_codepoints_);
-}
-
-CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const UnicodeText::const_iterator& span_begin,
- const UnicodeText::const_iterator& span_end, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const {
- if (!ValidNonEmptySpan(span) || span_begin == span_end) {
+ const UnicodeText::const_iterator& span_end,
+ const CodepointSpan& span) const {
+ if (!span.IsValid() || span.IsEmpty() || span_begin == span_end) {
return span;
}
const int start_offset = CountIgnoredSpanBoundaryCodepoints(
- span_begin, span_end, /*count_from_beginning=*/true,
- ignored_prefix_span_boundary_codepoints);
+ span_begin, span_end, /*count_from_beginning=*/true);
const int end_offset = CountIgnoredSpanBoundaryCodepoints(
- span_begin, span_end, /*count_from_beginning=*/false,
- ignored_suffix_span_boundary_codepoints);
+ span_begin, span_end, /*count_from_beginning=*/false);
if (span.first + start_offset < span.second - end_offset) {
return {span.first + start_offset, span.second - end_offset};
@@ -670,21 +640,10 @@
const std::string& FeatureProcessor::StripBoundaryCodepoints(
const std::string& value, std::string* buffer) const {
- return StripBoundaryCodepoints(value, buffer,
- ignored_span_boundary_codepoints_,
- ignored_span_boundary_codepoints_);
-}
-
-const std::string& FeatureProcessor::StripBoundaryCodepoints(
- const std::string& value, std::string* buffer,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const {
const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
- const CodepointSpan stripped_span = StripBoundaryCodepoints(
- value_unicode, initial_span, ignored_prefix_span_boundary_codepoints,
- ignored_suffix_span_boundary_codepoints);
+ const CodepointSpan stripped_span =
+ StripBoundaryCodepoints(value_unicode, initial_span);
if (initial_span != stripped_span) {
const UnicodeText stripped_token_value =
@@ -735,20 +694,24 @@
}
void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
- CodepointSpan input_span,
+ const CodepointSpan& input_span,
bool only_use_line_with_click,
std::vector<Token>* tokens,
int* click_pos) const {
const UnicodeText context_unicode =
UTF8ToUnicodeText(context, /*do_copy=*/false);
- RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
- tokens, click_pos);
+ const auto [span_begin, span_end] =
+ CodepointSpanToUnicodeTextRange(context_unicode, input_span);
+ RetokenizeAndFindClick(context_unicode, span_begin, span_end, input_span,
+ only_use_line_with_click, tokens, click_pos);
}
void FeatureProcessor::RetokenizeAndFindClick(
- const UnicodeText& context_unicode, CodepointSpan input_span,
- bool only_use_line_with_click, std::vector<Token>* tokens,
- int* click_pos) const {
+ const UnicodeText& context_unicode,
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end,
+ const CodepointSpan& input_span, bool only_use_line_with_click,
+ std::vector<Token>* tokens, int* click_pos) const {
TC3_CHECK(tokens != nullptr);
if (options_->split_tokens_on_selection_boundaries()) {
@@ -756,7 +719,8 @@
}
if (only_use_line_with_click) {
- StripTokensFromOtherLines(context_unicode, input_span, tokens);
+ StripTokensFromOtherLines(context_unicode, span_begin, span_end, input_span,
+ tokens);
}
int local_click_pos;
@@ -773,7 +737,7 @@
namespace internal {
-void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size,
std::vector<Token>* tokens, int* click_pos) {
int right_context_needed = relative_click_span.second + context_size;
if (*click_pos + right_context_needed + 1 >= tokens->size()) {
@@ -810,7 +774,7 @@
} // namespace internal
bool FeatureProcessor::HasEnoughSupportedCodepoints(
- const std::vector<Token>& tokens, TokenSpan token_span) const {
+ const std::vector<Token>& tokens, const TokenSpan& token_span) const {
if (options_->min_supported_codepoint_ratio() > 0) {
const float supported_codepoint_ratio =
SupportedCodepointsRatio(token_span, tokens);
@@ -824,13 +788,13 @@
}
bool FeatureProcessor::ExtractFeatures(
- const std::vector<Token>& tokens, TokenSpan token_span,
- CodepointSpan selection_span_for_feature,
+ const std::vector<Token>& tokens, const TokenSpan& token_span,
+ const CodepointSpan& selection_span_for_feature,
const EmbeddingExecutor* embedding_executor,
EmbeddingCache* embedding_cache, int feature_vector_size,
std::unique_ptr<CachedFeatures>* cached_features) const {
std::unique_ptr<std::vector<float>> features(new std::vector<float>());
- features->reserve(feature_vector_size * TokenSpanSize(token_span));
+ features->reserve(feature_vector_size * token_span.Size());
for (int i = token_span.first; i < token_span.second; ++i) {
if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
embedding_executor, embedding_cache,
@@ -862,7 +826,7 @@
}
bool FeatureProcessor::AppendTokenFeaturesWithCache(
- const Token& token, CodepointSpan selection_span_for_feature,
+ const Token& token, const CodepointSpan& selection_span_for_feature,
const EmbeddingExecutor* embedding_executor,
EmbeddingCache* embedding_cache,
std::vector<float>* output_features) const {
diff --git a/native/annotator/feature-processor.h b/native/annotator/feature-processor.h
index 78dbbce..554727a 100644
--- a/native/annotator/feature-processor.h
+++ b/native/annotator/feature-processor.h
@@ -49,22 +49,23 @@
// Splits tokens that contain the selection boundary inside them.
// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
-void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection,
std::vector<Token>* tokens);
// Returns the index of token that corresponds to the codepoint span.
-int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
+int CenterTokenFromClick(const CodepointSpan& span,
+ const std::vector<Token>& tokens);
// Returns the index of token that corresponds to the middle of the codepoint
// span.
int CenterTokenFromMiddleOfSelection(
- CodepointSpan span, const std::vector<Token>& selectable_tokens);
+ const CodepointSpan& span, const std::vector<Token>& selectable_tokens);
// Strips the tokens from the tokens vector that are not used for feature
// extraction because they are out of scope, or pads them so that there is
// enough tokens in the required context_size for all inferences with a click
// in relative_click_span.
-void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size,
std::vector<Token>* tokens, int* click_pos);
} // namespace internal
@@ -74,12 +75,25 @@
// token to overlap with the codepoint range to be considered part of it.
// Otherwise it must be fully included in the range.
TokenSpan CodepointSpanToTokenSpan(
- const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
+ const std::vector<Token>& selectable_tokens,
+ const CodepointSpan& codepoint_span,
bool snap_boundaries_to_containing_tokens = false);
// Converts a token span to a codepoint span in the given list of tokens.
CodepointSpan TokenSpanToCodepointSpan(
- const std::vector<Token>& selectable_tokens, TokenSpan token_span);
+ const std::vector<Token>& selectable_tokens, const TokenSpan& token_span);
+
+// Converts a codepoint span to a unicode text range, within the given unicode
+// text.
+// For an invalid span (with a negative index), returns (begin, begin). This
+// means that it is safe to call this function before checking the validity of
+// the span.
+// The indices must fit within the unicode text.
+// Note that the execution time is linear with respect to the codepoint indices.
+// Calling this function repeatedly for spans on the same text might lead to
+// inefficient code.
+UnicodeTextRange CodepointSpanToUnicodeTextRange(
+ const UnicodeText& unicode_text, const CodepointSpan& span);
// Takes care of preparing features for the span prediction model.
class FeatureProcessor {
@@ -132,25 +146,29 @@
// Retokenizes the context and input span, and finds the click position.
// Depending on the options, might modify tokens (split them or remove them).
void RetokenizeAndFindClick(const std::string& context,
- CodepointSpan input_span,
+ const CodepointSpan& input_span,
bool only_use_line_with_click,
std::vector<Token>* tokens, int* click_pos) const;
- // Same as above but takes UnicodeText.
+ // Same as above, but takes UnicodeText and iterators within it corresponding
+ // to input_span.
void RetokenizeAndFindClick(const UnicodeText& context_unicode,
- CodepointSpan input_span,
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end,
+ const CodepointSpan& input_span,
bool only_use_line_with_click,
std::vector<Token>* tokens, int* click_pos) const;
// Returns true if the token span has enough supported codepoints (as defined
// in the model config) or not and model should not run.
bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
- TokenSpan token_span) const;
+ const TokenSpan& token_span) const;
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
- bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
- CodepointSpan selection_span_for_feature,
+ bool ExtractFeatures(const std::vector<Token>& tokens,
+ const TokenSpan& token_span,
+ const CodepointSpan& selection_span_for_feature,
const EmbeddingExecutor* embedding_executor,
EmbeddingCache* embedding_cache, int feature_vector_size,
std::unique_ptr<CachedFeatures>* cached_features) const;
@@ -162,6 +180,11 @@
VectorSpan<Token> tokens,
std::vector<CodepointSpan>* selection_label_spans) const;
+ // Fills selection_label_relative_token_spans with number of tokens left and
+ // right from the click.
+ bool SelectionLabelRelativeTokenSpans(
+ std::vector<TokenSpan>* selection_label_relative_token_spans) const;
+
int DenseFeaturesCount() const {
return feature_extractor_.DenseFeaturesCount();
}
@@ -177,38 +200,17 @@
// start and end indices. If the span comprises entirely of boundary
// codepoints, the first index of span is returned for both indices.
CodepointSpan StripBoundaryCodepoints(const std::string& context,
- CodepointSpan span) const;
-
- // Same as previous, but also takes the ignored span boundary codepoints.
- CodepointSpan StripBoundaryCodepoints(
- const std::string& context, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const;
+ const CodepointSpan& span) const;
// Same as above but takes UnicodeText.
CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
- CodepointSpan span) const;
-
- // Same as the previous, but also takes the ignored span boundary codepoints.
- CodepointSpan StripBoundaryCodepoints(
- const UnicodeText& context_unicode, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const;
+ const CodepointSpan& span) const;
// Same as above but takes a pair of iterators for the span, for efficiency.
CodepointSpan StripBoundaryCodepoints(
const UnicodeText::const_iterator& span_begin,
- const UnicodeText::const_iterator& span_end, CodepointSpan span) const;
-
- // Same as previous, but also takes the ignored span boundary codepoints.
- CodepointSpan StripBoundaryCodepoints(
- const UnicodeText::const_iterator& span_begin,
- const UnicodeText::const_iterator& span_end, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const;
+ const UnicodeText::const_iterator& span_end,
+ const CodepointSpan& span) const;
// Same as above, but takes an optional buffer for saving the modified value.
// As an optimization, returns pointer to 'value' if nothing was stripped, or
@@ -216,13 +218,6 @@
const std::string& StripBoundaryCodepoints(const std::string& value,
std::string* buffer) const;
- // Same as previous, but also takes the ignored span boundary codepoints.
- const std::string& StripBoundaryCodepoints(
- const std::string& value, std::string* buffer,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const;
-
protected:
// Returns the class id corresponding to the given string collection
// identifier. There is a catch-all class id that the function returns for
@@ -245,11 +240,11 @@
CodepointSpan* span) const;
// Converts a span to the corresponding label given output_tokens.
- bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
+ bool SpanToLabel(const CodepointSpan& span,
const std::vector<Token>& output_tokens, int* label) const;
// Converts a token span to the corresponding label.
- int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
+ int TokenSpanToLabel(const TokenSpan& token_span) const;
// Returns the ratio of supported codepoints to total number of codepoints in
// the given token span.
@@ -268,35 +263,32 @@
const UnicodeText::const_iterator& span_end,
bool count_from_beginning) const;
- // Same as previous, but also takes the ignored span boundary codepoints.
- int CountIgnoredSpanBoundaryCodepoints(
- const UnicodeText::const_iterator& span_start,
- const UnicodeText::const_iterator& span_end, bool count_from_beginning,
- const std::unordered_set<int>& ignored_span_boundary_codepoints) const;
-
// Finds the center token index in tokens vector, using the method defined
// in options_.
- int FindCenterToken(CodepointSpan span,
+ int FindCenterToken(const CodepointSpan& span,
const std::vector<Token>& tokens) const;
// Removes all tokens from tokens that are not on a line (defined by calling
// SplitContext on the context) to which span points.
- void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
+ void StripTokensFromOtherLines(const std::string& context,
+ const CodepointSpan& span,
std::vector<Token>* tokens) const;
// Same as above but takes UnicodeText.
void StripTokensFromOtherLines(const UnicodeText& context_unicode,
- CodepointSpan span,
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end,
+ const CodepointSpan& span,
std::vector<Token>* tokens) const;
// Extracts the features of a token and appends them to the output vector.
// Uses the embedding cache to to avoid re-extracting the re-embedding the
// sparse features for the same token.
- bool AppendTokenFeaturesWithCache(const Token& token,
- CodepointSpan selection_span_for_feature,
- const EmbeddingExecutor* embedding_executor,
- EmbeddingCache* embedding_cache,
- std::vector<float>* output_features) const;
+ bool AppendTokenFeaturesWithCache(
+ const Token& token, const CodepointSpan& selection_span_for_feature,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache,
+ std::vector<float>* output_features) const;
protected:
const TokenFeatureExtractor feature_extractor_;
diff --git a/native/annotator/feature-processor_test.cc b/native/annotator/feature-processor_test.cc
new file mode 100644
index 0000000..86f25e4
--- /dev/null
+++ b/native/annotator/feature-processor_test.cc
@@ -0,0 +1,1050 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/feature-processor.h"
+
+#include "annotator/model-executor.h"
+#include "utils/tensor-view.h"
+#include "utils/utf8/unicodetext.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAreArray;
+using testing::FloatEq;
+using testing::Matcher;
+
+flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
+ const FeatureProcessorOptionsT& options) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+ return builder.Release();
+}
+
+template <typename T>
+std::vector<T> Subvector(const std::vector<T>& vector, int start, int end) {
+ return std::vector<T>(vector.begin() + start, vector.begin() + end);
+}
+
+Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
+ std::vector<Matcher<float>> matchers;
+ for (const float value : values) {
+ matchers.push_back(FloatEq(value));
+ }
+ return ElementsAreArray(matchers);
+}
+
+class TestingFeatureProcessor : public FeatureProcessor {
+ public:
+ using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
+ using FeatureProcessor::FeatureProcessor;
+ using FeatureProcessor::SpanToLabel;
+ using FeatureProcessor::StripTokensFromOtherLines;
+ using FeatureProcessor::supported_codepoint_ranges_;
+ using FeatureProcessor::SupportedCodepointsRatio;
+};
+
+// EmbeddingExecutor that always returns features based on
+class FakeEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+ bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ int dest_size) const override {
+ TC3_CHECK_GE(dest_size, 4);
+ EXPECT_EQ(sparse_features.size(), 1);
+ dest[0] = sparse_features.data()[0];
+ dest[1] = sparse_features.data()[0];
+ dest[2] = -sparse_features.data()[0];
+ dest[3] = -sparse_features.data()[0];
+ return true;
+ }
+
+ private:
+ std::vector<float> storage_;
+};
+
+class AnnotatorFeatureProcessorTest : public ::testing::Test {
+ protected:
+ AnnotatorFeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+ UniLib unilib_;
+};
+
+TEST_F(AnnotatorFeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěě", 6, 9),
+ Token("bař", 9, 12),
+ Token("@google.com", 12, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěěbař", 6, 12),
+ Token("@google.com", 12, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěě", 6, 9),
+ Token("bař@google.com", 9, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST_F(AnnotatorFeatureProcessorTest,
+ SplitTokensOnSelectionBoundariesCrossToken) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hě", 0, 2),
+ Token("lló", 2, 5),
+ Token("fěě", 6, 9),
+ Token("bař@google.com", 9, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, KeepLineWithClickFirst) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {0, 5};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the first line.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens,
+ ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, KeepLineWithClickSecond) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {18, 22};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the first line.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, KeepLineWithClickThird) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {24, 33};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the first line.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {18, 22};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the first line.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest,
+ KeepLineWithClickAndDoNotUsePipeAsNewLineCharacter) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ options.use_pipe_character_for_newline = false;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {18, 22};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině|Sěcond", 6, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the first line.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray({Token("Fiřst", 0, 5),
+ Token("Lině|Sěcond", 6, 17),
+ Token("Lině", 18, 22)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, ShouldSplitLinesOnPipe) {
+ FeatureProcessorOptionsT options;
+ options.use_pipe_character_for_newline = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+ /*do_copy=*/false);
+
+ const std::vector<UnicodeTextRange>& lines = feature_processor.SplitContext(
+ context_unicode, options.use_pipe_character_for_newline);
+ EXPECT_EQ(lines.size(), 3);
+ EXPECT_EQ(UnicodeText::UTF8Substring(lines[0].first, lines[0].second),
+ "Fiřst Lině");
+ EXPECT_EQ(UnicodeText::UTF8Substring(lines[1].first, lines[1].second),
+ "Sěcond Lině");
+ EXPECT_EQ(UnicodeText::UTF8Substring(lines[2].first, lines[2].second),
+ "Thiřd Lině");
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, ShouldNotSplitLinesOnPipe) {
+ FeatureProcessorOptionsT options;
+ options.use_pipe_character_for_newline = false;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+ /*do_copy=*/false);
+
+ const std::vector<UnicodeTextRange>& lines = feature_processor.SplitContext(
+ context_unicode, options.use_pipe_character_for_newline);
+ EXPECT_EQ(lines.size(), 2);
+ EXPECT_EQ(UnicodeText::UTF8Substring(lines[0].first, lines[0].second),
+ "Fiřst Lině|Sěcond Lině");
+ EXPECT_EQ(UnicodeText::UTF8Substring(lines[1].first, lines[1].second),
+ "Thiřd Lině");
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, KeepLineWithCrosslineClick) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {5, 23};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 18, 23),
+ Token("Lině", 19, 23),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the first line.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Fiřst", 0, 5), Token("Lině", 6, 10),
+ Token("Sěcond", 18, 23), Token("Lině", 19, 23),
+ Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SpanToLabel) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 1;
+ options.max_selection_span = 1;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+ std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
+ ASSERT_EQ(3, tokens.size());
+ int label;
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
+ EXPECT_EQ(kInvalidLabel, label);
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
+ EXPECT_NE(kInvalidLabel, label);
+ TokenSpan token_span;
+ feature_processor.LabelToTokenSpan(label, &token_span);
+ EXPECT_EQ(0, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ // Reconfigure with snapping enabled.
+ options.snap_label_span_boundaries_to_containing_tokens = true;
+ flatbuffers::DetachedBuffer options2_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor2(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
+ &unilib_);
+ int label2;
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+
+ // Cross a token boundary.
+ ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+
+ // Multiple tokens.
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ flatbuffers::DetachedBuffer options3_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor3(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
+ &unilib_);
+ tokens = feature_processor3.Tokenize("zero, one, two, three, four");
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
+ EXPECT_NE(kInvalidLabel, label2);
+ feature_processor3.LabelToTokenSpan(label2, &token_span);
+ EXPECT_EQ(1, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ int label3;
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 1;
+ options.max_selection_span = 1;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+ std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
+ ASSERT_EQ(3, tokens.size());
+ int label;
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
+ EXPECT_EQ(kInvalidLabel, label);
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
+ EXPECT_NE(kInvalidLabel, label);
+ TokenSpan token_span;
+ feature_processor.LabelToTokenSpan(label, &token_span);
+ EXPECT_EQ(0, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ // Reconfigure with snapping enabled.
+ options.snap_label_span_boundaries_to_containing_tokens = true;
+ flatbuffers::DetachedBuffer options2_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor2(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
+ &unilib_);
+ int label2;
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+
+ // Cross a token boundary.
+ ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+
+ // Multiple tokens.
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ flatbuffers::DetachedBuffer options3_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor3(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
+ &unilib_);
+ tokens = feature_processor3.Tokenize("zero, one, two, three, four");
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
+ EXPECT_NE(kInvalidLabel, label2);
+ feature_processor3.LabelToTokenSpan(label2, &token_span);
+ EXPECT_EQ(1, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ int label3;
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, CenterTokenFromClick) {
+ int token_index;
+
+ // Exactly aligned indices.
+ token_index = internal::CenterTokenFromClick(
+ {6, 11},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
+ EXPECT_EQ(token_index, 1);
+
+ // Click is contained in a token.
+ token_index = internal::CenterTokenFromClick(
+ {13, 17},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
+ EXPECT_EQ(token_index, 2);
+
+ // Click spans two tokens.
+ token_index = internal::CenterTokenFromClick(
+ {6, 17},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
+ EXPECT_EQ(token_index, kInvalidIndex);
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, CenterTokenFromMiddleOfSelection) {
+ int token_index;
+
+ // Selection of length 3. Exactly aligned indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {7, 27},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 2);
+
+ // Selection of length 1 token. Exactly aligned indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {21, 27},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 3);
+
+ // Selection marks sub-token range, with no tokens in it.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {29, 33},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, kInvalidIndex);
+
+ // Selection of length 2. Sub-token indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {3, 25},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 1);
+
+ // Selection of length 1. Sub-token indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {22, 34},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 4);
+
+ // Some invalid ones.
+ token_index = internal::CenterTokenFromMiddleOfSelection({7, 27}, {});
+ EXPECT_EQ(token_index, -1);
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SupportedCodepointsRatio) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.feature_version = 2;
+ options.embedding_size = 4;
+ options.bounds_sensitive_features.reset(
+ new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+ options.bounds_sensitive_features->enabled = true;
+ options.bounds_sensitive_features->num_tokens_before = 5;
+ options.bounds_sensitive_features->num_tokens_inside_left = 3;
+ options.bounds_sensitive_features->num_tokens_inside_right = 3;
+ options.bounds_sensitive_features->num_tokens_after = 5;
+ options.bounds_sensitive_features->include_inside_bag = true;
+ options.bounds_sensitive_features->include_inside_length = true;
+
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ {
+ options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
+ auto& range = options.supported_codepoint_ranges.back();
+ range->start = 0;
+ range->end = 128;
+ }
+
+ {
+ options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
+ auto& range = options.supported_codepoint_ranges.back();
+ range->start = 10000;
+ range->end = 10001;
+ }
+
+ {
+ options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
+ auto& range = options.supported_codepoint_ranges.back();
+ range->start = 20000;
+ range->end = 30000;
+ }
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ {0, 3}, feature_processor.Tokenize("aaa bbb ccc")),
+ FloatEq(1.0));
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ {0, 3}, feature_processor.Tokenize("aaa bbb ěěě")),
+ FloatEq(2.0 / 3));
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ {0, 3}, feature_processor.Tokenize("ěěě řřř ěěě")),
+ FloatEq(0.0));
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ {0, 0}, feature_processor.Tokenize("")),
+ FloatEq(0.0));
+ EXPECT_FALSE(
+ IsCodepointInRanges(-1, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(
+ IsCodepointInRanges(0, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(
+ IsCodepointInRanges(10, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(
+ IsCodepointInRanges(127, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(
+ IsCodepointInRanges(128, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(
+ IsCodepointInRanges(9999, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(IsCodepointInRanges(
+ 10000, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(IsCodepointInRanges(
+ 10001, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(IsCodepointInRanges(
+ 25000, feature_processor.supported_codepoint_ranges_));
+
+ const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7),
+ Token("eee", 8, 11)};
+
+ options.min_supported_codepoint_ratio = 0.0;
+ flatbuffers::DetachedBuffer options2_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor2(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
+ &unilib_);
+ EXPECT_TRUE(feature_processor2.HasEnoughSupportedCodepoints(
+ tokens, /*token_span=*/{0, 3}));
+
+ options.min_supported_codepoint_ratio = 0.2;
+ flatbuffers::DetachedBuffer options3_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor3(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
+ &unilib_);
+ EXPECT_TRUE(feature_processor3.HasEnoughSupportedCodepoints(
+ tokens, /*token_span=*/{0, 3}));
+
+ options.min_supported_codepoint_ratio = 0.5;
+ flatbuffers::DetachedBuffer options4_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor4(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()),
+ &unilib_);
+ EXPECT_FALSE(feature_processor4.HasEnoughSupportedCodepoints(
+ tokens, /*token_span=*/{0, 3}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, InSpanFeature) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.feature_version = 2;
+ options.embedding_size = 4;
+ options.extract_selection_mask_feature = true;
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ std::unique_ptr<CachedFeatures> cached_features;
+
+ FakeEmbeddingExecutor embedding_executor;
+
+ const std::vector<Token> tokens = {Token("aaa", 0, 3), Token("bbb", 4, 7),
+ Token("ccc", 8, 11), Token("ddd", 12, 15)};
+
+ EXPECT_TRUE(feature_processor.ExtractFeatures(
+ tokens, /*token_span=*/{0, 4},
+ /*selection_span_for_feature=*/{4, 11}, &embedding_executor,
+ /*embedding_cache=*/nullptr, /*feature_vector_size=*/5,
+ &cached_features));
+ std::vector<float> features;
+ cached_features->AppendClickContextFeaturesForClick(1, &features);
+ ASSERT_EQ(features.size(), 25);
+ EXPECT_THAT(features[4], FloatEq(0.0));
+ EXPECT_THAT(features[9], FloatEq(0.0));
+ EXPECT_THAT(features[14], FloatEq(1.0));
+ EXPECT_THAT(features[19], FloatEq(1.0));
+ EXPECT_THAT(features[24], FloatEq(0.0));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, EmbeddingCache) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.feature_version = 2;
+ options.embedding_size = 4;
+ options.bounds_sensitive_features.reset(
+ new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+ options.bounds_sensitive_features->enabled = true;
+ options.bounds_sensitive_features->num_tokens_before = 3;
+ options.bounds_sensitive_features->num_tokens_inside_left = 2;
+ options.bounds_sensitive_features->num_tokens_inside_right = 2;
+ options.bounds_sensitive_features->num_tokens_after = 3;
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ std::unique_ptr<CachedFeatures> cached_features;
+
+ FakeEmbeddingExecutor embedding_executor;
+
+ const std::vector<Token> tokens = {
+ Token("aaa", 0, 3), Token("bbb", 4, 7), Token("ccc", 8, 11),
+ Token("ddd", 12, 15), Token("eee", 16, 19), Token("fff", 20, 23)};
+
+ // We pre-populate the cache with dummy embeddings, to make sure they are
+ // used when populating the features vector.
+ const std::vector<float> cached_padding_features = {10.0, -10.0, 10.0, -10.0};
+ const std::vector<float> cached_features1 = {1.0, 2.0, 3.0, 4.0};
+ const std::vector<float> cached_features2 = {5.0, 6.0, 7.0, 8.0};
+ FeatureProcessor::EmbeddingCache embedding_cache = {
+ {{kInvalidIndex, kInvalidIndex}, cached_padding_features},
+ {{4, 7}, cached_features1},
+ {{12, 15}, cached_features2},
+ };
+
+ EXPECT_TRUE(feature_processor.ExtractFeatures(
+ tokens, /*token_span=*/{0, 6},
+ /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+ &embedding_executor, &embedding_cache, /*feature_vector_size=*/4,
+ &cached_features));
+ std::vector<float> features;
+ cached_features->AppendBoundsSensitiveFeaturesForSpan({2, 4}, &features);
+ ASSERT_EQ(features.size(), 40);
+ // Check that the dummy embeddings were used.
+ EXPECT_THAT(Subvector(features, 0, 4),
+ ElementsAreFloat(cached_padding_features));
+ EXPECT_THAT(Subvector(features, 8, 12), ElementsAreFloat(cached_features1));
+ EXPECT_THAT(Subvector(features, 16, 20), ElementsAreFloat(cached_features2));
+ EXPECT_THAT(Subvector(features, 24, 28), ElementsAreFloat(cached_features2));
+ EXPECT_THAT(Subvector(features, 36, 40),
+ ElementsAreFloat(cached_padding_features));
+ // Check that the real embeddings were cached.
+ EXPECT_EQ(embedding_cache.size(), 7);
+ EXPECT_THAT(Subvector(features, 4, 8),
+ ElementsAreFloat(embedding_cache.at({0, 3})));
+ EXPECT_THAT(Subvector(features, 12, 16),
+ ElementsAreFloat(embedding_cache.at({8, 11})));
+ EXPECT_THAT(Subvector(features, 20, 24),
+ ElementsAreFloat(embedding_cache.at({8, 11})));
+ EXPECT_THAT(Subvector(features, 28, 32),
+ ElementsAreFloat(embedding_cache.at({16, 19})));
+ EXPECT_THAT(Subvector(features, 32, 36),
+ ElementsAreFloat(embedding_cache.at({20, 23})));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
+ std::vector<Token> tokens_orig{
+ Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
+ Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
+ Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
+ Token("12", 0, 0)};
+
+ std::vector<Token> tokens;
+ int click_index;
+
+ // Try to click first token and see if it gets padded from left.
+ tokens = tokens_orig;
+ click_index = 0;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token(),
+ Token(),
+ Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+
+ // When we click the second token nothing should get padded.
+ tokens = tokens_orig;
+ click_index = 2;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+
+ // When we click the last token tokens should get padded from the right.
+ tokens = tokens_orig;
+ click_index = 12;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0),
+ Token("11", 0, 0),
+ Token("12", 0, 0),
+ Token(),
+ Token()}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, StripUnusedTokensWithRelativeClick) {
+ std::vector<Token> tokens_orig{
+ Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
+ Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
+ Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
+ Token("12", 0, 0)};
+
+ std::vector<Token> tokens;
+ int click_index;
+
+ // Try to click first token and see if it gets padded from left to maximum
+ // context_size.
+ tokens = tokens_orig;
+ click_index = 0;
+ internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token(),
+ Token(),
+ Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0),
+ Token("5", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+
+ // Clicking to the middle with enough context should not produce any padding.
+ tokens = tokens_orig;
+ click_index = 6;
+ internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0),
+ Token("5", 0, 0),
+ Token("6", 0, 0),
+ Token("7", 0, 0),
+ Token("8", 0, 0),
+ Token("9", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 5);
+
+ // Clicking at the end should pad right to maximum context_size.
+ tokens = tokens_orig;
+ click_index = 11;
+ internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
+ Token("7", 0, 0),
+ Token("8", 0, 0),
+ Token("9", 0, 0),
+ Token("10", 0, 0),
+ Token("11", 0, 0),
+ Token("12", 0, 0),
+ Token(),
+ Token()}));
+ // clang-format on
+ EXPECT_EQ(click_index, 5);
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
+ FeatureProcessorOptionsT options;
+ options.ignored_span_boundary_codepoints.push_back('.');
+ options.ignored_span_boundary_codepoints.push_back(',');
+ options.ignored_span_boundary_codepoints.push_back('[');
+ options.ignored_span_boundary_codepoints.push_back(']');
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string text1_utf8 = "ěščř";
+ const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text1.begin(), text1.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text1.begin(), text1.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text2_utf8 = ".,abčd";
+ const UnicodeText text2 = UTF8ToUnicodeText(text2_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text2.begin(), text2.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text2.begin(), text2.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text3_utf8 = ".,abčd[]";
+ const UnicodeText text3 = UTF8ToUnicodeText(text3_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text3.begin(), text3.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text3.begin(), text3.end(),
+ /*count_from_beginning=*/false),
+ 2);
+
+ const std::string text4_utf8 = "[abčd]";
+ const UnicodeText text4 = UTF8ToUnicodeText(text4_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text4.begin(), text4.end(),
+ /*count_from_beginning=*/true),
+ 1);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text4.begin(), text4.end(),
+ /*count_from_beginning=*/false),
+ 1);
+
+ const std::string text5_utf8 = "";
+ const UnicodeText text5 = UTF8ToUnicodeText(text5_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text5.begin(), text5.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text5.begin(), text5.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text6_utf8 = "012345ěščř";
+ const UnicodeText text6 = UTF8ToUnicodeText(text6_utf8, /*do_copy=*/false);
+ UnicodeText::const_iterator text6_begin = text6.begin();
+ std::advance(text6_begin, 6);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text6_begin, text6.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text6_begin, text6.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text7_utf8 = "012345.,ěščř";
+ const UnicodeText text7 = UTF8ToUnicodeText(text7_utf8, /*do_copy=*/false);
+ UnicodeText::const_iterator text7_begin = text7.begin();
+ std::advance(text7_begin, 6);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text7_begin, text7.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ UnicodeText::const_iterator text7_end = text7.begin();
+ std::advance(text7_end, 8);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text7.begin(), text7_end,
+ /*count_from_beginning=*/false),
+ 2);
+
+ // Test not stripping.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+ "Hello [[[Wořld]] or not?", {0, 24}),
+ CodepointSpan(0, 24));
+ // Test basic stripping.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+ "Hello [[[Wořld]] or not?", {6, 16}),
+ CodepointSpan(9, 14));
+ // Test stripping when everything is stripped.
+ EXPECT_EQ(
+ feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
+ CodepointSpan(6, 6));
+ // Test stripping empty string.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
+ CodepointSpan(0, 0));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, CodepointSpanToTokenSpan) {
+ const std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ // Spans matching the tokens exactly.
+ EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}));
+ EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}));
+ EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}));
+ EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}));
+
+ // Snapping to containing tokens has no effect.
+ EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}, true));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}, true));
+ EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}, true));
+ EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}, true));
+ EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}, true));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}, true));
+
+ // Span boundaries inside tokens.
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {1, 28}));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {1, 28}, true));
+
+ // Tokens adjacent to the span, but not overlapping.
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}, true));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/flatbuffer-utils.cc b/native/annotator/flatbuffer-utils.cc
new file mode 100644
index 0000000..d4cbe4a
--- /dev/null
+++ b/native/annotator/flatbuffer-utils.cc
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/flatbuffer-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+
+bool SwapFieldNamesForOffsetsInPath(ModelT* model) {
+ if (model->regex_model == nullptr || model->entity_data_schema.empty()) {
+ // Nothing to do.
+ return true;
+ }
+ const reflection::Schema* schema =
+ LoadAndVerifyFlatbuffer<reflection::Schema>(
+ model->entity_data_schema.data(), model->entity_data_schema.size());
+
+ for (std::unique_ptr<RegexModel_::PatternT>& pattern :
+ model->regex_model->patterns) {
+ for (std::unique_ptr<CapturingGroupT>& group : pattern->capturing_group) {
+ if (group->entity_field_path == nullptr) {
+ continue;
+ }
+
+ if (!SwapFieldNamesForOffsetsInPath(schema,
+ group->entity_field_path.get())) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+std::string SwapFieldNamesForOffsetsInPathInSerializedModel(
+ const std::string& model) {
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
+ TC3_CHECK(unpacked_model != nullptr);
+ TC3_CHECK(SwapFieldNamesForOffsetsInPath(unpacked_model.get()));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+std::string CreateDatetimeSerializedEntityData(
+ const DatetimeParseResult& parse_result) {
+ EntityDataT entity_data;
+ entity_data.datetime.reset(new EntityData_::DatetimeT());
+ entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
+ entity_data.datetime->granularity =
+ static_cast<EntityData_::Datetime_::Granularity>(
+ parse_result.granularity);
+
+ for (const auto& c : parse_result.datetime_components) {
+ EntityData_::Datetime_::DatetimeComponentT datetime_component;
+ datetime_component.absolute_value = c.value;
+ datetime_component.relative_count = c.relative_count;
+ datetime_component.component_type =
+ static_cast<EntityData_::Datetime_::DatetimeComponent_::ComponentType>(
+ c.component_type);
+ datetime_component.relation_type =
+ EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE;
+ if (c.relative_qualifier !=
+ DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
+ datetime_component.relation_type =
+ EntityData_::Datetime_::DatetimeComponent_::RelationType_RELATIVE;
+ }
+ entity_data.datetime->datetime_component.emplace_back(
+ new EntityData_::Datetime_::DatetimeComponentT(datetime_component));
+ }
+ return PackFlatbuffer<EntityData>(&entity_data);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/flatbuffer-utils.h b/native/annotator/flatbuffer-utils.h
new file mode 100644
index 0000000..cd7d653
--- /dev/null
+++ b/native/annotator/flatbuffer-utils.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers in the annotator model.
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
+
+#include <string>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+// Resolves field lookups by name to the concrete field offsets in the regex
+// rules of the model.
+bool SwapFieldNamesForOffsetsInPath(ModelT* model);
+
+// Same as above but for a serialized model.
+std::string SwapFieldNamesForOffsetsInPathInSerializedModel(
+ const std::string& model);
+
+std::string CreateDatetimeSerializedEntityData(
+ const DatetimeParseResult& parse_result);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
diff --git a/native/annotator/grammar/dates/annotations/annotation-options.h b/native/annotator/grammar/dates/annotations/annotation-options.h
deleted file mode 100755
index 29e9939..0000000
--- a/native/annotator/grammar/dates/annotations/annotation-options.h
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_
-
-#include <string>
-#include <vector>
-
-#include "utils/base/integral_types.h"
-
-namespace libtextclassifier3 {
-
-// Options for date/datetime/date range annotations.
-struct DateAnnotationOptions {
- // If enabled, extract special day offset like today, yesterday, etc.
- bool enable_special_day_offset;
-
- // If true, merge the adjacent day of week, time and date. e.g.
- // "20/2/2016 at 8pm" is extracted as a single instance instead of two
- // instance: "20/2/2016" and "8pm".
- bool merge_adjacent_components;
-
- // List the extra id of requested dates.
- std::vector<std::string> extra_requested_dates;
-
- // If true, try to include preposition to the extracted annotation. e.g.
- // "at 6pm". if it's false, only 6pm is included. offline-actions has special
- // requirements to include preposition.
- bool include_preposition;
-
- // The base timestamp (milliseconds) which used to convert relative time to
- // absolute time.
- // e.g.:
- // base timestamp is 2016/4/25, then tomorrow will be converted to
- // 2016/4/26.
- // base timestamp is 2016/4/25 10:30:20am, then 1 days, 2 hours, 10 minutes
- // and 5 seconds ago will be converted to 2016/4/24 08:20:15am
- int64 base_timestamp_millis;
-
- // If enabled, extract range in date annotator.
- // input: Monday, 5-6pm
- // If the flag is true, The extracted annotation only contains 1 range
- // instance which is from Monday 5pm to 6pm.
- // If the flag is false, The extracted annotation contains two date
- // instance: "Monday" and "6pm".
- bool enable_date_range;
-
- // Timezone in which the input text was written
- std::string reference_timezone;
- // Localization params.
- // The format of the locale lists should be "<lang_code-<county_code>"
- // comma-separated list of two-character language/country pairs.
- std::string locales;
-
- // If enabled, the annotation/rule_match priority score is used to set the and
- // priority score of the annotation.
- // In case of false the annotation priority score are set from
- // GrammarDatetimeModel's priority_score
- bool use_rule_priority_score;
-
- // If enabled, annotator will try to resolve the ambiguity by generating
- // possible alternative interpretations of the input text
- // e.g. '9:45' will be resolved to '9:45 AM' and '9:45 PM'.
- bool generate_alternative_interpretations_when_ambiguous;
-
- // List the ignored span in the date string e.g. 12 March @12PM, here '@'
- // can be ignored tokens.
- std::vector<std::string> ignored_spans;
-
- // Default Constructor
- DateAnnotationOptions()
- : enable_special_day_offset(true),
- merge_adjacent_components(true),
- include_preposition(false),
- base_timestamp_millis(0),
- enable_date_range(false),
- use_rule_priority_score(false),
- generate_alternative_interpretations_when_ambiguous(false) {}
-};
-
-} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_
diff --git a/native/annotator/grammar/dates/annotations/annotation-util.cc b/native/annotator/grammar/dates/annotations/annotation-util.cc
deleted file mode 100644
index 438206f..0000000
--- a/native/annotator/grammar/dates/annotations/annotation-util.cc
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/annotations/annotation-util.h"
-
-#include <algorithm>
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-int GetPropertyIndex(StringPiece name, const AnnotationData& annotation_data) {
- for (int i = 0; i < annotation_data.properties.size(); ++i) {
- if (annotation_data.properties[i].name == name.ToString()) {
- return i;
- }
- }
- return -1;
-}
-
-int GetPropertyIndex(StringPiece name, const Annotation& annotation) {
- return GetPropertyIndex(name, annotation.data);
-}
-
-int GetIntProperty(StringPiece name, const Annotation& annotation) {
- return GetIntProperty(name, annotation.data);
-}
-
-int GetIntProperty(StringPiece name, const AnnotationData& annotation_data) {
- const int index = GetPropertyIndex(name, annotation_data);
- if (index < 0) {
- TC3_DCHECK_GE(index, 0)
- << "No property with name " << name.ToString() << ".";
- return 0;
- }
-
- if (annotation_data.properties.at(index).int_values.size() != 1) {
- TC3_DCHECK_EQ(annotation_data.properties[index].int_values.size(), 1);
- return 0;
- }
-
- return annotation_data.properties.at(index).int_values.at(0);
-}
-
-int AddIntProperty(StringPiece name, int value, Annotation* annotation) {
- return AddRepeatedIntProperty(name, &value, 1, annotation);
-}
-
-int AddIntProperty(StringPiece name, int value,
- AnnotationData* annotation_data) {
- return AddRepeatedIntProperty(name, &value, 1, annotation_data);
-}
-
-int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
- Annotation* annotation) {
- return AddRepeatedIntProperty(name, start, size, &annotation->data);
-}
-
-int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
- AnnotationData* annotation_data) {
- Property property;
- property.name = name.ToString();
- auto first = start;
- auto last = start + size;
- while (first != last) {
- property.int_values.push_back(*first);
- first++;
- }
- annotation_data->properties.push_back(property);
- return annotation_data->properties.size() - 1;
-}
-
-int AddAnnotationDataProperty(const std::string& key,
- const AnnotationData& value,
- AnnotationData* annotation_data) {
- Property property;
- property.name = key;
- property.annotation_data_values.push_back(value);
- annotation_data->properties.push_back(property);
- return annotation_data->properties.size() - 1;
-}
-
-int AddAnnotationDataProperty(const std::string& key,
- const AnnotationData& value,
- Annotation* annotation) {
- return AddAnnotationDataProperty(key, value, &annotation->data);
-}
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/annotations/annotation-util.h b/native/annotator/grammar/dates/annotations/annotation-util.h
deleted file mode 100644
index e4afbfe..0000000
--- a/native/annotator/grammar/dates/annotations/annotation-util.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_
-
-#include "annotator/grammar/dates/annotations/annotation.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-// Return the index of property in annotation.data().properties().
-// Return -1 if the property does not exist.
-int GetPropertyIndex(StringPiece name, const Annotation& annotation);
-
-// Return the index of property in thing.properties().
-// Return -1 if the property does not exist.
-int GetPropertyIndex(StringPiece name, const AnnotationData& annotation_data);
-
-// Return the single int value for property 'name' of the annotation.
-// Returns 0 if the property does not exist or does not contain a single int
-// value.
-int GetIntProperty(StringPiece name, const Annotation& annotation);
-
-// Return the single float value for property 'name' of the annotation.
-// Returns 0 if the property does not exist or does not contain a single int
-// value.
-int GetIntProperty(StringPiece name, const AnnotationData& annotation_data);
-
-// Add a new property with a single int value to an Annotation instance.
-// Return the index of the property.
-int AddIntProperty(StringPiece name, const int value, Annotation* annotation);
-
-// Add a new property with a single int value to a Thing instance.
-// Return the index of the property.
-int AddIntProperty(StringPiece name, const int value,
- AnnotationData* annotation_data);
-
-// Add a new property with repeated int values to an Annotation instance.
-// Return the index of the property.
-int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
- Annotation* annotation);
-
-// Add a new property with repeated int values to a Thing instance.
-// Return the index of the property.
-int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
- AnnotationData* annotation_data);
-
-// Add a new property with Thing value.
-// Return the index of the property.
-int AddAnnotationDataProperty(const std::string& key,
- const AnnotationData& value,
- Annotation* annotation);
-
-// Add a new property with Thing value.
-// Return the index of the property.
-int AddAnnotationDataProperty(const std::string& key,
- const AnnotationData& value,
- AnnotationData* annotation_data);
-
-} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_
diff --git a/native/annotator/grammar/dates/annotations/annotation-util_test.cc b/native/annotator/grammar/dates/annotations/annotation-util_test.cc
deleted file mode 100644
index 6d25d64..0000000
--- a/native/annotator/grammar/dates/annotations/annotation-util_test.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/annotations/annotation-util.h"
-
-#include "annotator/grammar/dates/annotations/annotation.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(AnnotationUtilTest, VerifyIntFunctions) {
- Annotation annotation;
-
- int index_key1 = AddIntProperty("key1", 1, &annotation);
- int index_key2 = AddIntProperty("key2", 2, &annotation);
-
- static const int kValuesKey3[] = {3, 4, 5};
- int index_key3 =
- AddRepeatedIntProperty("key3", kValuesKey3, /*size=*/3, &annotation);
-
- EXPECT_EQ(2, GetIntProperty("key2", annotation));
- EXPECT_EQ(1, GetIntProperty("key1", annotation));
-
- EXPECT_EQ(index_key1, GetPropertyIndex("key1", annotation));
- EXPECT_EQ(index_key2, GetPropertyIndex("key2", annotation));
- EXPECT_EQ(index_key3, GetPropertyIndex("key3", annotation));
- EXPECT_EQ(-1, GetPropertyIndex("invalid_key", annotation));
-}
-
-TEST(AnnotationUtilTest, VerifyAnnotationDataFunctions) {
- Annotation annotation;
-
- AnnotationData true_annotation_data;
- Property true_property;
- true_property.bool_values.push_back(true);
- true_annotation_data.properties.push_back(true_property);
- int index_key1 =
- AddAnnotationDataProperty("key1", true_annotation_data, &annotation);
-
- AnnotationData false_annotation_data;
- Property false_property;
- false_property.bool_values.push_back(false);
- true_annotation_data.properties.push_back(false_property);
- int index_key2 =
- AddAnnotationDataProperty("key2", false_annotation_data, &annotation);
-
- EXPECT_EQ(index_key1, GetPropertyIndex("key1", annotation));
- EXPECT_EQ(index_key2, GetPropertyIndex("key2", annotation));
- EXPECT_EQ(-1, GetPropertyIndex("invalid_key", annotation));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/annotations/annotation.h b/native/annotator/grammar/dates/annotations/annotation.h
deleted file mode 100644
index e6ddb09..0000000
--- a/native/annotator/grammar/dates/annotations/annotation.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_
-
-#include <string>
-#include <vector>
-
-#include "utils/base/integral_types.h"
-
-namespace libtextclassifier3 {
-
-struct AnnotationData;
-
-// Define enum for each annotation.
-enum GrammarAnnotationType {
- // Date&time like "May 1", "12:20pm", etc.
- DATETIME = 0,
- // Datetime range like "2pm - 3pm".
- DATETIME_RANGE = 1,
-};
-
-struct Property {
- // TODO(hassan): Replace the name with enum e.g. PropertyType.
- std::string name;
- // At most one of these will have any values.
- std::vector<bool> bool_values;
- std::vector<int64> int_values;
- std::vector<double> double_values;
- std::vector<std::string> string_values;
- std::vector<AnnotationData> annotation_data_values;
-};
-
-struct AnnotationData {
- // TODO(hassan): Replace it type with GrammarAnnotationType
- std::string type;
- std::vector<Property> properties;
-};
-
-// Represents an annotation instance.
-// lets call it either AnnotationDetails
-struct Annotation {
- // Codepoint offsets into the original text specifying the substring of the
- // text that was annotated.
- int32 begin;
- int32 end;
-
- // Annotation priority score which can be used to resolve conflict between
- // annotators.
- float annotator_priority_score;
-
- // Represents the details of the annotation instance, including the type of
- // the annotation instance and its properties.
- AnnotationData data;
-};
-} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_
diff --git a/native/annotator/grammar/dates/cfg-datetime-annotator.cc b/native/annotator/grammar/dates/cfg-datetime-annotator.cc
deleted file mode 100644
index 99d3be0..0000000
--- a/native/annotator/grammar/dates/cfg-datetime-annotator.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/cfg-datetime-annotator.h"
-
-#include "annotator/datetime/utils.h"
-#include "annotator/grammar/dates/annotations/annotation-options.h"
-#include "annotator/grammar/utils.h"
-#include "utils/strings/split.h"
-#include "utils/tokenizer.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3::dates {
-namespace {
-
-static std::string GetReferenceLocale(const std::string& locales) {
- std::vector<StringPiece> split_locales = strings::Split(locales, ',');
- if (!split_locales.empty()) {
- return split_locales[0].ToString();
- }
- return "";
-}
-
-static void InterpretParseData(const DatetimeParsedData& datetime_parsed_data,
- const DateAnnotationOptions& options,
- const CalendarLib& calendarlib,
- int64* interpreted_time_ms_utc,
- DatetimeGranularity* granularity) {
- DatetimeGranularity local_granularity =
- calendarlib.GetGranularity(datetime_parsed_data);
- if (!calendarlib.InterpretParseData(
- datetime_parsed_data, options.base_timestamp_millis,
- options.reference_timezone, GetReferenceLocale(options.locales),
- /*prefer_future_for_unspecified_date=*/true, interpreted_time_ms_utc,
- granularity)) {
- TC3_LOG(WARNING) << "Failed to extract time in millis and Granularity.";
- // Fallingback to DatetimeParsedData's finest granularity
- *granularity = local_granularity;
- }
-}
-
-} // namespace
-
-CfgDatetimeAnnotator::CfgDatetimeAnnotator(
- const UniLib* unilib, const GrammarTokenizerOptions* tokenizer_options,
- const CalendarLib* calendar_lib, const DatetimeRules* datetime_rules,
- const float annotator_target_classification_score,
- const float annotator_priority_score)
- : calendar_lib_(*calendar_lib),
- tokenizer_(BuildTokenizer(unilib, tokenizer_options)),
- parser_(unilib, datetime_rules),
- annotator_target_classification_score_(
- annotator_target_classification_score),
- annotator_priority_score_(annotator_priority_score) {}
-
-void CfgDatetimeAnnotator::Parse(
- const std::string& input, const DateAnnotationOptions& annotation_options,
- const std::vector<Locale>& locales,
- std::vector<DatetimeParseResultSpan>* results) const {
- Parse(UTF8ToUnicodeText(input, /*do_copy=*/false), annotation_options,
- locales, results);
-}
-
-void CfgDatetimeAnnotator::ProcessDatetimeParseResult(
- const DateAnnotationOptions& annotation_options,
- const DatetimeParseResult& datetime_parse_result,
- std::vector<DatetimeParseResult>* results) const {
- DatetimeParsedData datetime_parsed_data;
- datetime_parsed_data.AddDatetimeComponents(
- datetime_parse_result.datetime_components);
-
- std::vector<DatetimeParsedData> interpretations;
- if (annotation_options.generate_alternative_interpretations_when_ambiguous) {
- FillInterpretations(datetime_parsed_data,
- calendar_lib_.GetGranularity(datetime_parsed_data),
- &interpretations);
- } else {
- interpretations.emplace_back(datetime_parsed_data);
- }
- for (const DatetimeParsedData& interpretation : interpretations) {
- results->emplace_back();
- interpretation.GetDatetimeComponents(&results->back().datetime_components);
- InterpretParseData(interpretation, annotation_options, calendar_lib_,
- &(results->back().time_ms_utc),
- &(results->back().granularity));
- std::sort(results->back().datetime_components.begin(),
- results->back().datetime_components.end(),
- [](const DatetimeComponent& a, const DatetimeComponent& b) {
- return a.component_type > b.component_type;
- });
- }
-}
-
-void CfgDatetimeAnnotator::Parse(
- const UnicodeText& input, const DateAnnotationOptions& annotation_options,
- const std::vector<Locale>& locales,
- std::vector<DatetimeParseResultSpan>* results) const {
- std::vector<DatetimeParseResultSpan> grammar_datetime_parse_result_spans =
- parser_.Parse(input.data(), tokenizer_.Tokenize(input), locales,
- annotation_options);
-
- for (const DatetimeParseResultSpan& grammar_datetime_parse_result_span :
- grammar_datetime_parse_result_spans) {
- DatetimeParseResultSpan datetime_parse_result_span;
- datetime_parse_result_span.span.first =
- grammar_datetime_parse_result_span.span.first;
- datetime_parse_result_span.span.second =
- grammar_datetime_parse_result_span.span.second;
- datetime_parse_result_span.priority_score = annotator_priority_score_;
- if (annotation_options.use_rule_priority_score) {
- datetime_parse_result_span.priority_score =
- grammar_datetime_parse_result_span.priority_score;
- }
- datetime_parse_result_span.target_classification_score =
- annotator_target_classification_score_;
- for (const DatetimeParseResult& grammar_datetime_parse_result :
- grammar_datetime_parse_result_span.data) {
- ProcessDatetimeParseResult(annotation_options,
- grammar_datetime_parse_result,
- &datetime_parse_result_span.data);
- }
- results->emplace_back(datetime_parse_result_span);
- }
-}
-
-} // namespace libtextclassifier3::dates
diff --git a/native/annotator/grammar/dates/cfg-datetime-annotator.h b/native/annotator/grammar/dates/cfg-datetime-annotator.h
deleted file mode 100644
index 73c9b7b..0000000
--- a/native/annotator/grammar/dates/cfg-datetime-annotator.h
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_
-
-#include "annotator/grammar/dates/annotations/annotation.h"
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/parser.h"
-#include "annotator/grammar/dates/utils/annotation-keys.h"
-#include "annotator/model_generated.h"
-#include "utils/calendar/calendar.h"
-#include "utils/i18n/locale.h"
-#include "utils/tokenizer.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3::dates {
-
-// Helper class to convert the parsed datetime expression from AnnotationList
-// (List of annotation generated from Grammar rules) to DatetimeParseResultSpan.
-class CfgDatetimeAnnotator {
- public:
- explicit CfgDatetimeAnnotator(
- const UniLib* unilib, const GrammarTokenizerOptions* tokenizer_options,
- const CalendarLib* calendar_lib, const DatetimeRules* datetime_rules,
- const float annotator_target_classification_score,
- const float annotator_priority_score);
-
- // CfgDatetimeAnnotator is neither copyable nor movable.
- CfgDatetimeAnnotator(const CfgDatetimeAnnotator&) = delete;
- CfgDatetimeAnnotator& operator=(const CfgDatetimeAnnotator&) = delete;
-
- // Parses the dates in 'input' and fills result. Makes sure that the results
- // do not overlap.
- // Method will return false if input does not contain any datetime span.
- void Parse(const std::string& input,
- const DateAnnotationOptions& annotation_options,
- const std::vector<Locale>& locales,
- std::vector<DatetimeParseResultSpan>* results) const;
-
- // UnicodeText version of parse.
- void Parse(const UnicodeText& input,
- const DateAnnotationOptions& annotation_options,
- const std::vector<Locale>& locales,
- std::vector<DatetimeParseResultSpan>* results) const;
-
- private:
- void ProcessDatetimeParseResult(
- const DateAnnotationOptions& annotation_options,
- const DatetimeParseResult& datetime_parse_result,
- std::vector<DatetimeParseResult>* results) const;
-
- const CalendarLib& calendar_lib_;
- const Tokenizer tokenizer_;
- DateParser parser_;
- const float annotator_target_classification_score_;
- const float annotator_priority_score_;
-};
-
-} // namespace libtextclassifier3::dates
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_
diff --git a/native/annotator/grammar/dates/dates.fbs b/native/annotator/grammar/dates/dates.fbs
deleted file mode 100755
index 6d535bc..0000000
--- a/native/annotator/grammar/dates/dates.fbs
+++ /dev/null
@@ -1,351 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-include "annotator/grammar/dates/timezone-code.fbs";
-include "utils/grammar/rules.fbs";
-
-// Type identifiers of all non-trivial matches.
-namespace libtextclassifier3.dates;
-enum MatchType : int {
- UNKNOWN = 0,
-
- // Match of a date extraction rule.
- DATETIME_RULE = 1,
-
- // Match of a date range extraction rule.
- DATETIME_RANGE_RULE = 2,
-
- // Match defined by an ExtractionRule (e.g., a single time-result that is
- // matched by a time-rule, which is ready to be output individually, with
- // this kind of match, we can retrieve it in range rules).
- DATETIME = 3,
-
- // Match defined by TermValue.
- TERM_VALUE = 4,
-
- // Matches defined by Nonterminal.
- NONTERMINAL = 5,
-
- DIGITS = 6,
- YEAR = 7,
- MONTH = 8,
- DAY = 9,
- HOUR = 10,
- MINUTE = 11,
- SECOND = 12,
- FRACTION_SECOND = 13,
- DAY_OF_WEEK = 14,
- TIME_VALUE = 15,
- TIME_SPAN = 16,
- TIME_ZONE_NAME = 17,
- TIME_ZONE_OFFSET = 18,
- TIME_PERIOD = 19,
- RELATIVE_DATE = 20,
- COMBINED_DIGITS = 21,
-}
-
-namespace libtextclassifier3.dates;
-enum BCAD : int {
- BCAD_NONE = -1,
- BC = 0,
- AD = 1,
-}
-
-namespace libtextclassifier3.dates;
-enum DayOfWeek : int {
- DOW_NONE = -1,
- SUNDAY = 1,
- MONDAY = 2,
- TUESDAY = 3,
- WEDNESDAY = 4,
- THURSDAY = 5,
- FRIDAY = 6,
- SATURDAY = 7,
-}
-
-namespace libtextclassifier3.dates;
-enum TimespanCode : int {
- TIMESPAN_CODE_NONE = -1,
- AM = 0,
- PM = 1,
- NOON = 2,
- MIDNIGHT = 3,
-
- // English "tonight".
- TONIGHT = 11,
-}
-
-// The datetime grammar rules.
-namespace libtextclassifier3.dates;
-table DatetimeRules {
- // The context free grammar rules.
- rules:grammar.RulesSet;
-
- // Values associated with grammar rule matches.
- extraction_rule:[ExtractionRuleParameter];
-
- term_value:[TermValue];
- nonterminal_value:[NonterminalValue];
-}
-
-namespace libtextclassifier3.dates;
-table TermValue {
- value:int;
-
- // A time segment e.g. 10AM - 12AM
- time_span_spec:TimeSpanSpec;
-
- // Time zone information representation
- time_zone_name_spec:TimeZoneNameSpec;
-}
-
-// Define nonterms from terms or other nonterms.
-namespace libtextclassifier3.dates;
-table NonterminalValue {
- // Mapping value.
- value:TermValue;
-
- // Parameter describing formatting choices for nonterminal messages
- nonterminal_parameter:NonterminalParameter;
-
- // Parameter interpreting past/future dates (e.g. "last year")
- relative_parameter:RelativeParameter;
-
- // Format info for nonterminals representing times.
- time_value_parameter:TimeValueParameter;
-
- // Parameter describing the format of time-zone info - e.g. "UTC-8"
- time_zone_offset_parameter:TimeZoneOffsetParameter;
-}
-
-namespace libtextclassifier3.dates.RelativeParameter_;
-enum RelativeType : int {
- NONE = 0,
- YEAR = 1,
- MONTH = 2,
- DAY = 3,
- WEEK = 4,
- HOUR = 5,
- MINUTE = 6,
- SECOND = 7,
-}
-
-namespace libtextclassifier3.dates.RelativeParameter_;
-enum Period : int {
- PERIOD_UNKNOWN = 0,
- PERIOD_PAST = 1,
- PERIOD_FUTURE = 2,
-}
-
-// Relative interpretation.
-// Indicates which day the day of week could be, for example "next Friday"
-// could means the Friday which is the closest Friday or the Friday in the
-// next week.
-namespace libtextclassifier3.dates.RelativeParameter_;
-enum Interpretation : int {
- UNKNOWN = 0,
-
- // The closest X in the past.
- NEAREST_LAST = 1,
-
- // The X before the closest X in the past.
- SECOND_LAST = 2,
-
- // The closest X in the future.
- NEAREST_NEXT = 3,
-
- // The X after the closest X in the future.
- SECOND_NEXT = 4,
-
- // X in the previous one.
- PREVIOUS = 5,
-
- // X in the coming one.
- COMING = 6,
-
- // X in current one, it can be both past and future.
- CURRENT = 7,
-
- // Some X.
- SOME = 8,
-
- // The closest X, it can be both past and future.
- NEAREST = 9,
-}
-
-namespace libtextclassifier3.dates;
-table RelativeParameter {
- type:RelativeParameter_.RelativeType = NONE;
- period:RelativeParameter_.Period = PERIOD_UNKNOWN;
- day_of_week_interpretation:[RelativeParameter_.Interpretation];
-}
-
-namespace libtextclassifier3.dates.NonterminalParameter_;
-enum Flag : int {
- IS_SPELLED = 1,
-}
-
-namespace libtextclassifier3.dates;
-table NonterminalParameter {
- // Bit-wise OR Flag.
- flag:uint = 0;
-
- combined_digits_format:string (shared);
-}
-
-namespace libtextclassifier3.dates.TimeValueParameter_;
-enum TimeValueValidation : int {
- // Allow extra spaces between sub-components in time-value.
- ALLOW_EXTRA_SPACE = 1,
- // 1 << 0
-
- // Disallow colon- or dot-context with digits for time-value.
- DISALLOW_COLON_DOT_CONTEXT = 2,
- // 1 << 1
-}
-
-namespace libtextclassifier3.dates;
-table TimeValueParameter {
- validation:uint = 0;
- // Bitwise-OR
-
- flag:uint = 0;
- // Bitwise-OR
-}
-
-namespace libtextclassifier3.dates.TimeZoneOffsetParameter_;
-enum Format : int {
- // Offset is in an uncategorized format.
- FORMAT_UNKNOWN = 0,
-
- // Offset contains 1-digit hour only, e.g. "UTC-8".
- FORMAT_H = 1,
-
- // Offset contains 2-digit hour only, e.g. "UTC-08".
- FORMAT_HH = 2,
-
- // Offset contains 1-digit hour and minute, e.g. "UTC-8:00".
- FORMAT_H_MM = 3,
-
- // Offset contains 2-digit hour and minute, e.g. "UTC-08:00".
- FORMAT_HH_MM = 4,
-
- // Offset contains 3-digit hour-and-minute, e.g. "UTC-800".
- FORMAT_HMM = 5,
-
- // Offset contains 4-digit hour-and-minute, e.g. "UTC-0800".
- FORMAT_HHMM = 6,
-}
-
-namespace libtextclassifier3.dates;
-table TimeZoneOffsetParameter {
- format:TimeZoneOffsetParameter_.Format = FORMAT_UNKNOWN;
-}
-
-namespace libtextclassifier3.dates.ExtractionRuleParameter_;
-enum ExtractionValidation : int {
- // Boundary checking for final match.
- LEFT_BOUND = 1,
-
- RIGHT_BOUND = 2,
- SPELLED_YEAR = 4,
- SPELLED_MONTH = 8,
- SPELLED_DAY = 16,
-
- // Without this validation-flag set, unconfident time-zone expression
- // are discarded in the output-callback, e.g. "-08:00, +8".
- ALLOW_UNCONFIDENT_TIME_ZONE = 32,
-}
-
-// Parameter info for extraction rule, help rule explanation.
-namespace libtextclassifier3.dates;
-table ExtractionRuleParameter {
- // Bit-wise OR Validation.
- validation:uint = 0;
-
- priority_delta:int;
- id:string (shared);
-
- // The score reflects the confidence score of the date/time match, which is
- // set while creating grammar rules.
- // e.g. given we have the rule which detect "22.33" as a HH.MM then because
- // of ambiguity the confidence of this match maybe relatively less.
- annotator_priority_score:float;
-}
-
-// Internal structure used to describe an hour-mapping segment.
-namespace libtextclassifier3.dates.TimeSpanSpec_;
-table Segment {
- // From 0 to 24, the beginning hour of the segment, always included.
- begin:int;
-
- // From 0 to 24, the ending hour of the segment, not included if the
- // segment is not closed. The value 0 means the beginning of the next
- // day, the same value as "begin" means a time-point.
- end:int;
-
- // From -24 to 24, the mapping offset in hours from spanned expressions
- // to 24-hour expressions. The value 0 means identical mapping.
- offset:int;
-
- // True if the segment is a closed one instead of a half-open one.
- // Always set it to true when describing time-points.
- is_closed:bool = false;
-
- // True if a strict check should be performed onto the segment which
- // disallows already-offset hours to be used in spanned expressions,
- // e.g. 15:30PM.
- is_strict:bool = false;
-
- // True if the time-span can be used without an explicitly specified
- // hour value, then it can generate an exact time point (the "begin"
- // o'clock sharp, like "noon") or a time range, like "Tonight".
- is_stand_alone:bool = false;
-}
-
-namespace libtextclassifier3.dates;
-table TimeSpanSpec {
- code:TimespanCode;
- segment:[TimeSpanSpec_.Segment];
-}
-
-namespace libtextclassifier3.dates.TimeZoneNameSpec_;
-enum TimeZoneType : int {
- // The corresponding name might represent a standard or daylight-saving
- // time-zone, depending on some external information, e.g. the date.
- AMBIGUOUS = 0,
-
- // The corresponding name represents a standard time-zone.
- STANDARD = 1,
-
- // The corresponding name represents a daylight-saving time-zone.
- DAYLIGHT = 2,
-}
-
-namespace libtextclassifier3.dates;
-table TimeZoneNameSpec {
- code:TimezoneCode;
- type:TimeZoneNameSpec_.TimeZoneType = AMBIGUOUS;
-
- // Set to true if the corresponding name is internationally used as an
- // abbreviation (or expression) of UTC. For example, "GMT" and "Z".
- is_utc:bool = false;
-
- // Set to false if the corresponding name is not an abbreviation. For example,
- // "Pacific Time" and "China Standard Time".
- is_abbreviation:bool = true;
-}
-
diff --git a/native/annotator/grammar/dates/extractor.cc b/native/annotator/grammar/dates/extractor.cc
deleted file mode 100644
index d2db23e..0000000
--- a/native/annotator/grammar/dates/extractor.cc
+++ /dev/null
@@ -1,913 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/extractor.h"
-
-#include <initializer_list>
-#include <map>
-
-#include "annotator/grammar/dates/utils/date-match.h"
-#include "annotator/grammar/dates/utils/date-utils.h"
-#include "utils/base/casts.h"
-#include "utils/base/logging.h"
-#include "utils/strings/numbers.h"
-
-namespace libtextclassifier3::dates {
-namespace {
-
-// Helper struct for time-related components.
-// Extracts all subnodes of a specified type.
-struct MatchComponents {
- MatchComponents(const grammar::Match* root,
- std::initializer_list<int16> types)
- : root(root),
- components(grammar::SelectAll(
- root, [root, &types](const grammar::Match* node) {
- if (node == root || node->type == grammar::Match::kUnknownType) {
- return false;
- }
- for (const int64 type : types) {
- if (node->type == type) {
- return true;
- }
- }
- return false;
- })) {}
-
- // Returns the index of the first submatch of the specified type or -1 if not
- // found.
- int IndexOf(const int16 type, const int start_index = 0) const {
- for (int i = start_index; i < components.size(); i++) {
- if (components[i]->type == type) {
- return i;
- }
- }
- return -1;
- }
-
- // Returns the first submatch of the specified type, or nullptr if not found.
- template <typename T>
- const T* SubmatchOf(const int16 type, const int start_index = 0) const {
- return SubmatchAt<T>(IndexOf(type, start_index));
- }
-
- template <typename T>
- const T* SubmatchAt(const int index) const {
- if (index < 0) {
- return nullptr;
- }
- return static_cast<const T*>(components[index]);
- }
-
- const grammar::Match* root;
- std::vector<const grammar::Match*> components;
-};
-
-// Helper method to check whether a time value has valid components.
-bool IsValidTimeValue(const TimeValueMatch& time_value) {
- // Can only specify seconds if minutes are present.
- if (time_value.minute == NO_VAL && time_value.second != NO_VAL) {
- return false;
- }
- // Can only specify fraction of seconds if seconds are present.
- if (time_value.second == NO_VAL && time_value.fraction_second >= 0.0) {
- return false;
- }
-
- const int8 h = time_value.hour;
- const int8 m = (time_value.minute < 0 ? 0 : time_value.minute);
- const int8 s = (time_value.second < 0 ? 0 : time_value.second);
- const double f =
- (time_value.fraction_second < 0.0 ? 0.0 : time_value.fraction_second);
-
- // Check value bounds.
- if (h == NO_VAL || h > 24 || m > 59 || s > 60) {
- return false;
- }
- if (h == 24 && (m != 0 || s != 0 || f > 0.0)) {
- return false;
- }
- if (s == 60 && m != 59) {
- return false;
- }
- return true;
-}
-
-int ParseLeadingDec32Value(const char* c_str) {
- int value;
- if (ParseInt32(c_str, &value)) {
- return value;
- }
- return NO_VAL;
-}
-
-double ParseLeadingDoubleValue(const char* c_str) {
- double value;
- if (ParseDouble(c_str, &value)) {
- return value;
- }
- return NO_VAL;
-}
-
-// Extracts digits as an integer and adds a typed match accordingly.
-template <typename T>
-void CheckDigits(const grammar::Match* match,
- const NonterminalValue* nonterminal, StringPiece match_text,
- grammar::Matcher* matcher) {
- TC3_CHECK(match->IsUnaryRule());
- const int value = ParseLeadingDec32Value(match_text.ToString().c_str());
- if (!T::IsValid(value)) {
- return;
- }
- const int num_digits = match_text.size();
- T* result = matcher->AllocateAndInitMatch<T>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->value = value;
- result->count_of_digits = num_digits;
- result->is_zero_prefixed = (num_digits >= 2 && match_text[0] == '0');
- matcher->AddMatch(result);
-}
-
-// Extracts digits as a decimal (as fraction, as if a "0." is prefixed) and
-// adds a typed match to the `er accordingly.
-template <typename T>
-void CheckDigitsAsFraction(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- StringPiece match_text, grammar::Matcher* matcher) {
- TC3_CHECK(match->IsUnaryRule());
- // TODO(smillius): Should should be achievable in a more straight-forward way.
- const double value =
- ParseLeadingDoubleValue(("0." + match_text.ToString()).data());
- if (!T::IsValid(value)) {
- return;
- }
- T* result = matcher->AllocateAndInitMatch<T>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->value = value;
- result->count_of_digits = match_text.size();
- matcher->AddMatch(result);
-}
-
-// Extracts consecutive digits as multiple integers according to a format and
-// adds a type match to the matcher accordingly.
-template <typename T>
-void CheckCombinedDigits(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- StringPiece match_text, grammar::Matcher* matcher) {
- TC3_CHECK(match->IsUnaryRule());
- const std::string& format =
- nonterminal->nonterminal_parameter()->combined_digits_format()->str();
- if (match_text.size() != format.size()) {
- return;
- }
-
- static std::map<char, CombinedDigitsMatch::Index>& kCombinedDigitsMatchIndex =
- *[]() {
- return new std::map<char, CombinedDigitsMatch::Index>{
- {'Y', CombinedDigitsMatch::INDEX_YEAR},
- {'M', CombinedDigitsMatch::INDEX_MONTH},
- {'D', CombinedDigitsMatch::INDEX_DAY},
- {'h', CombinedDigitsMatch::INDEX_HOUR},
- {'m', CombinedDigitsMatch::INDEX_MINUTE},
- {'s', CombinedDigitsMatch::INDEX_SECOND}};
- }();
-
- struct Segment {
- const int index;
- const int length;
- const int value;
- };
- std::vector<Segment> segments;
- int slice_start = 0;
- while (slice_start < format.size()) {
- int slice_end = slice_start + 1;
- // Advace right as long as we have the same format character.
- while (slice_end < format.size() &&
- format[slice_start] == format[slice_end]) {
- slice_end++;
- }
-
- const int slice_length = slice_end - slice_start;
- const int value = ParseLeadingDec32Value(
- std::string(match_text.data() + slice_start, slice_length).c_str());
-
- auto index = kCombinedDigitsMatchIndex.find(format[slice_start]);
- if (index == kCombinedDigitsMatchIndex.end()) {
- return;
- }
- if (!T::IsValid(index->second, value)) {
- return;
- }
- segments.push_back(Segment{index->second, slice_length, value});
- slice_start = slice_end;
- }
- T* result = matcher->AllocateAndInitMatch<T>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- for (const Segment& segment : segments) {
- result->values[segment.index] = segment.value;
- }
- result->count_of_digits = match_text.size();
- result->is_zero_prefixed =
- (match_text[0] == '0' && segments.front().length >= 2);
- matcher->AddMatch(result);
-}
-
-// Retrieves the corresponding value from an associated term-value mapping for
-// the nonterminal and adds a typed match to the matcher accordingly.
-template <typename T>
-void CheckMappedValue(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- const TermValueMatch* term =
- grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE);
- if (term == nullptr) {
- return;
- }
- const int value = term->term_value->value();
- if (!T::IsValid(value)) {
- return;
- }
- T* result = matcher->AllocateAndInitMatch<T>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->value = value;
- matcher->AddMatch(result);
-}
-
-// Checks if there is an associated value in the corresponding nonterminal and
-// adds a typed match to the matcher accordingly.
-template <typename T>
-void CheckDirectValue(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- const int value = nonterminal->value()->value();
- if (!T::IsValid(value)) {
- return;
- }
- T* result = matcher->AllocateAndInitMatch<T>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->value = value;
- matcher->AddMatch(result);
-}
-
-template <typename T>
-void CheckAndAddDirectOrMappedValue(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- if (nonterminal->value() != nullptr) {
- CheckDirectValue<T>(match, nonterminal, matcher);
- } else {
- CheckMappedValue<T>(match, nonterminal, matcher);
- }
-}
-
-template <typename T>
-void CheckAndAddNumericValue(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- StringPiece match_text,
- grammar::Matcher* matcher) {
- if (nonterminal->nonterminal_parameter() != nullptr &&
- nonterminal->nonterminal_parameter()->flag() &
- NonterminalParameter_::Flag_IS_SPELLED) {
- CheckMappedValue<T>(match, nonterminal, matcher);
- } else {
- CheckDigits<T>(match, nonterminal, match_text, matcher);
- }
-}
-
-// Tries to parse as digital time value.
-bool ParseDigitalTimeValue(const std::vector<UnicodeText::const_iterator>& text,
- const MatchComponents& components,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- // Required fields.
- const HourMatch* hour = components.SubmatchOf<HourMatch>(MatchType_HOUR);
- if (hour == nullptr || hour->count_of_digits == 0) {
- return false;
- }
-
- // Optional fields.
- const MinuteMatch* minute =
- components.SubmatchOf<MinuteMatch>(MatchType_MINUTE);
- if (minute != nullptr && minute->count_of_digits == 0) {
- return false;
- }
- const SecondMatch* second =
- components.SubmatchOf<SecondMatch>(MatchType_SECOND);
- if (second != nullptr && second->count_of_digits == 0) {
- return false;
- }
- const FractionSecondMatch* fraction_second =
- components.SubmatchOf<FractionSecondMatch>(MatchType_FRACTION_SECOND);
- if (fraction_second != nullptr && fraction_second->count_of_digits == 0) {
- return false;
- }
-
- // Validation.
- uint32 validation = nonterminal->time_value_parameter()->validation();
- const grammar::Match* end = hour;
- if (minute != nullptr) {
- if (second != nullptr) {
- if (fraction_second != nullptr) {
- end = fraction_second;
- } else {
- end = second;
- }
- } else {
- end = minute;
- }
- }
-
- // Check if there is any extra space between h m s f.
- if ((validation &
- TimeValueParameter_::TimeValueValidation_ALLOW_EXTRA_SPACE) == 0) {
- // Check whether there is whitespace between token.
- if (minute != nullptr && minute->HasLeadingWhitespace()) {
- return false;
- }
- if (second != nullptr && second->HasLeadingWhitespace()) {
- return false;
- }
- if (fraction_second != nullptr && fraction_second->HasLeadingWhitespace()) {
- return false;
- }
- }
-
- // Check if there is any ':' or '.' as a prefix or suffix.
- if (validation &
- TimeValueParameter_::TimeValueValidation_DISALLOW_COLON_DOT_CONTEXT) {
- const int begin_pos = hour->codepoint_span.first;
- const int end_pos = end->codepoint_span.second;
- if (begin_pos > 1 &&
- (*text[begin_pos - 1] == ':' || *text[begin_pos - 1] == '.') &&
- isdigit(*text[begin_pos - 2])) {
- return false;
- }
- // Last valid codepoint is at text.size() - 2 as we added the end position
- // of text for easier span extraction.
- if (end_pos < text.size() - 2 &&
- (*text[end_pos] == ':' || *text[end_pos] == '.') &&
- isdigit(*text[end_pos + 1])) {
- return false;
- }
- }
-
- TimeValueMatch time_value;
- time_value.Init(components.root->lhs, components.root->codepoint_span,
- components.root->match_offset);
- time_value.Reset();
- time_value.hour_match = hour;
- time_value.minute_match = minute;
- time_value.second_match = second;
- time_value.fraction_second_match = fraction_second;
- time_value.is_hour_zero_prefixed = hour->is_zero_prefixed;
- time_value.is_minute_one_digit =
- (minute != nullptr && minute->count_of_digits == 1);
- time_value.is_second_one_digit =
- (second != nullptr && second->count_of_digits == 1);
- time_value.hour = hour->value;
- time_value.minute = (minute != nullptr ? minute->value : NO_VAL);
- time_value.second = (second != nullptr ? second->value : NO_VAL);
- time_value.fraction_second =
- (fraction_second != nullptr ? fraction_second->value : NO_VAL);
-
- if (!IsValidTimeValue(time_value)) {
- return false;
- }
-
- TimeValueMatch* result = matcher->AllocateMatch<TimeValueMatch>();
- *result = time_value;
- matcher->AddMatch(result);
- return true;
-}
-
-// Tries to parsing a time from spelled out time components.
-bool ParseSpelledTimeValue(const MatchComponents& components,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- // Required fields.
- const HourMatch* hour = components.SubmatchOf<HourMatch>(MatchType_HOUR);
- if (hour == nullptr || hour->count_of_digits != 0) {
- return false;
- }
- // Optional fields.
- const MinuteMatch* minute =
- components.SubmatchOf<MinuteMatch>(MatchType_MINUTE);
- if (minute != nullptr && minute->count_of_digits != 0) {
- return false;
- }
- const SecondMatch* second =
- components.SubmatchOf<SecondMatch>(MatchType_SECOND);
- if (second != nullptr && second->count_of_digits != 0) {
- return false;
- }
-
- uint32 validation = nonterminal->time_value_parameter()->validation();
- // Check if there is any extra space between h m s.
- if ((validation &
- TimeValueParameter_::TimeValueValidation_ALLOW_EXTRA_SPACE) == 0) {
- // Check whether there is whitespace between token.
- if (minute != nullptr && minute->HasLeadingWhitespace()) {
- return false;
- }
- if (second != nullptr && second->HasLeadingWhitespace()) {
- return false;
- }
- }
-
- TimeValueMatch time_value;
- time_value.Init(components.root->lhs, components.root->codepoint_span,
- components.root->match_offset);
- time_value.Reset();
- time_value.hour_match = hour;
- time_value.minute_match = minute;
- time_value.second_match = second;
- time_value.is_hour_zero_prefixed = hour->is_zero_prefixed;
- time_value.is_minute_one_digit =
- (minute != nullptr && minute->count_of_digits == 1);
- time_value.is_second_one_digit =
- (second != nullptr && second->count_of_digits == 1);
- time_value.hour = hour->value;
- time_value.minute = (minute != nullptr ? minute->value : NO_VAL);
- time_value.second = (second != nullptr ? second->value : NO_VAL);
-
- if (!IsValidTimeValue(time_value)) {
- return false;
- }
-
- TimeValueMatch* result = matcher->AllocateMatch<TimeValueMatch>();
- *result = time_value;
- matcher->AddMatch(result);
- return true;
-}
-
-// Reconstructs and validates a time value from a match.
-void CheckTimeValue(const std::vector<UnicodeText::const_iterator>& text,
- const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- MatchComponents components(
- match, {MatchType_HOUR, MatchType_MINUTE, MatchType_SECOND,
- MatchType_FRACTION_SECOND});
- if (ParseDigitalTimeValue(text, components, nonterminal, matcher)) {
- return;
- }
- if (ParseSpelledTimeValue(components, nonterminal, matcher)) {
- return;
- }
-}
-
-// Validates a time span match.
-void CheckTimeSpan(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- const TermValueMatch* ts_name =
- grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE);
- const TermValue* term_value = ts_name->term_value;
- TC3_CHECK(term_value != nullptr);
- TC3_CHECK(term_value->time_span_spec() != nullptr);
- const TimeSpanSpec* ts_spec = term_value->time_span_spec();
- TimeSpanMatch* time_span = matcher->AllocateAndInitMatch<TimeSpanMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- time_span->Reset();
- time_span->nonterminal = nonterminal;
- time_span->time_span_spec = ts_spec;
- time_span->time_span_code = ts_spec->code();
- matcher->AddMatch(time_span);
-}
-
-// Validates a time period match.
-void CheckTimePeriod(const std::vector<UnicodeText::const_iterator>& text,
- const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- int period_value = NO_VAL;
-
- // If a value mapping exists, use it.
- if (nonterminal->value() != nullptr) {
- period_value = nonterminal->value()->value();
- } else if (const TermValueMatch* term =
- grammar::SelectFirstOfType<TermValueMatch>(
- match, MatchType_TERM_VALUE)) {
- period_value = term->term_value->value();
- } else if (const grammar::Match* digits =
- grammar::SelectFirstOfType<grammar::Match>(
- match, grammar::Match::kDigitsType)) {
- period_value = ParseLeadingDec32Value(
- std::string(text[digits->codepoint_span.first].utf8_data(),
- text[digits->codepoint_span.second].utf8_data() -
- text[digits->codepoint_span.first].utf8_data())
- .c_str());
- }
-
- if (period_value <= NO_VAL) {
- return;
- }
-
- TimePeriodMatch* result = matcher->AllocateAndInitMatch<TimePeriodMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->value = period_value;
- matcher->AddMatch(result);
-}
-
-// Reconstructs a date from a relative date rule match.
-void CheckRelativeDate(const DateAnnotationOptions& options,
- const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- if (!options.enable_special_day_offset &&
- grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE) !=
- nullptr) {
- // Special day offsets, like "Today", "Tomorrow" etc. are not enabled.
- return;
- }
-
- RelativeMatch* relative_match = matcher->AllocateAndInitMatch<RelativeMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- relative_match->Reset();
- relative_match->nonterminal = nonterminal;
-
- // Fill relative date information from individual components.
- grammar::Traverse(match, [match, relative_match](const grammar::Match* node) {
- // Ignore the current match.
- if (node == match || node->type == grammar::Match::kUnknownType) {
- return true;
- }
-
- if (node->type == MatchType_TERM_VALUE) {
- const int value =
- static_cast<const TermValueMatch*>(node)->term_value->value();
- relative_match->day = abs(value);
- if (value >= 0) {
- // Marks "today" as in the future.
- relative_match->is_future_date = true;
- }
- relative_match->existing |=
- (RelativeMatch::HAS_DAY | RelativeMatch::HAS_IS_FUTURE);
- return false;
- }
-
- // Parse info from nonterminal.
- const NonterminalValue* nonterminal =
- static_cast<const NonterminalMatch*>(node)->nonterminal;
- if (nonterminal != nullptr &&
- nonterminal->relative_parameter() != nullptr) {
- const RelativeParameter* relative_parameter =
- nonterminal->relative_parameter();
- if (relative_parameter->period() !=
- RelativeParameter_::Period_PERIOD_UNKNOWN) {
- relative_match->is_future_date =
- (relative_parameter->period() ==
- RelativeParameter_::Period_PERIOD_FUTURE);
- relative_match->existing |= RelativeMatch::HAS_IS_FUTURE;
- }
- if (relative_parameter->day_of_week_interpretation() != nullptr) {
- relative_match->day_of_week_nonterminal = nonterminal;
- relative_match->existing |= RelativeMatch::HAS_DAY_OF_WEEK;
- }
- }
-
- // Relative day of week.
- if (node->type == MatchType_DAY_OF_WEEK) {
- relative_match->day_of_week =
- static_cast<const DayOfWeekMatch*>(node)->value;
- return false;
- }
-
- if (node->type != MatchType_TIME_PERIOD) {
- return true;
- }
-
- const TimePeriodMatch* period = static_cast<const TimePeriodMatch*>(node);
- switch (nonterminal->relative_parameter()->type()) {
- case RelativeParameter_::RelativeType_YEAR: {
- relative_match->year = period->value;
- relative_match->existing |= RelativeMatch::HAS_YEAR;
- break;
- }
- case RelativeParameter_::RelativeType_MONTH: {
- relative_match->month = period->value;
- relative_match->existing |= RelativeMatch::HAS_MONTH;
- break;
- }
- case RelativeParameter_::RelativeType_WEEK: {
- relative_match->week = period->value;
- relative_match->existing |= RelativeMatch::HAS_WEEK;
- break;
- }
- case RelativeParameter_::RelativeType_DAY: {
- relative_match->day = period->value;
- relative_match->existing |= RelativeMatch::HAS_DAY;
- break;
- }
- case RelativeParameter_::RelativeType_HOUR: {
- relative_match->hour = period->value;
- relative_match->existing |= RelativeMatch::HAS_HOUR;
- break;
- }
- case RelativeParameter_::RelativeType_MINUTE: {
- relative_match->minute = period->value;
- relative_match->existing |= RelativeMatch::HAS_MINUTE;
- break;
- }
- case RelativeParameter_::RelativeType_SECOND: {
- relative_match->second = period->value;
- relative_match->existing |= RelativeMatch::HAS_SECOND;
- break;
- }
- default:
- break;
- }
-
- return true;
- });
- matcher->AddMatch(relative_match);
-}
-
-bool IsValidTimeZoneOffset(const int time_zone_offset) {
- return (time_zone_offset >= -720 && time_zone_offset <= 840 &&
- time_zone_offset % 15 == 0);
-}
-
-// Parses, validates and adds a time zone offset match.
-void CheckTimeZoneOffset(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- MatchComponents components(
- match, {MatchType_DIGITS, MatchType_TERM_VALUE, MatchType_NONTERMINAL});
- const TermValueMatch* tz_sign =
- components.SubmatchOf<TermValueMatch>(MatchType_TERM_VALUE);
- if (tz_sign == nullptr) {
- return;
- }
- const int sign = tz_sign->term_value->value();
- TC3_CHECK(sign == -1 || sign == 1);
-
- const int tz_digits_index = components.IndexOf(MatchType_DIGITS);
- if (tz_digits_index < 0) {
- return;
- }
- const DigitsMatch* tz_digits =
- components.SubmatchAt<DigitsMatch>(tz_digits_index);
- if (tz_digits == nullptr) {
- return;
- }
-
- int offset;
- if (tz_digits->count_of_digits >= 3) {
- offset = (tz_digits->value / 100) * 60 + (tz_digits->value % 100);
- } else {
- offset = tz_digits->value * 60;
- if (const DigitsMatch* tz_digits_extra = components.SubmatchOf<DigitsMatch>(
- MatchType_DIGITS, /*start_index=*/tz_digits_index + 1)) {
- offset += tz_digits_extra->value;
- }
- }
-
- const NonterminalMatch* tz_offset =
- components.SubmatchOf<NonterminalMatch>(MatchType_NONTERMINAL);
- if (tz_offset == nullptr) {
- return;
- }
-
- const int time_zone_offset = sign * offset;
- if (!IsValidTimeZoneOffset(time_zone_offset)) {
- return;
- }
-
- TimeZoneOffsetMatch* result =
- matcher->AllocateAndInitMatch<TimeZoneOffsetMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->time_zone_offset_param =
- tz_offset->nonterminal->time_zone_offset_parameter();
- result->time_zone_offset = time_zone_offset;
- matcher->AddMatch(result);
-}
-
-// Validates and adds a time zone name match.
-void CheckTimeZoneName(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- TC3_CHECK(match->IsUnaryRule());
- const TermValueMatch* tz_name =
- static_cast<const TermValueMatch*>(match->unary_rule_rhs());
- if (tz_name == nullptr) {
- return;
- }
- const TimeZoneNameSpec* tz_name_spec =
- tz_name->term_value->time_zone_name_spec();
- TimeZoneNameMatch* result = matcher->AllocateAndInitMatch<TimeZoneNameMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->time_zone_name_spec = tz_name_spec;
- result->time_zone_code = tz_name_spec->code();
- matcher->AddMatch(result);
-}
-
-// Adds a mapped term value match containing its value.
-void AddTermValue(const grammar::Match* match, const TermValue* term_value,
- grammar::Matcher* matcher) {
- TermValueMatch* term_match = matcher->AllocateAndInitMatch<TermValueMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- term_match->Reset();
- term_match->term_value = term_value;
- matcher->AddMatch(term_match);
-}
-
-// Adds a match for a nonterminal.
-void AddNonterminal(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- NonterminalMatch* result =
- matcher->AllocateAndInitMatch<NonterminalMatch>(*match);
- result->Reset();
- result->nonterminal = nonterminal;
- matcher->AddMatch(result);
-}
-
-// Adds a match for an extraction rule that is potentially used in a date range
-// rule.
-void AddExtractionRuleMatch(const grammar::Match* match,
- const ExtractionRuleParameter* rule,
- grammar::Matcher* matcher) {
- ExtractionMatch* result =
- matcher->AllocateAndInitMatch<ExtractionMatch>(*match);
- result->Reset();
- result->extraction_rule = rule;
- matcher->AddMatch(result);
-}
-
-} // namespace
-
-void DateExtractor::HandleExtractionRuleMatch(
- const ExtractionRuleParameter* rule, const grammar::Match* match,
- grammar::Matcher* matcher) {
- if (rule->id() != nullptr) {
- const std::string rule_id = rule->id()->str();
- bool keep = false;
- for (const std::string& extra_requested_dates_id :
- options_.extra_requested_dates) {
- if (extra_requested_dates_id == rule_id) {
- keep = true;
- break;
- }
- }
- if (!keep) {
- return;
- }
- }
- output_.push_back(
- Output{rule, matcher->AllocateAndInitMatch<grammar::Match>(*match)});
-}
-
-void DateExtractor::HandleRangeExtractionRuleMatch(const grammar::Match* match,
- grammar::Matcher* matcher) {
- // Collect the two datetime roots that make up the range.
- std::vector<const grammar::Match*> parts;
- grammar::Traverse(match, [match, &parts](const grammar::Match* node) {
- if (node == match || node->type == grammar::Match::kUnknownType) {
- // Just continue traversing the match.
- return true;
- }
-
- // Collect, but don't expand the individual datetime nodes.
- parts.push_back(node);
- return false;
- });
- TC3_CHECK_EQ(parts.size(), 2);
- range_output_.push_back(
- RangeOutput{matcher->AllocateAndInitMatch<grammar::Match>(*match),
- /*from=*/parts[0], /*to=*/parts[1]});
-}
-
-void DateExtractor::MatchFound(const grammar::Match* match,
- const grammar::CallbackId type,
- const int64 value, grammar::Matcher* matcher) {
- switch (type) {
- case MatchType_DATETIME_RULE: {
- HandleExtractionRuleMatch(
- /*rule=*/
- datetime_rules_->extraction_rule()->Get(value), match, matcher);
- return;
- }
- case MatchType_DATETIME_RANGE_RULE: {
- HandleRangeExtractionRuleMatch(match, matcher);
- return;
- }
- case MatchType_DATETIME: {
- // If an extraction rule is also part of a range extraction rule, then the
- // extraction rule is treated as a rule match and nonterminal match.
- // This type is used to match the rule as non terminal.
- AddExtractionRuleMatch(
- match, datetime_rules_->extraction_rule()->Get(value), matcher);
- return;
- }
- case MatchType_TERM_VALUE: {
- // Handle mapped terms.
- AddTermValue(match, datetime_rules_->term_value()->Get(value), matcher);
- return;
- }
- default:
- break;
- }
-
- // Handle non-terminals.
- const NonterminalValue* nonterminal =
- datetime_rules_->nonterminal_value()->Get(value);
- StringPiece match_text =
- StringPiece(text_[match->codepoint_span.first].utf8_data(),
- text_[match->codepoint_span.second].utf8_data() -
- text_[match->codepoint_span.first].utf8_data());
- switch (type) {
- case MatchType_NONTERMINAL:
- AddNonterminal(match, nonterminal, matcher);
- break;
- case MatchType_DIGITS:
- CheckDigits<DigitsMatch>(match, nonterminal, match_text, matcher);
- break;
- case MatchType_YEAR:
- CheckDigits<YearMatch>(match, nonterminal, match_text, matcher);
- break;
- case MatchType_MONTH:
- CheckAndAddNumericValue<MonthMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_DAY:
- CheckAndAddNumericValue<DayMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_DAY_OF_WEEK:
- CheckAndAddDirectOrMappedValue<DayOfWeekMatch>(match, nonterminal,
- matcher);
- break;
- case MatchType_HOUR:
- CheckAndAddNumericValue<HourMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_MINUTE:
- CheckAndAddNumericValue<MinuteMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_SECOND:
- CheckAndAddNumericValue<SecondMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_FRACTION_SECOND:
- CheckDigitsAsFraction<FractionSecondMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_TIME_VALUE:
- CheckTimeValue(text_, match, nonterminal, matcher);
- break;
- case MatchType_TIME_SPAN:
- CheckTimeSpan(match, nonterminal, matcher);
- break;
- case MatchType_TIME_ZONE_NAME:
- CheckTimeZoneName(match, nonterminal, matcher);
- break;
- case MatchType_TIME_ZONE_OFFSET:
- CheckTimeZoneOffset(match, nonterminal, matcher);
- break;
- case MatchType_TIME_PERIOD:
- CheckTimePeriod(text_, match, nonterminal, matcher);
- break;
- case MatchType_RELATIVE_DATE:
- CheckRelativeDate(options_, match, nonterminal, matcher);
- break;
- case MatchType_COMBINED_DIGITS:
- CheckCombinedDigits<CombinedDigitsMatch>(match, nonterminal, match_text,
- matcher);
- break;
- default:
- TC3_VLOG(ERROR) << "Unhandled match type: " << type;
- }
-}
-
-} // namespace libtextclassifier3::dates
diff --git a/native/annotator/grammar/dates/extractor.h b/native/annotator/grammar/dates/extractor.h
deleted file mode 100644
index 58c8880..0000000
--- a/native/annotator/grammar/dates/extractor.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_
-
-#include <vector>
-
-#include "annotator/grammar/dates/annotations/annotation-options.h"
-#include "annotator/grammar/dates/dates_generated.h"
-#include "utils/base/integral_types.h"
-#include "utils/grammar/callback-delegate.h"
-#include "utils/grammar/match.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/types.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3::dates {
-
-// A helper class for the datetime parser that extracts structured data from
-// the datetime grammar matches.
-// It handles simple sanity checking of the rule matches and interacts with the
-// grammar matcher to extract all datetime occurrences in a text.
-class DateExtractor : public grammar::CallbackDelegate {
- public:
- // Represents a date match for an extraction rule.
- struct Output {
- const ExtractionRuleParameter* rule = nullptr;
- const grammar::Match* match = nullptr;
- };
-
- // Represents a date match from a range extraction rule.
- struct RangeOutput {
- const grammar::Match* match = nullptr;
- const grammar::Match* from = nullptr;
- const grammar::Match* to = nullptr;
- };
-
- DateExtractor(const std::vector<UnicodeText::const_iterator>& text,
- const DateAnnotationOptions& options,
- const DatetimeRules* datetime_rules)
- : text_(text), options_(options), datetime_rules_(datetime_rules) {}
-
- // Handle a rule match in the date time grammar.
- // This checks the type of the match and does type dependent checks.
- void MatchFound(const grammar::Match* match, grammar::CallbackId type,
- int64 value, grammar::Matcher* matcher) override;
-
- const std::vector<Output>& output() const { return output_; }
- const std::vector<RangeOutput>& range_output() const { return range_output_; }
-
- private:
- // Extracts a date from a root rule match.
- void HandleExtractionRuleMatch(const ExtractionRuleParameter* rule,
- const grammar::Match* match,
- grammar::Matcher* matcher);
-
- // Extracts a date range from a root rule match.
- void HandleRangeExtractionRuleMatch(const grammar::Match* match,
- grammar::Matcher* matcher);
-
- const std::vector<UnicodeText::const_iterator>& text_;
- const DateAnnotationOptions& options_;
- const DatetimeRules* datetime_rules_;
-
- // Extraction results.
- std::vector<Output> output_;
- std::vector<RangeOutput> range_output_;
-};
-
-} // namespace libtextclassifier3::dates
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_
diff --git a/native/annotator/grammar/dates/parser.cc b/native/annotator/grammar/dates/parser.cc
deleted file mode 100644
index 37e65fc..0000000
--- a/native/annotator/grammar/dates/parser.cc
+++ /dev/null
@@ -1,794 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/parser.h"
-
-#include "annotator/grammar/dates/extractor.h"
-#include "annotator/grammar/dates/utils/date-match.h"
-#include "annotator/grammar/dates/utils/date-utils.h"
-#include "utils/base/integral_types.h"
-#include "utils/base/logging.h"
-#include "utils/base/macros.h"
-#include "utils/grammar/lexer.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/types.h"
-#include "utils/strings/split.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3::dates {
-namespace {
-
-// Helper methods to validate individual components from a date match.
-
-// Checks the validation requirement of a rule against a match.
-// For example if the rule asks for `SPELLED_MONTH`, then we check that the
-// match has the right flag.
-bool CheckMatchValidationAndFlag(
- const grammar::Match* match, const ExtractionRuleParameter* rule,
- const ExtractionRuleParameter_::ExtractionValidation validation,
- const NonterminalParameter_::Flag flag) {
- if (rule == nullptr || (rule->validation() & validation) == 0) {
- // No validation requirement.
- return true;
- }
- const NonterminalParameter* nonterminal_parameter =
- static_cast<const NonterminalMatch*>(match)
- ->nonterminal->nonterminal_parameter();
- return (nonterminal_parameter != nullptr &&
- (nonterminal_parameter->flag() & flag) != 0);
-}
-
-bool GenerateDate(const ExtractionRuleParameter* rule,
- const grammar::Match* match, DateMatch* date) {
- bool is_valid = true;
-
- // Post check and assign date components.
- grammar::Traverse(match, [rule, date, &is_valid](const grammar::Match* node) {
- switch (node->type) {
- case MatchType_YEAR: {
- if (CheckMatchValidationAndFlag(
- node, rule,
- ExtractionRuleParameter_::ExtractionValidation_SPELLED_YEAR,
- NonterminalParameter_::Flag_IS_SPELLED)) {
- date->year_match = static_cast<const YearMatch*>(node);
- date->year = date->year_match->value;
- } else {
- is_valid = false;
- }
- break;
- }
- case MatchType_MONTH: {
- if (CheckMatchValidationAndFlag(
- node, rule,
- ExtractionRuleParameter_::ExtractionValidation_SPELLED_MONTH,
- NonterminalParameter_::Flag_IS_SPELLED)) {
- date->month_match = static_cast<const MonthMatch*>(node);
- date->month = date->month_match->value;
- } else {
- is_valid = false;
- }
- break;
- }
- case MatchType_DAY: {
- if (CheckMatchValidationAndFlag(
- node, rule,
- ExtractionRuleParameter_::ExtractionValidation_SPELLED_DAY,
- NonterminalParameter_::Flag_IS_SPELLED)) {
- date->day_match = static_cast<const DayMatch*>(node);
- date->day = date->day_match->value;
- } else {
- is_valid = false;
- }
- break;
- }
- case MatchType_DAY_OF_WEEK: {
- date->day_of_week_match = static_cast<const DayOfWeekMatch*>(node);
- date->day_of_week =
- static_cast<DayOfWeek>(date->day_of_week_match->value);
- break;
- }
- case MatchType_TIME_VALUE: {
- date->time_value_match = static_cast<const TimeValueMatch*>(node);
- date->hour = date->time_value_match->hour;
- date->minute = date->time_value_match->minute;
- date->second = date->time_value_match->second;
- date->fraction_second = date->time_value_match->fraction_second;
- return false;
- }
- case MatchType_TIME_SPAN: {
- date->time_span_match = static_cast<const TimeSpanMatch*>(node);
- date->time_span_code = date->time_span_match->time_span_code;
- return false;
- }
- case MatchType_TIME_ZONE_NAME: {
- date->time_zone_name_match =
- static_cast<const TimeZoneNameMatch*>(node);
- date->time_zone_code = date->time_zone_name_match->time_zone_code;
- return false;
- }
- case MatchType_TIME_ZONE_OFFSET: {
- date->time_zone_offset_match =
- static_cast<const TimeZoneOffsetMatch*>(node);
- date->time_zone_offset = date->time_zone_offset_match->time_zone_offset;
- return false;
- }
- case MatchType_RELATIVE_DATE: {
- date->relative_match = static_cast<const RelativeMatch*>(node);
- return false;
- }
- case MatchType_COMBINED_DIGITS: {
- date->combined_digits_match =
- static_cast<const CombinedDigitsMatch*>(node);
- if (date->combined_digits_match->HasYear()) {
- date->year = date->combined_digits_match->GetYear();
- }
- if (date->combined_digits_match->HasMonth()) {
- date->month = date->combined_digits_match->GetMonth();
- }
- if (date->combined_digits_match->HasDay()) {
- date->day = date->combined_digits_match->GetDay();
- }
- if (date->combined_digits_match->HasHour()) {
- date->hour = date->combined_digits_match->GetHour();
- }
- if (date->combined_digits_match->HasMinute()) {
- date->minute = date->combined_digits_match->GetMinute();
- }
- if (date->combined_digits_match->HasSecond()) {
- date->second = date->combined_digits_match->GetSecond();
- }
- return false;
- }
- default:
- // Expand node further.
- return true;
- }
-
- return false;
- });
-
- if (is_valid) {
- date->begin = match->codepoint_span.first;
- date->end = match->codepoint_span.second;
- date->priority = rule ? rule->priority_delta() : 0;
- date->annotator_priority_score =
- rule ? rule->annotator_priority_score() : 0.0;
- }
- return is_valid;
-}
-
-bool GenerateFromOrToDateRange(const grammar::Match* match, DateMatch* date) {
- return GenerateDate(
- /*rule=*/(
- match->type == MatchType_DATETIME
- ? static_cast<const ExtractionMatch*>(match)->extraction_rule
- : nullptr),
- match, date);
-}
-
-bool GenerateDateRange(const grammar::Match* match, const grammar::Match* from,
- const grammar::Match* to, DateRangeMatch* date_range) {
- if (!GenerateFromOrToDateRange(from, &date_range->from)) {
- TC3_LOG(WARNING) << "Failed to generate date for `from`.";
- return false;
- }
- if (!GenerateFromOrToDateRange(to, &date_range->to)) {
- TC3_LOG(WARNING) << "Failed to generate date for `to`.";
- return false;
- }
- date_range->begin = match->codepoint_span.first;
- date_range->end = match->codepoint_span.second;
- return true;
-}
-
-bool NormalizeHour(DateMatch* date) {
- if (date->time_span_match == nullptr) {
- // Nothing to do.
- return true;
- }
- return NormalizeHourByTimeSpan(date->time_span_match->time_span_spec, date);
-}
-
-void CheckAndSetAmbiguousHour(DateMatch* date) {
- if (date->HasHour()) {
- // Use am-pm ambiguity as default.
- if (!date->HasTimeSpanCode() && date->hour >= 1 && date->hour <= 12 &&
- !(date->time_value_match != nullptr &&
- date->time_value_match->hour_match != nullptr &&
- date->time_value_match->hour_match->is_zero_prefixed)) {
- date->SetAmbiguousHourProperties(2, 12);
- }
- }
-}
-
-// Normalizes a date candidate.
-// Returns whether the candidate was successfully normalized.
-bool NormalizeDate(DateMatch* date) {
- // Normalize hour.
- if (!NormalizeHour(date)) {
- TC3_VLOG(ERROR) << "Hour normalization (according to time-span) failed."
- << date->DebugString();
- return false;
- }
- CheckAndSetAmbiguousHour(date);
- if (!date->IsValid()) {
- TC3_VLOG(ERROR) << "Fields inside date instance are ill-formed "
- << date->DebugString();
- }
- return true;
-}
-
-// Copies the field from one DateMatch to another whose field is null. for
-// example: if the from is "May 1, 8pm", and the to is "9pm", "May 1" will be
-// copied to "to". Now we only copy fields for date range requirement.fv
-void CopyFieldsForDateMatch(const DateMatch& from, DateMatch* to) {
- if (from.time_span_match != nullptr && to->time_span_match == nullptr) {
- to->time_span_match = from.time_span_match;
- to->time_span_code = from.time_span_code;
- }
- if (from.month_match != nullptr && to->month_match == nullptr) {
- to->month_match = from.month_match;
- to->month = from.month;
- }
-}
-
-// Normalizes a date range candidate.
-// Returns whether the date range was successfully normalized.
-bool NormalizeDateRange(DateRangeMatch* date_range) {
- CopyFieldsForDateMatch(date_range->from, &date_range->to);
- CopyFieldsForDateMatch(date_range->to, &date_range->from);
- return (NormalizeDate(&date_range->from) && NormalizeDate(&date_range->to));
-}
-
-bool CheckDate(const DateMatch& date, const ExtractionRuleParameter* rule) {
- // It's possible that "time_zone_name_match == NULL" when
- // "HasTimeZoneCode() == true", or "time_zone_offset_match == NULL" when
- // "HasTimeZoneOffset() == true" due to inference between endpoints, so we
- // must check if they really exist before using them.
- if (date.HasTimeZoneOffset()) {
- if (date.HasTimeZoneCode()) {
- if (date.time_zone_name_match != nullptr) {
- TC3_CHECK(date.time_zone_name_match->time_zone_name_spec != nullptr);
- const TimeZoneNameSpec* spec =
- date.time_zone_name_match->time_zone_name_spec;
- if (!spec->is_utc()) {
- return false;
- }
- if (!spec->is_abbreviation()) {
- return false;
- }
- }
- } else if (date.time_zone_offset_match != nullptr) {
- TC3_CHECK(date.time_zone_offset_match->time_zone_offset_param != nullptr);
- const TimeZoneOffsetParameter* param =
- date.time_zone_offset_match->time_zone_offset_param;
- if (param->format() == TimeZoneOffsetParameter_::Format_FORMAT_H ||
- param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HH) {
- return false;
- }
- if (!(rule->validation() &
- ExtractionRuleParameter_::
- ExtractionValidation_ALLOW_UNCONFIDENT_TIME_ZONE)) {
- if (param->format() == TimeZoneOffsetParameter_::Format_FORMAT_H_MM ||
- param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HH_MM ||
- param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HMM) {
- return false;
- }
- }
- }
- }
-
- // Case: 1 April could be extracted as year 1, month april.
- // We simply remove this case.
- if (!date.HasBcAd() && date.year_match != nullptr && date.year < 1000) {
- // We allow case like 11/5/01
- if (date.HasMonth() && date.HasDay() &&
- date.year_match->count_of_digits == 2) {
- } else {
- return false;
- }
- }
-
- // Ignore the date if the year is larger than 9999 (The maximum number of 4
- // digits).
- if (date.year_match != nullptr && date.year > 9999) {
- TC3_VLOG(ERROR) << "Year is greater than 9999.";
- return false;
- }
-
- // Case: spelled may could be month 5, it also used very common as modal
- // verbs. We ignore spelled may as month.
- if ((rule->validation() &
- ExtractionRuleParameter_::ExtractionValidation_SPELLED_MONTH) &&
- date.month == 5 && !date.HasYear() && !date.HasDay()) {
- return false;
- }
-
- return true;
-}
-
-bool CheckContext(const std::vector<UnicodeText::const_iterator>& text,
- const DateExtractor::Output& output) {
- const uint32 validation = output.rule->validation();
-
- // Nothing to check if we don't have any validation requirements for the
- // span boundaries.
- if ((validation &
- (ExtractionRuleParameter_::ExtractionValidation_LEFT_BOUND |
- ExtractionRuleParameter_::ExtractionValidation_RIGHT_BOUND)) == 0) {
- return true;
- }
-
- const int begin = output.match->codepoint_span.first;
- const int end = output.match->codepoint_span.second;
-
- // So far, we only check that the adjacent character cannot be a separator,
- // like /, - or .
- if ((validation &
- ExtractionRuleParameter_::ExtractionValidation_LEFT_BOUND) != 0) {
- if (begin > 0 && (*text[begin - 1] == '/' || *text[begin - 1] == '-' ||
- *text[begin - 1] == ':')) {
- return false;
- }
- }
- if ((validation &
- ExtractionRuleParameter_::ExtractionValidation_RIGHT_BOUND) != 0) {
- // Last valid codepoint is at text.size() - 2 as we added the end position
- // of text for easier span extraction.
- if (end < text.size() - 1 &&
- (*text[end] == '/' || *text[end] == '-' || *text[end] == ':')) {
- return false;
- }
- }
-
- return true;
-}
-
-// Validates a date match. Returns true if the candidate is valid.
-bool ValidateDate(const std::vector<UnicodeText::const_iterator>& text,
- const DateExtractor::Output& output, const DateMatch& date) {
- if (!CheckDate(date, output.rule)) {
- return false;
- }
- if (!CheckContext(text, output)) {
- return false;
- }
- return true;
-}
-
-// Builds matched date instances from the grammar output.
-std::vector<DateMatch> BuildDateMatches(
- const std::vector<UnicodeText::const_iterator>& text,
- const std::vector<DateExtractor::Output>& outputs) {
- std::vector<DateMatch> result;
- for (const DateExtractor::Output& output : outputs) {
- DateMatch date;
- if (GenerateDate(output.rule, output.match, &date)) {
- if (!NormalizeDate(&date)) {
- continue;
- }
- if (!ValidateDate(text, output, date)) {
- continue;
- }
- result.push_back(date);
- }
- }
- return result;
-}
-
-// Builds matched date range instances from the grammar output.
-std::vector<DateRangeMatch> BuildDateRangeMatches(
- const std::vector<UnicodeText::const_iterator>& text,
- const std::vector<DateExtractor::RangeOutput>& range_outputs) {
- std::vector<DateRangeMatch> result;
- for (const DateExtractor::RangeOutput& range_output : range_outputs) {
- DateRangeMatch date_range;
- if (GenerateDateRange(range_output.match, range_output.from,
- range_output.to, &date_range)) {
- if (!NormalizeDateRange(&date_range)) {
- continue;
- }
- result.push_back(date_range);
- }
- }
- return result;
-}
-
-template <typename T>
-void RemoveDeletedMatches(const std::vector<bool>& removed,
- std::vector<T>* matches) {
- int input = 0;
- for (int next = 0; next < matches->size(); ++next) {
- if (removed[next]) {
- continue;
- }
- if (input != next) {
- (*matches)[input] = (*matches)[next];
- }
- input++;
- }
- matches->resize(input);
-}
-
-// Removes duplicated date or date range instances.
-// Overlapping date and date ranges are not considered here.
-template <typename T>
-void RemoveDuplicatedDates(std::vector<T>* matches) {
- // Assumption: matches are sorted ascending by (begin, end).
- std::vector<bool> removed(matches->size(), false);
- for (int i = 0; i < matches->size(); i++) {
- if (removed[i]) {
- continue;
- }
- const T& candidate = matches->at(i);
- for (int j = i + 1; j < matches->size(); j++) {
- if (removed[j]) {
- continue;
- }
- const T& next = matches->at(j);
-
- // Not overlapping.
- if (next.begin >= candidate.end) {
- break;
- }
-
- // If matching the same span of text, then check the priority.
- if (candidate.begin == next.begin && candidate.end == next.end) {
- if (candidate.GetPriority() < next.GetPriority()) {
- removed[i] = true;
- break;
- } else {
- removed[j] = true;
- continue;
- }
- }
-
- // Checks if `next` is fully covered by fields of `candidate`.
- if (next.end <= candidate.end) {
- removed[j] = true;
- continue;
- }
-
- // Checks whether `candidate`/`next` is a refinement.
- if (IsRefinement(candidate, next)) {
- removed[j] = true;
- continue;
- } else if (IsRefinement(next, candidate)) {
- removed[i] = true;
- break;
- }
- }
- }
- RemoveDeletedMatches(removed, matches);
-}
-
-// Filters out simple overtriggering simple matches.
-bool IsBlacklistedDate(const UniLib& unilib,
- const std::vector<UnicodeText::const_iterator>& text,
- const DateMatch& match) {
- const int begin = match.begin;
- const int end = match.end;
- if (end - begin != 3) {
- return false;
- }
-
- std::string text_lower =
- unilib
- .ToLowerText(
- UTF8ToUnicodeText(text[begin].utf8_data(),
- text[end].utf8_data() - text[begin].utf8_data(),
- /*do_copy=*/false))
- .ToUTF8String();
-
- // "sun" is not a good abbreviation for a standalone day of the week.
- if (match.IsStandaloneRelativeDayOfWeek() &&
- (text_lower == "sun" || text_lower == "mon")) {
- return true;
- }
-
- // "mar" is not a good abbreviation for single month.
- if (match.HasMonth() && text_lower == "mar") {
- return true;
- }
-
- return false;
-}
-
-// Checks if two date matches are adjacent and mergeable.
-bool AreDateMatchesAdjacentAndMergeable(
- const UniLib& unilib, const std::vector<UnicodeText::const_iterator>& text,
- const std::vector<std::string>& ignored_spans, const DateMatch& prev,
- const DateMatch& next) {
- // Check the context between the two matches.
- if (next.begin <= prev.end) {
- // The two matches are not adjacent.
- return false;
- }
- UnicodeText span;
- for (int i = prev.end; i < next.begin; i++) {
- const char32 codepoint = *text[i];
- if (unilib.IsWhitespace(codepoint)) {
- continue;
- }
- span.push_back(unilib.ToLower(codepoint));
- }
- if (span.empty()) {
- return true;
- }
- const std::string span_text = span.ToUTF8String();
- bool matched = false;
- for (const std::string& ignored_span : ignored_spans) {
- if (span_text == ignored_span) {
- matched = true;
- break;
- }
- }
- if (!matched) {
- return false;
- }
- return IsDateMatchMergeable(prev, next);
-}
-
-// Merges adjacent date and date range.
-// For e.g. Monday, 5-10pm, the date "Monday" and the time range "5-10pm" will
-// be merged
-void MergeDateRangeAndDate(const UniLib& unilib,
- const std::vector<UnicodeText::const_iterator>& text,
- const std::vector<std::string>& ignored_spans,
- const std::vector<DateMatch>& dates,
- std::vector<DateRangeMatch>* date_ranges) {
- // For each range, check the date before or after the it to see if they could
- // be merged. Both the range and date array are sorted, so we only need to
- // scan the date array once.
- int next_date = 0;
- for (int i = 0; i < date_ranges->size(); i++) {
- DateRangeMatch* date_range = &date_ranges->at(i);
- // So far we only merge time range with a date.
- if (!date_range->from.HasHour()) {
- continue;
- }
-
- for (; next_date < dates.size(); next_date++) {
- const DateMatch& date = dates[next_date];
-
- // If the range is before the date, we check whether `date_range->to` can
- // be merged with the date.
- if (date_range->end <= date.begin) {
- DateMatch merged_date = date;
- if (AreDateMatchesAdjacentAndMergeable(unilib, text, ignored_spans,
- date_range->to, date)) {
- MergeDateMatch(date_range->to, &merged_date, /*update_span=*/true);
- date_range->to = merged_date;
- date_range->end = date_range->to.end;
- MergeDateMatch(date, &date_range->from, /*update_span=*/false);
- next_date++;
-
- // Check the second date after the range to see if it could be merged
- // further. For example: 10-11pm, Monday, May 15. 10-11pm is merged
- // with Monday and then we check that it could be merged with May 15
- // as well.
- if (next_date < dates.size()) {
- DateMatch next_match = dates[next_date];
- if (AreDateMatchesAdjacentAndMergeable(
- unilib, text, ignored_spans, date_range->to, next_match)) {
- MergeDateMatch(date_range->to, &next_match, /*update_span=*/true);
- date_range->to = next_match;
- date_range->end = date_range->to.end;
- MergeDateMatch(dates[next_date], &date_range->from,
- /*update_span=*/false);
- next_date++;
- }
- }
- }
- // Since the range is before the date, we try to check if the next range
- // could be merged with the current date.
- break;
- } else if (date_range->end > date.end && date_range->begin > date.begin) {
- // If the range is after the date, we check if `date_range.from` can be
- // merged with the date. Here is a special case, the date before range
- // could be partially overlapped. This is because the range.from could
- // be extracted as year in date. For example: March 3, 10-11pm is
- // extracted as date March 3, 2010 and the range 10-11pm. In this
- // case, we simply clear the year from date.
- DateMatch merged_date = date;
- if (date.HasYear() &&
- date.year_match->codepoint_span.second > date_range->begin) {
- merged_date.year_match = nullptr;
- merged_date.year = NO_VAL;
- merged_date.end = date.year_match->match_offset;
- }
- // Check and merge the range and the date before the range.
- if (AreDateMatchesAdjacentAndMergeable(unilib, text, ignored_spans,
- merged_date, date_range->from)) {
- MergeDateMatch(merged_date, &date_range->from, /*update_span=*/true);
- date_range->begin = date_range->from.begin;
- MergeDateMatch(merged_date, &date_range->to, /*update_span=*/false);
-
- // Check if the second date before the range can be merged as well.
- if (next_date > 0) {
- DateMatch prev_match = dates[next_date - 1];
- if (prev_match.end <= date_range->from.begin) {
- if (AreDateMatchesAdjacentAndMergeable(unilib, text,
- ignored_spans, prev_match,
- date_range->from)) {
- MergeDateMatch(prev_match, &date_range->from,
- /*update_span=*/true);
- date_range->begin = date_range->from.begin;
- MergeDateMatch(prev_match, &date_range->to,
- /*update_span=*/false);
- }
- }
- }
- next_date++;
- break;
- } else {
- // Since the date is before the date range, we move to the next date
- // to check if it could be merged with the current range.
- continue;
- }
- } else {
- // The date is either fully overlapped by the date range or the date
- // span end is after the date range. Move to the next date in both
- // cases.
- }
- }
- }
-}
-
-// Removes the dates which are part of a range. e.g. in "May 1 - 3", the date
-// "May 1" is fully contained in the range.
-void RemoveOverlappedDateByRange(const std::vector<DateRangeMatch>& ranges,
- std::vector<DateMatch>* dates) {
- int next_date = 0;
- std::vector<bool> removed(dates->size(), false);
- for (int i = 0; i < ranges.size(); ++i) {
- const auto& range = ranges[i];
- for (; next_date < dates->size(); ++next_date) {
- const auto& date = dates->at(next_date);
- // So far we don't touch the partially overlapped case.
- if (date.begin >= range.begin && date.end <= range.end) {
- // Fully contained.
- removed[next_date] = true;
- } else if (date.end <= range.begin) {
- continue; // date is behind range, go to next date
- } else if (date.begin >= range.end) {
- break; // range is behind date, go to next range
- }
- }
- }
- RemoveDeletedMatches(removed, dates);
-}
-
-// Converts candidate dates and date ranges.
-void FillDateInstances(
- const UniLib& unilib, const std::vector<UnicodeText::const_iterator>& text,
- const DateAnnotationOptions& options, std::vector<DateMatch>* date_matches,
- std::vector<DatetimeParseResultSpan>* datetime_parse_result_spans) {
- int i = 0;
- for (int j = 1; j < date_matches->size(); j++) {
- if (options.merge_adjacent_components &&
- AreDateMatchesAdjacentAndMergeable(unilib, text, options.ignored_spans,
- date_matches->at(i),
- date_matches->at(j))) {
- MergeDateMatch(date_matches->at(i), &date_matches->at(j), true);
- } else {
- if (!IsBlacklistedDate(unilib, text, date_matches->at(i))) {
- DatetimeParseResultSpan datetime_parse_result_span;
- FillDateInstance(date_matches->at(i), &datetime_parse_result_span);
- datetime_parse_result_spans->push_back(datetime_parse_result_span);
- }
- }
- i = j;
- }
- if (!IsBlacklistedDate(unilib, text, date_matches->at(i))) {
- DatetimeParseResultSpan datetime_parse_result_span;
- FillDateInstance(date_matches->at(i), &datetime_parse_result_span);
- datetime_parse_result_spans->push_back(datetime_parse_result_span);
- }
-}
-
-void FillDateRangeInstances(
- const std::vector<DateRangeMatch>& date_range_matches,
- std::vector<DatetimeParseResultSpan>* datetime_parse_result_spans) {
- for (const DateRangeMatch& date_range_match : date_range_matches) {
- DatetimeParseResultSpan datetime_parse_result_span;
- FillDateRangeInstance(date_range_match, &datetime_parse_result_span);
- datetime_parse_result_spans->push_back(datetime_parse_result_span);
- }
-}
-
-// Fills `DatetimeParseResultSpan` from `DateMatch` and `DateRangeMatch`
-// instances.
-std::vector<DatetimeParseResultSpan> GetOutputAsAnnotationList(
- const UniLib& unilib, const DateExtractor& extractor,
- const std::vector<UnicodeText::const_iterator>& text,
- const DateAnnotationOptions& options) {
- std::vector<DatetimeParseResultSpan> datetime_parse_result_spans;
- std::vector<DateMatch> date_matches =
- BuildDateMatches(text, extractor.output());
-
- std::sort(
- date_matches.begin(), date_matches.end(),
- // Order by increasing begin, and decreasing end (decreasing length).
- [](const DateMatch& a, const DateMatch& b) {
- return (a.begin < b.begin || (a.begin == b.begin && a.end > b.end));
- });
-
- if (!date_matches.empty()) {
- RemoveDuplicatedDates(&date_matches);
- }
-
- if (options.enable_date_range) {
- std::vector<DateRangeMatch> date_range_matches =
- BuildDateRangeMatches(text, extractor.range_output());
-
- if (!date_range_matches.empty()) {
- std::sort(
- date_range_matches.begin(), date_range_matches.end(),
- // Order by increasing begin, and decreasing end (decreasing length).
- [](const DateRangeMatch& a, const DateRangeMatch& b) {
- return (a.begin < b.begin || (a.begin == b.begin && a.end > b.end));
- });
- RemoveDuplicatedDates(&date_range_matches);
- }
-
- if (!date_matches.empty()) {
- MergeDateRangeAndDate(unilib, text, options.ignored_spans, date_matches,
- &date_range_matches);
- RemoveOverlappedDateByRange(date_range_matches, &date_matches);
- }
- FillDateRangeInstances(date_range_matches, &datetime_parse_result_spans);
- }
-
- if (!date_matches.empty()) {
- FillDateInstances(unilib, text, options, &date_matches,
- &datetime_parse_result_spans);
- }
- return datetime_parse_result_spans;
-}
-
-} // namespace
-
-std::vector<DatetimeParseResultSpan> DateParser::Parse(
- StringPiece text, const std::vector<Token>& tokens,
- const std::vector<Locale>& locales,
- const DateAnnotationOptions& options) const {
- std::vector<UnicodeText::const_iterator> codepoint_offsets;
- const UnicodeText text_unicode = UTF8ToUnicodeText(text,
- /*do_copy=*/false);
- for (auto it = text_unicode.begin(); it != text_unicode.end(); it++) {
- codepoint_offsets.push_back(it);
- }
- codepoint_offsets.push_back(text_unicode.end());
- DateExtractor extractor(codepoint_offsets, options, datetime_rules_);
- // Select locale matching rules.
- // Only use a shard if locales match or the shard doesn't specify a locale
- // restriction.
- std::vector<const grammar::RulesSet_::Rules*> locale_rules =
- SelectLocaleMatchingShards(datetime_rules_->rules(), rules_locales_,
- locales);
- if (locale_rules.empty()) {
- return {};
- }
- grammar::Matcher matcher(&unilib_, datetime_rules_->rules(), locale_rules,
- &extractor);
- lexer_.Process(text_unicode, tokens, /*annotations=*/nullptr, &matcher);
- return GetOutputAsAnnotationList(unilib_, extractor, codepoint_offsets,
- options);
-}
-
-} // namespace libtextclassifier3::dates
diff --git a/native/annotator/grammar/dates/parser.h b/native/annotator/grammar/dates/parser.h
deleted file mode 100644
index be919df..0000000
--- a/native/annotator/grammar/dates/parser.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_
-
-#include <vector>
-
-#include "annotator/grammar/dates/annotations/annotation-options.h"
-#include "annotator/grammar/dates/annotations/annotation.h"
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/utils/date-match.h"
-#include "utils/grammar/lexer.h"
-#include "utils/grammar/rules-utils.h"
-#include "utils/i18n/locale.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3::dates {
-
-// Parses datetime expressions in the input with the datetime grammar and
-// constructs, validates, deduplicates and normalizes date time annotations.
-class DateParser {
- public:
- explicit DateParser(const UniLib* unilib, const DatetimeRules* datetime_rules)
- : unilib_(*unilib),
- lexer_(unilib, datetime_rules->rules()),
- datetime_rules_(datetime_rules),
- rules_locales_(ParseRulesLocales(datetime_rules->rules())) {}
-
- // Parses the dates in the input. Makes sure that the results do not
- // overlap.
- std::vector<DatetimeParseResultSpan> Parse(
- StringPiece text, const std::vector<Token>& tokens,
- const std::vector<Locale>& locales,
- const DateAnnotationOptions& options) const;
-
- private:
- const UniLib& unilib_;
- const grammar::Lexer lexer_;
-
- // The datetime grammar.
- const DatetimeRules* datetime_rules_;
-
- // Pre-parsed locales of the rules.
- const std::vector<std::vector<Locale>> rules_locales_;
-};
-
-} // namespace libtextclassifier3::dates
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_
diff --git a/native/annotator/grammar/dates/timezone-code.fbs b/native/annotator/grammar/dates/timezone-code.fbs
deleted file mode 100755
index ff615ee..0000000
--- a/native/annotator/grammar/dates/timezone-code.fbs
+++ /dev/null
@@ -1,593 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-namespace libtextclassifier3.dates;
-enum TimezoneCode : int {
- TIMEZONE_CODE_NONE = -1,
- ETC_UNKNOWN = 0,
- PST8PDT = 1,
- // Delegate.
-
- AFRICA_ABIDJAN = 2,
- AFRICA_ACCRA = 3,
- AFRICA_ADDIS_ABABA = 4,
- AFRICA_ALGIERS = 5,
- AFRICA_ASMARA = 6,
- AFRICA_BAMAKO = 7,
- // Delegate.
-
- AFRICA_BANGUI = 8,
- AFRICA_BANJUL = 9,
- AFRICA_BISSAU = 10,
- AFRICA_BLANTYRE = 11,
- AFRICA_BRAZZAVILLE = 12,
- AFRICA_BUJUMBURA = 13,
- EGYPT = 14,
- // Delegate.
-
- AFRICA_CASABLANCA = 15,
- AFRICA_CEUTA = 16,
- AFRICA_CONAKRY = 17,
- AFRICA_DAKAR = 18,
- AFRICA_DAR_ES_SALAAM = 19,
- AFRICA_DJIBOUTI = 20,
- AFRICA_DOUALA = 21,
- AFRICA_EL_AAIUN = 22,
- AFRICA_FREETOWN = 23,
- AFRICA_GABORONE = 24,
- AFRICA_HARARE = 25,
- AFRICA_JOHANNESBURG = 26,
- AFRICA_KAMPALA = 27,
- AFRICA_KHARTOUM = 28,
- AFRICA_KIGALI = 29,
- AFRICA_KINSHASA = 30,
- AFRICA_LAGOS = 31,
- AFRICA_LIBREVILLE = 32,
- AFRICA_LOME = 33,
- AFRICA_LUANDA = 34,
- AFRICA_LUBUMBASHI = 35,
- AFRICA_LUSAKA = 36,
- AFRICA_MALABO = 37,
- AFRICA_MAPUTO = 38,
- AFRICA_MASERU = 39,
- AFRICA_MBABANE = 40,
- AFRICA_MOGADISHU = 41,
- AFRICA_MONROVIA = 42,
- AFRICA_NAIROBI = 43,
- AFRICA_NDJAMENA = 44,
- AFRICA_NIAMEY = 45,
- AFRICA_NOUAKCHOTT = 46,
- AFRICA_OUAGADOUGOU = 47,
- AFRICA_PORTO_NOVO = 48,
- AFRICA_SAO_TOME = 49,
- LIBYA = 51,
- // Delegate.
-
- AFRICA_TUNIS = 52,
- AFRICA_WINDHOEK = 53,
- US_ALEUTIAN = 54,
- // Delegate.
-
- US_ALASKA = 55,
- // Delegate.
-
- AMERICA_ANGUILLA = 56,
- AMERICA_ANTIGUA = 57,
- AMERICA_ARAGUAINA = 58,
- AMERICA_BUENOS_AIRES = 59,
- AMERICA_CATAMARCA = 60,
- AMERICA_CORDOBA = 62,
- AMERICA_JUJUY = 63,
- AMERICA_ARGENTINA_LA_RIOJA = 64,
- AMERICA_MENDOZA = 65,
- AMERICA_ARGENTINA_RIO_GALLEGOS = 66,
- AMERICA_ARGENTINA_SAN_JUAN = 67,
- AMERICA_ARGENTINA_TUCUMAN = 68,
- AMERICA_ARGENTINA_USHUAIA = 69,
- AMERICA_ARUBA = 70,
- AMERICA_ASUNCION = 71,
- AMERICA_BAHIA = 72,
- AMERICA_BARBADOS = 73,
- AMERICA_BELEM = 74,
- AMERICA_BELIZE = 75,
- AMERICA_BOA_VISTA = 76,
- AMERICA_BOGOTA = 77,
- AMERICA_BOISE = 78,
- AMERICA_CAMBRIDGE_BAY = 79,
- AMERICA_CAMPO_GRANDE = 80,
- AMERICA_CANCUN = 81,
- AMERICA_CARACAS = 82,
- AMERICA_CAYENNE = 83,
- AMERICA_CAYMAN = 84,
- CST6CDT = 85,
- // Delegate.
-
- AMERICA_CHIHUAHUA = 86,
- AMERICA_COSTA_RICA = 87,
- AMERICA_CUIABA = 88,
- AMERICA_CURACAO = 89,
- AMERICA_DANMARKSHAVN = 90,
- AMERICA_DAWSON = 91,
- AMERICA_DAWSON_CREEK = 92,
- NAVAJO = 93,
- // Delegate.
-
- US_MICHIGAN = 94,
- // Delegate.
-
- AMERICA_DOMINICA = 95,
- CANADA_MOUNTAIN = 96,
- // Delegate.
-
- AMERICA_EIRUNEPE = 97,
- AMERICA_EL_SALVADOR = 98,
- AMERICA_FORTALEZA = 99,
- AMERICA_GLACE_BAY = 100,
- AMERICA_GODTHAB = 101,
- AMERICA_GOOSE_BAY = 102,
- AMERICA_GRAND_TURK = 103,
- AMERICA_GRENADA = 104,
- AMERICA_GUADELOUPE = 105,
- AMERICA_GUATEMALA = 106,
- AMERICA_GUAYAQUIL = 107,
- AMERICA_GUYANA = 108,
- AMERICA_HALIFAX = 109,
- // Delegate.
-
- CUBA = 110,
- // Delegate.
-
- AMERICA_HERMOSILLO = 111,
- AMERICA_KNOX_IN = 113,
- // Delegate.
-
- AMERICA_INDIANA_MARENGO = 114,
- US_EAST_INDIANA = 115,
- AMERICA_INDIANA_VEVAY = 116,
- AMERICA_INUVIK = 117,
- AMERICA_IQALUIT = 118,
- JAMAICA = 119,
- // Delegate.
-
- AMERICA_JUNEAU = 120,
- AMERICA_KENTUCKY_MONTICELLO = 122,
- AMERICA_LA_PAZ = 123,
- AMERICA_LIMA = 124,
- AMERICA_LOUISVILLE = 125,
- AMERICA_MACEIO = 126,
- AMERICA_MANAGUA = 127,
- BRAZIL_WEST = 128,
- // Delegate.
-
- AMERICA_MARTINIQUE = 129,
- MEXICO_BAJASUR = 130,
- // Delegate.
-
- AMERICA_MENOMINEE = 131,
- AMERICA_MERIDA = 132,
- MEXICO_GENERAL = 133,
- // Delegate.
-
- AMERICA_MIQUELON = 134,
- AMERICA_MONTERREY = 135,
- AMERICA_MONTEVIDEO = 136,
- AMERICA_MONTREAL = 137,
- AMERICA_MONTSERRAT = 138,
- AMERICA_NASSAU = 139,
- EST5EDT = 140,
- // Delegate.
-
- AMERICA_NIPIGON = 141,
- AMERICA_NOME = 142,
- AMERICA_NORONHA = 143,
- // Delegate.
-
- AMERICA_NORTH_DAKOTA_CENTER = 144,
- AMERICA_PANAMA = 145,
- AMERICA_PANGNIRTUNG = 146,
- AMERICA_PARAMARIBO = 147,
- US_ARIZONA = 148,
- // Delegate.
-
- AMERICA_PORT_AU_PRINCE = 149,
- AMERICA_PORT_OF_SPAIN = 150,
- AMERICA_PORTO_VELHO = 151,
- AMERICA_PUERTO_RICO = 152,
- AMERICA_RAINY_RIVER = 153,
- AMERICA_RANKIN_INLET = 154,
- AMERICA_RECIFE = 155,
- AMERICA_REGINA = 156,
- // Delegate.
-
- BRAZIL_ACRE = 157,
- AMERICA_SANTIAGO = 158,
- // Delegate.
-
- AMERICA_SANTO_DOMINGO = 159,
- BRAZIL_EAST = 160,
- // Delegate.
-
- AMERICA_SCORESBYSUND = 161,
- AMERICA_ST_JOHNS = 163,
- // Delegate.
-
- AMERICA_ST_KITTS = 164,
- AMERICA_ST_LUCIA = 165,
- AMERICA_VIRGIN = 166,
- // Delegate.
-
- AMERICA_ST_VINCENT = 167,
- AMERICA_SWIFT_CURRENT = 168,
- AMERICA_TEGUCIGALPA = 169,
- AMERICA_THULE = 170,
- AMERICA_THUNDER_BAY = 171,
- AMERICA_TIJUANA = 172,
- CANADA_EASTERN = 173,
- // Delegate.
-
- AMERICA_TORTOLA = 174,
- CANADA_PACIFIC = 175,
- // Delegate.
-
- CANADA_YUKON = 176,
- // Delegate.
-
- CANADA_CENTRAL = 177,
- // Delegate.
-
- AMERICA_YAKUTAT = 178,
- AMERICA_YELLOWKNIFE = 179,
- ANTARCTICA_CASEY = 180,
- ANTARCTICA_DAVIS = 181,
- ANTARCTICA_DUMONTDURVILLE = 182,
- ANTARCTICA_MAWSON = 183,
- ANTARCTICA_MCMURDO = 184,
- ANTARCTICA_PALMER = 185,
- ANTARCTICA_ROTHERA = 186,
- ANTARCTICA_SYOWA = 188,
- ANTARCTICA_VOSTOK = 189,
- ATLANTIC_JAN_MAYEN = 190,
- // Delegate.
-
- ASIA_ADEN = 191,
- ASIA_ALMATY = 192,
- ASIA_AMMAN = 193,
- ASIA_ANADYR = 194,
- ASIA_AQTAU = 195,
- ASIA_AQTOBE = 196,
- ASIA_ASHGABAT = 197,
- // Delegate.
-
- ASIA_BAGHDAD = 198,
- ASIA_BAHRAIN = 199,
- ASIA_BAKU = 200,
- ASIA_BANGKOK = 201,
- ASIA_BEIRUT = 202,
- ASIA_BISHKEK = 203,
- ASIA_BRUNEI = 204,
- ASIA_KOLKATA = 205,
- // Delegate.
-
- ASIA_CHOIBALSAN = 206,
- ASIA_COLOMBO = 208,
- ASIA_DAMASCUS = 209,
- ASIA_DACCA = 210,
- ASIA_DILI = 211,
- ASIA_DUBAI = 212,
- ASIA_DUSHANBE = 213,
- ASIA_GAZA = 214,
- HONGKONG = 216,
- // Delegate.
-
- ASIA_HOVD = 217,
- ASIA_IRKUTSK = 218,
- ASIA_JAKARTA = 220,
- ASIA_JAYAPURA = 221,
- ISRAEL = 222,
- // Delegate.
-
- ASIA_KABUL = 223,
- ASIA_KAMCHATKA = 224,
- ASIA_KARACHI = 225,
- ASIA_KATMANDU = 227,
- ASIA_KRASNOYARSK = 228,
- ASIA_KUALA_LUMPUR = 229,
- ASIA_KUCHING = 230,
- ASIA_KUWAIT = 231,
- ASIA_MACAO = 232,
- ASIA_MAGADAN = 233,
- ASIA_MAKASSAR = 234,
- // Delegate.
-
- ASIA_MANILA = 235,
- ASIA_MUSCAT = 236,
- ASIA_NICOSIA = 237,
- // Delegate.
-
- ASIA_NOVOSIBIRSK = 238,
- ASIA_OMSK = 239,
- ASIA_ORAL = 240,
- ASIA_PHNOM_PENH = 241,
- ASIA_PONTIANAK = 242,
- ASIA_PYONGYANG = 243,
- ASIA_QATAR = 244,
- ASIA_QYZYLORDA = 245,
- ASIA_RANGOON = 246,
- ASIA_RIYADH = 247,
- ASIA_SAIGON = 248,
- ASIA_SAKHALIN = 249,
- ASIA_SAMARKAND = 250,
- ROK = 251,
- // Delegate.
-
- PRC = 252,
- SINGAPORE = 253,
- // Delegate.
-
- ROC = 254,
- // Delegate.
-
- ASIA_TASHKENT = 255,
- ASIA_TBILISI = 256,
- IRAN = 257,
- // Delegate.
-
- ASIA_THIMBU = 258,
- JAPAN = 259,
- // Delegate.
-
- ASIA_ULAN_BATOR = 260,
- // Delegate.
-
- ASIA_URUMQI = 261,
- ASIA_VIENTIANE = 262,
- ASIA_VLADIVOSTOK = 263,
- ASIA_YAKUTSK = 264,
- ASIA_YEKATERINBURG = 265,
- ASIA_YEREVAN = 266,
- ATLANTIC_AZORES = 267,
- ATLANTIC_BERMUDA = 268,
- ATLANTIC_CANARY = 269,
- ATLANTIC_CAPE_VERDE = 270,
- ATLANTIC_FAROE = 271,
- // Delegate.
-
- ATLANTIC_MADEIRA = 273,
- ICELAND = 274,
- // Delegate.
-
- ATLANTIC_SOUTH_GEORGIA = 275,
- ATLANTIC_STANLEY = 276,
- ATLANTIC_ST_HELENA = 277,
- AUSTRALIA_SOUTH = 278,
- // Delegate.
-
- AUSTRALIA_BRISBANE = 279,
- // Delegate.
-
- AUSTRALIA_YANCOWINNA = 280,
- // Delegate.
-
- AUSTRALIA_NORTH = 281,
- // Delegate.
-
- AUSTRALIA_HOBART = 282,
- // Delegate.
-
- AUSTRALIA_LINDEMAN = 283,
- AUSTRALIA_LHI = 284,
- AUSTRALIA_VICTORIA = 285,
- // Delegate.
-
- AUSTRALIA_WEST = 286,
- // Delegate.
-
- AUSTRALIA_ACT = 287,
- EUROPE_AMSTERDAM = 288,
- EUROPE_ANDORRA = 289,
- EUROPE_ATHENS = 290,
- EUROPE_BELGRADE = 292,
- EUROPE_BERLIN = 293,
- EUROPE_BRATISLAVA = 294,
- EUROPE_BRUSSELS = 295,
- EUROPE_BUCHAREST = 296,
- EUROPE_BUDAPEST = 297,
- EUROPE_CHISINAU = 298,
- // Delegate.
-
- EUROPE_COPENHAGEN = 299,
- EIRE = 300,
- EUROPE_GIBRALTAR = 301,
- EUROPE_HELSINKI = 302,
- TURKEY = 303,
- EUROPE_KALININGRAD = 304,
- EUROPE_KIEV = 305,
- PORTUGAL = 306,
- // Delegate.
-
- EUROPE_LJUBLJANA = 307,
- GB = 308,
- EUROPE_LUXEMBOURG = 309,
- EUROPE_MADRID = 310,
- EUROPE_MALTA = 311,
- EUROPE_MARIEHAMN = 312,
- EUROPE_MINSK = 313,
- EUROPE_MONACO = 314,
- W_SU = 315,
- // Delegate.
-
- EUROPE_OSLO = 317,
- EUROPE_PARIS = 318,
- EUROPE_PRAGUE = 319,
- EUROPE_RIGA = 320,
- EUROPE_ROME = 321,
- EUROPE_SAMARA = 322,
- EUROPE_SAN_MARINO = 323,
- EUROPE_SARAJEVO = 324,
- EUROPE_SIMFEROPOL = 325,
- EUROPE_SKOPJE = 326,
- EUROPE_SOFIA = 327,
- EUROPE_STOCKHOLM = 328,
- EUROPE_TALLINN = 329,
- EUROPE_TIRANE = 330,
- EUROPE_UZHGOROD = 331,
- EUROPE_VADUZ = 332,
- EUROPE_VATICAN = 333,
- EUROPE_VIENNA = 334,
- EUROPE_VILNIUS = 335,
- POLAND = 336,
- // Delegate.
-
- EUROPE_ZAGREB = 337,
- EUROPE_ZAPOROZHYE = 338,
- EUROPE_ZURICH = 339,
- INDIAN_ANTANANARIVO = 340,
- INDIAN_CHAGOS = 341,
- INDIAN_CHRISTMAS = 342,
- INDIAN_COCOS = 343,
- INDIAN_COMORO = 344,
- INDIAN_KERGUELEN = 345,
- INDIAN_MAHE = 346,
- INDIAN_MALDIVES = 347,
- INDIAN_MAURITIUS = 348,
- INDIAN_MAYOTTE = 349,
- INDIAN_REUNION = 350,
- PACIFIC_APIA = 351,
- NZ = 352,
- NZ_CHAT = 353,
- PACIFIC_EASTER = 354,
- PACIFIC_EFATE = 355,
- PACIFIC_ENDERBURY = 356,
- PACIFIC_FAKAOFO = 357,
- PACIFIC_FIJI = 358,
- PACIFIC_FUNAFUTI = 359,
- PACIFIC_GALAPAGOS = 360,
- PACIFIC_GAMBIER = 361,
- PACIFIC_GUADALCANAL = 362,
- PACIFIC_GUAM = 363,
- US_HAWAII = 364,
- // Delegate.
-
- PACIFIC_JOHNSTON = 365,
- PACIFIC_KIRITIMATI = 366,
- PACIFIC_KOSRAE = 367,
- KWAJALEIN = 368,
- PACIFIC_MAJURO = 369,
- PACIFIC_MARQUESAS = 370,
- PACIFIC_MIDWAY = 371,
- PACIFIC_NAURU = 372,
- PACIFIC_NIUE = 373,
- PACIFIC_NORFOLK = 374,
- PACIFIC_NOUMEA = 375,
- US_SAMOA = 376,
- // Delegate.
-
- PACIFIC_PALAU = 377,
- PACIFIC_PITCAIRN = 378,
- PACIFIC_PONAPE = 379,
- PACIFIC_PORT_MORESBY = 380,
- PACIFIC_RAROTONGA = 381,
- PACIFIC_SAIPAN = 382,
- PACIFIC_TAHITI = 383,
- PACIFIC_TARAWA = 384,
- PACIFIC_TONGATAPU = 385,
- PACIFIC_YAP = 386,
- PACIFIC_WAKE = 387,
- PACIFIC_WALLIS = 388,
- AMERICA_ATIKOKAN = 390,
- AUSTRALIA_CURRIE = 391,
- ETC_GMT_EAST_14 = 392,
- ETC_GMT_EAST_13 = 393,
- ETC_GMT_EAST_12 = 394,
- ETC_GMT_EAST_11 = 395,
- ETC_GMT_EAST_10 = 396,
- ETC_GMT_EAST_9 = 397,
- ETC_GMT_EAST_8 = 398,
- ETC_GMT_EAST_7 = 399,
- ETC_GMT_EAST_6 = 400,
- ETC_GMT_EAST_5 = 401,
- ETC_GMT_EAST_4 = 402,
- ETC_GMT_EAST_3 = 403,
- ETC_GMT_EAST_2 = 404,
- ETC_GMT_EAST_1 = 405,
- GMT = 406,
- // Delegate.
-
- ETC_GMT_WEST_1 = 407,
- ETC_GMT_WEST_2 = 408,
- ETC_GMT_WEST_3 = 409,
- SYSTEMV_AST4 = 410,
- // Delegate.
-
- EST = 411,
- SYSTEMV_CST6 = 412,
- // Delegate.
-
- MST = 413,
- // Delegate.
-
- SYSTEMV_PST8 = 414,
- // Delegate.
-
- SYSTEMV_YST9 = 415,
- // Delegate.
-
- HST = 416,
- // Delegate.
-
- ETC_GMT_WEST_11 = 417,
- ETC_GMT_WEST_12 = 418,
- AMERICA_NORTH_DAKOTA_NEW_SALEM = 419,
- AMERICA_INDIANA_PETERSBURG = 420,
- AMERICA_INDIANA_VINCENNES = 421,
- AMERICA_MONCTON = 422,
- AMERICA_BLANC_SABLON = 423,
- EUROPE_GUERNSEY = 424,
- EUROPE_ISLE_OF_MAN = 425,
- EUROPE_JERSEY = 426,
- EUROPE_PODGORICA = 427,
- EUROPE_VOLGOGRAD = 428,
- AMERICA_INDIANA_WINAMAC = 429,
- AUSTRALIA_EUCLA = 430,
- AMERICA_INDIANA_TELL_CITY = 431,
- AMERICA_RESOLUTE = 432,
- AMERICA_ARGENTINA_SAN_LUIS = 433,
- AMERICA_SANTAREM = 434,
- AMERICA_ARGENTINA_SALTA = 435,
- AMERICA_BAHIA_BANDERAS = 436,
- AMERICA_MARIGOT = 437,
- AMERICA_MATAMOROS = 438,
- AMERICA_OJINAGA = 439,
- AMERICA_SANTA_ISABEL = 440,
- AMERICA_ST_BARTHELEMY = 441,
- ANTARCTICA_MACQUARIE = 442,
- ASIA_NOVOKUZNETSK = 443,
- AFRICA_JUBA = 444,
- AMERICA_METLAKATLA = 445,
- AMERICA_NORTH_DAKOTA_BEULAH = 446,
- AMERICA_SITKA = 447,
- ASIA_HEBRON = 448,
- AMERICA_CRESTON = 449,
- AMERICA_KRALENDIJK = 450,
- AMERICA_LOWER_PRINCES = 451,
- ANTARCTICA_TROLL = 452,
- ASIA_KHANDYGA = 453,
- ASIA_UST_NERA = 454,
- EUROPE_BUSINGEN = 455,
- ASIA_CHITA = 456,
- ASIA_SREDNEKOLYMSK = 457,
-}
-
diff --git a/native/annotator/grammar/dates/utils/annotation-keys.cc b/native/annotator/grammar/dates/utils/annotation-keys.cc
deleted file mode 100644
index 3438c6d..0000000
--- a/native/annotator/grammar/dates/utils/annotation-keys.cc
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/utils/annotation-keys.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-const char* const kDateTimeType = "dateTime";
-const char* const kDateTimeRangeType = "dateTimeRange";
-const char* const kDateTime = "dateTime";
-const char* const kDateTimeSupplementary = "dateTimeSupplementary";
-const char* const kDateTimeRelative = "dateTimeRelative";
-const char* const kDateTimeRangeFrom = "dateTimeRangeFrom";
-const char* const kDateTimeRangeTo = "dateTimeRangeTo";
-} // namespace dates
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/annotation-keys.h b/native/annotator/grammar/dates/utils/annotation-keys.h
deleted file mode 100644
index f970a51..0000000
--- a/native/annotator/grammar/dates/utils/annotation-keys.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_
-
-namespace libtextclassifier3 {
-namespace dates {
-
-// Date time specific constants not defined in standard schemas.
-//
-// Date annotator output two type of annotation. One is date&time like "May 1",
-// "12:20pm", etc. Another is range like "2pm - 3pm". The two string identify
-// the type of annotation and are used as type in Thing proto.
-extern const char* const kDateTimeType;
-extern const char* const kDateTimeRangeType;
-
-// kDateTime contains most common field for date time. It's integer array and
-// the format is (year, month, day, hour, minute, second, fraction_sec,
-// day_of_week). All eight fields must be provided. If the field is not
-// extracted, the value is -1 in the array.
-extern const char* const kDateTime;
-
-// kDateTimeSupplementary contains uncommon field like timespan, timezone. It's
-// integer array and the format is (bc_ad, timespan_code, timezone_code,
-// timezone_offset). Al four fields must be provided. If the field is not
-// extracted, the value is -1 in the array.
-extern const char* const kDateTimeSupplementary;
-
-// kDateTimeRelative contains fields for relative date time. It's integer
-// array and the format is (is_future, year, month, day, week, hour, minute,
-// second, day_of_week, dow_interpretation*). The first nine fields must be
-// provided and dow_interpretation could have zero or multiple values.
-// If the field is not extracted, the value is -1 in the array.
-extern const char* const kDateTimeRelative;
-
-// Date time range specific constants not defined in standard schemas.
-// kDateTimeRangeFrom and kDateTimeRangeTo define the from/to of a date/time
-// range. The value is thing object which contains a date time.
-extern const char* const kDateTimeRangeFrom;
-extern const char* const kDateTimeRangeTo;
-
-} // namespace dates
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_
diff --git a/native/annotator/grammar/dates/utils/date-match.cc b/native/annotator/grammar/dates/utils/date-match.cc
deleted file mode 100644
index d9fca52..0000000
--- a/native/annotator/grammar/dates/utils/date-match.cc
+++ /dev/null
@@ -1,440 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/utils/date-match.h"
-
-#include <algorithm>
-
-#include "annotator/grammar/dates/utils/date-utils.h"
-#include "annotator/types.h"
-#include "utils/strings/append.h"
-
-static const int kAM = 0;
-static const int kPM = 1;
-
-namespace libtextclassifier3 {
-namespace dates {
-
-namespace {
-static int GetMeridiemValue(const TimespanCode& timespan_code) {
- switch (timespan_code) {
- case TimespanCode_AM:
- case TimespanCode_MIDNIGHT:
- // MIDNIGHT [3] -> AM
- return kAM;
- case TimespanCode_TONIGHT:
- // TONIGHT [11] -> PM
- case TimespanCode_NOON:
- // NOON [2] -> PM
- case TimespanCode_PM:
- return kPM;
- case TimespanCode_TIMESPAN_CODE_NONE:
- default:
- TC3_LOG(WARNING) << "Failed to extract time span code.";
- }
- return NO_VAL;
-}
-
-static int GetRelativeCount(const RelativeParameter* relative_parameter) {
- for (const int interpretation :
- *relative_parameter->day_of_week_interpretation()) {
- switch (interpretation) {
- case RelativeParameter_::Interpretation_NEAREST_LAST:
- case RelativeParameter_::Interpretation_PREVIOUS:
- return -1;
- case RelativeParameter_::Interpretation_SECOND_LAST:
- return -2;
- case RelativeParameter_::Interpretation_SECOND_NEXT:
- return 2;
- case RelativeParameter_::Interpretation_COMING:
- case RelativeParameter_::Interpretation_SOME:
- case RelativeParameter_::Interpretation_NEAREST:
- case RelativeParameter_::Interpretation_NEAREST_NEXT:
- return 1;
- case RelativeParameter_::Interpretation_CURRENT:
- return 0;
- }
- }
- return 0;
-}
-} // namespace
-
-using strings::JoinStrings;
-using strings::SStringAppendF;
-
-std::string DateMatch::DebugString() const {
- std::string res;
-#if !defined(NDEBUG)
- if (begin >= 0 && end >= 0) {
- SStringAppendF(&res, 0, "[%u,%u)", begin, end);
- }
-
- if (HasDayOfWeek()) {
- SStringAppendF(&res, 0, "%u", day_of_week);
- }
-
- if (HasYear()) {
- int year_output = year;
- if (HasBcAd() && bc_ad == BCAD_BC) {
- year_output = -year;
- }
- SStringAppendF(&res, 0, "%u/", year_output);
- } else {
- SStringAppendF(&res, 0, "____/");
- }
-
- if (HasMonth()) {
- SStringAppendF(&res, 0, "%u/", month);
- } else {
- SStringAppendF(&res, 0, "__/");
- }
-
- if (HasDay()) {
- SStringAppendF(&res, 0, "%u ", day);
- } else {
- SStringAppendF(&res, 0, "__ ");
- }
-
- if (HasHour()) {
- SStringAppendF(&res, 0, "%u:", hour);
- } else {
- SStringAppendF(&res, 0, "__:");
- }
-
- if (HasMinute()) {
- SStringAppendF(&res, 0, "%u:", minute);
- } else {
- SStringAppendF(&res, 0, "__:");
- }
-
- if (HasSecond()) {
- if (HasFractionSecond()) {
- SStringAppendF(&res, 0, "%u.%lf ", second, fraction_second);
- } else {
- SStringAppendF(&res, 0, "%u ", second);
- }
- } else {
- SStringAppendF(&res, 0, "__ ");
- }
-
- if (HasTimeSpanCode() && TimespanCode_TIMESPAN_CODE_NONE < time_span_code &&
- time_span_code <= TimespanCode_MAX) {
- SStringAppendF(&res, 0, "TS=%u ", time_span_code);
- }
-
- if (HasTimeZoneCode() && time_zone_code != -1) {
- SStringAppendF(&res, 0, "TZ= %u ", time_zone_code);
- }
-
- if (HasTimeZoneOffset()) {
- SStringAppendF(&res, 0, "TZO=%u ", time_zone_offset);
- }
-
- if (HasRelativeDate()) {
- const RelativeMatch* rm = relative_match;
- SStringAppendF(&res, 0, (rm->is_future_date ? "future " : "past "));
- if (rm->day_of_week != NO_VAL) {
- SStringAppendF(&res, 0, "DOW:%d ", rm->day_of_week);
- }
- if (rm->year != NO_VAL) {
- SStringAppendF(&res, 0, "Y:%d ", rm->year);
- }
- if (rm->month != NO_VAL) {
- SStringAppendF(&res, 0, "M:%d ", rm->month);
- }
- if (rm->day != NO_VAL) {
- SStringAppendF(&res, 0, "D:%d ", rm->day);
- }
- if (rm->week != NO_VAL) {
- SStringAppendF(&res, 0, "W:%d ", rm->week);
- }
- if (rm->hour != NO_VAL) {
- SStringAppendF(&res, 0, "H:%d ", rm->hour);
- }
- if (rm->minute != NO_VAL) {
- SStringAppendF(&res, 0, "M:%d ", rm->minute);
- }
- if (rm->second != NO_VAL) {
- SStringAppendF(&res, 0, "S:%d ", rm->second);
- }
- }
-
- SStringAppendF(&res, 0, "prio=%d ", priority);
- SStringAppendF(&res, 0, "conf-score=%lf ", annotator_priority_score);
-
- if (IsHourAmbiguous()) {
- std::vector<int8> values;
- GetPossibleHourValues(&values);
- std::string str_values;
-
- for (unsigned int i = 0; i < values.size(); ++i) {
- SStringAppendF(&str_values, 0, "%u,", values[i]);
- }
- SStringAppendF(&res, 0, "amb=%s ", str_values.c_str());
- }
-
- std::vector<std::string> tags;
- if (is_inferred) {
- tags.push_back("inferred");
- }
- if (!tags.empty()) {
- SStringAppendF(&res, 0, "tag=%s ", JoinStrings(",", tags).c_str());
- }
-#endif // !defined(NDEBUG)
- return res;
-}
-
-void DateMatch::GetPossibleHourValues(std::vector<int8>* values) const {
- TC3_CHECK(values != nullptr);
- values->clear();
- if (HasHour()) {
- int8 possible_hour = hour;
- values->push_back(possible_hour);
- for (int count = 1; count < ambiguous_hour_count; ++count) {
- possible_hour += ambiguous_hour_interval;
- if (possible_hour >= 24) {
- possible_hour -= 24;
- }
- values->push_back(possible_hour);
- }
- }
-}
-
-DatetimeComponent::RelativeQualifier DateMatch::GetRelativeQualifier() const {
- if (HasRelativeDate()) {
- if (relative_match->existing & RelativeMatch::HAS_IS_FUTURE) {
- if (!relative_match->is_future_date) {
- return DatetimeComponent::RelativeQualifier::PAST;
- }
- }
- return DatetimeComponent::RelativeQualifier::FUTURE;
- }
- return DatetimeComponent::RelativeQualifier::UNSPECIFIED;
-}
-
-// Embed RelativeQualifier information of DatetimeComponent as a sign of
-// relative counter field of datetime component i.e. relative counter is
-// negative when relative qualifier RelativeQualifier::PAST.
-int GetAdjustedRelativeCounter(
- const DatetimeComponent::RelativeQualifier& relative_qualifier,
- const int relative_counter) {
- if (DatetimeComponent::RelativeQualifier::PAST == relative_qualifier) {
- return -relative_counter;
- }
- return relative_counter;
-}
-
-Optional<DatetimeComponent> CreateDatetimeComponent(
- const DatetimeComponent::ComponentType& component_type,
- const DatetimeComponent::RelativeQualifier& relative_qualifier,
- const int absolute_value, const int relative_value) {
- if (absolute_value == NO_VAL && relative_value == NO_VAL) {
- return Optional<DatetimeComponent>();
- }
- return Optional<DatetimeComponent>(DatetimeComponent(
- component_type,
- (relative_value != NO_VAL)
- ? relative_qualifier
- : DatetimeComponent::RelativeQualifier::UNSPECIFIED,
- (absolute_value != NO_VAL) ? absolute_value : 0,
- (relative_value != NO_VAL)
- ? GetAdjustedRelativeCounter(relative_qualifier, relative_value)
- : 0));
-}
-
-Optional<DatetimeComponent> CreateDayOfWeekComponent(
- const RelativeMatch* relative_match,
- const DatetimeComponent::RelativeQualifier& relative_qualifier,
- const DayOfWeek& absolute_day_of_week) {
- DatetimeComponent::RelativeQualifier updated_relative_qualifier =
- relative_qualifier;
- int absolute_value = absolute_day_of_week;
- int relative_value = NO_VAL;
- if (relative_match) {
- relative_value = relative_match->day_of_week;
- if (relative_match->existing & RelativeMatch::HAS_DAY_OF_WEEK) {
- if (relative_match->IsStandaloneRelativeDayOfWeek() &&
- absolute_day_of_week == DayOfWeek_DOW_NONE) {
- absolute_value = relative_match->day_of_week;
- }
- // Check if the relative date has day of week with week period.
- if (relative_match->existing & RelativeMatch::HAS_WEEK) {
- relative_value = 1;
- } else {
- const NonterminalValue* nonterminal =
- relative_match->day_of_week_nonterminal;
- TC3_CHECK(nonterminal != nullptr);
- TC3_CHECK(nonterminal->relative_parameter());
- const RelativeParameter* rp = nonterminal->relative_parameter();
- if (rp->day_of_week_interpretation()) {
- relative_value = GetRelativeCount(rp);
- if (relative_value < 0) {
- relative_value = abs(relative_value);
- updated_relative_qualifier =
- DatetimeComponent::RelativeQualifier::PAST;
- } else if (relative_value > 0) {
- updated_relative_qualifier =
- DatetimeComponent::RelativeQualifier::FUTURE;
- }
- }
- }
- }
- }
- return CreateDatetimeComponent(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- updated_relative_qualifier, absolute_value,
- relative_value);
-}
-
-// Resolve the year’s ambiguity.
-// If the year in the date has 4 digits i.e. DD/MM/YYYY then there is no
-// ambiguity, the year value is YYYY but certain format i.e. MM/DD/YY is
-// ambiguous e.g. in {April/23/15} year value can be 15 or 1915 or 2015.
-// Following heuristic is used to resolve the ambiguity.
-// - For YYYY there is nothing to resolve.
-// - For all YY years
-// - Value less than 50 will be resolved to 20YY
-// - Value greater or equal 50 will be resolved to 19YY
-static int InterpretYear(int parsed_year) {
- if (parsed_year == NO_VAL) {
- return parsed_year;
- }
- if (parsed_year < 100) {
- if (parsed_year < 50) {
- return parsed_year + 2000;
- }
- return parsed_year + 1900;
- }
- return parsed_year;
-}
-
-Optional<DatetimeComponent> DateMatch::GetDatetimeComponent(
- const DatetimeComponent::ComponentType& component_type) const {
- switch (component_type) {
- case DatetimeComponent::ComponentType::YEAR:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), InterpretYear(year),
- (relative_match != nullptr) ? relative_match->year : NO_VAL);
- case DatetimeComponent::ComponentType::MONTH:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), month,
- (relative_match != nullptr) ? relative_match->month : NO_VAL);
- case DatetimeComponent::ComponentType::DAY_OF_MONTH:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), day,
- (relative_match != nullptr) ? relative_match->day : NO_VAL);
- case DatetimeComponent::ComponentType::HOUR:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), hour,
- (relative_match != nullptr) ? relative_match->hour : NO_VAL);
- case DatetimeComponent::ComponentType::MINUTE:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), minute,
- (relative_match != nullptr) ? relative_match->minute : NO_VAL);
- case DatetimeComponent::ComponentType::SECOND:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), second,
- (relative_match != nullptr) ? relative_match->second : NO_VAL);
- case DatetimeComponent::ComponentType::DAY_OF_WEEK:
- return CreateDayOfWeekComponent(relative_match, GetRelativeQualifier(),
- day_of_week);
- case DatetimeComponent::ComponentType::MERIDIEM:
- return CreateDatetimeComponent(component_type, GetRelativeQualifier(),
- GetMeridiemValue(time_span_code), NO_VAL);
- case DatetimeComponent::ComponentType::ZONE_OFFSET:
- if (HasTimeZoneOffset()) {
- return Optional<DatetimeComponent>(DatetimeComponent(
- component_type, DatetimeComponent::RelativeQualifier::UNSPECIFIED,
- time_zone_offset, /*arg_relative_count=*/0));
- }
- return Optional<DatetimeComponent>();
- case DatetimeComponent::ComponentType::WEEK:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), NO_VAL,
- HasRelativeDate() ? relative_match->week : NO_VAL);
- default:
- return Optional<DatetimeComponent>();
- }
-}
-
-bool DateMatch::IsValid() const {
- if (!HasYear() && HasBcAd()) {
- return false;
- }
- if (!HasMonth() && HasYear() && (HasDay() || HasDayOfWeek())) {
- return false;
- }
- if (!HasDay() && HasDayOfWeek() && (HasYear() || HasMonth())) {
- return false;
- }
- if (!HasDay() && !HasDayOfWeek() && HasHour() && (HasYear() || HasMonth())) {
- return false;
- }
- if (!HasHour() && (HasMinute() || HasSecond() || HasFractionSecond())) {
- return false;
- }
- if (!HasMinute() && (HasSecond() || HasFractionSecond())) {
- return false;
- }
- if (!HasSecond() && HasFractionSecond()) {
- return false;
- }
- // Check whether day exists in a month, to exclude cases like "April 31".
- if (HasDay() && HasMonth() && day > GetLastDayOfMonth(year, month)) {
- return false;
- }
- return (HasDateFields() || HasTimeFields() || HasRelativeDate());
-}
-
-void DateMatch::FillDatetimeComponents(
- std::vector<DatetimeComponent>* datetime_component) const {
- static const std::vector<DatetimeComponent::ComponentType>*
- kDatetimeComponents = new std::vector<DatetimeComponent::ComponentType>{
- DatetimeComponent::ComponentType::ZONE_OFFSET,
- DatetimeComponent::ComponentType::MERIDIEM,
- DatetimeComponent::ComponentType::SECOND,
- DatetimeComponent::ComponentType::MINUTE,
- DatetimeComponent::ComponentType::HOUR,
- DatetimeComponent::ComponentType::DAY_OF_MONTH,
- DatetimeComponent::ComponentType::DAY_OF_WEEK,
- DatetimeComponent::ComponentType::WEEK,
- DatetimeComponent::ComponentType::MONTH,
- DatetimeComponent::ComponentType::YEAR};
-
- for (const DatetimeComponent::ComponentType& component_type :
- *kDatetimeComponents) {
- Optional<DatetimeComponent> date_time =
- GetDatetimeComponent(component_type);
- if (date_time.has_value()) {
- datetime_component->emplace_back(date_time.value());
- }
- }
-}
-
-std::string DateRangeMatch::DebugString() const {
- std::string res;
- // The method is only called for debugging purposes.
-#if !defined(NDEBUG)
- if (begin >= 0 && end >= 0) {
- SStringAppendF(&res, 0, "[%u,%u)\n", begin, end);
- }
- SStringAppendF(&res, 0, "from: %s \n", from.DebugString().c_str());
- SStringAppendF(&res, 0, "to: %s\n", to.DebugString().c_str());
-#endif // !defined(NDEBUG)
- return res;
-}
-
-} // namespace dates
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/date-match.h b/native/annotator/grammar/dates/utils/date-match.h
deleted file mode 100644
index 285e9b3..0000000
--- a/native/annotator/grammar/dates/utils/date-match.h
+++ /dev/null
@@ -1,537 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_
-
-#include <stddef.h>
-#include <stdint.h>
-
-#include <algorithm>
-#include <vector>
-
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/timezone-code_generated.h"
-#include "utils/grammar/match.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-
-static constexpr int NO_VAL = -1;
-
-// POD match data structure.
-struct MatchBase : public grammar::Match {
- void Reset() { type = MatchType::MatchType_UNKNOWN; }
-};
-
-struct ExtractionMatch : public MatchBase {
- const ExtractionRuleParameter* extraction_rule;
-
- void Reset() {
- MatchBase::Reset();
- type = MatchType::MatchType_DATETIME_RULE;
- extraction_rule = nullptr;
- }
-};
-
-struct TermValueMatch : public MatchBase {
- const TermValue* term_value;
-
- void Reset() {
- MatchBase::Reset();
- type = MatchType::MatchType_TERM_VALUE;
- term_value = nullptr;
- }
-};
-
-struct NonterminalMatch : public MatchBase {
- const NonterminalValue* nonterminal;
-
- void Reset() {
- MatchBase::Reset();
- type = MatchType::MatchType_NONTERMINAL;
- nonterminal = nullptr;
- }
-};
-
-struct IntegerMatch : public NonterminalMatch {
- int value;
- int8 count_of_digits; // When expression is in digits format.
- bool is_zero_prefixed; // When expression is in digits format.
-
- void Reset() {
- NonterminalMatch::Reset();
- value = NO_VAL;
- count_of_digits = 0;
- is_zero_prefixed = false;
- }
-};
-
-struct DigitsMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_DIGITS;
- }
-
- static bool IsValid(int x) { return true; }
-};
-
-struct YearMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_YEAR;
- }
-
- static bool IsValid(int x) { return x >= 1; }
-};
-
-struct MonthMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_MONTH;
- }
-
- static bool IsValid(int x) { return (x >= 1 && x <= 12); }
-};
-
-struct DayMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_DAY;
- }
-
- static bool IsValid(int x) { return (x >= 1 && x <= 31); }
-};
-
-struct HourMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_HOUR;
- }
-
- static bool IsValid(int x) { return (x >= 0 && x <= 24); }
-};
-
-struct MinuteMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_MINUTE;
- }
-
- static bool IsValid(int x) { return (x >= 0 && x <= 59); }
-};
-
-struct SecondMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_SECOND;
- }
-
- static bool IsValid(int x) { return (x >= 0 && x <= 60); }
-};
-
-struct DecimalMatch : public NonterminalMatch {
- double value;
- int8 count_of_digits; // When expression is in digits format.
-
- void Reset() {
- NonterminalMatch::Reset();
- value = NO_VAL;
- count_of_digits = 0;
- }
-};
-
-struct FractionSecondMatch : public DecimalMatch {
- void Reset() {
- DecimalMatch::Reset();
- type = MatchType::MatchType_FRACTION_SECOND;
- }
-
- static bool IsValid(double x) { return (x >= 0.0 && x < 1.0); }
-};
-
-// CombinedIntegersMatch<N> is used for expressions containing multiple (up
-// to N) matches of integers without delimeters between them (because
-// CFG-grammar is based on tokenizer, it could not split a token into several
-// pieces like using regular-expression). For example, "1130" contains "11"
-// and "30" meaning November 30.
-template <int N>
-struct CombinedIntegersMatch : public NonterminalMatch {
- enum {
- SIZE = N,
- };
-
- int values[SIZE];
- int8 count_of_digits; // When expression is in digits format.
- bool is_zero_prefixed; // When expression is in digits format.
-
- void Reset() {
- NonterminalMatch::Reset();
- for (int i = 0; i < SIZE; ++i) {
- values[i] = NO_VAL;
- }
- count_of_digits = 0;
- is_zero_prefixed = false;
- }
-};
-
-struct CombinedDigitsMatch : public CombinedIntegersMatch<6> {
- enum Index {
- INDEX_YEAR = 0,
- INDEX_MONTH = 1,
- INDEX_DAY = 2,
- INDEX_HOUR = 3,
- INDEX_MINUTE = 4,
- INDEX_SECOND = 5,
- };
-
- bool HasYear() const { return values[INDEX_YEAR] != NO_VAL; }
- bool HasMonth() const { return values[INDEX_MONTH] != NO_VAL; }
- bool HasDay() const { return values[INDEX_DAY] != NO_VAL; }
- bool HasHour() const { return values[INDEX_HOUR] != NO_VAL; }
- bool HasMinute() const { return values[INDEX_MINUTE] != NO_VAL; }
- bool HasSecond() const { return values[INDEX_SECOND] != NO_VAL; }
-
- int GetYear() const { return values[INDEX_YEAR]; }
- int GetMonth() const { return values[INDEX_MONTH]; }
- int GetDay() const { return values[INDEX_DAY]; }
- int GetHour() const { return values[INDEX_HOUR]; }
- int GetMinute() const { return values[INDEX_MINUTE]; }
- int GetSecond() const { return values[INDEX_SECOND]; }
-
- void Reset() {
- CombinedIntegersMatch<SIZE>::Reset();
- type = MatchType::MatchType_COMBINED_DIGITS;
- }
-
- static bool IsValid(int i, int x) {
- switch (i) {
- case INDEX_YEAR:
- return YearMatch::IsValid(x);
- case INDEX_MONTH:
- return MonthMatch::IsValid(x);
- case INDEX_DAY:
- return DayMatch::IsValid(x);
- case INDEX_HOUR:
- return HourMatch::IsValid(x);
- case INDEX_MINUTE:
- return MinuteMatch::IsValid(x);
- case INDEX_SECOND:
- return SecondMatch::IsValid(x);
- default:
- return false;
- }
- }
-};
-
-struct TimeValueMatch : public NonterminalMatch {
- const HourMatch* hour_match;
- const MinuteMatch* minute_match;
- const SecondMatch* second_match;
- const FractionSecondMatch* fraction_second_match;
-
- bool is_hour_zero_prefixed : 1;
- bool is_minute_one_digit : 1;
- bool is_second_one_digit : 1;
-
- int8 hour;
- int8 minute;
- int8 second;
- double fraction_second;
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_TIME_VALUE;
- hour_match = nullptr;
- minute_match = nullptr;
- second_match = nullptr;
- fraction_second_match = nullptr;
- is_hour_zero_prefixed = false;
- is_minute_one_digit = false;
- is_second_one_digit = false;
- hour = NO_VAL;
- minute = NO_VAL;
- second = NO_VAL;
- fraction_second = NO_VAL;
- }
-};
-
-struct TimeSpanMatch : public NonterminalMatch {
- const TimeSpanSpec* time_span_spec;
- TimespanCode time_span_code;
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_TIME_SPAN;
- time_span_spec = nullptr;
- time_span_code = TimespanCode_TIMESPAN_CODE_NONE;
- }
-};
-
-struct TimeZoneNameMatch : public NonterminalMatch {
- const TimeZoneNameSpec* time_zone_name_spec;
- TimezoneCode time_zone_code;
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_TIME_ZONE_NAME;
- time_zone_name_spec = nullptr;
- time_zone_code = TimezoneCode_TIMEZONE_CODE_NONE;
- }
-};
-
-struct TimeZoneOffsetMatch : public NonterminalMatch {
- const TimeZoneOffsetParameter* time_zone_offset_param;
- int16 time_zone_offset;
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_TIME_ZONE_OFFSET;
- time_zone_offset_param = nullptr;
- time_zone_offset = 0;
- }
-};
-
-struct DayOfWeekMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_DAY_OF_WEEK;
- }
-
- static bool IsValid(int x) {
- return (x > DayOfWeek_DOW_NONE && x <= DayOfWeek_MAX);
- }
-};
-
-struct TimePeriodMatch : public NonterminalMatch {
- int value;
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_TIME_PERIOD;
- value = NO_VAL;
- }
-};
-
-struct RelativeMatch : public NonterminalMatch {
- enum {
- HAS_NONE = 0,
- HAS_YEAR = 1 << 0,
- HAS_MONTH = 1 << 1,
- HAS_DAY = 1 << 2,
- HAS_WEEK = 1 << 3,
- HAS_HOUR = 1 << 4,
- HAS_MINUTE = 1 << 5,
- HAS_SECOND = 1 << 6,
- HAS_DAY_OF_WEEK = 1 << 7,
- HAS_IS_FUTURE = 1 << 31,
- };
- uint32 existing;
-
- int year;
- int month;
- int day;
- int week;
- int hour;
- int minute;
- int second;
- const NonterminalValue* day_of_week_nonterminal;
- int8 day_of_week;
- bool is_future_date;
-
- bool HasDay() const { return existing & HAS_DAY; }
-
- bool HasDayFields() const { return existing & (HAS_DAY | HAS_DAY_OF_WEEK); }
-
- bool HasTimeValueFields() const {
- return existing & (HAS_HOUR | HAS_MINUTE | HAS_SECOND);
- }
-
- bool IsStandaloneRelativeDayOfWeek() const {
- return (existing & HAS_DAY_OF_WEEK) && (existing & ~HAS_DAY_OF_WEEK) == 0;
- }
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_RELATIVE_DATE;
- existing = HAS_NONE;
- year = NO_VAL;
- month = NO_VAL;
- day = NO_VAL;
- week = NO_VAL;
- hour = NO_VAL;
- minute = NO_VAL;
- second = NO_VAL;
- day_of_week = NO_VAL;
- is_future_date = false;
- }
-};
-
-// This is not necessarily POD, it is used to keep the final matched result.
-struct DateMatch {
- // Sub-matches in the date match.
- const YearMatch* year_match = nullptr;
- const MonthMatch* month_match = nullptr;
- const DayMatch* day_match = nullptr;
- const DayOfWeekMatch* day_of_week_match = nullptr;
- const TimeValueMatch* time_value_match = nullptr;
- const TimeSpanMatch* time_span_match = nullptr;
- const TimeZoneNameMatch* time_zone_name_match = nullptr;
- const TimeZoneOffsetMatch* time_zone_offset_match = nullptr;
- const RelativeMatch* relative_match = nullptr;
- const CombinedDigitsMatch* combined_digits_match = nullptr;
-
- // [begin, end) indicates the Document position where the date or date range
- // was found.
- int begin = -1;
- int end = -1;
- int priority = 0;
- float annotator_priority_score = 0.0;
-
- int year = NO_VAL;
- int8 month = NO_VAL;
- int8 day = NO_VAL;
- DayOfWeek day_of_week = DayOfWeek_DOW_NONE;
- BCAD bc_ad = BCAD_BCAD_NONE;
- int8 hour = NO_VAL;
- int8 minute = NO_VAL;
- int8 second = NO_VAL;
- double fraction_second = NO_VAL;
- TimespanCode time_span_code = TimespanCode_TIMESPAN_CODE_NONE;
- int time_zone_code = TimezoneCode_TIMEZONE_CODE_NONE;
- int16 time_zone_offset = std::numeric_limits<int16>::min();
-
- // Fields about ambiguous hours. These fields are used to interpret the
- // possible values of ambiguous hours. Since all kinds of known ambiguities
- // are in the form of arithmetic progression (starting from .hour field),
- // we can use "ambiguous_hour_count" to denote the count of ambiguous hours,
- // and use "ambiguous_hour_interval" to denote the distance between a pair
- // of adjacent possible hours. Values in the arithmetic progression are
- // shrunk into [0, 23] (MOD 24). One can use the GetPossibleHourValues()
- // method for the complete list of possible hours.
- uint8 ambiguous_hour_count = 0;
- uint8 ambiguous_hour_interval = 0;
-
- bool is_inferred = false;
-
- // This field is set in function PerformRefinements to remove some DateMatch
- // like overlapped, duplicated, etc.
- bool is_removed = false;
-
- std::string DebugString() const;
-
- bool HasYear() const { return year != NO_VAL; }
- bool HasMonth() const { return month != NO_VAL; }
- bool HasDay() const { return day != NO_VAL; }
- bool HasDayOfWeek() const { return day_of_week != DayOfWeek_DOW_NONE; }
- bool HasBcAd() const { return bc_ad != BCAD_BCAD_NONE; }
- bool HasHour() const { return hour != NO_VAL; }
- bool HasMinute() const { return minute != NO_VAL; }
- bool HasSecond() const { return second != NO_VAL; }
- bool HasFractionSecond() const { return fraction_second != NO_VAL; }
- bool HasTimeSpanCode() const {
- return time_span_code != TimespanCode_TIMESPAN_CODE_NONE;
- }
- bool HasTimeZoneCode() const {
- return time_zone_code != TimezoneCode_TIMEZONE_CODE_NONE;
- }
- bool HasTimeZoneOffset() const {
- return time_zone_offset != std::numeric_limits<int16>::min();
- }
-
- bool HasRelativeDate() const { return relative_match != nullptr; }
-
- bool IsHourAmbiguous() const { return ambiguous_hour_count >= 2; }
-
- bool IsStandaloneTime() const {
- return (HasHour() || HasMinute()) && !HasDayOfWeek() && !HasDay() &&
- !HasMonth() && !HasYear();
- }
-
- void SetAmbiguousHourProperties(uint8 count, uint8 interval) {
- ambiguous_hour_count = count;
- ambiguous_hour_interval = interval;
- }
-
- // Outputs all the possible hour values. If current DateMatch does not
- // contain an hour, nothing will be output. If the hour is not ambiguous,
- // only one value (= .hour) will be output. This method clears the vector
- // "values" first, and it is not guaranteed that the values in the vector
- // are in a sorted order.
- void GetPossibleHourValues(std::vector<int8>* values) const;
-
- int GetPriority() const { return priority; }
-
- float GetAnnotatorPriorityScore() const { return annotator_priority_score; }
-
- bool IsStandaloneRelativeDayOfWeek() const {
- return (HasRelativeDate() &&
- relative_match->IsStandaloneRelativeDayOfWeek() &&
- !HasDateFields() && !HasTimeFields() && !HasTimeSpanCode());
- }
-
- bool HasDateFields() const {
- return (HasYear() || HasMonth() || HasDay() || HasDayOfWeek() || HasBcAd());
- }
- bool HasTimeValueFields() const {
- return (HasHour() || HasMinute() || HasSecond() || HasFractionSecond());
- }
- bool HasTimeSpanFields() const { return HasTimeSpanCode(); }
- bool HasTimeZoneFields() const {
- return (HasTimeZoneCode() || HasTimeZoneOffset());
- }
- bool HasTimeFields() const {
- return (HasTimeValueFields() || HasTimeSpanFields() || HasTimeZoneFields());
- }
-
- bool IsValid() const;
-
- // Overall relative qualifier of the DateMatch e.g. 2 year ago is 'PAST' and
- // next week is 'FUTURE'.
- DatetimeComponent::RelativeQualifier GetRelativeQualifier() const;
-
- // Getter method to get the 'DatetimeComponent' of given 'ComponentType'.
- Optional<DatetimeComponent> GetDatetimeComponent(
- const DatetimeComponent::ComponentType& component_type) const;
-
- void FillDatetimeComponents(
- std::vector<DatetimeComponent>* datetime_component) const;
-};
-
-// Represent a matched date range which includes the from and to matched date.
-struct DateRangeMatch {
- int begin = -1;
- int end = -1;
-
- DateMatch from;
- DateMatch to;
-
- std::string DebugString() const;
-
- int GetPriority() const {
- return std::max(from.GetPriority(), to.GetPriority());
- }
-
- float GetAnnotatorPriorityScore() const {
- return std::max(from.GetAnnotatorPriorityScore(),
- to.GetAnnotatorPriorityScore());
- }
-};
-
-} // namespace dates
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_
diff --git a/native/annotator/grammar/dates/utils/date-match_test.cc b/native/annotator/grammar/dates/utils/date-match_test.cc
deleted file mode 100644
index f10f32a..0000000
--- a/native/annotator/grammar/dates/utils/date-match_test.cc
+++ /dev/null
@@ -1,397 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/utils/date-match.h"
-
-#include <stdint.h>
-
-#include <string>
-
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/timezone-code_generated.h"
-#include "annotator/grammar/dates/utils/date-utils.h"
-#include "utils/strings/append.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-namespace {
-
-class DateMatchTest : public ::testing::Test {
- protected:
- enum {
- X = NO_VAL,
- };
-
- static DayOfWeek DOW_X() { return DayOfWeek_DOW_NONE; }
- static DayOfWeek SUN() { return DayOfWeek_SUNDAY; }
-
- static BCAD BCAD_X() { return BCAD_BCAD_NONE; }
- static BCAD BC() { return BCAD_BC; }
-
- DateMatch& SetDate(DateMatch* date, int year, int8 month, int8 day,
- DayOfWeek day_of_week = DOW_X(), BCAD bc_ad = BCAD_X()) {
- date->year = year;
- date->month = month;
- date->day = day;
- date->day_of_week = day_of_week;
- date->bc_ad = bc_ad;
- return *date;
- }
-
- DateMatch& SetTimeValue(DateMatch* date, int8 hour, int8 minute = X,
- int8 second = X, double fraction_second = X) {
- date->hour = hour;
- date->minute = minute;
- date->second = second;
- date->fraction_second = fraction_second;
- return *date;
- }
-
- DateMatch& SetTimeSpan(DateMatch* date, TimespanCode time_span_code) {
- date->time_span_code = time_span_code;
- return *date;
- }
-
- DateMatch& SetTimeZone(DateMatch* date, TimezoneCode time_zone_code,
- int16 time_zone_offset = INT16_MIN) {
- date->time_zone_code = time_zone_code;
- date->time_zone_offset = time_zone_offset;
- return *date;
- }
-
- bool SameDate(const DateMatch& a, const DateMatch& b) {
- return (a.day == b.day && a.month == b.month && a.year == b.year &&
- a.day_of_week == b.day_of_week);
- }
-
- DateMatch& SetDayOfWeek(DateMatch* date, DayOfWeek dow) {
- date->day_of_week = dow;
- return *date;
- }
-};
-
-TEST_F(DateMatchTest, BitFieldWidth) {
- // For DateMatch::day_of_week (:8).
- EXPECT_GE(DayOfWeek_MIN, INT8_MIN);
- EXPECT_LE(DayOfWeek_MAX, INT8_MAX);
-
- // For DateMatch::bc_ad (:8).
- EXPECT_GE(BCAD_MIN, INT8_MIN);
- EXPECT_LE(BCAD_MAX, INT8_MAX);
-
- // For DateMatch::time_span_code (:16).
- EXPECT_GE(TimespanCode_MIN, INT16_MIN);
- EXPECT_LE(TimespanCode_MAX, INT16_MAX);
-}
-
-TEST_F(DateMatchTest, IsValid) {
- // Valid: dates.
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, 1, X);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, X, X);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, 26);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, X);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, X, 26);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26, SUN());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, 26, SUN());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, X, 26, SUN());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26, DOW_X(), BC());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- // Valid: times.
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30, 59, 0.99);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30, 59);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- // Valid: mixed.
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26);
- SetTimeValue(&d, 12, 30, 59, 0.99);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, 26);
- SetTimeValue(&d, 12, 30, 59);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, X, X, SUN());
- SetTimeValue(&d, 12, 30);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- // Invalid: dates.
- {
- DateMatch d;
- SetDate(&d, X, 1, 26, DOW_X(), BC());
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, X, 26);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, X, X, SUN());
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, X, SUN());
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- // Invalid: times.
- {
- DateMatch d;
- SetTimeValue(&d, 12, X, 59);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, X, X, 0.99);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30, X, 0.99);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, X, 30);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- // Invalid: mixed.
- {
- DateMatch d;
- SetDate(&d, 2014, 1, X);
- SetTimeValue(&d, 12);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- // Invalid: empty.
- {
- DateMatch d;
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
-}
-
-std::string DebugStrings(const std::vector<DateMatch>& instances) {
- std::string res;
- for (int i = 0; i < instances.size(); ++i) {
- ::libtextclassifier3::strings::SStringAppendF(
- &res, 0, "[%d] == %s\n", i, instances[i].DebugString().c_str());
- }
- return res;
-}
-
-TEST_F(DateMatchTest, IsRefinement) {
- {
- DateMatch a;
- SetDate(&a, 2014, 2, X);
- DateMatch b;
- SetDate(&b, 2014, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- DateMatch b;
- SetDate(&b, 2014, 2, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- DateMatch b;
- SetDate(&b, X, 2, 24);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, 0, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, 0, 0);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, 0, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- SetTimeSpan(&a, TimespanCode_AM);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- SetTimeZone(&a, TimezoneCode_PST8PDT);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- a.priority += 10;
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, X, 2, 24);
- SetTimeValue(&b, 9, 0, X);
- EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, X, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetTimeValue(&a, 9, 0, 0);
- DateMatch b;
- SetTimeValue(&b, 9, X, X);
- SetTimeSpan(&b, TimespanCode_AM);
- EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
-}
-
-TEST_F(DateMatchTest, FillDateInstance_AnnotatorPriorityScore) {
- DateMatch date_match;
- SetDate(&date_match, 2014, 2, X);
- date_match.annotator_priority_score = 0.5;
- DatetimeParseResultSpan datetime_parse_result_span;
- FillDateInstance(date_match, &datetime_parse_result_span);
- EXPECT_FLOAT_EQ(datetime_parse_result_span.priority_score, 0.5)
- << DebugStrings({date_match});
-}
-
-TEST_F(DateMatchTest, MergeDateMatch_AnnotatorPriorityScore) {
- DateMatch a;
- SetDate(&a, 2014, 2, 4);
- a.annotator_priority_score = 0.5;
-
- DateMatch b;
- SetTimeValue(&b, 10, 45, 23);
- b.annotator_priority_score = 1.0;
-
- MergeDateMatch(b, &a, false);
- EXPECT_FLOAT_EQ(a.annotator_priority_score, 1.0);
-}
-
-} // namespace
-} // namespace dates
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/date-utils.cc b/native/annotator/grammar/dates/utils/date-utils.cc
deleted file mode 100644
index ea8015d..0000000
--- a/native/annotator/grammar/dates/utils/date-utils.cc
+++ /dev/null
@@ -1,399 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/utils/date-utils.h"
-
-#include <algorithm>
-#include <ctime>
-
-#include "annotator/grammar/dates/annotations/annotation-util.h"
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/utils/annotation-keys.h"
-#include "annotator/grammar/dates/utils/date-match.h"
-#include "annotator/types.h"
-#include "utils/base/macros.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-
-bool IsLeapYear(int year) {
- // For the sake of completeness, we want to be able to decide
- // whether a year is a leap year all the way back to 0 Julian, or
- // 4714 BCE. But we don't want to take the modulus of a negative
- // number, because this may not be very well-defined or portable. So
- // we increment the year by some large multiple of 400, which is the
- // periodicity of this leap-year calculation.
- if (year < 0) {
- year += 8000;
- }
- return ((year) % 4 == 0 && ((year) % 100 != 0 || (year) % 400 == 0));
-}
-
-namespace {
-#define SECSPERMIN (60)
-#define MINSPERHOUR (60)
-#define HOURSPERDAY (24)
-#define DAYSPERWEEK (7)
-#define DAYSPERNYEAR (365)
-#define DAYSPERLYEAR (366)
-#define MONSPERYEAR (12)
-
-const int8 kDaysPerMonth[2][1 + MONSPERYEAR] = {
- {-1, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31},
- {-1, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31},
-};
-} // namespace
-
-int8 GetLastDayOfMonth(int year, int month) {
- if (year == 0) { // No year specified
- return kDaysPerMonth[1][month];
- }
- return kDaysPerMonth[IsLeapYear(year)][month];
-}
-
-namespace {
-inline bool IsHourInSegment(const TimeSpanSpec_::Segment* segment, int8 hour,
- bool is_exact) {
- return (hour >= segment->begin() &&
- (hour < segment->end() ||
- (hour == segment->end() && is_exact && segment->is_closed())));
-}
-
-Property* FindOrCreateDefaultDateTime(AnnotationData* inst) {
- // Refer comments for kDateTime in annotation-keys.h to see the format.
- static constexpr int kDefault[] = {-1, -1, -1, -1, -1, -1, -1, -1};
-
- int idx = GetPropertyIndex(kDateTime, *inst);
- if (idx < 0) {
- idx = AddRepeatedIntProperty(kDateTime, kDefault, TC3_ARRAYSIZE(kDefault),
- inst);
- }
- return &inst->properties[idx];
-}
-
-void IncrementDayOfWeek(DayOfWeek* dow) {
- static const DayOfWeek dow_ring[] = {DayOfWeek_MONDAY, DayOfWeek_TUESDAY,
- DayOfWeek_WEDNESDAY, DayOfWeek_THURSDAY,
- DayOfWeek_FRIDAY, DayOfWeek_SATURDAY,
- DayOfWeek_SUNDAY, DayOfWeek_MONDAY};
- const auto& cur_dow =
- std::find(std::begin(dow_ring), std::end(dow_ring), *dow);
- if (cur_dow != std::end(dow_ring)) {
- *dow = *std::next(cur_dow);
- }
-}
-} // namespace
-
-bool NormalizeHourByTimeSpan(const TimeSpanSpec* ts_spec, DateMatch* date) {
- if (ts_spec->segment() == nullptr) {
- return false;
- }
- if (date->HasHour()) {
- const bool is_exact =
- (!date->HasMinute() ||
- (date->minute == 0 &&
- (!date->HasSecond() ||
- (date->second == 0 &&
- (!date->HasFractionSecond() || date->fraction_second == 0.0)))));
- for (const TimeSpanSpec_::Segment* segment : *ts_spec->segment()) {
- if (IsHourInSegment(segment, date->hour + segment->offset(), is_exact)) {
- date->hour += segment->offset();
- return true;
- }
- if (!segment->is_strict() &&
- IsHourInSegment(segment, date->hour, is_exact)) {
- return true;
- }
- }
- } else {
- for (const TimeSpanSpec_::Segment* segment : *ts_spec->segment()) {
- if (segment->is_stand_alone()) {
- if (segment->begin() == segment->end()) {
- date->hour = segment->begin();
- }
- // Allow stand-alone time-span points and ranges.
- return true;
- }
- }
- }
- return false;
-}
-
-bool IsRefinement(const DateMatch& a, const DateMatch& b) {
- int count = 0;
- if (b.HasBcAd()) {
- if (!a.HasBcAd() || a.bc_ad != b.bc_ad) return false;
- } else if (a.HasBcAd()) {
- if (a.bc_ad == BCAD_BC) return false;
- ++count;
- }
- if (b.HasYear()) {
- if (!a.HasYear() || a.year != b.year) return false;
- } else if (a.HasYear()) {
- ++count;
- }
- if (b.HasMonth()) {
- if (!a.HasMonth() || a.month != b.month) return false;
- } else if (a.HasMonth()) {
- ++count;
- }
- if (b.HasDay()) {
- if (!a.HasDay() || a.day != b.day) return false;
- } else if (a.HasDay()) {
- ++count;
- }
- if (b.HasDayOfWeek()) {
- if (!a.HasDayOfWeek() || a.day_of_week != b.day_of_week) return false;
- } else if (a.HasDayOfWeek()) {
- ++count;
- }
- if (b.HasHour()) {
- if (!a.HasHour()) return false;
- std::vector<int8> possible_hours;
- b.GetPossibleHourValues(&possible_hours);
- if (std::find(possible_hours.begin(), possible_hours.end(), a.hour) ==
- possible_hours.end()) {
- return false;
- }
- } else if (a.HasHour()) {
- ++count;
- }
- if (b.HasMinute()) {
- if (!a.HasMinute() || a.minute != b.minute) return false;
- } else if (a.HasMinute()) {
- ++count;
- }
- if (b.HasSecond()) {
- if (!a.HasSecond() || a.second != b.second) return false;
- } else if (a.HasSecond()) {
- ++count;
- }
- if (b.HasFractionSecond()) {
- if (!a.HasFractionSecond() || a.fraction_second != b.fraction_second)
- return false;
- } else if (a.HasFractionSecond()) {
- ++count;
- }
- if (b.HasTimeSpanCode()) {
- if (!a.HasTimeSpanCode() || a.time_span_code != b.time_span_code)
- return false;
- } else if (a.HasTimeSpanCode()) {
- ++count;
- }
- if (b.HasTimeZoneCode()) {
- if (!a.HasTimeZoneCode() || a.time_zone_code != b.time_zone_code)
- return false;
- } else if (a.HasTimeZoneCode()) {
- ++count;
- }
- if (b.HasTimeZoneOffset()) {
- if (!a.HasTimeZoneOffset() || a.time_zone_offset != b.time_zone_offset)
- return false;
- } else if (a.HasTimeZoneOffset()) {
- ++count;
- }
- return (count > 0 || a.priority >= b.priority);
-}
-
-bool IsRefinement(const DateRangeMatch& a, const DateRangeMatch& b) {
- return false;
-}
-
-bool IsPrecedent(const DateMatch& a, const DateMatch& b) {
- if (a.HasYear() && b.HasYear()) {
- if (a.year < b.year) return true;
- if (a.year > b.year) return false;
- }
-
- if (a.HasMonth() && b.HasMonth()) {
- if (a.month < b.month) return true;
- if (a.month > b.month) return false;
- }
-
- if (a.HasDay() && b.HasDay()) {
- if (a.day < b.day) return true;
- if (a.day > b.day) return false;
- }
-
- if (a.HasHour() && b.HasHour()) {
- if (a.hour < b.hour) return true;
- if (a.hour > b.hour) return false;
- }
-
- if (a.HasMinute() && b.HasHour()) {
- if (a.minute < b.hour) return true;
- if (a.minute > b.hour) return false;
- }
-
- if (a.HasSecond() && b.HasSecond()) {
- if (a.second < b.hour) return true;
- if (a.second > b.hour) return false;
- }
-
- return false;
-}
-
-void FillDateInstance(const DateMatch& date,
- DatetimeParseResultSpan* instance) {
- instance->span.first = date.begin;
- instance->span.second = date.end;
- instance->priority_score = date.GetAnnotatorPriorityScore();
- DatetimeParseResult datetime_parse_result;
- date.FillDatetimeComponents(&datetime_parse_result.datetime_components);
- instance->data.emplace_back(datetime_parse_result);
-}
-
-void FillDateRangeInstance(const DateRangeMatch& range,
- DatetimeParseResultSpan* instance) {
- instance->span.first = range.begin;
- instance->span.second = range.end;
- instance->priority_score = range.GetAnnotatorPriorityScore();
-
- // Filling from DatetimeParseResult.
- instance->data.emplace_back();
- range.from.FillDatetimeComponents(&instance->data.back().datetime_components);
-
- // Filling to DatetimeParseResult.
- instance->data.emplace_back();
- range.to.FillDatetimeComponents(&instance->data.back().datetime_components);
-}
-
-namespace {
-bool AnyOverlappedField(const DateMatch& prev, const DateMatch& next) {
-#define Field(f) \
- if (prev.f && next.f) return true
- Field(year_match);
- Field(month_match);
- Field(day_match);
- Field(day_of_week_match);
- Field(time_value_match);
- Field(time_span_match);
- Field(time_zone_name_match);
- Field(time_zone_offset_match);
- Field(relative_match);
- Field(combined_digits_match);
-#undef Field
- return false;
-}
-
-void MergeDateMatchImpl(const DateMatch& prev, DateMatch* next,
- bool update_span) {
-#define RM(f) \
- if (!next->f) next->f = prev.f
- RM(year_match);
- RM(month_match);
- RM(day_match);
- RM(day_of_week_match);
- RM(time_value_match);
- RM(time_span_match);
- RM(time_zone_name_match);
- RM(time_zone_offset_match);
- RM(relative_match);
- RM(combined_digits_match);
-#undef RM
-
-#define RV(f) \
- if (next->f == NO_VAL) next->f = prev.f
- RV(year);
- RV(month);
- RV(day);
- RV(hour);
- RV(minute);
- RV(second);
- RV(fraction_second);
-#undef RV
-
-#define RE(f, v) \
- if (next->f == v) next->f = prev.f
- RE(day_of_week, DayOfWeek_DOW_NONE);
- RE(bc_ad, BCAD_BCAD_NONE);
- RE(time_span_code, TimespanCode_TIMESPAN_CODE_NONE);
- RE(time_zone_code, TimezoneCode_TIMEZONE_CODE_NONE);
-#undef RE
-
- if (next->time_zone_offset == std::numeric_limits<int16>::min()) {
- next->time_zone_offset = prev.time_zone_offset;
- }
-
- next->priority = std::max(next->priority, prev.priority);
- next->annotator_priority_score =
- std::max(next->annotator_priority_score, prev.annotator_priority_score);
- if (update_span) {
- next->begin = std::min(next->begin, prev.begin);
- next->end = std::max(next->end, prev.end);
- }
-}
-} // namespace
-
-bool IsDateMatchMergeable(const DateMatch& prev, const DateMatch& next) {
- // Do not merge if they share the same field.
- if (AnyOverlappedField(prev, next)) {
- return false;
- }
-
- // It's impossible that both prev and next have relative date since it's
- // excluded by overlapping check before.
- if (prev.HasRelativeDate() || next.HasRelativeDate()) {
- // If one of them is relative date, then we merge:
- // - if relative match shouldn't have time, and always has DOW or day.
- // - if not both relative match and non relative match has day.
- // - if non relative match has time or day.
- const DateMatch* rm = &prev;
- const DateMatch* non_rm = &prev;
- if (prev.HasRelativeDate()) {
- non_rm = &next;
- } else {
- rm = &next;
- }
-
- const RelativeMatch* relative_match = rm->relative_match;
- // Relative Match should have day or DOW but no time.
- if (!relative_match->HasDayFields() ||
- relative_match->HasTimeValueFields()) {
- return false;
- }
- // Check if both relative match and non relative match has day.
- if (non_rm->HasDateFields() && relative_match->HasDay()) {
- return false;
- }
- // Non relative match should have either hour (time) or day (date).
- if (!non_rm->HasHour() && !non_rm->HasDay()) {
- return false;
- }
- } else {
- // Only one match has date and another has time.
- if ((prev.HasDateFields() && next.HasDateFields()) ||
- (prev.HasTimeFields() && next.HasTimeFields())) {
- return false;
- }
- // DOW never be extracted as a single DateMatch except in RelativeMatch. So
- // here, we always merge one with day and another one with hour.
- if (!(prev.HasDay() || next.HasDay()) ||
- !(prev.HasHour() || next.HasHour())) {
- return false;
- }
- }
- return true;
-}
-
-void MergeDateMatch(const DateMatch& prev, DateMatch* next, bool update_span) {
- if (IsDateMatchMergeable(prev, *next)) {
- MergeDateMatchImpl(prev, next, update_span);
- }
-}
-
-} // namespace dates
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/date-utils.h b/native/annotator/grammar/dates/utils/date-utils.h
deleted file mode 100644
index 2fcda92..0000000
--- a/native/annotator/grammar/dates/utils/date-utils.h
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_
-
-#include <stddef.h>
-#include <stdint.h>
-
-#include <ctime>
-#include <vector>
-
-#include "annotator/grammar/dates/annotations/annotation.h"
-#include "annotator/grammar/dates/utils/date-match.h"
-#include "utils/base/casts.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-
-bool IsLeapYear(int year);
-
-int8 GetLastDayOfMonth(int year, int month);
-
-// Normalizes hour value of the specified date using the specified time-span
-// specification. Returns true if the original hour value (can be no-value)
-// is compatible with the time-span and gets normalized successfully, or
-// false otherwise.
-bool NormalizeHourByTimeSpan(const TimeSpanSpec* ts_spec, DateMatch* date);
-
-// Returns true iff "a" is considered as a refinement of "b". For example,
-// besides fully compatible fields, having more fields or higher priority.
-bool IsRefinement(const DateMatch& a, const DateMatch& b);
-bool IsRefinement(const DateRangeMatch& a, const DateRangeMatch& b);
-
-// Returns true iff "a" occurs strictly before "b"
-bool IsPrecedent(const DateMatch& a, const DateMatch& b);
-
-// Fill DatetimeParseResult based on DateMatch object which is created from
-// matched rule. The matched string is extracted from tokenizer which provides
-// an interface to access the clean text based on the matched range.
-void FillDateInstance(const DateMatch& date, DatetimeParseResult* instance);
-
-// Fill DatetimeParseResultSpan based on DateMatch object which is created from
-// matched rule. The matched string is extracted from tokenizer which provides
-// an interface to access the clean text based on the matched range.
-void FillDateInstance(const DateMatch& date, DatetimeParseResultSpan* instance);
-
-// Fill DatetimeParseResultSpan based on DateRangeMatch object which i screated
-// from matched rule.
-void FillDateRangeInstance(const DateRangeMatch& range,
- DatetimeParseResultSpan* instance);
-
-// Merge the fields in DateMatch prev to next if there is no overlapped field.
-// If update_span is true, the span of next is also updated.
-// e.g.: prev is 11am, next is: May 1, then the merged next is May 1, 11am
-void MergeDateMatch(const DateMatch& prev, DateMatch* next, bool update_span);
-
-// If DateMatches have no overlapped field, then they could be merged as the
-// following rules:
-// -- If both don't have relative match and one DateMatch has day but another
-// DateMatch has hour.
-// -- If one have relative match then follow the rules in code.
-// It's impossible to get DateMatch which only has DOW and not in relative
-// match according to current rules.
-bool IsDateMatchMergeable(const DateMatch& prev, const DateMatch& next);
-} // namespace dates
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_
diff --git a/native/annotator/grammar/grammar-annotator.cc b/native/annotator/grammar/grammar-annotator.cc
index baa3fac..cf36454 100644
--- a/native/annotator/grammar/grammar-annotator.cc
+++ b/native/annotator/grammar/grammar-annotator.cc
@@ -19,12 +19,8 @@
#include "annotator/feature-processor.h"
#include "annotator/grammar/utils.h"
#include "annotator/types.h"
+#include "utils/base/arena.h"
#include "utils/base/logging.h"
-#include "utils/grammar/callback-delegate.h"
-#include "utils/grammar/match.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/rules-utils.h"
-#include "utils/grammar/types.h"
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/utf8/unicodetext.h"
@@ -32,448 +28,296 @@
namespace libtextclassifier3 {
namespace {
-// Returns the unicode codepoint offsets in a utf8 encoded text.
-std::vector<UnicodeText::const_iterator> UnicodeCodepointOffsets(
- const UnicodeText& text) {
- std::vector<UnicodeText::const_iterator> offsets;
- for (auto it = text.begin(); it != text.end(); it++) {
- offsets.push_back(it);
+// Retrieves all capturing nodes from a parse tree.
+std::unordered_map<uint16, const grammar::ParseTree*> GetCapturingNodes(
+ const grammar::ParseTree* parse_tree) {
+ std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes;
+ for (const grammar::MappingNode* mapping_node :
+ grammar::SelectAllOfType<grammar::MappingNode>(
+ parse_tree, grammar::ParseTree::Type::kMapping)) {
+ capturing_nodes[mapping_node->id] = mapping_node;
}
- offsets.push_back(text.end());
- return offsets;
+ return capturing_nodes;
+}
+
+// Computes the selection boundaries from a parse tree.
+CodepointSpan MatchSelectionBoundaries(
+ const grammar::ParseTree* parse_tree,
+ const GrammarModel_::RuleClassificationResult* classification) {
+ if (classification->capturing_group() == nullptr) {
+ // Use full match as selection span.
+ return parse_tree->codepoint_span;
+ }
+
+ // Set information from capturing matches.
+ CodepointSpan span{kInvalidIndex, kInvalidIndex};
+ std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes =
+ GetCapturingNodes(parse_tree);
+
+ // Compute span boundaries.
+ for (int i = 0; i < classification->capturing_group()->size(); i++) {
+ auto it = capturing_nodes.find(i);
+ if (it == capturing_nodes.end()) {
+ // Capturing group is not active, skip.
+ continue;
+ }
+ const CapturingGroup* group = classification->capturing_group()->Get(i);
+ if (group->extend_selection()) {
+ if (span.first == kInvalidIndex) {
+ span = it->second->codepoint_span;
+ } else {
+ span.first = std::min(span.first, it->second->codepoint_span.first);
+ span.second = std::max(span.second, it->second->codepoint_span.second);
+ }
+ }
+ }
+ return span;
}
} // namespace
-class GrammarAnnotatorCallbackDelegate : public grammar::CallbackDelegate {
- public:
- explicit GrammarAnnotatorCallbackDelegate(
- const UniLib* unilib, const GrammarModel* model,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
- const ModeFlag mode)
- : unilib_(*unilib),
- model_(model),
- entity_data_builder_(entity_data_builder),
- mode_(mode) {}
+GrammarAnnotator::GrammarAnnotator(
+ const UniLib* unilib, const GrammarModel* model,
+ const MutableFlatbufferBuilder* entity_data_builder)
+ : unilib_(*unilib),
+ model_(model),
+ tokenizer_(BuildTokenizer(unilib, model->tokenizer_options())),
+ entity_data_builder_(entity_data_builder),
+ analyzer_(unilib, model->rules(), &tokenizer_) {}
- // Handles a grammar rule match in the annotator grammar.
- void MatchFound(const grammar::Match* match, grammar::CallbackId type,
- int64 value, grammar::Matcher* matcher) override {
- switch (static_cast<GrammarAnnotator::Callback>(type)) {
- case GrammarAnnotator::Callback::kRuleMatch: {
- HandleRuleMatch(match, /*rule_id=*/value);
- return;
- }
- default:
- grammar::CallbackDelegate::MatchFound(match, type, value, matcher);
+// Filters out results that do not overlap with a reference span.
+std::vector<grammar::Derivation> GrammarAnnotator::OverlappingDerivations(
+ const CodepointSpan& selection,
+ const std::vector<grammar::Derivation>& derivations,
+ const bool only_exact_overlap) const {
+ std::vector<grammar::Derivation> result;
+ for (const grammar::Derivation& derivation : derivations) {
+ // Discard matches that do not match the selection.
+ // Simple check.
+ if (!SpansOverlap(selection, derivation.parse_tree->codepoint_span)) {
+ continue;
}
+
+ // Compute exact selection boundaries (without assertions and
+ // non-capturing parts).
+ const CodepointSpan span = MatchSelectionBoundaries(
+ derivation.parse_tree,
+ model_->rule_classification_result()->Get(derivation.rule_id));
+ if (!SpansOverlap(selection, span) ||
+ (only_exact_overlap && span != selection)) {
+ continue;
+ }
+ result.push_back(derivation);
}
+ return result;
+}
- // Deduplicate and populate annotations from grammar matches.
- bool GetAnnotations(const std::vector<UnicodeText::const_iterator>& text,
- std::vector<AnnotatedSpan>* annotations) const {
- for (const grammar::Derivation& candidate :
- grammar::DeduplicateDerivations(candidates_)) {
- // Check that assertions are fulfilled.
- if (!grammar::VerifyAssertions(candidate.match)) {
- continue;
- }
- if (!AddAnnotatedSpanFromMatch(text, candidate, annotations)) {
- return false;
- }
- }
+bool GrammarAnnotator::InstantiateAnnotatedSpanFromDerivation(
+ const grammar::TextContext& input_context,
+ const grammar::ParseTree* parse_tree,
+ const GrammarModel_::RuleClassificationResult* interpretation,
+ AnnotatedSpan* result) const {
+ result->span = MatchSelectionBoundaries(parse_tree, interpretation);
+ ClassificationResult classification;
+ if (!InstantiateClassificationFromDerivation(
+ input_context, parse_tree, interpretation, &classification)) {
+ return false;
+ }
+ result->classification.push_back(classification);
+ return true;
+}
+
+// Instantiates a classification result from a rule match.
+bool GrammarAnnotator::InstantiateClassificationFromDerivation(
+ const grammar::TextContext& input_context,
+ const grammar::ParseTree* parse_tree,
+ const GrammarModel_::RuleClassificationResult* interpretation,
+ ClassificationResult* classification) const {
+ classification->collection = interpretation->collection_name()->str();
+ classification->score = interpretation->target_classification_score();
+ classification->priority_score = interpretation->priority_score();
+
+ // Assemble entity data.
+ if (entity_data_builder_ == nullptr) {
return true;
}
-
- bool GetTextSelection(const std::vector<UnicodeText::const_iterator>& text,
- const CodepointSpan& selection, AnnotatedSpan* result) {
- std::vector<grammar::Derivation> selection_candidates;
- // Deduplicate and verify matches.
- auto maybe_interpretation = GetBestValidInterpretation(
- grammar::DeduplicateDerivations(GetOverlappingRuleMatches(
- selection, candidates_, /*only_exact_overlap=*/false)));
- if (!maybe_interpretation.has_value()) {
- return false;
- }
- const GrammarModel_::RuleClassificationResult* interpretation;
- const grammar::Match* match;
- std::tie(interpretation, match) = maybe_interpretation.value();
- return InstantiateAnnotatedSpanFromInterpretation(text, interpretation,
- match, result);
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+ if (interpretation->serialized_entity_data() != nullptr) {
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(interpretation->serialized_entity_data()->data(),
+ interpretation->serialized_entity_data()->size()));
+ }
+ if (interpretation->entity_data() != nullptr) {
+ entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
+ interpretation->entity_data()));
}
- // Provides a classification results from the grammar matches.
- bool GetClassification(const std::vector<UnicodeText::const_iterator>& text,
- const CodepointSpan& selection,
- ClassificationResult* classification) const {
- // Deduplicate and verify matches.
- auto maybe_interpretation = GetBestValidInterpretation(
- grammar::DeduplicateDerivations(GetOverlappingRuleMatches(
- selection, candidates_, /*only_exact_overlap=*/true)));
- if (!maybe_interpretation.has_value()) {
- return false;
- }
-
- // Instantiate result.
- const GrammarModel_::RuleClassificationResult* interpretation;
- const grammar::Match* match;
- std::tie(interpretation, match) = maybe_interpretation.value();
- return InstantiateClassificationInterpretation(text, interpretation, match,
- classification);
- }
-
- private:
- // Handles annotation/selection/classification rule matches.
- void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) {
- if ((model_->rule_classification_result()->Get(rule_id)->enabled_modes() &
- mode_) != 0) {
- candidates_.push_back(grammar::Derivation{match, rule_id});
- }
- }
-
- // Computes the selection boundaries from a grammar match.
- CodepointSpan MatchSelectionBoundaries(
- const grammar::Match* match,
- const GrammarModel_::RuleClassificationResult* classification) const {
- if (classification->capturing_group() == nullptr) {
- // Use full match as selection span.
- return match->codepoint_span;
- }
-
- // Set information from capturing matches.
- CodepointSpan span{kInvalidIndex, kInvalidIndex};
+ // Populate entity data from the capturing matches.
+ if (interpretation->capturing_group() != nullptr) {
// Gather active capturing matches.
- std::unordered_map<uint16, const grammar::Match*> capturing_matches;
- for (const grammar::MappingMatch* match :
- grammar::SelectAllOfType<grammar::MappingMatch>(
- match, grammar::Match::kMappingMatch)) {
- capturing_matches[match->id] = match;
- }
+ std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes =
+ GetCapturingNodes(parse_tree);
- // Compute span boundaries.
- for (int i = 0; i < classification->capturing_group()->size(); i++) {
- auto it = capturing_matches.find(i);
- if (it == capturing_matches.end()) {
+ for (int i = 0; i < interpretation->capturing_group()->size(); i++) {
+ auto it = capturing_nodes.find(i);
+ if (it == capturing_nodes.end()) {
// Capturing group is not active, skip.
continue;
}
- const CapturingGroup* group = classification->capturing_group()->Get(i);
- if (group->extend_selection()) {
- if (span.first == kInvalidIndex) {
- span = it->second->codepoint_span;
- } else {
- span.first = std::min(span.first, it->second->codepoint_span.first);
- span.second =
- std::max(span.second, it->second->codepoint_span.second);
+ const CapturingGroup* group = interpretation->capturing_group()->Get(i);
+
+ // Add static entity data.
+ if (group->serialized_entity_data() != nullptr) {
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(interpretation->serialized_entity_data()->data(),
+ interpretation->serialized_entity_data()->size()));
+ }
+
+ // Set entity field from captured text.
+ if (group->entity_field_path() != nullptr) {
+ const grammar::ParseTree* capturing_match = it->second;
+ UnicodeText match_text =
+ input_context.Span(capturing_match->codepoint_span);
+ if (group->normalization_options() != nullptr) {
+ match_text = NormalizeText(unilib_, group->normalization_options(),
+ match_text);
+ }
+ if (!entity_data->ParseAndSet(group->entity_field_path(),
+ match_text.ToUTF8String())) {
+ TC3_LOG(ERROR) << "Could not set entity data from capturing match.";
+ return false;
}
}
}
- return span;
}
- // Filters out results that do not overlap with a reference span.
- std::vector<grammar::Derivation> GetOverlappingRuleMatches(
- const CodepointSpan& selection,
- const std::vector<grammar::Derivation>& candidates,
- const bool only_exact_overlap) const {
- std::vector<grammar::Derivation> result;
- for (const grammar::Derivation& candidate : candidates) {
- // Discard matches that do not match the selection.
- // Simple check.
- if (!SpansOverlap(selection, candidate.match->codepoint_span)) {
- continue;
- }
-
- // Compute exact selection boundaries (without assertions and
- // non-capturing parts).
- const CodepointSpan span = MatchSelectionBoundaries(
- candidate.match,
- model_->rule_classification_result()->Get(candidate.rule_id));
- if (!SpansOverlap(selection, span) ||
- (only_exact_overlap && span != selection)) {
- continue;
- }
- result.push_back(candidate);
- }
- return result;
+ if (entity_data && entity_data->HasExplicitlySetFields()) {
+ classification->serialized_entity_data = entity_data->Serialize();
}
-
- // Returns the best valid interpretation of a set of candidate matches.
- Optional<std::pair<const GrammarModel_::RuleClassificationResult*,
- const grammar::Match*>>
- GetBestValidInterpretation(
- const std::vector<grammar::Derivation>& candidates) const {
- const GrammarModel_::RuleClassificationResult* best_interpretation =
- nullptr;
- const grammar::Match* best_match = nullptr;
- for (const grammar::Derivation& candidate : candidates) {
- if (!grammar::VerifyAssertions(candidate.match)) {
- continue;
- }
- const GrammarModel_::RuleClassificationResult*
- rule_classification_result =
- model_->rule_classification_result()->Get(candidate.rule_id);
- if (best_interpretation == nullptr ||
- best_interpretation->priority_score() <
- rule_classification_result->priority_score()) {
- best_interpretation = rule_classification_result;
- best_match = candidate.match;
- }
- }
-
- // No valid interpretation found.
- Optional<std::pair<const GrammarModel_::RuleClassificationResult*,
- const grammar::Match*>>
- result;
- if (best_interpretation != nullptr) {
- result = {best_interpretation, best_match};
- }
- return result;
- }
-
- // Instantiates an annotated span from a rule match and appends it to the
- // result.
- bool AddAnnotatedSpanFromMatch(
- const std::vector<UnicodeText::const_iterator>& text,
- const grammar::Derivation& candidate,
- std::vector<AnnotatedSpan>* result) const {
- if (candidate.rule_id < 0 ||
- candidate.rule_id >= model_->rule_classification_result()->size()) {
- TC3_LOG(INFO) << "Invalid rule id.";
- return false;
- }
- const GrammarModel_::RuleClassificationResult* interpretation =
- model_->rule_classification_result()->Get(candidate.rule_id);
- result->emplace_back();
- return InstantiateAnnotatedSpanFromInterpretation(
- text, interpretation, candidate.match, &result->back());
- }
-
- bool InstantiateAnnotatedSpanFromInterpretation(
- const std::vector<UnicodeText::const_iterator>& text,
- const GrammarModel_::RuleClassificationResult* interpretation,
- const grammar::Match* match, AnnotatedSpan* result) const {
- result->span = MatchSelectionBoundaries(match, interpretation);
- ClassificationResult classification;
- if (!InstantiateClassificationInterpretation(text, interpretation, match,
- &classification)) {
- return false;
- }
- result->classification.push_back(classification);
- return true;
- }
-
- // Instantiates a classification result from a rule match.
- bool InstantiateClassificationInterpretation(
- const std::vector<UnicodeText::const_iterator>& text,
- const GrammarModel_::RuleClassificationResult* interpretation,
- const grammar::Match* match, ClassificationResult* classification) const {
- classification->collection = interpretation->collection_name()->str();
- classification->score = interpretation->target_classification_score();
- classification->priority_score = interpretation->priority_score();
-
- // Assemble entity data.
- if (entity_data_builder_ == nullptr) {
- return true;
- }
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder_->NewRoot();
- if (interpretation->serialized_entity_data() != nullptr) {
- entity_data->MergeFromSerializedFlatbuffer(
- StringPiece(interpretation->serialized_entity_data()->data(),
- interpretation->serialized_entity_data()->size()));
- }
- if (interpretation->entity_data() != nullptr) {
- entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
- interpretation->entity_data()));
- }
-
- // Populate entity data from the capturing matches.
- if (interpretation->capturing_group() != nullptr) {
- // Gather active capturing matches.
- std::unordered_map<uint16, const grammar::Match*> capturing_matches;
- for (const grammar::MappingMatch* match :
- grammar::SelectAllOfType<grammar::MappingMatch>(
- match, grammar::Match::kMappingMatch)) {
- capturing_matches[match->id] = match;
- }
- for (int i = 0; i < interpretation->capturing_group()->size(); i++) {
- auto it = capturing_matches.find(i);
- if (it == capturing_matches.end()) {
- // Capturing group is not active, skip.
- continue;
- }
- const CapturingGroup* group = interpretation->capturing_group()->Get(i);
-
- // Add static entity data.
- if (group->serialized_entity_data() != nullptr) {
- entity_data->MergeFromSerializedFlatbuffer(
- StringPiece(interpretation->serialized_entity_data()->data(),
- interpretation->serialized_entity_data()->size()));
- }
-
- // Set entity field from captured text.
- if (group->entity_field_path() != nullptr) {
- const grammar::Match* capturing_match = it->second;
- StringPiece group_text = StringPiece(
- text[capturing_match->codepoint_span.first].utf8_data(),
- text[capturing_match->codepoint_span.second].utf8_data() -
- text[capturing_match->codepoint_span.first].utf8_data());
- UnicodeText normalized_group_text =
- UTF8ToUnicodeText(group_text, /*do_copy=*/false);
- if (group->normalization_options() != nullptr) {
- normalized_group_text = NormalizeText(
- unilib_, group->normalization_options(), normalized_group_text);
- }
- if (!entity_data->ParseAndSet(group->entity_field_path(),
- normalized_group_text.ToUTF8String())) {
- TC3_LOG(ERROR) << "Could not set entity data from capturing match.";
- return false;
- }
- }
- }
- }
-
- if (entity_data && entity_data->HasExplicitlySetFields()) {
- classification->serialized_entity_data = entity_data->Serialize();
- }
- return true;
- }
-
- const UniLib& unilib_;
- const GrammarModel* model_;
- const ReflectiveFlatbufferBuilder* entity_data_builder_;
- const ModeFlag mode_;
-
- // All annotation/selection/classification rule match candidates.
- // Grammar rule matches are recorded, deduplicated and then instantiated.
- std::vector<grammar::Derivation> candidates_;
-};
-
-GrammarAnnotator::GrammarAnnotator(
- const UniLib* unilib, const GrammarModel* model,
- const ReflectiveFlatbufferBuilder* entity_data_builder)
- : unilib_(*unilib),
- model_(model),
- lexer_(unilib, model->rules()),
- tokenizer_(BuildTokenizer(unilib, model->tokenizer_options())),
- entity_data_builder_(entity_data_builder),
- rules_locales_(grammar::ParseRulesLocales(model->rules())) {}
+ return true;
+}
bool GrammarAnnotator::Annotate(const std::vector<Locale>& locales,
const UnicodeText& text,
std::vector<AnnotatedSpan>* result) const {
- if (model_ == nullptr || model_->rules() == nullptr) {
- // Nothing to do.
- return true;
+ grammar::TextContext input_context =
+ analyzer_.BuildTextContextForInput(text, locales);
+
+ UnsafeArena arena(/*block_size=*/16 << 10);
+
+ for (const grammar::Derivation& derivation : ValidDeduplicatedDerivations(
+ analyzer_.parser().Parse(input_context, &arena))) {
+ const GrammarModel_::RuleClassificationResult* interpretation =
+ model_->rule_classification_result()->Get(derivation.rule_id);
+ if ((interpretation->enabled_modes() & ModeFlag_ANNOTATION) == 0) {
+ continue;
+ }
+ result->emplace_back();
+ if (!InstantiateAnnotatedSpanFromDerivation(
+ input_context, derivation.parse_tree, interpretation,
+ &result->back())) {
+ return false;
+ }
}
- // Select locale matching rules.
- std::vector<const grammar::RulesSet_::Rules*> locale_rules =
- SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
- if (locale_rules.empty()) {
- // Nothing to do.
- return true;
- }
-
- // Run the grammar.
- GrammarAnnotatorCallbackDelegate callback_handler(
- &unilib_, model_, entity_data_builder_,
- /*mode=*/ModeFlag_ANNOTATION);
- grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
- &callback_handler);
- lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr,
- &matcher);
-
- // Populate results.
- return callback_handler.GetAnnotations(UnicodeCodepointOffsets(text), result);
+ return true;
}
bool GrammarAnnotator::SuggestSelection(const std::vector<Locale>& locales,
const UnicodeText& text,
const CodepointSpan& selection,
AnnotatedSpan* result) const {
- if (model_ == nullptr || model_->rules() == nullptr ||
- selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) {
- // Nothing to do.
+ if (!selection.IsValid() || selection.IsEmpty()) {
return false;
}
- // Select locale matching rules.
- std::vector<const grammar::RulesSet_::Rules*> locale_rules =
- SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
- if (locale_rules.empty()) {
- // Nothing to do.
- return true;
+ grammar::TextContext input_context =
+ analyzer_.BuildTextContextForInput(text, locales);
+
+ UnsafeArena arena(/*block_size=*/16 << 10);
+
+ const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr;
+ const grammar::ParseTree* best_match = nullptr;
+ for (const grammar::Derivation& derivation :
+ ValidDeduplicatedDerivations(OverlappingDerivations(
+ selection, analyzer_.parser().Parse(input_context, &arena),
+ /*only_exact_overlap=*/false))) {
+ const GrammarModel_::RuleClassificationResult* interpretation =
+ model_->rule_classification_result()->Get(derivation.rule_id);
+ if ((interpretation->enabled_modes() & ModeFlag_SELECTION) == 0) {
+ continue;
+ }
+ if (best_interpretation == nullptr ||
+ interpretation->priority_score() >
+ best_interpretation->priority_score()) {
+ best_interpretation = interpretation;
+ best_match = derivation.parse_tree;
+ }
}
- // Run the grammar.
- GrammarAnnotatorCallbackDelegate callback_handler(
- &unilib_, model_, entity_data_builder_,
- /*mode=*/ModeFlag_SELECTION);
- grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
- &callback_handler);
- lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr,
- &matcher);
+ if (best_interpretation == nullptr) {
+ return false;
+ }
- // Populate the result.
- return callback_handler.GetTextSelection(UnicodeCodepointOffsets(text),
- selection, result);
+ return InstantiateAnnotatedSpanFromDerivation(input_context, best_match,
+ best_interpretation, result);
}
bool GrammarAnnotator::ClassifyText(
const std::vector<Locale>& locales, const UnicodeText& text,
const CodepointSpan& selection,
ClassificationResult* classification_result) const {
- if (model_ == nullptr || model_->rules() == nullptr ||
- selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) {
+ if (!selection.IsValid() || selection.IsEmpty()) {
// Nothing to do.
return false;
}
- // Select locale matching rules.
- std::vector<const grammar::RulesSet_::Rules*> locale_rules =
- SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
- if (locale_rules.empty()) {
- // Nothing to do.
+ grammar::TextContext input_context =
+ analyzer_.BuildTextContextForInput(text, locales);
+
+ if (const TokenSpan context_span = CodepointSpanToTokenSpan(
+ input_context.tokens, selection,
+ /*snap_boundaries_to_containing_tokens=*/true);
+ context_span.IsValid()) {
+ if (model_->context_left_num_tokens() != kInvalidIndex) {
+ input_context.context_span.first =
+ std::max(0, context_span.first - model_->context_left_num_tokens());
+ }
+ if (model_->context_right_num_tokens() != kInvalidIndex) {
+ input_context.context_span.second =
+ std::min(static_cast<int>(input_context.tokens.size()),
+ context_span.second + model_->context_right_num_tokens());
+ }
+ }
+
+ UnsafeArena arena(/*block_size=*/16 << 10);
+
+ const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr;
+ const grammar::ParseTree* best_match = nullptr;
+ for (const grammar::Derivation& derivation :
+ ValidDeduplicatedDerivations(OverlappingDerivations(
+ selection, analyzer_.parser().Parse(input_context, &arena),
+ /*only_exact_overlap=*/true))) {
+ const GrammarModel_::RuleClassificationResult* interpretation =
+ model_->rule_classification_result()->Get(derivation.rule_id);
+ if ((interpretation->enabled_modes() & ModeFlag_CLASSIFICATION) == 0) {
+ continue;
+ }
+ if (best_interpretation == nullptr ||
+ interpretation->priority_score() >
+ best_interpretation->priority_score()) {
+ best_interpretation = interpretation;
+ best_match = derivation.parse_tree;
+ }
+ }
+
+ if (best_interpretation == nullptr) {
return false;
}
- // Run the grammar.
- GrammarAnnotatorCallbackDelegate callback_handler(
- &unilib_, model_, entity_data_builder_,
- /*mode=*/ModeFlag_CLASSIFICATION);
- grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
- &callback_handler);
-
- const std::vector<Token> tokens = tokenizer_.Tokenize(text);
- if (model_->context_left_num_tokens() == -1 &&
- model_->context_right_num_tokens() == -1) {
- // Use all tokens.
- lexer_.Process(text, tokens, /*annotations=*/{}, &matcher);
- } else {
- TokenSpan context_span = CodepointSpanToTokenSpan(
- tokens, selection, /*snap_boundaries_to_containing_tokens=*/true);
- std::vector<Token>::const_iterator begin = tokens.begin();
- std::vector<Token>::const_iterator end = tokens.begin();
- if (model_->context_left_num_tokens() != -1) {
- std::advance(begin, std::max(0, context_span.first -
- model_->context_left_num_tokens()));
- }
- if (model_->context_right_num_tokens() == -1) {
- end = tokens.end();
- } else {
- std::advance(end, std::min(static_cast<int>(tokens.size()),
- context_span.second +
- model_->context_right_num_tokens()));
- }
- lexer_.Process(text, begin, end,
- /*annotations=*/nullptr, &matcher);
- }
-
- // Populate result.
- return callback_handler.GetClassification(UnicodeCodepointOffsets(text),
- selection, classification_result);
+ return InstantiateClassificationFromDerivation(
+ input_context, best_match, best_interpretation, classification_result);
}
} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/grammar-annotator.h b/native/annotator/grammar/grammar-annotator.h
index 365bb44..251b557 100644
--- a/native/annotator/grammar/grammar-annotator.h
+++ b/native/annotator/grammar/grammar-annotator.h
@@ -21,8 +21,10 @@
#include "annotator/model_generated.h"
#include "annotator/types.h"
-#include "utils/flatbuffers.h"
-#include "utils/grammar/lexer.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/analyzer.h"
+#include "utils/grammar/evaluated-derivation.h"
+#include "utils/grammar/text-context.h"
#include "utils/i18n/locale.h"
#include "utils/tokenizer.h"
#include "utils/utf8/unicodetext.h"
@@ -33,13 +35,9 @@
// Grammar backed annotator.
class GrammarAnnotator {
public:
- enum class Callback : grammar::CallbackId {
- kRuleMatch = 1,
- };
-
explicit GrammarAnnotator(
const UniLib* unilib, const GrammarModel* model,
- const ReflectiveFlatbufferBuilder* entity_data_builder);
+ const MutableFlatbufferBuilder* entity_data_builder);
// Annotates a given text.
// Returns true if the text was successfully annotated.
@@ -59,14 +57,31 @@
AnnotatedSpan* result) const;
private:
+ // Filters out derivations that do not overlap with a reference span.
+ std::vector<grammar::Derivation> OverlappingDerivations(
+ const CodepointSpan& selection,
+ const std::vector<grammar::Derivation>& derivations,
+ const bool only_exact_overlap) const;
+
+ // Fills out an annotated span from a grammar match result.
+ bool InstantiateAnnotatedSpanFromDerivation(
+ const grammar::TextContext& input_context,
+ const grammar::ParseTree* parse_tree,
+ const GrammarModel_::RuleClassificationResult* interpretation,
+ AnnotatedSpan* result) const;
+
+ // Instantiates a classification result from a rule match.
+ bool InstantiateClassificationFromDerivation(
+ const grammar::TextContext& input_context,
+ const grammar::ParseTree* parse_tree,
+ const GrammarModel_::RuleClassificationResult* interpretation,
+ ClassificationResult* classification) const;
+
const UniLib& unilib_;
const GrammarModel* model_;
- const grammar::Lexer lexer_;
const Tokenizer tokenizer_;
- const ReflectiveFlatbufferBuilder* entity_data_builder_;
-
- // Pre-parsed locales of the rules.
- const std::vector<std::vector<Locale>> rules_locales_;
+ const MutableFlatbufferBuilder* entity_data_builder_;
+ const grammar::Analyzer analyzer_;
};
} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/grammar-annotator_test.cc b/native/annotator/grammar/grammar-annotator_test.cc
new file mode 100644
index 0000000..6fcd1f5
--- /dev/null
+++ b/native/annotator/grammar/grammar-annotator_test.cc
@@ -0,0 +1,554 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/grammar-annotator.h"
+
+#include <memory>
+
+#include "annotator/grammar/test-utils.h"
+#include "annotator/grammar/utils.h"
+#include "annotator/model_generated.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/utils/locale-shard-map.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+
+flatbuffers::DetachedBuffer PackModel(const GrammarModelT& model) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(GrammarModel::Pack(builder, &model));
+ return builder.Release();
+}
+
+TEST_F(GrammarAnnotatorTest, AnnotatesWithGrammarRules) {
+  // Create test rules.
+  GrammarModelT grammar_model;
+  SetTestTokenizerOptions(&grammar_model);
+  grammar_model.rules.reset(new grammar::RulesSetT);
+  grammar::LocaleShardMap locale_shard_map =
+      grammar::LocaleShardMap::CreateLocaleShardMap({""});
+  grammar::Rules rules(locale_shard_map);
+  rules.Add("<carrier>", {"lx"});
+  rules.Add("<carrier>", {"aa"});
+  rules.Add("<flight_code>", {"<2_digits>"});
+  rules.Add("<flight_code>", {"<3_digits>"});
+  rules.Add("<flight_code>", {"<4_digits>"});
+  rules.Add(
+      "<flight>", {"<carrier>", "<flight_code>"},
+      /*callback=*/
+      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+      /*callback_param=*/
+      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
+  rules.Finalize().Serialize(/*include_debug_information=*/false,
+                             grammar_model.rules.get());
+  flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+  GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+  std::vector<AnnotatedSpan> result;
+  EXPECT_TRUE(annotator.Annotate(
+      {Locale::FromBCP47("en")},
+      UTF8ToUnicodeText(
+          "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
+          /*do_copy=*/false),
+      &result));
+
+  EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(11, 16, "flight"),
+                                  IsAnnotatedSpan(51, 57, "flight")));
+}
+
+TEST_F(GrammarAnnotatorTest, HandlesAssertions) {
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+
+ // Flight: carrier + flight code and check right context.
+ rules.Add(
+ "<flight>", {"<carrier>", "<flight_code>", "<context_assertion>?"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
+
+ // Exclude matches like: LX 38.00 etc.
+ rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
+ /*negative=*/true);
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(annotator.Annotate(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("My flight: LX 38 arriving at 4pm, I'll fly back on "
+ "AA2014 on LX 38.00",
+ /*do_copy=*/false),
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(11, 16, "flight"),
+ IsAnnotatedSpan(51, 57, "flight")));
+}
+
+TEST_F(GrammarAnnotatorTest, HandlesCapturingGroups) {
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.AddValueMapping("<low_confidence_phone>", {"<digits>"},
+ /*value=*/0);
+
+ // Create rule result.
+ const int classification_result_id =
+ AddRuleClassificationResult("phone", ModeFlag_ALL, 1.0, &grammar_model);
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.emplace_back(new CapturingGroupT);
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.back()
+ ->extend_selection = true;
+
+ rules.Add(
+ "<phone>", {"please", "call", "<low_confidence_phone>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/classification_result_id);
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(annotator.Annotate(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("Please call 911 before 10 am!", /*do_copy=*/false),
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(12, 15, "phone")));
+}
+
+TEST_F(GrammarAnnotatorTest, ClassifiesTextWithGrammarRules) {
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ rules.Add(
+ "<flight>", {"<carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ ClassificationResult result;
+ EXPECT_TRUE(annotator.ClassifyText(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText(
+ "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
+ /*do_copy=*/false),
+ CodepointSpan{11, 16}, &result));
+ EXPECT_THAT(result, IsClassificationResult("flight"));
+}
+
+TEST_F(GrammarAnnotatorTest, ClassifiesTextWithAssertions) {
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+
+ // Use unbounded context.
+ grammar_model.context_left_num_tokens = -1;
+ grammar_model.context_right_num_tokens = -1;
+
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ rules.AddValueMapping("<flight_selection>", {"<carrier>", "<flight_code>"},
+ /*value=*/0);
+
+ // Flight: carrier + flight code and check right context.
+ const int classification_result_id =
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model);
+ rules.Add(
+ "<flight>", {"<flight_selection>", "<context_assertion>?"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ classification_result_id);
+
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.emplace_back(new CapturingGroupT);
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.back()
+ ->extend_selection = true;
+
+ // Exclude matches like: LX 38.00 etc.
+ rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
+ /*negative=*/true);
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ EXPECT_FALSE(annotator.ClassifyText(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("See LX 38.00", /*do_copy=*/false), CodepointSpan{4, 9},
+ nullptr));
+ EXPECT_FALSE(annotator.ClassifyText(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("See LX 38 00", /*do_copy=*/false), CodepointSpan{4, 9},
+ nullptr));
+ ClassificationResult result;
+ EXPECT_TRUE(annotator.ClassifyText(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("See LX 38, seat 5", /*do_copy=*/false),
+ CodepointSpan{4, 9}, &result));
+ EXPECT_THAT(result, IsClassificationResult("flight"));
+}
+
+TEST_F(GrammarAnnotatorTest, ClassifiesTextWithContext) {
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+
+ // Max three tokens to the left ("tracking number: ...").
+ grammar_model.context_left_num_tokens = 3;
+ grammar_model.context_right_num_tokens = 0;
+
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add("<tracking_number>", {"<5_digits>"});
+ rules.Add("<tracking_number>", {"<6_digits>"});
+ rules.Add("<tracking_number>", {"<7_digits>"});
+ rules.Add("<tracking_number>", {"<8_digits>"});
+ rules.Add("<tracking_number>", {"<9_digits>"});
+ rules.Add("<tracking_number>", {"<10_digits>"});
+ rules.AddValueMapping("<captured_tracking_number>", {"<tracking_number>"},
+ /*value=*/0);
+ rules.Add("<parcel_tracking_trigger>", {"tracking", "number?", ":?"});
+
+ const int classification_result_id = AddRuleClassificationResult(
+ "parcel_tracking", ModeFlag_ALL, 1.0, &grammar_model);
+ rules.Add(
+ "<parcel_tracking>",
+ {"<parcel_tracking_trigger>", "<captured_tracking_number>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ classification_result_id);
+
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.emplace_back(new CapturingGroupT);
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.back()
+ ->extend_selection = true;
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ ClassificationResult result;
+ EXPECT_TRUE(annotator.ClassifyText(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("Use tracking number 012345 for live parcel tracking.",
+ /*do_copy=*/false),
+ CodepointSpan{20, 26}, &result));
+ EXPECT_THAT(result, IsClassificationResult("parcel_tracking"));
+
+ EXPECT_FALSE(annotator.ClassifyText(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("Call phone 012345 for live parcel tracking.",
+ /*do_copy=*/false),
+ CodepointSpan{11, 17}, &result));
+}
+
+TEST_F(GrammarAnnotatorTest, SuggestsTextSelection) {
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ rules.Add(
+ "<flight>", {"<carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ AnnotatedSpan selection;
+ EXPECT_TRUE(annotator.SuggestSelection(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText(
+ "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
+ /*do_copy=*/false),
+ /*selection=*/CodepointSpan{14, 15}, &selection));
+ EXPECT_THAT(selection, IsAnnotatedSpan(11, 16, "flight"));
+}
+
+TEST_F(GrammarAnnotatorTest, SetsFixedEntityData) {
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ const int person_result =
+ AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model);
+ rules.Add(
+ "<person>", {"barack", "obama"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/person_result);
+
+ // Add test entity data.
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+ entity_data->Set("person", "Former President Barack Obama");
+ grammar_model.rule_classification_result[person_result]
+ ->serialized_entity_data = entity_data->Serialize();
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(annotator.Annotate(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("I saw Barack Obama today", /*do_copy=*/false),
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(6, 18, "person")));
+
+ // Check entity data.
+ // As we don't have generated code for the ad-hoc generated entity data
+ // schema, we have to check manually using field offsets.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ result.front().classification.front().serialized_entity_data.data()));
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Former President Barack Obama");
+}
+
+TEST_F(GrammarAnnotatorTest, SetsEntityDataFromCapturingMatches) {
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ const int person_result =
+ AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model);
+
+ rules.Add("<person>", {"barack?", "obama"});
+ rules.Add("<person>", {"zapp?", "brannigan"});
+ rules.AddValueMapping("<captured_person>", {"<person>"},
+ /*value=*/0);
+ rules.Add(
+ "<test>", {"<captured_person>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/person_result);
+
+ // Set capturing group entity data information.
+ grammar_model.rule_classification_result[person_result]
+ ->capturing_group.emplace_back(new CapturingGroupT);
+ CapturingGroupT* group =
+ grammar_model.rule_classification_result[person_result]
+ ->capturing_group.back()
+ .get();
+ group->entity_field_path.reset(new FlatbufferFieldPathT);
+ group->entity_field_path->field.emplace_back(new FlatbufferFieldT);
+ group->entity_field_path->field.back()->field_name = "person";
+ group->normalization_options.reset(new NormalizationOptionsT);
+ group->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(annotator.Annotate(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("I saw Zapp Brannigan today", /*do_copy=*/false),
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(6, 20, "person")));
+
+ // Check entity data.
+ // As we don't have generated code for the ad-hoc generated entity data
+ // schema, we have to check manually using field offsets.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ result.front().classification.front().serialized_entity_data.data()));
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "ZAPP BRANNIGAN");
+}
+
+TEST_F(GrammarAnnotatorTest, RespectsRuleModes) {
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ grammar::Rules rules(locale_shard_map);
+ rules.Add("<classification_carrier>", {"ei"});
+ rules.Add("<classification_carrier>", {"en"});
+ rules.Add("<selection_carrier>", {"ai"});
+ rules.Add("<selection_carrier>", {"bx"});
+ rules.Add("<annotation_carrier>", {"aa"});
+ rules.Add("<annotation_carrier>", {"lx"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ rules.Add(
+ "<flight>", {"<annotation_carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
+ rules.Add(
+ "<flight>", {"<selection_carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight",
+ ModeFlag_CLASSIFICATION_AND_SELECTION, 1.0,
+ &grammar_model));
+ rules.Add(
+ "<flight>", {"<classification_carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_CLASSIFICATION, 1.0,
+ &grammar_model));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ const UnicodeText text = UTF8ToUnicodeText(
+ "My flight: LX 38 arriving at 4pm, I'll fly back on EI2014 but maybe "
+ "also on bx 222");
+ const std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ // Annotation, only high confidence pattern.
+ {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(annotator.Annotate(locales, text, &result));
+ EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(11, 16, "flight")));
+ }
+
+ // Selection, annotation patterns + selection.
+ {
+ AnnotatedSpan selection;
+
+ // Selects 'LX 38'.
+ EXPECT_TRUE(annotator.SuggestSelection(locales, text,
+ /*selection=*/CodepointSpan{14, 15},
+ &selection));
+ EXPECT_THAT(selection, IsAnnotatedSpan(11, 16, "flight"));
+
+ // Selects 'bx 222'.
+ EXPECT_TRUE(annotator.SuggestSelection(locales, text,
+ /*selection=*/CodepointSpan{76, 77},
+ &selection));
+ EXPECT_THAT(selection, IsAnnotatedSpan(76, 82, "flight"));
+
+ // Doesn't select 'EI2014'.
+ EXPECT_FALSE(annotator.SuggestSelection(locales, text,
+ /*selection=*/CodepointSpan{51, 51},
+ &selection));
+ }
+
+ // Classification, all patterns.
+ {
+ ClassificationResult result;
+
+ // Classifies 'LX 38'.
+ EXPECT_TRUE(
+ annotator.ClassifyText(locales, text, CodepointSpan{11, 16}, &result));
+ EXPECT_THAT(result, IsClassificationResult("flight"));
+
+ // Classifies 'EI2014'.
+ EXPECT_TRUE(
+ annotator.ClassifyText(locales, text, CodepointSpan{51, 57}, &result));
+ EXPECT_THAT(result, IsClassificationResult("flight"));
+
+ // Classifies 'bx 222'.
+ EXPECT_TRUE(
+ annotator.ClassifyText(locales, text, CodepointSpan{76, 82}, &result));
+ EXPECT_THAT(result, IsClassificationResult("flight"));
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/test-utils.cc b/native/annotator/grammar/test-utils.cc
new file mode 100644
index 0000000..45bc301
--- /dev/null
+++ b/native/annotator/grammar/test-utils.cc
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/test-utils.h"
+
+#include "utils/tokenizer.h"
+
+namespace libtextclassifier3 {
+
+GrammarAnnotator GrammarAnnotatorTest::CreateGrammarAnnotator(
+ const ::flatbuffers::DetachedBuffer& serialized_model) {
+ return GrammarAnnotator(
+ unilib_.get(),
+ flatbuffers::GetRoot<GrammarModel>(serialized_model.data()),
+ entity_data_builder_.get());
+}
+
+void SetTestTokenizerOptions(GrammarModelT* model) {
+ model->tokenizer_options.reset(new GrammarTokenizerOptionsT);
+ model->tokenizer_options->tokenization_type = TokenizationType_ICU;
+ model->tokenizer_options->icu_preserve_whitespace_tokens = false;
+ model->tokenizer_options->tokenize_on_script_change = true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/test-utils.h b/native/annotator/grammar/test-utils.h
new file mode 100644
index 0000000..e6b2071
--- /dev/null
+++ b/native/annotator/grammar/test-utils.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_TEST_UTILS_H_
+
+#include <memory>
+
+#include "actions/test-utils.h"
+#include "annotator/grammar/grammar-annotator.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/utf8/unilib.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+// TODO(sofian): Move these matchers to a level-up library, usable for more
+// tests in text_classifier.
+MATCHER_P3(IsAnnotatedSpan, start, end, collection,
+ "is annotated span with begin that " +
+ ::testing::DescribeMatcher<int>(start, negation) +
+ ", end that " + ::testing::DescribeMatcher<int>(end, negation) +
+ ", collection that " +
+ ::testing::DescribeMatcher<std::string>(collection, negation)) {
+ return ::testing::ExplainMatchResult(CodepointSpan(start, end), arg.span,
+ result_listener) &&
+ ::testing::ExplainMatchResult(::testing::StrEq(collection),
+ arg.classification.front().collection,
+ result_listener);
+}
+
+MATCHER_P(IsClassificationResult, collection,
+ "is classification result with collection that " +
+ ::testing::DescribeMatcher<std::string>(collection, negation)) {
+ return ::testing::ExplainMatchResult(::testing::StrEq(collection),
+ arg.collection, result_listener);
+}
+
+class GrammarAnnotatorTest : public ::testing::Test {
+ protected:
+ GrammarAnnotatorTest()
+ : unilib_(CreateUniLibForTesting()),
+ serialized_entity_data_schema_(TestEntityDataSchema()),
+ entity_data_builder_(new MutableFlatbufferBuilder(
+ flatbuffers::GetRoot<reflection::Schema>(
+ serialized_entity_data_schema_.data()))) {}
+
+ GrammarAnnotator CreateGrammarAnnotator(
+ const ::flatbuffers::DetachedBuffer& serialized_model);
+
+ std::unique_ptr<UniLib> unilib_;
+ const std::string serialized_entity_data_schema_;
+ std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
+};
+
+void SetTestTokenizerOptions(GrammarModelT* model);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_TEST_UTILS_H_
diff --git a/native/annotator/grammar/utils.cc b/native/annotator/grammar/utils.cc
index 8b9363d..bb58190 100644
--- a/native/annotator/grammar/utils.cc
+++ b/native/annotator/grammar/utils.cc
@@ -53,13 +53,14 @@
int AddRuleClassificationResult(const std::string& collection,
const ModeFlag& enabled_modes,
- GrammarModelT* model) {
+ float priority_score, GrammarModelT* model) {
const int result_id = model->rule_classification_result.size();
model->rule_classification_result.emplace_back(new RuleClassificationResultT);
RuleClassificationResultT* result =
model->rule_classification_result.back().get();
result->collection_name = collection;
result->enabled_modes = enabled_modes;
+ result->priority_score = priority_score;
return result_id;
}
diff --git a/native/annotator/grammar/utils.h b/native/annotator/grammar/utils.h
index 4d870fd..21d383f 100644
--- a/native/annotator/grammar/utils.h
+++ b/native/annotator/grammar/utils.h
@@ -35,7 +35,7 @@
// Returns the ID associated with the created classification rule.
int AddRuleClassificationResult(const std::string& collection,
const ModeFlag& enabled_modes,
- GrammarModelT* model);
+ float priority_score, GrammarModelT* model);
} // namespace libtextclassifier3
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
index e9f688a..34fa490 100644
--- a/native/annotator/knowledge/knowledge-engine-dummy.h
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -19,6 +19,7 @@
#include <string>
+#include "annotator/knowledge/knowledge-engine-types.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/status.h"
@@ -36,33 +37,38 @@
void SetPriorityScore(float priority_score) {}
- bool ClassifyText(const std::string& text, CodepointSpan selection_indices,
- AnnotationUsecase annotation_usecase,
- const Optional<LocationContext>& location_context,
- const Permissions& permissions,
- ClassificationResult* classification_result) const {
- return false;
+ Status ClassifyText(const std::string& text, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
+ const Optional<LocationContext>& location_context,
+ const Permissions& permissions,
+ ClassificationResult* classification_result) const {
+ return Status(StatusCode::UNIMPLEMENTED, "Not implemented.");
}
- bool Chunk(const std::string& text, AnnotationUsecase annotation_usecase,
- const Optional<LocationContext>& location_context,
- const Permissions& permissions,
- std::vector<AnnotatedSpan>* result) const {
- return true;
+ Status Chunk(const std::string& text, AnnotationUsecase annotation_usecase,
+ const Optional<LocationContext>& location_context,
+ const Permissions& permissions, const AnnotateMode annotate_mode,
+ Annotations* result) const {
+ return Status::OK;
}
Status ChunkMultipleSpans(
const std::vector<std::string>& text_fragments,
+ const std::vector<FragmentMetadata>& fragment_metadata,
AnnotationUsecase annotation_usecase,
const Optional<LocationContext>& location_context,
- const Permissions& permissions,
- std::vector<std::vector<AnnotatedSpan>>* results) const {
+ const Permissions& permissions, const AnnotateMode annotate_mode,
+ Annotations* results) const {
return Status::OK;
}
- bool LookUpEntity(const std::string& id,
- std::string* serialized_knowledge_result) const {
- return false;
+ StatusOr<std::string> LookUpEntity(const std::string& id) const {
+ return Status(StatusCode::UNIMPLEMENTED, "Not implemented.");
+ }
+
+ StatusOr<std::string> LookUpEntityProperty(
+ const std::string& mid_str, const std::string& property) const {
+ return Status(StatusCode::UNIMPLEMENTED, "Not implemented");
}
};
diff --git a/native/annotator/knowledge/knowledge-engine-types.h b/native/annotator/knowledge/knowledge-engine-types.h
new file mode 100644
index 0000000..04b71cb
--- /dev/null
+++ b/native/annotator/knowledge/knowledge-engine-types.h
@@ -0,0 +1,31 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_TYPES_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_TYPES_H_
+
+namespace libtextclassifier3 {
+
+enum AnnotateMode { kEntityAnnotation, kEntityAndTopicalityAnnotation };
+
+struct FragmentMetadata {
+ float relative_bounding_box_top;
+ float relative_bounding_box_height;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_TYPES_H_
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
old mode 100755
new mode 100644
index bdb7a17..57187f5
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -16,9 +16,9 @@
include "annotator/entity-data.fbs";
include "annotator/experimental/experimental.fbs";
-include "annotator/grammar/dates/dates.fbs";
include "utils/codepoint-range.fbs";
-include "utils/flatbuffers.fbs";
+include "utils/container/bit-vector.fbs";
+include "utils/flatbuffers/flatbuffers.fbs";
include "utils/grammar/rules.fbs";
include "utils/intents/intent-config.fbs";
include "utils/normalization.fbs";
@@ -131,6 +131,8 @@
NINETY = 70,
HUNDRED = 71,
THOUSAND = 72,
+ NOON = 73,
+ MIDNIGHT = 74,
}
namespace libtextclassifier3;
@@ -154,6 +156,7 @@
GROUP_DUMMY1 = 12,
GROUP_DUMMY2 = 13,
+ GROUP_ABSOLUTETIME = 14,
}
// Options for the model that predicts text selection.
@@ -370,85 +373,6 @@
tokenize_on_script_change:bool = false;
}
-// Options for grammar date/datetime/date range annotations.
-namespace libtextclassifier3.GrammarDatetimeModel_;
-table AnnotationOptions {
- // If enabled, extract special day offset like today, yesterday, etc.
- enable_special_day_offset:bool = true;
-
- // If true, merge the adjacent day of week, time and date. e.g.
- // "20/2/2016 at 8pm" is extracted as a single instance instead of two
- // instance: "20/2/2016" and "8pm".
- merge_adjacent_components:bool = true;
-
- // List the extra id of requested dates.
- extra_requested_dates:[string];
-
- // If true, try to include preposition to the extracted annotation. e.g.
- // "at 6pm". if it's false, only 6pm is included. offline-actions has
- // special requirements to include preposition.
- include_preposition:bool = true;
-
- // If enabled, extract range in date annotator.
- // input: Monday, 5-6pm
- // If the flag is true, The extracted annotation only contains 1 range
- // instance which is from Monday 5pm to 6pm.
- // If the flag is false, The extracted annotation contains two date
- // instance: "Monday" and "6pm".
- enable_date_range:bool = true;
- reserved_6:int16 (deprecated);
-
- // If enabled, the rule priority score is used to set the priority score of
- // the annotation.
- // In case of false the annotation priority score is set from
- // GrammarDatetimeModel's priority_score
- use_rule_priority_score:bool = false;
-
- // If enabled, annotator will try to resolve the ambiguity by generating
- // possible alternative interpretations of the input text
- // e.g. '9:45' will be resolved to '9:45 AM' and '9:45 PM'.
- generate_alternative_interpretations_when_ambiguous:bool;
-
- // List of spans which grammar will ignore during the match e.g. if
- // “@” is in the allowed span list and input is “12 March @ 12PM” then “@”
- // will be ignored and 12 March @ 12PM will be translate to
- // {Day:12 Month: March Hour: 12 MERIDIAN: PM}.
- // This can also be achieved by adding additional rules e.g.
- // <Digit_Day> <Month> <Time>
- // <Digit_Day> <Month> @ <Time>
- // Though this is doable in the grammar but requires multiple rules, this
- // list enables the rule to represent multiple rules.
- ignored_spans:[string];
-}
-
-namespace libtextclassifier3;
-table GrammarDatetimeModel {
- // List of BCP 47 locale strings representing all locales supported by the
- // model.
- locales:[string];
-
- // If true, will give only future dates (when the day is not specified).
- prefer_future_for_unspecified_date:bool = false;
-
- // Grammar specific tokenizer options.
- grammar_tokenizer_options:GrammarTokenizerOptions;
-
- // The modes for which to apply the grammars.
- enabled_modes:ModeFlag = ALL;
-
- // The datetime grammar rules.
- datetime_rules:dates.DatetimeRules;
-
- // The final score to assign to the results of grammar model
- target_classification_score:float = 1;
-
- // The priority score used for conflict resolution with the other models.
- priority_score:float = 0;
-
- // Options for grammar annotations.
- annotation_options:GrammarDatetimeModel_.AnnotationOptions;
-}
-
namespace libtextclassifier3.DatetimeModelLibrary_;
table Item {
key:string (shared);
@@ -502,12 +426,29 @@
// Grammar specific tokenizer options.
tokenizer_options:GrammarTokenizerOptions;
+
+ // The score.
+ target_classification_score:float = 1;
+
+ // The priority score used for conflict resolution with the other models.
+ priority_score:float = 1;
+}
+
+namespace libtextclassifier3.MoneyParsingOptions_;
+table QuantitiesNameToExponentEntry {
+ key:string (key, shared);
+ value:int;
}
namespace libtextclassifier3;
table MoneyParsingOptions {
// Separators (codepoints) marking decimal or thousand in the money amount.
separators:[int];
+
+ // Mapping between a quantity string (e.g. "million") and the power of 10
+ // it multiplies the amount with (e.g. 6 in case of "million").
+ // NOTE: The entries need to be sorted by key since we use LookupByKey.
+ quantities_name_to_exponent:[MoneyParsingOptions_.QuantitiesNameToExponentEntry];
}
namespace libtextclassifier3.ModelTriggeringOptions_;
@@ -652,13 +593,16 @@
triggering_locales:string (shared);
embedding_pruning_mask:Model_.EmbeddingPruningMask;
- grammar_datetime_model:GrammarDatetimeModel;
+ reserved_25:int16 (deprecated);
contact_annotator_options:ContactAnnotatorOptions;
money_parsing_options:MoneyParsingOptions;
translate_annotator_options:TranslateAnnotatorOptions;
grammar_model:GrammarModel;
conflict_resolution_options:Model_.ConflictResolutionOptions;
experimental_model:ExperimentalModel;
+ pod_ner_model:PodNerModel;
+ vocab_model:VocabModel;
+ datetime_grammar_model:GrammarModel;
}
// Method for selecting the center token.
@@ -985,4 +929,120 @@
backoff_options:TranslateAnnotatorOptions_.BackoffOptions;
}
+namespace libtextclassifier3.PodNerModel_;
+table Collection {
+ // Collection's name (e.g., "location", "person").
+ name:string (shared);
+
+ // Priority scores used for conflict resolution with the other annotators
+ // when the annotation is made over a single/multi token text.
+ single_token_priority_score:float;
+
+ multi_token_priority_score:float;
+}
+
+namespace libtextclassifier3.PodNerModel_.Label_;
+enum BoiseType : int {
+ NONE = 0,
+ BEGIN = 1,
+ O = 2,
+ // No label.
+
+ INTERMEDIATE = 3,
+ SINGLE = 4,
+ END = 5,
+}
+
+namespace libtextclassifier3.PodNerModel_.Label_;
+enum MentionType : int {
+ UNDEFINED = 0,
+ NAM = 1,
+ NOM = 2,
+}
+
+namespace libtextclassifier3.PodNerModel_;
+table Label {
+ boise_type:Label_.BoiseType;
+ mention_type:Label_.MentionType;
+ collection_id:int;
+ // points to the collections array above.
+}
+
+namespace libtextclassifier3;
+table PodNerModel {
+ tflite_model:[ubyte];
+ word_piece_vocab:[ubyte];
+ lowercase_input:bool = true;
+
+ // Index of mention_logits tensor in the output of the tflite model. Can
+ // be found in the textproto output after model is converted to tflite.
+ logits_index_in_output_tensor:int = 0;
+
+ // Whether to append a period at the end of an input that doesn't already
+ // end in punctuation.
+ append_final_period:bool = false;
+
+ // Priority score used for conflict resolution with the other models. Used
+ // only if collections_array is empty.
+ priority_score:float = 0;
+
+ // Maximum number of wordpieces supported by the model.
+ max_num_wordpieces:int = 128;
+
+ // In case of long text (number of wordpieces greater than the max) we use
+ // sliding window approach, this determines the number of overlapping
+ // wordpieces between two consecutive windows. This overlap enables context
+ // for each word NER annotates.
+ sliding_window_num_wordpieces_overlap:int = 20;
+ reserved_9:int16 (deprecated);
+
+ // The possible labels the ner model can output. If empty the default labels
+ // will be used.
+ labels:[PodNerModel_.Label];
+
+ // If the ratio of unknown wordpieces in the input text is greater than this
+ // maximum, the text won't be annotated.
+ max_ratio_unknown_wordpieces:float = 0.1;
+
+ // Possible collections for labeled entities.
+ collections:[PodNerModel_.Collection];
+
+ // Minimum word-length and wordpieces-length required for the text to be
+ // annotated.
+ min_number_of_tokens:int = 1;
+
+ min_number_of_wordpieces:int = 1;
+}
+
+namespace libtextclassifier3;
+table VocabModel {
+  // A trie that stores a list of vocabs that trigger "Define". An id is
+  // returned when looking up a vocab from the trie, and the id can be used
+ // to access more information about that vocab. The marisa trie library
+ // requires 8-byte alignment because the first thing in a marisa trie is a
+ // 64-bit integer.
+ vocab_trie:[ubyte] (force_align: 8);
+
+ // A bit vector that tells if the vocab should trigger "Define" for users of
+ // beginner proficiency only. To look up the bit vector, use the id returned
+ // by the trie.
+ beginner_level:BitVectorData;
+
+  // A bit vector that tells if the vocab should not trigger "Define" when
+  // its leading character is in upper case. To look up the bit vector, use
+  // the id returned by the trie.
+ do_not_trigger_in_upper_case:BitVectorData;
+
+ // Comma-separated list of locales (BCP 47 tags) that the model supports, that
+ // are used to prevent triggering on input in unsupported languages. If
+ // empty, the model will trigger on all inputs.
+ triggering_locales:string (shared);
+
+ // The final score to assign to the results of the vocab model
+ target_classification_score:float = 1;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+}
+
root_type libtextclassifier3.Model;
diff --git a/native/annotator/number/number_test-include.cc b/native/annotator/number/number_test-include.cc
new file mode 100644
index 0000000..f47933f
--- /dev/null
+++ b/native/annotator/number/number_test-include.cc
@@ -0,0 +1,1111 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/number/number_test-include.h"
+
+#include <string>
+#include <vector>
+
+#include "annotator/collections.h"
+#include "annotator/model_generated.h"
+#include "annotator/types-test-util.h"
+#include "annotator/types.h"
+#include "utils/tokenizer-utils.h"
+#include "utils/utf8/unicodetext.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+using ::testing::AllOf;
+using ::testing::ElementsAre;
+using ::testing::Field;
+using ::testing::Matcher;
+using ::testing::UnorderedElementsAre;
+
+const NumberAnnotatorOptions*
+NumberAnnotatorTest::TestingNumberAnnotatorOptions() {
+ static const flatbuffers::DetachedBuffer* options_data = []() {
+ NumberAnnotatorOptionsT options;
+ options.enabled = true;
+ options.priority_score = -10.0;
+ options.float_number_priority_score = 1.0;
+ options.enabled_annotation_usecases =
+ 1 << AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ options.max_number_of_digits = 20;
+
+ options.percentage_priority_score = 1.0;
+ options.percentage_annotation_usecases =
+ (1 << AnnotationUsecase_ANNOTATION_USECASE_RAW) +
+ (1 << AnnotationUsecase_ANNOTATION_USECASE_SMART);
+ std::set<std::string> percent_suffixes({"パーセント", "percent", "pércént",
+ "pc", "pct", "%", "٪", "﹪", "%"});
+ for (const std::string& string_value : percent_suffixes) {
+ options.percentage_pieces_string.append(string_value);
+ options.percentage_pieces_string.push_back('\0');
+ }
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(NumberAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+ }();
+
+ return flatbuffers::GetRoot<NumberAnnotatorOptions>(options_data->data());
+}
+
+MATCHER_P(IsCorrectCollection, collection, "collection is " + collection) {
+ return arg.collection == collection;
+}
+
+MATCHER_P(IsCorrectNumericValue, numeric_value,
+ "numeric value is " + std::to_string(numeric_value)) {
+ return arg.numeric_value == numeric_value;
+}
+
+MATCHER_P(IsCorrectNumericDoubleValue, numeric_double_value,
+ "numeric double value is " + std::to_string(numeric_double_value)) {
+ return arg.numeric_double_value == numeric_double_value;
+}
+
+MATCHER_P(IsCorrectScore, score, "score is " + std::to_string(score)) {
+ return arg.score == score;
+}
+
+MATCHER_P(IsCorrectPriortyScore, priority_score,
+ "priority score is " + std::to_string(priority_score)) {
+ return arg.priority_score == priority_score;
+}
+
+MATCHER_P(IsCorrectSpan, span,
+ "span is (" + std::to_string(span.first) + "," +
+ std::to_string(span.second) + ")") {
+ return arg.span == span;
+}
+
+MATCHER_P(Classification, inner, "") {
+ return testing::ExplainMatchResult(inner, arg.classification,
+ result_listener);
+}
+
+static Matcher<AnnotatedSpan> IsAnnotatedSpan(
+ const CodepointSpan& codepoint_span, const std::string& collection,
+ const int int_value, const double double_value,
+ const float priority_score = -10, const float score = 1) {
+ return AllOf(
+ IsCorrectSpan(codepoint_span),
+ Classification(ElementsAre(AllOf(
+ IsCorrectCollection(collection), IsCorrectNumericValue(int_value),
+ IsCorrectNumericDoubleValue(double_value), IsCorrectScore(score),
+ IsCorrectPriortyScore(priority_score)))));
+}
+
+TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberCorrectly) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345 ..."), {4, 9},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_EQ(classification_result.collection, "number");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+ EXPECT_FLOAT_EQ(classification_result.numeric_double_value, 12345);
+}
+
+TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberAsFloatCorrectly) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345.12345 ..."), {4, 15},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_EQ(classification_result.collection, "number");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+ EXPECT_FLOAT_EQ(classification_result.numeric_double_value, 12345.12345);
+}
+
+TEST_F(NumberAnnotatorTest,
+ ClassifiesAndParsesNumberAsFloatCorrectlyWithoutDecimals) {
+ ClassificationResult classification_result;
+ // The dot after a number is considered punctuation, not part of a floating
+ // number.
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345. ..."), {4, 9},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345. ..."), {4, 10},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_EQ(classification_result.collection, "number");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+ EXPECT_FLOAT_EQ(classification_result.numeric_double_value, 12345);
+
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345. ..."), {4, 9},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_EQ(classification_result.collection, "number");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+ EXPECT_FLOAT_EQ(classification_result.numeric_double_value, 12345);
+}
+
+TEST_F(NumberAnnotatorTest, FindsAllIntegerAndFloatNumbersInText) {
+ std::vector<AnnotatedSpan> result;
+ // In the context "68.9#" -> 68.9 is a number because # is punctuation.
+  // In the context "68.9#?" -> 68.9 is not a number because it is followed by
+ // punctuation signs.
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("how much is 2 plus 5 divided by 7% minus 3.14 "
+ "what about 68.9# or 68.9#?"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(12, 13), "number",
+ /*int_value=*/2, /*double_value=*/2.0),
+ IsAnnotatedSpan(CodepointSpan(19, 20), "number",
+ /*int_value=*/5, /*double_value=*/5.0),
+ IsAnnotatedSpan(CodepointSpan(32, 33), "number",
+ /*int_value=*/7, /*double_value=*/7.0),
+ IsAnnotatedSpan(CodepointSpan(32, 34), "percentage",
+ /*int_value=*/7, /*double_value=*/7.0,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(41, 45), "number",
+ /*int_value=*/3, /*double_value=*/3.14,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(57, 61), "number",
+ /*int_value=*/68, /*double_value=*/68.9,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, ClassifiesNonNumberCorrectly) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 123a45 ..."), {4, 10},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345..12345 ..."), {4, 16},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345a ..."), {4, 11},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, ClassifiesNumberSelectionCorrectly) {
+ ClassificationResult classification_result;
+ // Punctuation after a number is not part of the number.
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 14, ..."), {4, 6},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_EQ(classification_result.collection, "number");
+ EXPECT_EQ(classification_result.numeric_value, 14);
+ EXPECT_EQ(classification_result.numeric_double_value, 14);
+
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 14, ..."), {4, 7},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, ClassifiesPercentageSignCorrectly) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 99% ..."), {4, 7},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_EQ(classification_result.collection, "percentage");
+ EXPECT_EQ(classification_result.numeric_value, 99);
+ EXPECT_EQ(classification_result.numeric_double_value, 99);
+}
+
+TEST_F(NumberAnnotatorTest, ClassifiesPercentageWordCorrectly) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 15 percent ..."), {4, 14},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_EQ(classification_result.collection, "percentage");
+ EXPECT_EQ(classification_result.numeric_value, 15);
+ EXPECT_EQ(classification_result.numeric_double_value, 15);
+}
+
+TEST_F(NumberAnnotatorTest, ClassifiesNonAsciiPercentageIncorrectSuffix) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("15 café"), {0, 7},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, ClassifiesNonAsciiFrPercentageCorrectSuffix) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("25 pércént"), {0, 10},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_EQ(classification_result.collection, "percentage");
+ EXPECT_EQ(classification_result.numeric_value, 25);
+ EXPECT_EQ(classification_result.numeric_double_value, 25);
+}
+
+TEST_F(NumberAnnotatorTest, ClassifiesNonAsciiJaPercentageCorrectSuffix) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("10パーセント"), {0, 7},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_EQ(classification_result.collection, "percentage");
+ EXPECT_EQ(classification_result.numeric_value, 10);
+ EXPECT_EQ(classification_result.numeric_double_value, 10);
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("明日の降水確率は10パーセント 音量を12にセット"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(8, 10), "number",
+ /*int_value=*/10, /*double_value=*/10.0),
+ IsAnnotatedSpan(CodepointSpan(8, 15), "percentage",
+ /*int_value=*/10, /*double_value=*/10.0,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(20, 22), "number",
+ /*int_value=*/12, /*double_value=*/12.0)));
+}
+
+TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... 12345 ... 9 is my number and 27% or 68# #38 #39 "
+ "but not $99."),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(4, 9), "number",
+ /*int_value=*/12345, /*double_value=*/12345.0),
+ IsAnnotatedSpan(CodepointSpan(14, 15), "number",
+ /*int_value=*/9, /*double_value=*/9.0),
+ IsAnnotatedSpan(CodepointSpan(33, 35), "number",
+ /*int_value=*/27, /*double_value=*/27.0),
+ IsAnnotatedSpan(CodepointSpan(33, 36), "percentage",
+ /*int_value=*/27, /*double_value=*/27.0,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(40, 42), "number",
+ /*int_value=*/68, /*double_value=*/68.0),
+ IsAnnotatedSpan(CodepointSpan(45, 47), "number",
+ /*int_value=*/38, /*double_value=*/38.0),
+ IsAnnotatedSpan(CodepointSpan(49, 51), "number",
+ /*int_value=*/39, /*double_value=*/39.0)));
+}
+
+TEST_F(NumberAnnotatorTest, FindsNoNumberInText) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... 12345a ... 12345..12345 and 123a45 are not valid. "
+ "And -#5% is also bad."),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ ASSERT_EQ(result.size(), 0);
+}
+
+TEST_F(NumberAnnotatorTest, FindsNumberWithPunctuation) {
+ std::vector<AnnotatedSpan> result;
+  // A number should be followed by only one punctuation sign => 15 is not a
+ // number.
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText(
+ "It's 12, 13, 14! Or 15??? For sure 16: 17; 18. and -19"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(5, 7), "number",
+ /*int_value=*/12, /*double_value=*/12.0),
+ IsAnnotatedSpan(CodepointSpan(9, 11), "number",
+ /*int_value=*/13, /*double_value=*/13.0),
+ IsAnnotatedSpan(CodepointSpan(13, 15), "number",
+ /*int_value=*/14, /*double_value=*/14.0),
+ IsAnnotatedSpan(CodepointSpan(35, 37), "number",
+ /*int_value=*/16, /*double_value=*/16.0),
+ IsAnnotatedSpan(CodepointSpan(39, 41), "number",
+ /*int_value=*/17, /*double_value=*/17.0),
+ IsAnnotatedSpan(CodepointSpan(43, 45), "number",
+ /*int_value=*/18, /*double_value=*/18.0),
+ IsAnnotatedSpan(CodepointSpan(51, 54), "number",
+ /*int_value=*/-19, /*double_value=*/-19.0)));
+}
+
+TEST_F(NumberAnnotatorTest, FindsFloatNumberWithPunctuation) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("It's 12.123, 13.45, 14.54321! Or 15.1? Maybe 16.33: "
+ "17.21; but for sure 18.90."),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(5, 11), "number",
+ /*int_value=*/12, /*double_value=*/12.123,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(13, 18), "number",
+ /*int_value=*/13, /*double_value=*/13.45,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(20, 28), "number",
+ /*int_value=*/14, /*double_value=*/14.54321,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(33, 37), "number",
+ /*int_value=*/15, /*double_value=*/15.1,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(45, 50), "number",
+ /*int_value=*/16, /*double_value=*/16.33,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(52, 57), "number",
+ /*int_value=*/17, /*double_value=*/17.21,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(72, 77), "number",
+ /*int_value=*/18, /*double_value=*/18.9,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, HandlesNumbersAtBeginning) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("-5"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ EXPECT_THAT(result, UnorderedElementsAre(IsAnnotatedSpan(
+ CodepointSpan(0, 2), "number",
+ /*int_value=*/-5, /*double_value=*/-5)));
+}
+
+TEST_F(NumberAnnotatorTest, HandlesNegativeNumbers) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("Number -5 and -5% and not number --5%"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(7, 9), "number",
+ /*int_value=*/-5, /*double_value=*/-5),
+ IsAnnotatedSpan(CodepointSpan(14, 16), "number",
+ /*int_value=*/-5, /*double_value=*/-5),
+ IsAnnotatedSpan(CodepointSpan(14, 17), "percentage",
+ /*int_value=*/-5, /*double_value=*/-5,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, FindGoodPercentageContexts) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText(
+ "5 percent, 10 pct, 25 pc and 17%, -5 percent, 10% are percentages"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(0, 1), "number",
+ /*int_value=*/5, /*double_value=*/5),
+ IsAnnotatedSpan(CodepointSpan(0, 9), "percentage",
+ /*int_value=*/5, /*double_value=*/5,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(11, 13), "number",
+ /*int_value=*/10, /*double_value=*/10),
+ IsAnnotatedSpan(CodepointSpan(11, 17), "percentage",
+ /*int_value=*/10, /*double_value=*/10,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(19, 21), "number",
+ /*int_value=*/25, /*double_value=*/25),
+ IsAnnotatedSpan(CodepointSpan(19, 24), "percentage",
+ /*int_value=*/25, /*double_value=*/25,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(29, 31), "number",
+ /*int_value=*/17, /*double_value=*/17),
+ IsAnnotatedSpan(CodepointSpan(29, 32), "percentage",
+ /*int_value=*/17, /*double_value=*/17,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(34, 36), "number",
+ /*int_value=*/-5, /*double_value=*/-5),
+ IsAnnotatedSpan(CodepointSpan(34, 44), "percentage",
+ /*int_value=*/-5, /*double_value=*/-5,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(46, 48), "number",
+ /*int_value=*/10, /*double_value=*/10),
+ IsAnnotatedSpan(CodepointSpan(46, 49), "percentage",
+ /*int_value=*/10, /*double_value=*/10,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, FindSinglePercentageInContext) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("5%"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ EXPECT_THAT(result, UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(0, 1), "number",
+ /*int_value=*/5, /*double_value=*/5),
+ IsAnnotatedSpan(CodepointSpan(0, 2), "percentage",
+ /*int_value=*/5, /*double_value=*/5,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, IgnoreBadPercentageContexts) {
+ std::vector<AnnotatedSpan> result;
+  // A valid number followed by only a single punctuation element (no percent
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("10, pct, 25 prc, 5#: percentage are not percentages"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(0, 2), "number",
+ /*int_value=*/10, /*double_value=*/10),
+ IsAnnotatedSpan(CodepointSpan(9, 11), "number",
+ /*int_value=*/25, /*double_value=*/25)));
+}
+
+TEST_F(NumberAnnotatorTest, IgnoreBadPercentagePunctuationContexts) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText(
+ "#!24% or :?33 percent are not valid percentages, nor numbers."),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_TRUE(result.empty());
+}
+
+TEST_F(NumberAnnotatorTest, FindPercentageInNonAsciiContext) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText(
+ "At the café 10% or 25 percent of people are nice. Only 10%!"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(12, 14), "number",
+ /*int_value=*/10, /*double_value=*/10),
+ IsAnnotatedSpan(CodepointSpan(12, 15), "percentage",
+ /*int_value=*/10, /*double_value=*/10,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(19, 21), "number",
+ /*int_value=*/25, /*double_value=*/25),
+ IsAnnotatedSpan(CodepointSpan(19, 29), "percentage",
+ /*int_value=*/25, /*double_value=*/25,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(55, 57), "number",
+ /*int_value=*/10, /*double_value=*/10),
+ IsAnnotatedSpan(CodepointSpan(55, 58), "percentage",
+ /*int_value=*/10, /*double_value=*/10,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest,
+ WhenPercentSuffixWithAdditionalIgnoredCharactersDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("23#!? percent"), {0, 13},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest,
+ WhenPercentSuffixWithAdditionalRandomTokensDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("23 asdf 3.14 pct asdf"), {0, 21},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest,
+ WhenPercentSuffixWithAdditionalRandomPrefixSuffixDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("abdf23 percentabdf"), {0, 18},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest,
+ WhenPercentSuffixWithAdditionalRandomStringsDoesNotParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("#?!23 percent#!?"), {0, 16},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenBothPercentSymbolAndSuffixDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("23% percent"), {0, 11},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest,
+ WhenPercentSymbolWithAdditionalPrefixCharactersDoesNotParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("#?23%"), {0, 5},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenNumberWithAdditionalCharactersDoesNotParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("23#!?"), {0, 5},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest,
+ WhenPercentSymbolWithAdditionalCharactersDoesNotParsesIt) {
+ ClassificationResult classification_result;
+  // The trailing '!' is not part of the percentage annotation span.
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("23%!"), {0, 3},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_EQ(classification_result.collection, "percentage");
+ EXPECT_EQ(classification_result.numeric_value, 23);
+ EXPECT_EQ(classification_result.numeric_double_value, 23);
+
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("23%!"), {0, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest,
+ WhenAdditionalCharactersWithMisplacedPercentSymbolDoesNotParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("23.:;%"), {0, 6},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenMultipleMinusSignsDoesNotParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("--11"), {1, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_THAT(classification_result,
+ AllOf(Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, -11),
+ Field(&ClassificationResult::numeric_double_value, -11)));
+
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("--11"), {0, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenMultipleMinusSignsPercentSignDoesNotParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("--11%"), {1, 5},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_THAT(classification_result,
+ AllOf(Field(&ClassificationResult::collection, "percentage"),
+ Field(&ClassificationResult::numeric_value, -11),
+ Field(&ClassificationResult::numeric_double_value, -11)));
+
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("--11%"), {0, 5},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenPlusMinusSignsDoesNotParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("+-11"), {1, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_THAT(classification_result,
+ AllOf(Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, -11),
+ Field(&ClassificationResult::numeric_double_value, -11)));
+
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("+-11"), {0, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenMinusPlusSignsDoesNotParsesIt) {
+ ClassificationResult classification_result;
+  // A '+' immediately before a number is not included in the number annotation.
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("-+11"), {1, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("-+11"), {0, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenMinusSignSuffixDoesNotParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("10-"), {0, 3},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenMultipleCharSuffixDoesNotParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("10**"), {0, 2},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_THAT(classification_result,
+ AllOf(Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, 10),
+ Field(&ClassificationResult::numeric_double_value, 10)));
+
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("10**"), {0, 3},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("10**"), {0, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenMultipleCharPrefixDoesNotParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("**10"), {1, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("**10"), {0, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenLowestSupportedNumberParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("-1000000000"), {0, 11},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_THAT(
+ classification_result,
+ AllOf(Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, -1000000000),
+ Field(&ClassificationResult::numeric_double_value, -1000000000)));
+}
+
+TEST_F(NumberAnnotatorTest, WhenLargestSupportedNumberParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("1000000000"), {0, 10},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_THAT(
+ classification_result,
+ AllOf(Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, 1000000000),
+ Field(&ClassificationResult::numeric_double_value, 1000000000)));
+}
+
+TEST_F(NumberAnnotatorTest, WhenLowestSupportedFloatNumberParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("-999999999.999999999"), {0, 20},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_THAT(classification_result,
+ AllOf(Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, -1000000000),
+ Field(&ClassificationResult::numeric_double_value,
+ -999999999.999999999)));
+}
+
+TEST_F(NumberAnnotatorTest, WhenLargestFloatSupportedNumberParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("999999999.999999999"), {0, 19},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_THAT(classification_result,
+ AllOf(Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, 1000000000),
+ Field(&ClassificationResult::numeric_double_value,
+ 999999999.999999999)));
+}
+
+TEST_F(NumberAnnotatorTest, WhenLargeNumberDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("1234567890123456789012345678901234567890"), {0, 40},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenMinusInTheMiddleDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("2016-2017"), {0, 9},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenSuffixWithoutNumberDoesNotParseIt) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... % ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ ASSERT_EQ(result.size(), 0);
+}
+
+TEST_F(NumberAnnotatorTest, WhenPrefixWithoutNumberDoesNotParseIt) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... $ ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ ASSERT_EQ(result.size(), 0);
+}
+
+TEST_F(NumberAnnotatorTest, WhenPrefixAndSuffixWithoutNumberDoesNotParseIt) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... $% ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ ASSERT_EQ(result.size(), 0);
+}
+
+TEST_F(NumberAnnotatorTest, ForNumberAnnotationsSetsScoreAndPriorityScore) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345 ..."), {4, 9},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_EQ(classification_result.collection, "number");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+ EXPECT_EQ(classification_result.numeric_double_value, 12345);
+ EXPECT_EQ(classification_result.score, 1);
+ EXPECT_EQ(classification_result.priority_score, -10);
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("Come at 9 or 10 ok?"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(8, 9), "number",
+ /*int_value=*/9, /*double_value=*/9),
+ IsAnnotatedSpan(CodepointSpan(13, 15), "number",
+ /*int_value=*/10, /*double_value=*/10)));
+}
+
+TEST_F(NumberAnnotatorTest,
+ ForFloatNumberAnnotationsSetsScoreAndPriorityScore) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345.12345 ..."), {4, 15},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_EQ(classification_result.collection, "number");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+ EXPECT_EQ(classification_result.numeric_double_value, 12345.12345);
+ EXPECT_EQ(classification_result.score, 1);
+ EXPECT_EQ(classification_result.priority_score, 1);
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("Results are between 12.5 and 13.5, right?"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(20, 24), "number",
+ /*int_value=*/12, /*double_value=*/12.5,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(29, 33), "number",
+ /*int_value=*/13, /*double_value=*/13.5,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, ForPercentageAnnotationsSetsScoreAndPriorityScore) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345% ..."), {4, 10},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_EQ(classification_result.collection, "percentage");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+ EXPECT_EQ(classification_result.numeric_double_value, 12345);
+ EXPECT_EQ(classification_result.score, 1);
+ EXPECT_EQ(classification_result.priority_score, 1);
+
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345 percent ..."), {4, 17},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_EQ(classification_result.collection, "percentage");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+ EXPECT_EQ(classification_result.numeric_double_value, 12345);
+ EXPECT_EQ(classification_result.score, 1);
+ EXPECT_EQ(classification_result.priority_score, 1);
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("Results are between 9% and 10 percent."),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(20, 21), "number",
+ /*int_value=*/9, /*double_value=*/9),
+ IsAnnotatedSpan(CodepointSpan(20, 22), "percentage",
+ /*int_value=*/9, /*double_value=*/9,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(27, 29), "number",
+ /*int_value=*/10, /*double_value=*/10),
+ IsAnnotatedSpan(CodepointSpan(27, 37), "percentage",
+ /*int_value=*/10, /*double_value=*/10,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, NumberDisabledPercentageEnabledForSmartUsecase) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345 ..."), {4, 9},
+ AnnotationUsecase_ANNOTATION_USECASE_SMART, &classification_result));
+
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345% ..."), {4, 10},
+ AnnotationUsecase_ANNOTATION_USECASE_SMART, &classification_result));
+ EXPECT_EQ(classification_result.collection, "percentage");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+ EXPECT_EQ(classification_result.numeric_double_value, 12345.0);
+ EXPECT_EQ(classification_result.score, 1);
+ EXPECT_EQ(classification_result.priority_score, 1);
+
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345percent ..."), {4, 16},
+ AnnotationUsecase_ANNOTATION_USECASE_SMART, &classification_result));
+ EXPECT_EQ(classification_result.collection, "percentage");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+ EXPECT_EQ(classification_result.numeric_double_value, 12345);
+ EXPECT_EQ(classification_result.score, 1);
+ EXPECT_EQ(classification_result.priority_score, 1);
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("Accuracy for experiment 3 is 9%."),
+ AnnotationUsecase_ANNOTATION_USECASE_SMART, &result));
+ EXPECT_THAT(result, UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(29, 31), "percentage",
+ /*int_value=*/9, /*double_value=*/9.0,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, MathOperatorsNotAnnotatedAsNumbersFindAll) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("how much is 2 + 2 or 5 - 96 * 89"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(12, 13), "number",
+ /*int_value=*/2, /*double_value=*/2),
+ IsAnnotatedSpan(CodepointSpan(16, 17), "number",
+ /*int_value=*/2, /*double_value=*/2),
+ IsAnnotatedSpan(CodepointSpan(21, 22), "number",
+ /*int_value=*/5, /*double_value=*/5),
+ IsAnnotatedSpan(CodepointSpan(25, 27), "number",
+ /*int_value=*/96, /*double_value=*/96),
+ IsAnnotatedSpan(CodepointSpan(30, 32), "number",
+ /*int_value=*/89, /*double_value=*/89)));
+}
+
+TEST_F(NumberAnnotatorTest, MathOperatorsNotAnnotatedAsNumbersClassifyText) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("2 + 2"), {2, 3},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("2 - 96 * 89"), {2, 3},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, SlashSeparatesTwoNumbersFindAll) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("what's 1 + 2/3 * 4/5 * 6 / 7"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(7, 8), "number",
+ /*int_value=*/1, /*double_value=*/1),
+ IsAnnotatedSpan(CodepointSpan(11, 12), "number",
+ /*int_value=*/2, /*double_value=*/2),
+ IsAnnotatedSpan(CodepointSpan(13, 14), "number",
+ /*int_value=*/3, /*double_value=*/3),
+ IsAnnotatedSpan(CodepointSpan(17, 18), "number",
+ /*int_value=*/4, /*double_value=*/4),
+ IsAnnotatedSpan(CodepointSpan(19, 20), "number",
+ /*int_value=*/5, /*double_value=*/5),
+ IsAnnotatedSpan(CodepointSpan(23, 24), "number",
+ /*int_value=*/6, /*double_value=*/6),
+ IsAnnotatedSpan(CodepointSpan(27, 28), "number",
+ /*int_value=*/7, /*double_value=*/7)));
+}
+
+TEST_F(NumberAnnotatorTest, SlashSeparatesTwoNumbersClassifyText) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("what's 1 + 2/3 * 4"), {11, 12},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_EQ(classification_result.collection, "number");
+ EXPECT_EQ(classification_result.numeric_value, 2);
+ EXPECT_EQ(classification_result.numeric_double_value, 2);
+ EXPECT_EQ(classification_result.score, 1);
+
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("what's 1 + 2/3 * 4"), {13, 14},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+ EXPECT_EQ(classification_result.collection, "number");
+ EXPECT_EQ(classification_result.numeric_value, 3);
+ EXPECT_EQ(classification_result.numeric_double_value, 3);
+ EXPECT_EQ(classification_result.score, 1);
+}
+
+TEST_F(NumberAnnotatorTest, SlashDoesNotSeparatesTwoNumbersFindAll) {
+ std::vector<AnnotatedSpan> result;
+  // The 2 in the "2/" context is annotated as a number because '/' is punctuation.
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("what's 2a2/3 or 2/s4 or 2/ or /3 or //3 or 2//"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result, UnorderedElementsAre(IsAnnotatedSpan(
+ CodepointSpan(24, 25), "number",
+ /*int_value=*/2, /*double_value=*/2)));
+}
+
+TEST_F(NumberAnnotatorTest, BracketsContextAnnotatedFindAll) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("The interval is: (12, 13) or [-12, -4.5)"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(18, 20), "number",
+ /*int_value=*/12, /*double_value=*/12),
+ IsAnnotatedSpan(CodepointSpan(22, 24), "number",
+ /*int_value=*/13, /*double_value=*/13),
+ IsAnnotatedSpan(CodepointSpan(30, 33), "number",
+ /*int_value=*/-12, /*double_value=*/-12),
+ IsAnnotatedSpan(CodepointSpan(35, 39), "number",
+ /*int_value=*/-4, /*double_value=*/-4.5,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, BracketsContextNotAnnotatedFindAll) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("The interval is: -(12, 138*)"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_TRUE(result.empty());
+}
+
+TEST_F(NumberAnnotatorTest, FractionalNumberDotsFindAll) {
+ std::vector<AnnotatedSpan> result;
+ // Dots source: https://unicode-search.net/unicode-namesearch.pl?term=period
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("3.1 3﹒2 3.3"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result, UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(0, 3), "number",
+ /*int_value=*/3, /*double_value=*/3.1,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(4, 7), "number",
+ /*int_value=*/3, /*double_value=*/3.2,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(8, 11), "number",
+ /*int_value=*/3, /*double_value=*/3.3,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, NonAsciiDigitsFindAll) {
+ std::vector<AnnotatedSpan> result;
+ // Dots source: https://unicode-search.net/unicode-namesearch.pl?term=period
+ // Digits source: https://unicode-search.net/unicode-namesearch.pl?term=digit
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("3 3﹒2 3.3%"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result, UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(0, 1), "number",
+ /*int_value=*/3, /*double_value=*/3),
+ IsAnnotatedSpan(CodepointSpan(2, 5), "number",
+ /*int_value=*/3, /*double_value=*/3.2,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(6, 9), "number",
+ /*int_value=*/3, /*double_value=*/3.3,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(6, 10), "percentage",
+ /*int_value=*/3, /*double_value=*/3.3,
+ /*priority_score=*/1)));
+}
+
+TEST_F(NumberAnnotatorTest, AnnotatedZeroPrecededNumbersFindAll) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("Numbers: 0.9 or 09 or 09.9 or 032310"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result, UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(9, 12), "number",
+ /*int_value=*/0, /*double_value=*/0.9,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(16, 18), "number",
+ /*int_value=*/9, /*double_value=*/9),
+ IsAnnotatedSpan(CodepointSpan(22, 26), "number",
+ /*int_value=*/9, /*double_value=*/9.9,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(30, 36), "number",
+ /*int_value=*/32310,
+ /*double_value=*/32310)));
+}
+
+TEST_F(NumberAnnotatorTest, ZeroAfterDotFindAll) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("15.0 16.00"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(0, 4), "number",
+ /*int_value=*/15, /*double_value=*/15),
+ IsAnnotatedSpan(CodepointSpan(5, 10), "number",
+ /*int_value=*/16, /*double_value=*/16)));
+}
+
+TEST_F(NumberAnnotatorTest, NineDotNineFindAll) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("9.9 9.99 99.99 99.999 99.9999"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result,
+ UnorderedElementsAre(
+ IsAnnotatedSpan(CodepointSpan(0, 3), "number",
+ /*int_value=*/9, /*double_value=*/9.9,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(4, 8), "number",
+ /*int_value=*/9, /*double_value=*/9.99,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(9, 14), "number",
+ /*int_value=*/99, /*double_value=*/99.99,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(15, 21), "number",
+ /*int_value=*/99, /*double_value=*/99.999,
+ /*priority_score=*/1),
+ IsAnnotatedSpan(CodepointSpan(22, 29), "number",
+ /*int_value=*/99, /*double_value=*/99.9999,
+ /*priority_score=*/1)));
+}
+
+} // namespace test_internal
+} // namespace libtextclassifier3
diff --git a/native/annotator/number/number_test-include.h b/native/annotator/number/number_test-include.h
new file mode 100644
index 0000000..9de7c86
--- /dev/null
+++ b/native/annotator/number/number_test-include.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_
+
+#include "annotator/number/number.h"
+#include "utils/jvm-test-utils.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+class NumberAnnotatorTest : public ::testing::Test {
+ protected:
+ NumberAnnotatorTest()
+ : unilib_(CreateUniLibForTesting()),
+ number_annotator_(TestingNumberAnnotatorOptions(), unilib_.get()) {}
+
+ const NumberAnnotatorOptions* TestingNumberAnnotatorOptions();
+
+ std::unique_ptr<UniLib> unilib_;
+ NumberAnnotator number_annotator_;
+};
+
+} // namespace test_internal
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_
diff --git a/native/annotator/person_name/person_name_model.fbs b/native/annotator/person_name/person_name_model.fbs
old mode 100755
new mode 100644
diff --git a/native/annotator/pod_ner/pod-ner-impl.cc b/native/annotator/pod_ner/pod-ner-impl.cc
new file mode 100644
index 0000000..666b7c7
--- /dev/null
+++ b/native/annotator/pod_ner/pod-ner-impl.cc
@@ -0,0 +1,520 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/pod_ner/pod-ner-impl.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <ctime>
+#include <iostream>
+#include <memory>
+#include <ostream>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/model_generated.h"
+#include "annotator/pod_ner/utils.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/bert_tokenizer.h"
+#include "utils/tflite-model-executor.h"
+#include "utils/tokenizer-utils.h"
+#include "utils/utf8/unicodetext.h"
+#include "absl/strings/ascii.h"
+#include "tensorflow/lite/kernels/builtin_op_kernels.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+#include "tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h"
+#include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+
+namespace libtextclassifier3 {
+
+using PodNerModel_::CollectionT;
+using PodNerModel_::LabelT;
+using ::tflite::support::text::tokenizer::TokenizerResult;
+
+namespace {
+
+using PodNerModel_::Label_::BoiseType;
+using PodNerModel_::Label_::BoiseType_BEGIN;
+using PodNerModel_::Label_::BoiseType_END;
+using PodNerModel_::Label_::BoiseType_INTERMEDIATE;
+using PodNerModel_::Label_::BoiseType_O;
+using PodNerModel_::Label_::BoiseType_SINGLE;
+using PodNerModel_::Label_::MentionType;
+using PodNerModel_::Label_::MentionType_NAM;
+using PodNerModel_::Label_::MentionType_NOM;
+using PodNerModel_::Label_::MentionType_UNDEFINED;
+
+void EmplaceToLabelVector(BoiseType boise_type, MentionType mention_type,
+ int collection_id, std::vector<LabelT> *labels) {
+ labels->emplace_back();
+ labels->back().boise_type = boise_type;
+ labels->back().mention_type = mention_type;
+ labels->back().collection_id = collection_id;
+}
+
+void FillDefaultLabelsAndCollections(float default_priority,
+ std::vector<LabelT> *labels,
+ std::vector<CollectionT> *collections) {
+ std::vector<std::string> collection_names = {
+ "art", "consumer_good", "event", "location",
+ "organization", "ner_entity", "person", "undefined"};
+ collections->clear();
+ for (const std::string &collection_name : collection_names) {
+ collections->emplace_back();
+ collections->back().name = collection_name;
+ collections->back().single_token_priority_score = default_priority;
+ collections->back().multi_token_priority_score = default_priority;
+ }
+
+ labels->clear();
+ for (auto boise_type :
+ {BoiseType_BEGIN, BoiseType_END, BoiseType_INTERMEDIATE}) {
+ for (auto mention_type : {MentionType_NAM, MentionType_NOM}) {
+ for (int i = 0; i < collections->size() - 1; ++i) { // skip undefined
+ EmplaceToLabelVector(boise_type, mention_type, i, labels);
+ }
+ }
+ }
+ EmplaceToLabelVector(BoiseType_O, MentionType_UNDEFINED, 7, labels);
+ for (auto mention_type : {MentionType_NAM, MentionType_NOM}) {
+ for (int i = 0; i < collections->size() - 1; ++i) { // skip undefined
+ EmplaceToLabelVector(BoiseType_SINGLE, mention_type, i, labels);
+ }
+ }
+}
+
+std::unique_ptr<tflite::Interpreter> CreateInterpreter(
+ const PodNerModel *model) {
+ TC3_CHECK(model != nullptr);
+ if (model->tflite_model() == nullptr) {
+ TC3_LOG(ERROR) << "Unable to create tf.lite interpreter, model is null.";
+ return nullptr;
+ }
+
+ const tflite::Model *tflite_model =
+ tflite::GetModel(model->tflite_model()->Data());
+ if (tflite_model == nullptr) {
+ TC3_LOG(ERROR) << "Unable to create tf.lite interpreter, model is null.";
+ return nullptr;
+ }
+
+ std::unique_ptr<tflite::OpResolver> resolver =
+ BuildOpResolver([](tflite::MutableOpResolver *mutable_resolver) {
+ mutable_resolver->AddBuiltin(::tflite::BuiltinOperator_SHAPE,
+ ::tflite::ops::builtin::Register_SHAPE());
+ mutable_resolver->AddBuiltin(::tflite::BuiltinOperator_RANGE,
+ ::tflite::ops::builtin::Register_RANGE());
+ mutable_resolver->AddBuiltin(
+ ::tflite::BuiltinOperator_ARG_MAX,
+ ::tflite::ops::builtin::Register_ARG_MAX());
+ mutable_resolver->AddBuiltin(
+ ::tflite::BuiltinOperator_EXPAND_DIMS,
+ ::tflite::ops::builtin::Register_EXPAND_DIMS());
+ mutable_resolver->AddCustom(
+ "LayerNorm", ::seq_flow_lite::ops::custom::Register_LAYER_NORM());
+ });
+
+ std::unique_ptr<tflite::Interpreter> tflite_interpreter;
+ tflite::InterpreterBuilder(tflite_model, *resolver,
+ nullptr)(&tflite_interpreter);
+ if (tflite_interpreter == nullptr) {
+ TC3_LOG(ERROR) << "Unable to create tf.lite interpreter.";
+ return nullptr;
+ }
+ return tflite_interpreter;
+}
+
+bool FindSpecialWordpieceIds(const std::unique_ptr<BertTokenizer> &tokenizer,
+ int *cls_id, int *sep_id, int *period_id,
+ int *unknown_id) {
+ if (!tokenizer->LookupId("[CLS]", cls_id)) {
+ TC3_LOG(ERROR) << "Couldn't find [CLS] wordpiece.";
+ return false;
+ }
+ if (!tokenizer->LookupId("[SEP]", sep_id)) {
+ TC3_LOG(ERROR) << "Couldn't find [SEP] wordpiece.";
+ return false;
+ }
+ if (!tokenizer->LookupId(".", period_id)) {
+ TC3_LOG(ERROR) << "Couldn't find [.] wordpiece.";
+ return false;
+ }
+ if (!tokenizer->LookupId("[UNK]", unknown_id)) {
+ TC3_LOG(ERROR) << "Couldn't find [UNK] wordpiece.";
+ return false;
+ }
+ return true;
+}
+// WARNING: This tokenizer is not exactly the one the model was trained with
+// so there might be nuances.
+std::unique_ptr<BertTokenizer> CreateTokenizer(const PodNerModel *model) {
+ TC3_CHECK(model != nullptr);
+ if (model->word_piece_vocab() == nullptr) {
+ TC3_LOG(ERROR)
+ << "Unable to create tokenizer, model or word_pieces is null.";
+ return nullptr;
+ }
+
+ return std::unique_ptr<BertTokenizer>(new BertTokenizer(
+ reinterpret_cast<const char *>(model->word_piece_vocab()->Data()),
+ model->word_piece_vocab()->size()));
+}
+
+} // namespace
+
+std::unique_ptr<PodNerAnnotator> PodNerAnnotator::Create(
+ const PodNerModel *model, const UniLib &unilib) {
+ if (model == nullptr) {
+ TC3_LOG(ERROR) << "Create received null model.";
+ return nullptr;
+ }
+
+ std::unique_ptr<BertTokenizer> tokenizer = CreateTokenizer(model);
+ if (tokenizer == nullptr) {
+ return nullptr;
+ }
+
+ int cls_id, sep_id, period_id, unknown_wordpiece_id;
+ if (!FindSpecialWordpieceIds(tokenizer, &cls_id, &sep_id, &period_id,
+ &unknown_wordpiece_id)) {
+ return nullptr;
+ }
+
+ std::unique_ptr<PodNerAnnotator> annotator(new PodNerAnnotator(unilib));
+ annotator->tokenizer_ = std::move(tokenizer);
+ annotator->lowercase_input_ = model->lowercase_input();
+ annotator->logits_index_in_output_tensor_ =
+ model->logits_index_in_output_tensor();
+ annotator->append_final_period_ = model->append_final_period();
+ if (model->labels() && model->labels()->size() > 0 && model->collections() &&
+ model->collections()->size() > 0) {
+ annotator->labels_.clear();
+ for (const PodNerModel_::Label *label : *model->labels()) {
+ annotator->labels_.emplace_back();
+ annotator->labels_.back().boise_type = label->boise_type();
+ annotator->labels_.back().mention_type = label->mention_type();
+ annotator->labels_.back().collection_id = label->collection_id();
+ }
+ for (const PodNerModel_::Collection *collection : *model->collections()) {
+ annotator->collections_.emplace_back();
+ annotator->collections_.back().name = collection->name()->str();
+ annotator->collections_.back().single_token_priority_score =
+ collection->single_token_priority_score();
+ annotator->collections_.back().multi_token_priority_score =
+ collection->multi_token_priority_score();
+ }
+ } else {
+ FillDefaultLabelsAndCollections(
+ model->priority_score(), &annotator->labels_, &annotator->collections_);
+ }
+ int max_num_surrounding_wordpieces = model->append_final_period() ? 3 : 2;
+ annotator->max_num_effective_wordpieces_ =
+ model->max_num_wordpieces() - max_num_surrounding_wordpieces;
+ annotator->sliding_window_num_wordpieces_overlap_ =
+ model->sliding_window_num_wordpieces_overlap();
+ annotator->max_ratio_unknown_wordpieces_ =
+ model->max_ratio_unknown_wordpieces();
+ annotator->min_number_of_tokens_ = model->min_number_of_tokens();
+ annotator->min_number_of_wordpieces_ = model->min_number_of_wordpieces();
+ annotator->cls_wordpiece_id_ = cls_id;
+ annotator->sep_wordpiece_id_ = sep_id;
+ annotator->period_wordpiece_id_ = period_id;
+ annotator->unknown_wordpiece_id_ = unknown_wordpiece_id;
+ annotator->model_ = model;
+
+ return annotator;
+}
+
+std::vector<LabelT> PodNerAnnotator::ReadResultsFromInterpreter(
+ tflite::Interpreter &interpreter) const {
+ TfLiteTensor *output =
+ interpreter.tensor(interpreter.outputs()[logits_index_in_output_tensor_]);
+ TC3_CHECK_EQ(output->dims->size, 3);
+ TC3_CHECK_EQ(output->dims->data[0], 1);
+ TC3_CHECK_EQ(output->dims->data[2], labels_.size());
+ std::vector<LabelT> return_value(output->dims->data[1]);
+ std::vector<float> probs(output->dims->data[1]);
+ for (int step = 0, index = 0; step < output->dims->data[1]; ++step) {
+ float max_prob = 0.0f;
+ int max_index = 0;
+ for (int cindex = 0; cindex < output->dims->data[2]; ++cindex) {
+ const float probability =
+ ::seq_flow_lite::PodDequantize(*output, index++);
+ if (probability > max_prob) {
+ max_prob = probability;
+ max_index = cindex;
+ }
+ }
+ return_value[step] = labels_[max_index];
+ probs[step] = max_prob;
+ }
+ return return_value;
+}
+
+std::vector<LabelT> PodNerAnnotator::ExecuteModel(
+ const VectorSpan<int> &wordpiece_indices,
+ const VectorSpan<int32_t> &token_starts,
+ const VectorSpan<Token> &tokens) const {
+ // Check that there are not more input indices than supported.
+ if (wordpiece_indices.size() > max_num_effective_wordpieces_) {
+ TC3_LOG(ERROR) << "More than " << max_num_effective_wordpieces_
+ << " indices passed to POD NER model.";
+ return {};
+ }
+ if (wordpiece_indices.size() <= 0 || token_starts.size() <= 0 ||
+ tokens.size() <= 0) {
+ TC3_LOG(ERROR) << "ExecuteModel received illegal input, #wordpiece_indices="
+ << wordpiece_indices.size()
+ << " #token_starts=" << token_starts.size()
+ << " #tokens=" << tokens.size();
+ return {};
+ }
+
+ // For the CLS (at the beginning) and SEP (at the end) wordpieces.
+ int num_additional_wordpieces = 2;
+ bool should_append_final_period = false;
+ // Optionally add a final period wordpiece if the final token is not
+ // already punctuation. This can improve performance for models trained on
+ // data mostly ending in sentence-final punctuation.
+ const std::string &last_token = (tokens.end() - 1)->value;
+ if (append_final_period_ &&
+ (last_token.size() != 1 || !unilib_.IsPunctuation(last_token.at(0)))) {
+ should_append_final_period = true;
+ num_additional_wordpieces++;
+ }
+
+ // Interpreter needs to be created for each inference call separately,
+ // otherwise the class is not thread-safe.
+ std::unique_ptr<tflite::Interpreter> interpreter = CreateInterpreter(model_);
+ if (interpreter == nullptr) {
+ TC3_LOG(ERROR) << "Couldn't create Interpreter.";
+ return {};
+ }
+
+ TfLiteStatus status;
+ status = interpreter->ResizeInputTensor(
+ interpreter->inputs()[0],
+ {1, wordpiece_indices.size() + num_additional_wordpieces});
+ TC3_CHECK_EQ(status, kTfLiteOk);
+ status = interpreter->ResizeInputTensor(interpreter->inputs()[1],
+ {1, token_starts.size()});
+ TC3_CHECK_EQ(status, kTfLiteOk);
+
+ status = interpreter->AllocateTensors();
+ TC3_CHECK_EQ(status, kTfLiteOk);
+
+ TfLiteTensor *tensor = interpreter->tensor(interpreter->inputs()[0]);
+ int wordpiece_tensor_index = 0;
+ tensor->data.i32[wordpiece_tensor_index++] = cls_wordpiece_id_;
+ for (int wordpiece_index : wordpiece_indices) {
+ tensor->data.i32[wordpiece_tensor_index++] = wordpiece_index;
+ }
+
+ if (should_append_final_period) {
+ tensor->data.i32[wordpiece_tensor_index++] = period_wordpiece_id_;
+ }
+ tensor->data.i32[wordpiece_tensor_index++] = sep_wordpiece_id_;
+
+ tensor = interpreter->tensor(interpreter->inputs()[1]);
+ for (int i = 0; i < token_starts.size(); ++i) {
+ // Need to add one because of the starting CLS wordpiece and reduce the
+ // offset from the first wordpiece.
+ tensor->data.i32[i] = token_starts[i] + 1 - token_starts[0];
+ }
+
+ status = interpreter->Invoke();
+ TC3_CHECK_EQ(status, kTfLiteOk);
+
+ return ReadResultsFromInterpreter(*interpreter);
+}
+
+bool PodNerAnnotator::PrepareText(const UnicodeText &text_unicode,
+ std::vector<int32_t> *wordpiece_indices,
+ std::vector<int32_t> *token_starts,
+ std::vector<Token> *tokens) const {
+ *tokens = TokenizeOnWhiteSpacePunctuationAndChineseLetter(
+ text_unicode.ToUTF8String());
+ tokens->erase(std::remove_if(tokens->begin(), tokens->end(),
+ [](const Token &token) {
+ return token.start == token.end;
+ }),
+ tokens->end());
+
+ for (const Token &token : *tokens) {
+ const std::string token_text =
+ lowercase_input_ ? unilib_
+ .ToLowerText(UTF8ToUnicodeText(
+ token.value, /*do_copy=*/false))
+ .ToUTF8String()
+ : token.value;
+
+ const TokenizerResult wordpiece_tokenization =
+ tokenizer_->TokenizeSingleToken(token_text);
+
+ std::vector<int> wordpiece_ids;
+ for (const std::string &wordpiece : wordpiece_tokenization.subwords) {
+ if (!tokenizer_->LookupId(wordpiece, &(wordpiece_ids.emplace_back()))) {
+ TC3_LOG(ERROR) << "Couldn't find wordpiece " << wordpiece;
+ return false;
+ }
+ }
+
+ if (wordpiece_ids.empty()) {
+ TC3_LOG(ERROR) << "wordpiece_ids.empty()";
+ return false;
+ }
+ token_starts->push_back(wordpiece_indices->size());
+ for (const int64 wordpiece_id : wordpiece_ids) {
+ wordpiece_indices->push_back(wordpiece_id);
+ }
+ }
+
+ return true;
+}
+
+bool PodNerAnnotator::Annotate(const UnicodeText &context,
+ std::vector<AnnotatedSpan> *results) const {
+ return AnnotateAroundSpanOfInterest(context, {0, context.size_codepoints()},
+ results);
+}
+
+bool PodNerAnnotator::AnnotateAroundSpanOfInterest(
+ const UnicodeText &context, const CodepointSpan &span_of_interest,
+ std::vector<AnnotatedSpan> *results) const {
+ TC3_CHECK(results != nullptr);
+
+ std::vector<int32_t> wordpiece_indices;
+ std::vector<int32_t> token_starts;
+ std::vector<Token> tokens;
+ if (!PrepareText(context, &wordpiece_indices, &token_starts, &tokens)) {
+ TC3_LOG(ERROR) << "PodNerAnnotator PrepareText(...) failed.";
+ return false;
+ }
+ const int unknown_wordpieces_count =
+ std::count(wordpiece_indices.begin(), wordpiece_indices.end(),
+ unknown_wordpiece_id_);
+ if (tokens.empty() || tokens.size() < min_number_of_tokens_ ||
+ wordpiece_indices.size() < min_number_of_wordpieces_ ||
+ (static_cast<float>(unknown_wordpieces_count) /
+ wordpiece_indices.size()) > max_ratio_unknown_wordpieces_) {
+ return true;
+ }
+
+ std::vector<LabelT> labels;
+ int first_token_index_entire_window = 0;
+
+ WindowGenerator window_generator(
+ wordpiece_indices, token_starts, tokens, max_num_effective_wordpieces_,
+ sliding_window_num_wordpieces_overlap_, span_of_interest);
+ while (!window_generator.Done()) {
+ VectorSpan<int32_t> cur_wordpiece_indices;
+ VectorSpan<int32_t> cur_token_starts;
+ VectorSpan<Token> cur_tokens;
+ if (!window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
+ &cur_tokens) ||
+ cur_tokens.size() <= 0 || cur_token_starts.size() <= 0 ||
+ cur_wordpiece_indices.size() <= 0) {
+ return false;
+ }
+ std::vector<LabelT> new_labels =
+ ExecuteModel(cur_wordpiece_indices, cur_token_starts, cur_tokens);
+ if (labels.empty()) { // First loop.
+ first_token_index_entire_window = cur_tokens.begin() - tokens.begin();
+ }
+ if (!MergeLabelsIntoLeftSequence(
+ /*labels_right=*/new_labels,
+ /*index_first_right_tag_in_left=*/cur_tokens.begin() -
+ tokens.begin() - first_token_index_entire_window,
+ /*labels_left=*/&labels)) {
+ return false;
+ }
+ }
+
+ if (labels.empty()) {
+ return false;
+ }
+ ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(tokens.begin() + first_token_index_entire_window,
+ tokens.end()),
+ labels, collections_, {PodNerModel_::Label_::MentionType_NAM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, results);
+
+ return true;
+}
+
+bool PodNerAnnotator::SuggestSelection(const UnicodeText &context,
+ CodepointSpan click,
+ AnnotatedSpan *result) const {
+ TC3_VLOG(INFO) << "POD NER SuggestSelection " << click;
+ std::vector<AnnotatedSpan> annotations;
+ if (!AnnotateAroundSpanOfInterest(context, click, &annotations)) {
+ TC3_VLOG(INFO) << "POD NER SuggestSelection: Annotate error. Returning: "
+ << click;
+ *result = {};
+ return false;
+ }
+
+ for (const AnnotatedSpan &annotation : annotations) {
+ TC3_VLOG(INFO) << "POD NER SuggestSelection: " << annotation;
+ if (annotation.span.first <= click.first &&
+ annotation.span.second >= click.second) {
+ TC3_VLOG(INFO) << "POD NER SuggestSelection: Accepted.";
+ *result = annotation;
+ return true;
+ }
+ }
+
+ TC3_VLOG(INFO)
+ << "POD NER SuggestSelection: No annotation matched click. Returning: "
+ << click;
+ *result = {};
+ return false;
+}
+
+bool PodNerAnnotator::ClassifyText(const UnicodeText &context,
+ CodepointSpan click,
+ ClassificationResult *result) const {
+ TC3_VLOG(INFO) << "POD NER ClassifyText " << click;
+ std::vector<AnnotatedSpan> annotations;
+ if (!AnnotateAroundSpanOfInterest(context, click, &annotations)) {
+ return false;
+ }
+
+ for (const AnnotatedSpan &annotation : annotations) {
+ if (annotation.span.first <= click.first &&
+ annotation.span.second >= click.second) {
+ if (annotation.classification.empty()) {
+ return false;
+ }
+ *result = annotation.classification[0];
+ return true;
+ }
+ }
+ return false;
+}
+
+std::vector<std::string> PodNerAnnotator::GetSupportedCollections() const {
+ std::vector<std::string> result;
+ for (const PodNerModel_::CollectionT &collection : collections_) {
+ result.push_back(collection.name);
+ }
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/pod_ner/pod-ner-impl.h b/native/annotator/pod_ner/pod-ner-impl.h
new file mode 100644
index 0000000..2dd2a33
--- /dev/null
+++ b/native/annotator/pod_ner/pod-ner-impl.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_IMPL_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_IMPL_H_
+
+#include <memory>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/bert_tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+
+// Uses POD NER BERT-based model for annotating various types of entities.
+class PodNerAnnotator {
+ public:
+ static std::unique_ptr<PodNerAnnotator> Create(const PodNerModel *model,
+ const UniLib &unilib);
+
+ bool Annotate(const UnicodeText &context,
+ std::vector<AnnotatedSpan> *results) const;
+
+ // Returns true if an entity was detected under 'click', and the selection
+ // indices expanded and assigned to 'result'. Otherwise returns false, and
+ // resets 'result'.
+ bool SuggestSelection(const UnicodeText &context, CodepointSpan click,
+ AnnotatedSpan *result) const;
+
+ bool ClassifyText(const UnicodeText &context, CodepointSpan click,
+ ClassificationResult *result) const;
+
+ std::vector<std::string> GetSupportedCollections() const;
+
+ private:
+ explicit PodNerAnnotator(const UniLib &unilib) : unilib_(unilib) {}
+
+ std::vector<PodNerModel_::LabelT> ReadResultsFromInterpreter(
+ tflite::Interpreter &interpreter) const;
+
+ std::vector<PodNerModel_::LabelT> ExecuteModel(
+ const VectorSpan<int> &wordpiece_indices,
+ const VectorSpan<int32_t> &token_starts,
+ const VectorSpan<Token> &tokens) const;
+
+ bool PrepareText(const UnicodeText &text_unicode,
+ std::vector<int32_t> *wordpiece_indices,
+ std::vector<int32_t> *token_starts,
+ std::vector<Token> *tokens) const;
+
+ bool AnnotateAroundSpanOfInterest(const UnicodeText &context,
+ const CodepointSpan &span_of_interest,
+ std::vector<AnnotatedSpan> *results) const;
+
+ const UniLib &unilib_;
+ bool lowercase_input_;
+ int logits_index_in_output_tensor_;
+ bool append_final_period_;
+ int max_num_effective_wordpieces_;
+ int sliding_window_num_wordpieces_overlap_;
+ float max_ratio_unknown_wordpieces_;
+ int min_number_of_tokens_;
+ int min_number_of_wordpieces_;
+ int cls_wordpiece_id_;
+ int sep_wordpiece_id_;
+ int period_wordpiece_id_;
+ int unknown_wordpiece_id_;
+ std::vector<PodNerModel_::CollectionT> collections_;
+ std::vector<PodNerModel_::LabelT> labels_;
+ std::unique_ptr<BertTokenizer> tokenizer_;
+ const PodNerModel *model_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_IMPL_H_
diff --git a/native/annotator/pod_ner/pod-ner-impl_test.cc b/native/annotator/pod_ner/pod-ner-impl_test.cc
new file mode 100644
index 0000000..c7d0bee
--- /dev/null
+++ b/native/annotator/pod_ner/pod-ner-impl_test.cc
@@ -0,0 +1,562 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/pod_ner/pod-ner-impl.h"
+
+#include <iostream>
+#include <memory>
+#include <thread> // NOLINT(build/c++11)
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/test-data-test-utils.h"
+#include "utils/tokenizer-utils.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::testing::IsEmpty;
+using ::testing::Not;
+
+using PodNerModel_::Label_::BoiseType;
+using PodNerModel_::Label_::BoiseType_BEGIN;
+using PodNerModel_::Label_::BoiseType_END;
+using PodNerModel_::Label_::BoiseType_INTERMEDIATE;
+using PodNerModel_::Label_::BoiseType_O;
+using PodNerModel_::Label_::BoiseType_SINGLE;
+using PodNerModel_::Label_::MentionType;
+using PodNerModel_::Label_::MentionType_NAM;
+using PodNerModel_::Label_::MentionType_NOM;
+using PodNerModel_::Label_::MentionType_UNDEFINED;
+
+constexpr int kMinNumberOfTokens = 1;
+constexpr int kMinNumberOfWordpieces = 1;
+constexpr float kDefaultPriorityScore = 0.5;
+
+class PodNerTest : public testing::Test {
+ protected:
+ PodNerTest() {
+ PodNerModelT model;
+
+ model.min_number_of_tokens = kMinNumberOfTokens;
+ model.min_number_of_wordpieces = kMinNumberOfWordpieces;
+ model.priority_score = kDefaultPriorityScore;
+
+ const std::string tflite_model_buffer =
+ GetTestFileContent("annotator/pod_ner/test_data/tflite_model.tflite");
+ model.tflite_model = std::vector<uint8_t>(tflite_model_buffer.begin(),
+ tflite_model_buffer.end());
+ const std::string word_piece_vocab_buffer =
+ GetTestFileContent("annotator/pod_ner/test_data/vocab.txt");
+ model.word_piece_vocab = std::vector<uint8_t>(
+ word_piece_vocab_buffer.begin(), word_piece_vocab_buffer.end());
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(PodNerModel::Pack(builder, &model));
+
+ model_buffer_ =
+ std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+ model_ = static_cast<const PodNerModel*>(
+ flatbuffers::GetRoot<PodNerModel>(model_buffer_.data()));
+
+ model.append_final_period = true;
+ flatbuffers::FlatBufferBuilder builder_append_final_period;
+ builder_append_final_period.Finish(
+ PodNerModel::Pack(builder_append_final_period, &model));
+
+ model_buffer_append_final_period_ =
+ std::string(reinterpret_cast<const char*>(
+ builder_append_final_period.GetBufferPointer()),
+ builder_append_final_period.GetSize());
+ model_append_final_period_ =
+ static_cast<const PodNerModel*>(flatbuffers::GetRoot<PodNerModel>(
+ model_buffer_append_final_period_.data()));
+
+ unilib_ = CreateUniLibForTesting();
+ }
+
+ std::string model_buffer_;
+ const PodNerModel* model_;
+ std::string model_buffer_append_final_period_;
+ const PodNerModel* model_append_final_period_;
+ std::unique_ptr<UniLib> unilib_;
+};
+
+TEST_F(PodNerTest, AnnotateSmokeTest) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ {
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(
+ UTF8ToUnicodeText("Google New York , in New York"), &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+ }
+
+ {
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(
+ UTF8ToUnicodeText("Jamie I'm in the first picture and Cameron and Zach "
+ "are in the second "
+ "picture."),
+ &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+ }
+}
+
+TEST_F(PodNerTest, AnnotateEmptyInput) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ {
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(""), &annotations));
+ EXPECT_THAT(annotations, IsEmpty());
+ }
+}
+
+void FillCollections(
+ const std::vector<std::string>& collection_names,
+ const std::vector<float>& single_token_priority_scores,
+ const std::vector<float>& multi_token_priority_scores,
+ std::vector<std::unique_ptr<PodNerModel_::CollectionT>>* collections) {
+ ASSERT_TRUE(collection_names.size() == single_token_priority_scores.size() &&
+ collection_names.size() == multi_token_priority_scores.size());
+ collections->clear();
+ for (int i = 0; i < collection_names.size(); ++i) {
+ collections->push_back(std::make_unique<PodNerModel_::CollectionT>());
+ collections->back()->name = collection_names[i];
+ collections->back()->single_token_priority_score =
+ single_token_priority_scores[i];
+ collections->back()->multi_token_priority_score =
+ multi_token_priority_scores[i];
+ }
+}
+
+void EmplaceToLabelVector(
+ BoiseType boise_type, MentionType mention_type, int collection_id,
+ std::vector<std::unique_ptr<PodNerModel_::LabelT>>* labels) {
+ labels->push_back(std::make_unique<PodNerModel_::LabelT>());
+ labels->back()->boise_type = boise_type;
+ labels->back()->mention_type = mention_type;
+ labels->back()->collection_id = collection_id;
+}
+
+void FillLabels(int num_collections,
+ std::vector<std::unique_ptr<PodNerModel_::LabelT>>* labels) {
+ labels->clear();
+ for (auto boise_type :
+ {BoiseType_BEGIN, BoiseType_END, BoiseType_INTERMEDIATE}) {
+ for (auto mention_type : {MentionType_NAM, MentionType_NOM}) {
+ for (int i = 0; i < num_collections - 1; ++i) { // skip undefined
+ EmplaceToLabelVector(boise_type, mention_type, i, labels);
+ }
+ }
+ }
+ EmplaceToLabelVector(BoiseType_O, MentionType_UNDEFINED, num_collections - 1,
+ labels);
+ for (auto mention_type : {MentionType_NAM, MentionType_NOM}) {
+ for (int i = 0; i < num_collections - 1; ++i) { // skip undefined
+ EmplaceToLabelVector(BoiseType_SINGLE, mention_type, i, labels);
+ }
+ }
+}
+
+TEST_F(PodNerTest, AnnotateDefaultCollections) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ std::string multi_word_location = "I live in New York";
+ std::string single_word_location = "I live in Zurich";
+ {
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(multi_word_location),
+ &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+ EXPECT_EQ(annotations[0].classification[0].collection, "location");
+ EXPECT_EQ(annotations[0].classification[0].priority_score,
+ kDefaultPriorityScore);
+
+ annotations.clear();
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(single_word_location),
+ &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+ EXPECT_EQ(annotations[0].classification[0].collection, "location");
+ EXPECT_EQ(annotations[0].classification[0].priority_score,
+ kDefaultPriorityScore);
+ }
+}
+
+TEST_F(PodNerTest, AnnotateConfigurableCollections) {
+ std::unique_ptr<PodNerModelT> unpacked_model(model_->UnPack());
+ ASSERT_TRUE(unpacked_model != nullptr);
+
+ float xxx_single_token_priority = 0.9;
+ float xxx_multi_token_priority = 1.7;
+ const std::vector<std::string> collection_names = {
+ "art", "consumer_good", "event", "xxx",
+ "organization", "ner_entity", "person", "undefined"};
+ FillCollections(collection_names,
+ /*single_token_priority_scores=*/
+ {0., 0., 0., xxx_single_token_priority, 0., 0., 0., 0.},
+ /*multi_token_priority_scores=*/
+ {0., 0., 0., xxx_multi_token_priority, 0., 0., 0., 0.},
+ &(unpacked_model->collections));
+ FillLabels(collection_names.size(), &(unpacked_model->labels));
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(PodNerModel::Pack(builder, unpacked_model.get()));
+ std::string model_buffer =
+ std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+ std::unique_ptr<PodNerAnnotator> annotator = PodNerAnnotator::Create(
+ static_cast<const PodNerModel*>(
+ flatbuffers::GetRoot<PodNerModel>(model_buffer.data())),
+ *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ std::string multi_word_location = "I live in New York";
+ std::string single_word_location = "I live in Zurich";
+ {
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(multi_word_location),
+ &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+ EXPECT_EQ(annotations[0].classification[0].collection, "xxx");
+ EXPECT_EQ(annotations[0].classification[0].priority_score,
+ xxx_multi_token_priority);
+
+ annotations.clear();
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(single_word_location),
+ &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+ EXPECT_EQ(annotations[0].classification[0].collection, "xxx");
+ EXPECT_EQ(annotations[0].classification[0].priority_score,
+ xxx_single_token_priority);
+ }
+}
+
+TEST_F(PodNerTest, AnnotateMinNumTokens) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ std::string text = "in New York";
+ {
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(text), &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+ }
+
+ std::unique_ptr<PodNerModelT> unpacked_model(model_->UnPack());
+ ASSERT_TRUE(unpacked_model != nullptr);
+
+ unpacked_model->min_number_of_tokens = 4;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(PodNerModel::Pack(builder, unpacked_model.get()));
+
+ std::string model_buffer =
+ std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+ annotator = PodNerAnnotator::Create(
+ static_cast<const PodNerModel*>(
+ flatbuffers::GetRoot<PodNerModel>(model_buffer.data())),
+ *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+ {
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(text), &annotations));
+ EXPECT_THAT(annotations, IsEmpty());
+ }
+}
+
+TEST_F(PodNerTest, AnnotateMinNumWordpieces) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ std::string text = "in New York";
+ {
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(text), &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+ }
+
+ std::unique_ptr<PodNerModelT> unpacked_model(model_->UnPack());
+ ASSERT_TRUE(unpacked_model != nullptr);
+
+ unpacked_model->min_number_of_wordpieces = 10;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(PodNerModel::Pack(builder, unpacked_model.get()));
+
+ std::string model_buffer =
+ std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+ annotator = PodNerAnnotator::Create(
+ static_cast<const PodNerModel*>(
+ flatbuffers::GetRoot<PodNerModel>(model_buffer.data())),
+ *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+ {
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(text), &annotations));
+ EXPECT_THAT(annotations, IsEmpty());
+ }
+}
+
+TEST_F(PodNerTest, AnnotateNonstandardText) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ const std::string nonstandard_text =
+ "abcNxCDU1RWNvbXByLXI4NS8xNzcwLzE3NzA4NDY2L3J1Ymluby1raWRzLXJlY2xpbmVyLXd"
+ "pdGgtY3VwLWhvbGRlci5qcGc=/"
+ "UnViaW5vIEtpZHMgUmVjbGluZXIgd2l0aCBDdXAgSG9sZGVyIGJ5IEhhcnJpZXQgQmVl."
+ "html>";
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(
+ annotator->Annotate(UTF8ToUnicodeText(nonstandard_text), &annotations));
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
+TEST_F(PodNerTest, AnnotateTextWithLinefeed) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ std::string nonstandard_text = "My name is Kuba\x09";
+ nonstandard_text += "and this is a test.";
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(
+ annotator->Annotate(UTF8ToUnicodeText(nonstandard_text), &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+ EXPECT_EQ(annotations[0].span, CodepointSpan(11, 15));
+
+ nonstandard_text = "My name is Kuba\x09 and this is a test.";
+ ASSERT_TRUE(
+ annotator->Annotate(UTF8ToUnicodeText(nonstandard_text), &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+ EXPECT_EQ(annotations[0].span, CodepointSpan(11, 15));
+}
+
+TEST_F(PodNerTest, AnnotateWithUnknownWordpieces) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ const std::string long_text =
+ "It is easy to spend a fun and exciting day in Seattle without a car. "
+ "There are lots of ways to modify this itinerary. Add a ferry ride "
+ "from the waterfront. Spending the day at the Seattle Center or at the "
+ "aquarium could easily extend this from one to several days. Take the "
+ "Underground Tour in Pioneer Square. Visit the Klondike Gold Rush "
+ "Museum which is fun and free. In the summer months you can ride the "
+ "passenger-only Water Taxi from the waterfront to West Seattle and "
+ "Alki Beach. Here's a sample one day itinerary: Start at the Space "
+ "Needle by taking the Seattle Monorail from downtown. Look around the "
+ "Seattle Center or go to the Space Needle.";
+ const std::string text_with_unknown_wordpieces = "před chvílí";
+
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(
+ annotator->Annotate(UTF8ToUnicodeText("Google New York , in New York. " +
+ text_with_unknown_wordpieces),
+ &annotations));
+ EXPECT_THAT(annotations, IsEmpty());
+ ASSERT_TRUE(annotator->Annotate(
+ UTF8ToUnicodeText(long_text + " " + text_with_unknown_wordpieces),
+ &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+}
+
+class PodNerTestWithOrWithoutFinalPeriod
+ : public PodNerTest,
+ public testing::WithParamInterface<bool> {};
+
+INSTANTIATE_TEST_SUITE_P(TestAnnotateLongText,
+ PodNerTestWithOrWithoutFinalPeriod,
+ testing::Values(true, false));
+
+TEST_P(PodNerTestWithOrWithoutFinalPeriod, AnnotateLongText) {
+ std::unique_ptr<PodNerAnnotator> annotator = PodNerAnnotator::Create(
+ GetParam() ? model_append_final_period_ : model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ const std::string long_text =
+ "It is easy to spend a fun and exciting day in Seattle without a car. "
+ "There are lots of ways to modify this itinerary. Add a ferry ride "
+ "from the waterfront. Spending the day at the Seattle Center or at the "
+ "aquarium could easily extend this from one to several days. Take the "
+ "Underground Tour in Pioneer Square. Visit the Klondike Gold Rush "
+ "Museum which is fun and free. In the summer months you can ride the "
+ "passenger-only Water Taxi from the waterfront to West Seattle and "
+ "Alki Beach. Here's a sample one day itinerary: Start at the Space "
+ "Needle by taking the Seattle Monorail from downtown. Look around the "
+ "Seattle Center or go to the Space Needle. If you're interested in "
+ "music the EMP-SFM (Experience Music Project - Science Fiction Musuem) "
+ "is located at the foot of the Space Needle. It has a lot of rock'n "
+ "roll memorabilia that you may find interesting. The Chihuly Garden "
+ "and Glass musuem is near the Space Needle and you can get a "
+ "combination ticket for both. It gets really good reviews. If you're "
+ "interested, then the Bill & Melinda Gates Foundation is across from "
+ "the EMP and has a visitors center that is free. Come see how Bill "
+ "Gates is giving away his millions. Take the Monorail back downtown. "
+ "You will be at 5th and Pine (Westlake Center). Head west to the Pike "
+ "Place Market. Look around then head for the Pike Place hill climb "
+ "which is a series of steps that walk down to the waterfront. You will "
+ "end up across the street from the Seattle Aquarium. Plenty of things "
+ "to do on the waterfront, boat cruises, seafood restaurants, the "
+ "Aquarium, or your typical tourist activities. You can walk or take "
+ "the waterfront trolley bus. Note that waterfront construction has "
+ "relocated the trolley Metro bus route 99 that will take you from "
+ "Pioneer Square all the way to the end of the waterfront where you can "
+ "visit the Seattle Art Musuem's XXX Sculpture Garden just north of "
+ "Pier 70. The route goes thru Chinatown/International District, "
+ "through Pioneer Square, up 1st ave past the Pike Place Market and to "
+ "1st and Cedar which is walking distance to the Space Needle. It then "
+ "goes down Broad Street toward the Olympic Sculpture Garden. It runs "
+ "approximately every 30 minutes during the day and early evening.";
+ std::vector<AnnotatedSpan> annotations;
+ ASSERT_TRUE(annotator->Annotate(UTF8ToUnicodeText(long_text), &annotations));
+ EXPECT_THAT(annotations, Not(IsEmpty()));
+
+ const std::string location_from_beginning = "Seattle";
+ int start_span_location_from_beginning =
+ long_text.find(location_from_beginning);
+ EXPECT_EQ(annotations[0].span,
+ CodepointSpan(start_span_location_from_beginning,
+ start_span_location_from_beginning +
+ location_from_beginning.length()));
+
+ const std::string location_from_end = "Olympic Sculpture Garden";
+ int start_span_location_from_end = long_text.find(location_from_end);
+ const AnnotatedSpan& last_annotation = *annotations.rbegin();
+ EXPECT_EQ(
+ last_annotation.span,
+ CodepointSpan(start_span_location_from_end,
+ start_span_location_from_end + location_from_end.length()));
+}
+
+TEST_F(PodNerTest, SuggestSelectionLongText) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ const std::string long_text =
+ "It is easy to spend a fun and exciting day in Seattle without a car. "
+ "There are lots of ways to modify this itinerary. Add a ferry ride "
+ "from the waterfront. Spending the day at the Seattle Center or at the "
+ "aquarium could easily extend this from one to several days. Take the "
+ "Underground Tour in Pioneer Square. Visit the Klondike Gold Rush "
+ "Museum which is fun and free. In the summer months you can ride the "
+ "passenger-only Water Taxi from the waterfront to West Seattle and "
+ "Alki Beach. Here's a sample one day itinerary: Start at the Space "
+ "Needle by taking the Seattle Monorail from downtown. Look around the "
+ "Seattle Center or go to the Space Needle. If you're interested in "
+ "music the EMP-SFM (Experience Music Project - Science Fiction Musuem) "
+ "is located at the foot of the Space Needle. It has a lot of rock'n "
+ "roll memorabilia that you may find interesting. The Chihuly Garden "
+ "and Glass musuem is near the Space Needle and you can get a "
+ "combination ticket for both. It gets really good reviews. If you're "
+ "interested, then the Bill & Melinda Gates Foundation is across from "
+ "the EMP and has a visitors center that is free. Come see how Bill "
+ "Gates is giving away his millions. Take the Monorail back downtown. "
+ "You will be at 5th and Pine (Westlake Center). Head west to the Pike "
+ "Place Market. Look around then head for the Pike Place hill climb "
+ "which is a series of steps that walk down to the waterfront. You will "
+ "end up across the street from the Seattle Aquarium. Plenty of things "
+ "to do on the waterfront, boat cruises, seafood restaurants, the "
+ "Aquarium, or your typical tourist activities. You can walk or take "
+ "the waterfront trolley bus. Note that waterfront construction has "
+ "relocated the trolley Metro bus route 99 that will take you from "
+ "Pioneer Square all the way to the end of the waterfront where you can "
+ "visit the Seattle Art Musuem's XXX Sculpture Garden just north of "
+ "Pier 70. The route goes thru Chinatown/International District, "
+ "through Pioneer Square, up 1st ave past the Pike Place Market and to "
+ "1st and Cedar which is walking distance to the Space Needle. It then "
+ "goes down Broad Street toward the Olympic Sculpture Garden. It runs "
+ "approximately every 30 minutes during the day and early evening.";
+ const std::string klondike = "Klondike Gold Rush Museum";
+ int klondike_start = long_text.find(klondike);
+
+ AnnotatedSpan suggested_span;
+ EXPECT_TRUE(annotator->SuggestSelection(UTF8ToUnicodeText(long_text),
+ {klondike_start, klondike_start + 8},
+ &suggested_span));
+ EXPECT_EQ(suggested_span.span,
+ CodepointSpan(klondike_start, klondike_start + klondike.length()));
+}
+
+TEST_F(PodNerTest, SuggestSelectionTest) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ AnnotatedSpan suggested_span;
+ EXPECT_TRUE(annotator->SuggestSelection(
+ UTF8ToUnicodeText("Google New York, in New York"), {7, 10},
+ &suggested_span));
+ EXPECT_EQ(suggested_span.span, CodepointSpan(7, 15));
+ EXPECT_FALSE(annotator->SuggestSelection(
+ UTF8ToUnicodeText("Google New York, in New York"), {17, 19},
+ &suggested_span));
+ EXPECT_EQ(suggested_span.span, CodepointSpan(kInvalidIndex, kInvalidIndex));
+}
+
+TEST_F(PodNerTest, ClassifyTextTest) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ ClassificationResult result;
+ ASSERT_TRUE(annotator->ClassifyText(UTF8ToUnicodeText("We met in New York"),
+ {10, 18}, &result));
+ EXPECT_EQ(result.collection, "location");
+}
+
+TEST_F(PodNerTest, ThreadSafety) {
+ std::unique_ptr<PodNerAnnotator> annotator =
+ PodNerAnnotator::Create(model_, *unilib_);
+ ASSERT_TRUE(annotator != nullptr);
+
+ // Do inference in 20 threads. When run with --config=tsan, this should fire
+ // if there's a problem.
+ std::vector<std::thread> thread_pool(20);
+ for (std::thread& thread : thread_pool) {
+ thread = std::thread([&annotator]() {
+ AnnotatedSpan suggested_span;
+ EXPECT_TRUE(annotator->SuggestSelection(
+ UTF8ToUnicodeText("Google New York, in New York"), {7, 10},
+ &suggested_span));
+ EXPECT_EQ(suggested_span.span, CodepointSpan(7, 15));
+ });
+ }
+ for (std::thread& thread : thread_pool) {
+ thread.join();
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/pod_ner/pod-ner.h b/native/annotator/pod_ner/pod-ner.h
new file mode 100644
index 0000000..812e94e
--- /dev/null
+++ b/native/annotator/pod_ner/pod-ner.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_H_
+
+#if defined TC3_POD_NER_ANNOTATOR_FLAG_DEFINED
+#include "annotator/pod_ner/pod-ner-flag-defined.h"
+#else
+#if defined TC3_POD_NER_ANNOTATOR_IMPL
+#include "annotator/pod_ner/pod-ner-impl.h"
+#elif defined TC3_POD_NER_ANNOTATOR_DUMMY
+#include "annotator/pod_ner/pod-ner-dummy.h"
+#else
+#error No POD NER implementation specified.
+#endif
+#endif // TC3_POD_NER_ANNOTATOR_FLAG_DEFINED
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_H_
diff --git a/native/annotator/pod_ner/test_data/tflite_model.tflite b/native/annotator/pod_ner/test_data/tflite_model.tflite
new file mode 100644
index 0000000..d1286a7
--- /dev/null
+++ b/native/annotator/pod_ner/test_data/tflite_model.tflite
Binary files differ
diff --git a/native/annotator/pod_ner/test_data/vocab.txt b/native/annotator/pod_ner/test_data/vocab.txt
new file mode 100644
index 0000000..fb14027
--- /dev/null
+++ b/native/annotator/pod_ner/test_data/vocab.txt
@@ -0,0 +1,30522 @@
+[PAD]
+[unused0]
+[unused1]
+[unused2]
+[unused3]
+[unused4]
+[unused5]
+[unused6]
+[unused7]
+[unused8]
+[unused9]
+[unused10]
+[unused11]
+[unused12]
+[unused13]
+[unused14]
+[unused15]
+[unused16]
+[unused17]
+[unused18]
+[unused19]
+[unused20]
+[unused21]
+[unused22]
+[unused23]
+[unused24]
+[unused25]
+[unused26]
+[unused27]
+[unused28]
+[unused29]
+[unused30]
+[unused31]
+[unused32]
+[unused33]
+[unused34]
+[unused35]
+[unused36]
+[unused37]
+[unused38]
+[unused39]
+[unused40]
+[unused41]
+[unused42]
+[unused43]
+[unused44]
+[unused45]
+[unused46]
+[unused47]
+[unused48]
+[unused49]
+[unused50]
+[unused51]
+[unused52]
+[unused53]
+[unused54]
+[unused55]
+[unused56]
+[unused57]
+[unused58]
+[unused59]
+[unused60]
+[unused61]
+[unused62]
+[unused63]
+[unused64]
+[unused65]
+[unused66]
+[unused67]
+[unused68]
+[unused69]
+[unused70]
+[unused71]
+[unused72]
+[unused73]
+[unused74]
+[unused75]
+[unused76]
+[unused77]
+[unused78]
+[unused79]
+[unused80]
+[unused81]
+[unused82]
+[unused83]
+[unused84]
+[unused85]
+[unused86]
+[unused87]
+[unused88]
+[unused89]
+[unused90]
+[unused91]
+[unused92]
+[unused93]
+[unused94]
+[unused95]
+[unused96]
+[unused97]
+[unused98]
+[UNK]
+[CLS]
+[SEP]
+[MASK]
+[unused99]
+[unused100]
+[unused101]
+[unused102]
+[unused103]
+[unused104]
+[unused105]
+[unused106]
+[unused107]
+[unused108]
+[unused109]
+[unused110]
+[unused111]
+[unused112]
+[unused113]
+[unused114]
+[unused115]
+[unused116]
+[unused117]
+[unused118]
+[unused119]
+[unused120]
+[unused121]
+[unused122]
+[unused123]
+[unused124]
+[unused125]
+[unused126]
+[unused127]
+[unused128]
+[unused129]
+[unused130]
+[unused131]
+[unused132]
+[unused133]
+[unused134]
+[unused135]
+[unused136]
+[unused137]
+[unused138]
+[unused139]
+[unused140]
+[unused141]
+[unused142]
+[unused143]
+[unused144]
+[unused145]
+[unused146]
+[unused147]
+[unused148]
+[unused149]
+[unused150]
+[unused151]
+[unused152]
+[unused153]
+[unused154]
+[unused155]
+[unused156]
+[unused157]
+[unused158]
+[unused159]
+[unused160]
+[unused161]
+[unused162]
+[unused163]
+[unused164]
+[unused165]
+[unused166]
+[unused167]
+[unused168]
+[unused169]
+[unused170]
+[unused171]
+[unused172]
+[unused173]
+[unused174]
+[unused175]
+[unused176]
+[unused177]
+[unused178]
+[unused179]
+[unused180]
+[unused181]
+[unused182]
+[unused183]
+[unused184]
+[unused185]
+[unused186]
+[unused187]
+[unused188]
+[unused189]
+[unused190]
+[unused191]
+[unused192]
+[unused193]
+[unused194]
+[unused195]
+[unused196]
+[unused197]
+[unused198]
+[unused199]
+[unused200]
+[unused201]
+[unused202]
+[unused203]
+[unused204]
+[unused205]
+[unused206]
+[unused207]
+[unused208]
+[unused209]
+[unused210]
+[unused211]
+[unused212]
+[unused213]
+[unused214]
+[unused215]
+[unused216]
+[unused217]
+[unused218]
+[unused219]
+[unused220]
+[unused221]
+[unused222]
+[unused223]
+[unused224]
+[unused225]
+[unused226]
+[unused227]
+[unused228]
+[unused229]
+[unused230]
+[unused231]
+[unused232]
+[unused233]
+[unused234]
+[unused235]
+[unused236]
+[unused237]
+[unused238]
+[unused239]
+[unused240]
+[unused241]
+[unused242]
+[unused243]
+[unused244]
+[unused245]
+[unused246]
+[unused247]
+[unused248]
+[unused249]
+[unused250]
+[unused251]
+[unused252]
+[unused253]
+[unused254]
+[unused255]
+[unused256]
+[unused257]
+[unused258]
+[unused259]
+[unused260]
+[unused261]
+[unused262]
+[unused263]
+[unused264]
+[unused265]
+[unused266]
+[unused267]
+[unused268]
+[unused269]
+[unused270]
+[unused271]
+[unused272]
+[unused273]
+[unused274]
+[unused275]
+[unused276]
+[unused277]
+[unused278]
+[unused279]
+[unused280]
+[unused281]
+[unused282]
+[unused283]
+[unused284]
+[unused285]
+[unused286]
+[unused287]
+[unused288]
+[unused289]
+[unused290]
+[unused291]
+[unused292]
+[unused293]
+[unused294]
+[unused295]
+[unused296]
+[unused297]
+[unused298]
+[unused299]
+[unused300]
+[unused301]
+[unused302]
+[unused303]
+[unused304]
+[unused305]
+[unused306]
+[unused307]
+[unused308]
+[unused309]
+[unused310]
+[unused311]
+[unused312]
+[unused313]
+[unused314]
+[unused315]
+[unused316]
+[unused317]
+[unused318]
+[unused319]
+[unused320]
+[unused321]
+[unused322]
+[unused323]
+[unused324]
+[unused325]
+[unused326]
+[unused327]
+[unused328]
+[unused329]
+[unused330]
+[unused331]
+[unused332]
+[unused333]
+[unused334]
+[unused335]
+[unused336]
+[unused337]
+[unused338]
+[unused339]
+[unused340]
+[unused341]
+[unused342]
+[unused343]
+[unused344]
+[unused345]
+[unused346]
+[unused347]
+[unused348]
+[unused349]
+[unused350]
+[unused351]
+[unused352]
+[unused353]
+[unused354]
+[unused355]
+[unused356]
+[unused357]
+[unused358]
+[unused359]
+[unused360]
+[unused361]
+[unused362]
+[unused363]
+[unused364]
+[unused365]
+[unused366]
+[unused367]
+[unused368]
+[unused369]
+[unused370]
+[unused371]
+[unused372]
+[unused373]
+[unused374]
+[unused375]
+[unused376]
+[unused377]
+[unused378]
+[unused379]
+[unused380]
+[unused381]
+[unused382]
+[unused383]
+[unused384]
+[unused385]
+[unused386]
+[unused387]
+[unused388]
+[unused389]
+[unused390]
+[unused391]
+[unused392]
+[unused393]
+[unused394]
+[unused395]
+[unused396]
+[unused397]
+[unused398]
+[unused399]
+[unused400]
+[unused401]
+[unused402]
+[unused403]
+[unused404]
+[unused405]
+[unused406]
+[unused407]
+[unused408]
+[unused409]
+[unused410]
+[unused411]
+[unused412]
+[unused413]
+[unused414]
+[unused415]
+[unused416]
+[unused417]
+[unused418]
+[unused419]
+[unused420]
+[unused421]
+[unused422]
+[unused423]
+[unused424]
+[unused425]
+[unused426]
+[unused427]
+[unused428]
+[unused429]
+[unused430]
+[unused431]
+[unused432]
+[unused433]
+[unused434]
+[unused435]
+[unused436]
+[unused437]
+[unused438]
+[unused439]
+[unused440]
+[unused441]
+[unused442]
+[unused443]
+[unused444]
+[unused445]
+[unused446]
+[unused447]
+[unused448]
+[unused449]
+[unused450]
+[unused451]
+[unused452]
+[unused453]
+[unused454]
+[unused455]
+[unused456]
+[unused457]
+[unused458]
+[unused459]
+[unused460]
+[unused461]
+[unused462]
+[unused463]
+[unused464]
+[unused465]
+[unused466]
+[unused467]
+[unused468]
+[unused469]
+[unused470]
+[unused471]
+[unused472]
+[unused473]
+[unused474]
+[unused475]
+[unused476]
+[unused477]
+[unused478]
+[unused479]
+[unused480]
+[unused481]
+[unused482]
+[unused483]
+[unused484]
+[unused485]
+[unused486]
+[unused487]
+[unused488]
+[unused489]
+[unused490]
+[unused491]
+[unused492]
+[unused493]
+[unused494]
+[unused495]
+[unused496]
+[unused497]
+[unused498]
+[unused499]
+[unused500]
+[unused501]
+[unused502]
+[unused503]
+[unused504]
+[unused505]
+[unused506]
+[unused507]
+[unused508]
+[unused509]
+[unused510]
+[unused511]
+[unused512]
+[unused513]
+[unused514]
+[unused515]
+[unused516]
+[unused517]
+[unused518]
+[unused519]
+[unused520]
+[unused521]
+[unused522]
+[unused523]
+[unused524]
+[unused525]
+[unused526]
+[unused527]
+[unused528]
+[unused529]
+[unused530]
+[unused531]
+[unused532]
+[unused533]
+[unused534]
+[unused535]
+[unused536]
+[unused537]
+[unused538]
+[unused539]
+[unused540]
+[unused541]
+[unused542]
+[unused543]
+[unused544]
+[unused545]
+[unused546]
+[unused547]
+[unused548]
+[unused549]
+[unused550]
+[unused551]
+[unused552]
+[unused553]
+[unused554]
+[unused555]
+[unused556]
+[unused557]
+[unused558]
+[unused559]
+[unused560]
+[unused561]
+[unused562]
+[unused563]
+[unused564]
+[unused565]
+[unused566]
+[unused567]
+[unused568]
+[unused569]
+[unused570]
+[unused571]
+[unused572]
+[unused573]
+[unused574]
+[unused575]
+[unused576]
+[unused577]
+[unused578]
+[unused579]
+[unused580]
+[unused581]
+[unused582]
+[unused583]
+[unused584]
+[unused585]
+[unused586]
+[unused587]
+[unused588]
+[unused589]
+[unused590]
+[unused591]
+[unused592]
+[unused593]
+[unused594]
+[unused595]
+[unused596]
+[unused597]
+[unused598]
+[unused599]
+[unused600]
+[unused601]
+[unused602]
+[unused603]
+[unused604]
+[unused605]
+[unused606]
+[unused607]
+[unused608]
+[unused609]
+[unused610]
+[unused611]
+[unused612]
+[unused613]
+[unused614]
+[unused615]
+[unused616]
+[unused617]
+[unused618]
+[unused619]
+[unused620]
+[unused621]
+[unused622]
+[unused623]
+[unused624]
+[unused625]
+[unused626]
+[unused627]
+[unused628]
+[unused629]
+[unused630]
+[unused631]
+[unused632]
+[unused633]
+[unused634]
+[unused635]
+[unused636]
+[unused637]
+[unused638]
+[unused639]
+[unused640]
+[unused641]
+[unused642]
+[unused643]
+[unused644]
+[unused645]
+[unused646]
+[unused647]
+[unused648]
+[unused649]
+[unused650]
+[unused651]
+[unused652]
+[unused653]
+[unused654]
+[unused655]
+[unused656]
+[unused657]
+[unused658]
+[unused659]
+[unused660]
+[unused661]
+[unused662]
+[unused663]
+[unused664]
+[unused665]
+[unused666]
+[unused667]
+[unused668]
+[unused669]
+[unused670]
+[unused671]
+[unused672]
+[unused673]
+[unused674]
+[unused675]
+[unused676]
+[unused677]
+[unused678]
+[unused679]
+[unused680]
+[unused681]
+[unused682]
+[unused683]
+[unused684]
+[unused685]
+[unused686]
+[unused687]
+[unused688]
+[unused689]
+[unused690]
+[unused691]
+[unused692]
+[unused693]
+[unused694]
+[unused695]
+[unused696]
+[unused697]
+[unused698]
+[unused699]
+[unused700]
+[unused701]
+[unused702]
+[unused703]
+[unused704]
+[unused705]
+[unused706]
+[unused707]
+[unused708]
+[unused709]
+[unused710]
+[unused711]
+[unused712]
+[unused713]
+[unused714]
+[unused715]
+[unused716]
+[unused717]
+[unused718]
+[unused719]
+[unused720]
+[unused721]
+[unused722]
+[unused723]
+[unused724]
+[unused725]
+[unused726]
+[unused727]
+[unused728]
+[unused729]
+[unused730]
+[unused731]
+[unused732]
+[unused733]
+[unused734]
+[unused735]
+[unused736]
+[unused737]
+[unused738]
+[unused739]
+[unused740]
+[unused741]
+[unused742]
+[unused743]
+[unused744]
+[unused745]
+[unused746]
+[unused747]
+[unused748]
+[unused749]
+[unused750]
+[unused751]
+[unused752]
+[unused753]
+[unused754]
+[unused755]
+[unused756]
+[unused757]
+[unused758]
+[unused759]
+[unused760]
+[unused761]
+[unused762]
+[unused763]
+[unused764]
+[unused765]
+[unused766]
+[unused767]
+[unused768]
+[unused769]
+[unused770]
+[unused771]
+[unused772]
+[unused773]
+[unused774]
+[unused775]
+[unused776]
+[unused777]
+[unused778]
+[unused779]
+[unused780]
+[unused781]
+[unused782]
+[unused783]
+[unused784]
+[unused785]
+[unused786]
+[unused787]
+[unused788]
+[unused789]
+[unused790]
+[unused791]
+[unused792]
+[unused793]
+[unused794]
+[unused795]
+[unused796]
+[unused797]
+[unused798]
+[unused799]
+[unused800]
+[unused801]
+[unused802]
+[unused803]
+[unused804]
+[unused805]
+[unused806]
+[unused807]
+[unused808]
+[unused809]
+[unused810]
+[unused811]
+[unused812]
+[unused813]
+[unused814]
+[unused815]
+[unused816]
+[unused817]
+[unused818]
+[unused819]
+[unused820]
+[unused821]
+[unused822]
+[unused823]
+[unused824]
+[unused825]
+[unused826]
+[unused827]
+[unused828]
+[unused829]
+[unused830]
+[unused831]
+[unused832]
+[unused833]
+[unused834]
+[unused835]
+[unused836]
+[unused837]
+[unused838]
+[unused839]
+[unused840]
+[unused841]
+[unused842]
+[unused843]
+[unused844]
+[unused845]
+[unused846]
+[unused847]
+[unused848]
+[unused849]
+[unused850]
+[unused851]
+[unused852]
+[unused853]
+[unused854]
+[unused855]
+[unused856]
+[unused857]
+[unused858]
+[unused859]
+[unused860]
+[unused861]
+[unused862]
+[unused863]
+[unused864]
+[unused865]
+[unused866]
+[unused867]
+[unused868]
+[unused869]
+[unused870]
+[unused871]
+[unused872]
+[unused873]
+[unused874]
+[unused875]
+[unused876]
+[unused877]
+[unused878]
+[unused879]
+[unused880]
+[unused881]
+[unused882]
+[unused883]
+[unused884]
+[unused885]
+[unused886]
+[unused887]
+[unused888]
+[unused889]
+[unused890]
+[unused891]
+[unused892]
+[unused893]
+[unused894]
+[unused895]
+[unused896]
+[unused897]
+[unused898]
+[unused899]
+[unused900]
+[unused901]
+[unused902]
+[unused903]
+[unused904]
+[unused905]
+[unused906]
+[unused907]
+[unused908]
+[unused909]
+[unused910]
+[unused911]
+[unused912]
+[unused913]
+[unused914]
+[unused915]
+[unused916]
+[unused917]
+[unused918]
+[unused919]
+[unused920]
+[unused921]
+[unused922]
+[unused923]
+[unused924]
+[unused925]
+[unused926]
+[unused927]
+[unused928]
+[unused929]
+[unused930]
+[unused931]
+[unused932]
+[unused933]
+[unused934]
+[unused935]
+[unused936]
+[unused937]
+[unused938]
+[unused939]
+[unused940]
+[unused941]
+[unused942]
+[unused943]
+[unused944]
+[unused945]
+[unused946]
+[unused947]
+[unused948]
+[unused949]
+[unused950]
+[unused951]
+[unused952]
+[unused953]
+[unused954]
+[unused955]
+[unused956]
+[unused957]
+[unused958]
+[unused959]
+[unused960]
+[unused961]
+[unused962]
+[unused963]
+[unused964]
+[unused965]
+[unused966]
+[unused967]
+[unused968]
+[unused969]
+[unused970]
+[unused971]
+[unused972]
+[unused973]
+[unused974]
+[unused975]
+[unused976]
+[unused977]
+[unused978]
+[unused979]
+[unused980]
+[unused981]
+[unused982]
+[unused983]
+[unused984]
+[unused985]
+[unused986]
+[unused987]
+[unused988]
+[unused989]
+[unused990]
+[unused991]
+[unused992]
+[unused993]
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+¡
+¢
+£
+¤
+¥
+¦
+§
+¨
+©
+ª
+«
+¬
+®
+°
+±
+²
+³
+´
+µ
+¶
+·
+¹
+º
+»
+¼
+½
+¾
+¿
+×
+ß
+æ
+ð
+÷
+ø
+þ
+đ
+ħ
+ı
+ł
+ŋ
+œ
+ƒ
+ɐ
+ɑ
+ɒ
+ɔ
+ɕ
+ə
+ɛ
+ɡ
+ɣ
+ɨ
+ɪ
+ɫ
+ɬ
+ɯ
+ɲ
+ɴ
+ɹ
+ɾ
+ʀ
+ʁ
+ʂ
+ʃ
+ʉ
+ʊ
+ʋ
+ʌ
+ʎ
+ʐ
+ʑ
+ʒ
+ʔ
+ʰ
+ʲ
+ʳ
+ʷ
+ʸ
+ʻ
+ʼ
+ʾ
+ʿ
+ˈ
+ː
+ˡ
+ˢ
+ˣ
+ˤ
+α
+β
+γ
+δ
+ε
+ζ
+η
+θ
+ι
+κ
+λ
+μ
+ν
+ξ
+ο
+π
+ρ
+ς
+σ
+τ
+υ
+φ
+χ
+ψ
+ω
+а
+б
+в
+г
+д
+е
+ж
+з
+и
+к
+л
+м
+н
+о
+п
+р
+с
+т
+у
+ф
+х
+ц
+ч
+ш
+щ
+ъ
+ы
+ь
+э
+ю
+я
+ђ
+є
+і
+ј
+љ
+њ
+ћ
+ӏ
+ա
+բ
+գ
+դ
+ե
+թ
+ի
+լ
+կ
+հ
+մ
+յ
+ն
+ո
+պ
+ս
+վ
+տ
+ր
+ւ
+ք
+־
+א
+ב
+ג
+ד
+ה
+ו
+ז
+ח
+ט
+י
+ך
+כ
+ל
+ם
+מ
+ן
+נ
+ס
+ע
+ף
+פ
+ץ
+צ
+ק
+ר
+ש
+ת
+،
+ء
+ا
+ب
+ة
+ت
+ث
+ج
+ح
+خ
+د
+ذ
+ر
+ز
+س
+ش
+ص
+ض
+ط
+ظ
+ع
+غ
+ـ
+ف
+ق
+ك
+ل
+م
+ن
+ه
+و
+ى
+ي
+ٹ
+پ
+چ
+ک
+گ
+ں
+ھ
+ہ
+ی
+ے
+अ
+आ
+उ
+ए
+क
+ख
+ग
+च
+ज
+ट
+ड
+ण
+त
+थ
+द
+ध
+न
+प
+ब
+भ
+म
+य
+र
+ल
+व
+श
+ष
+स
+ह
+ा
+ि
+ी
+ो
+।
+॥
+ং
+অ
+আ
+ই
+উ
+এ
+ও
+ক
+খ
+গ
+চ
+ছ
+জ
+ট
+ড
+ণ
+ত
+থ
+দ
+ধ
+ন
+প
+ব
+ভ
+ম
+য
+র
+ল
+শ
+ষ
+স
+হ
+া
+ি
+ী
+ে
+க
+ச
+ட
+த
+ந
+ன
+ப
+ம
+ய
+ர
+ல
+ள
+வ
+ா
+ி
+ு
+ே
+ை
+ನ
+ರ
+ಾ
+ක
+ය
+ර
+ල
+ව
+ා
+ก
+ง
+ต
+ท
+น
+พ
+ม
+ย
+ร
+ล
+ว
+ส
+อ
+า
+เ
+་
+།
+ག
+ང
+ད
+ན
+པ
+བ
+མ
+འ
+ར
+ལ
+ས
+မ
+ა
+ბ
+გ
+დ
+ე
+ვ
+თ
+ი
+კ
+ლ
+მ
+ნ
+ო
+რ
+ს
+ტ
+უ
+ᄀ
+ᄂ
+ᄃ
+ᄅ
+ᄆ
+ᄇ
+ᄉ
+ᄊ
+ᄋ
+ᄌ
+ᄎ
+ᄏ
+ᄐ
+ᄑ
+ᄒ
+ᅡ
+ᅢ
+ᅥ
+ᅦ
+ᅧ
+ᅩ
+ᅪ
+ᅭ
+ᅮ
+ᅯ
+ᅲ
+ᅳ
+ᅴ
+ᅵ
+ᆨ
+ᆫ
+ᆯ
+ᆷ
+ᆸ
+ᆼ
+ᴬ
+ᴮ
+ᴰ
+ᴵ
+ᴺ
+ᵀ
+ᵃ
+ᵇ
+ᵈ
+ᵉ
+ᵍ
+ᵏ
+ᵐ
+ᵒ
+ᵖ
+ᵗ
+ᵘ
+ᵢ
+ᵣ
+ᵤ
+ᵥ
+ᶜ
+ᶠ
+‐
+‑
+‒
+–
+—
+―
+‖
+‘
+’
+‚
+“
+”
+„
+†
+‡
+•
+…
+‰
+′
+″
+›
+‿
+⁄
+⁰
+ⁱ
+⁴
+⁵
+⁶
+⁷
+⁸
+⁹
+⁺
+⁻
+ⁿ
+₀
+₁
+₂
+₃
+₄
+₅
+₆
+₇
+₈
+₉
+₊
+₍
+₎
+ₐ
+ₑ
+ₒ
+ₓ
+ₕ
+ₖ
+ₗ
+ₘ
+ₙ
+ₚ
+ₛ
+ₜ
+₤
+₩
+€
+₱
+₹
+ℓ
+№
+ℝ
+™
+⅓
+⅔
+←
+↑
+→
+↓
+↔
+↦
+⇄
+⇌
+⇒
+∂
+∅
+∆
+∇
+∈
+−
+∗
+∘
+√
+∞
+∧
+∨
+∩
+∪
+≈
+≡
+≤
+≥
+⊂
+⊆
+⊕
+⊗
+⋅
+─
+│
+■
+▪
+●
+★
+☆
+☉
+♠
+♣
+♥
+♦
+♭
+♯
+⟨
+⟩
+ⱼ
+⺩
+⺼
+⽥
+、
+。
+〈
+〉
+《
+》
+「
+」
+『
+』
+〜
+あ
+い
+う
+え
+お
+か
+き
+く
+け
+こ
+さ
+し
+す
+せ
+そ
+た
+ち
+っ
+つ
+て
+と
+な
+に
+ぬ
+ね
+の
+は
+ひ
+ふ
+へ
+ほ
+ま
+み
+む
+め
+も
+や
+ゆ
+よ
+ら
+り
+る
+れ
+ろ
+を
+ん
+ァ
+ア
+ィ
+イ
+ウ
+ェ
+エ
+オ
+カ
+キ
+ク
+ケ
+コ
+サ
+シ
+ス
+セ
+タ
+チ
+ッ
+ツ
+テ
+ト
+ナ
+ニ
+ノ
+ハ
+ヒ
+フ
+ヘ
+ホ
+マ
+ミ
+ム
+メ
+モ
+ャ
+ュ
+ョ
+ラ
+リ
+ル
+レ
+ロ
+ワ
+ン
+・
+ー
+一
+三
+上
+下
+不
+世
+中
+主
+久
+之
+也
+事
+二
+五
+井
+京
+人
+亻
+仁
+介
+代
+仮
+伊
+会
+佐
+侍
+保
+信
+健
+元
+光
+八
+公
+内
+出
+分
+前
+劉
+力
+加
+勝
+北
+区
+十
+千
+南
+博
+原
+口
+古
+史
+司
+合
+吉
+同
+名
+和
+囗
+四
+国
+國
+土
+地
+坂
+城
+堂
+場
+士
+夏
+外
+大
+天
+太
+夫
+奈
+女
+子
+学
+宀
+宇
+安
+宗
+定
+宣
+宮
+家
+宿
+寺
+將
+小
+尚
+山
+岡
+島
+崎
+川
+州
+巿
+帝
+平
+年
+幸
+广
+弘
+張
+彳
+後
+御
+德
+心
+忄
+志
+忠
+愛
+成
+我
+戦
+戸
+手
+扌
+政
+文
+新
+方
+日
+明
+星
+春
+昭
+智
+曲
+書
+月
+有
+朝
+木
+本
+李
+村
+東
+松
+林
+森
+楊
+樹
+橋
+歌
+止
+正
+武
+比
+氏
+民
+水
+氵
+氷
+永
+江
+沢
+河
+治
+法
+海
+清
+漢
+瀬
+火
+版
+犬
+王
+生
+田
+男
+疒
+発
+白
+的
+皇
+目
+相
+省
+真
+石
+示
+社
+神
+福
+禾
+秀
+秋
+空
+立
+章
+竹
+糹
+美
+義
+耳
+良
+艹
+花
+英
+華
+葉
+藤
+行
+街
+西
+見
+訁
+語
+谷
+貝
+貴
+車
+軍
+辶
+道
+郎
+郡
+部
+都
+里
+野
+金
+鈴
+镇
+長
+門
+間
+阝
+阿
+陳
+陽
+雄
+青
+面
+風
+食
+香
+馬
+高
+龍
+龸
+fi
+fl
+!
+(
+)
+,
+-
+.
+/
+:
+?
+~
+the
+of
+and
+in
+to
+was
+he
+is
+as
+for
+on
+with
+that
+it
+his
+by
+at
+from
+her
+##s
+she
+you
+had
+an
+were
+but
+be
+this
+are
+not
+my
+they
+one
+which
+or
+have
+him
+me
+first
+all
+also
+their
+has
+up
+who
+out
+been
+when
+after
+there
+into
+new
+two
+its
+##a
+time
+would
+no
+what
+about
+said
+we
+over
+then
+other
+so
+more
+##e
+can
+if
+like
+back
+them
+only
+some
+could
+##i
+where
+just
+##ing
+during
+before
+##n
+do
+##o
+made
+school
+through
+than
+now
+years
+most
+world
+may
+between
+down
+well
+three
+##d
+year
+while
+will
+##ed
+##r
+##y
+later
+##t
+city
+under
+around
+did
+such
+being
+used
+state
+people
+part
+know
+against
+your
+many
+second
+university
+both
+national
+##er
+these
+don
+known
+off
+way
+until
+re
+how
+even
+get
+head
+...
+didn
+##ly
+team
+american
+because
+de
+##l
+born
+united
+film
+since
+still
+long
+work
+south
+us
+became
+any
+high
+again
+day
+family
+see
+right
+man
+eyes
+house
+season
+war
+states
+including
+took
+life
+north
+same
+each
+called
+name
+much
+place
+however
+go
+four
+group
+another
+found
+won
+area
+here
+going
+10
+away
+series
+left
+home
+music
+best
+make
+hand
+number
+company
+several
+never
+last
+john
+000
+very
+album
+take
+end
+good
+too
+following
+released
+game
+played
+little
+began
+district
+##m
+old
+want
+those
+side
+held
+own
+early
+county
+ll
+league
+use
+west
+##u
+face
+think
+##es
+2010
+government
+##h
+march
+came
+small
+general
+town
+june
+##on
+line
+based
+something
+##k
+september
+thought
+looked
+along
+international
+2011
+air
+july
+club
+went
+january
+october
+our
+august
+april
+york
+12
+few
+2012
+2008
+east
+show
+member
+college
+2009
+father
+public
+##us
+come
+men
+five
+set
+station
+church
+##c
+next
+former
+november
+room
+party
+located
+december
+2013
+age
+got
+2007
+##g
+system
+let
+love
+2006
+though
+every
+2014
+look
+song
+water
+century
+without
+body
+black
+night
+within
+great
+women
+single
+ve
+building
+large
+population
+river
+named
+band
+white
+started
+##an
+once
+15
+20
+should
+18
+2015
+service
+top
+built
+british
+open
+death
+king
+moved
+local
+times
+children
+february
+book
+why
+11
+door
+need
+president
+order
+final
+road
+wasn
+although
+due
+major
+died
+village
+third
+knew
+2016
+asked
+turned
+st
+wanted
+say
+##p
+together
+received
+main
+son
+served
+different
+##en
+behind
+himself
+felt
+members
+power
+football
+law
+voice
+play
+##in
+near
+park
+history
+30
+having
+2005
+16
+##man
+saw
+mother
+##al
+army
+point
+front
+help
+english
+street
+art
+late
+hands
+games
+award
+##ia
+young
+14
+put
+published
+country
+division
+across
+told
+13
+often
+ever
+french
+london
+center
+six
+red
+2017
+led
+days
+include
+light
+25
+find
+tell
+among
+species
+really
+according
+central
+half
+2004
+form
+original
+gave
+office
+making
+enough
+lost
+full
+opened
+must
+included
+live
+given
+german
+player
+run
+business
+woman
+community
+cup
+might
+million
+land
+2000
+court
+development
+17
+short
+round
+ii
+km
+seen
+class
+story
+always
+become
+sure
+research
+almost
+director
+council
+la
+##2
+career
+things
+using
+island
+##z
+couldn
+car
+##is
+24
+close
+force
+##1
+better
+free
+support
+control
+field
+students
+2003
+education
+married
+##b
+nothing
+worked
+others
+record
+big
+inside
+level
+anything
+continued
+give
+james
+##3
+military
+established
+non
+returned
+feel
+does
+title
+written
+thing
+feet
+william
+far
+co
+association
+hard
+already
+2002
+##ra
+championship
+human
+western
+100
+##na
+department
+hall
+role
+various
+production
+21
+19
+heart
+2001
+living
+fire
+version
+##ers
+##f
+television
+royal
+##4
+produced
+working
+act
+case
+society
+region
+present
+radio
+period
+looking
+least
+total
+keep
+england
+wife
+program
+per
+brother
+mind
+special
+22
+##le
+am
+works
+soon
+##6
+political
+george
+services
+taken
+created
+##7
+further
+able
+reached
+david
+union
+joined
+upon
+done
+important
+social
+information
+either
+##ic
+##x
+appeared
+position
+ground
+lead
+rock
+dark
+election
+23
+board
+france
+hair
+course
+arms
+site
+police
+girl
+instead
+real
+sound
+##v
+words
+moment
+##te
+someone
+##8
+summer
+project
+announced
+san
+less
+wrote
+past
+followed
+##5
+blue
+founded
+al
+finally
+india
+taking
+records
+america
+##ne
+1999
+design
+considered
+northern
+god
+stop
+battle
+toward
+european
+outside
+described
+track
+today
+playing
+language
+28
+call
+26
+heard
+professional
+low
+australia
+miles
+california
+win
+yet
+green
+##ie
+trying
+blood
+##ton
+southern
+science
+maybe
+everything
+match
+square
+27
+mouth
+video
+race
+recorded
+leave
+above
+##9
+daughter
+points
+space
+1998
+museum
+change
+middle
+common
+##0
+move
+tv
+post
+##ta
+lake
+seven
+tried
+elected
+closed
+ten
+paul
+minister
+##th
+months
+start
+chief
+return
+canada
+person
+sea
+release
+similar
+modern
+brought
+rest
+hit
+formed
+mr
+##la
+1997
+floor
+event
+doing
+thomas
+1996
+robert
+care
+killed
+training
+star
+week
+needed
+turn
+finished
+railway
+rather
+news
+health
+sent
+example
+ran
+term
+michael
+coming
+currently
+yes
+forces
+despite
+gold
+areas
+50
+stage
+fact
+29
+dead
+says
+popular
+2018
+originally
+germany
+probably
+developed
+result
+pulled
+friend
+stood
+money
+running
+mi
+signed
+word
+songs
+child
+eventually
+met
+tour
+average
+teams
+minutes
+festival
+current
+deep
+kind
+1995
+decided
+usually
+eastern
+seemed
+##ness
+episode
+bed
+added
+table
+indian
+private
+charles
+route
+available
+idea
+throughout
+centre
+addition
+appointed
+style
+1994
+books
+eight
+construction
+press
+mean
+wall
+friends
+remained
+schools
+study
+##ch
+##um
+institute
+oh
+chinese
+sometimes
+events
+possible
+1992
+australian
+type
+brown
+forward
+talk
+process
+food
+debut
+seat
+performance
+committee
+features
+character
+arts
+herself
+else
+lot
+strong
+russian
+range
+hours
+peter
+arm
+##da
+morning
+dr
+sold
+##ry
+quickly
+directed
+1993
+guitar
+china
+##w
+31
+list
+##ma
+performed
+media
+uk
+players
+smile
+##rs
+myself
+40
+placed
+coach
+province
+towards
+wouldn
+leading
+whole
+boy
+official
+designed
+grand
+census
+##el
+europe
+attack
+japanese
+henry
+1991
+##re
+##os
+cross
+getting
+alone
+action
+lower
+network
+wide
+washington
+japan
+1990
+hospital
+believe
+changed
+sister
+##ar
+hold
+gone
+sir
+hadn
+ship
+##ka
+studies
+academy
+shot
+rights
+below
+base
+bad
+involved
+kept
+largest
+##ist
+bank
+future
+especially
+beginning
+mark
+movement
+section
+female
+magazine
+plan
+professor
+lord
+longer
+##ian
+sat
+walked
+hill
+actually
+civil
+energy
+model
+families
+size
+thus
+aircraft
+completed
+includes
+data
+captain
+##or
+fight
+vocals
+featured
+richard
+bridge
+fourth
+1989
+officer
+stone
+hear
+##ism
+means
+medical
+groups
+management
+self
+lips
+competition
+entire
+lived
+technology
+leaving
+federal
+tournament
+bit
+passed
+hot
+independent
+awards
+kingdom
+mary
+spent
+fine
+doesn
+reported
+##ling
+jack
+fall
+raised
+itself
+stay
+true
+studio
+1988
+sports
+replaced
+paris
+systems
+saint
+leader
+theatre
+whose
+market
+capital
+parents
+spanish
+canadian
+earth
+##ity
+cut
+degree
+writing
+bay
+christian
+awarded
+natural
+higher
+bill
+##as
+coast
+provided
+previous
+senior
+ft
+valley
+organization
+stopped
+onto
+countries
+parts
+conference
+queen
+security
+interest
+saying
+allowed
+master
+earlier
+phone
+matter
+smith
+winning
+try
+happened
+moving
+campaign
+los
+##ley
+breath
+nearly
+mid
+1987
+certain
+girls
+date
+italian
+african
+standing
+fell
+artist
+##ted
+shows
+deal
+mine
+industry
+1986
+##ng
+everyone
+republic
+provide
+collection
+library
+student
+##ville
+primary
+owned
+older
+via
+heavy
+1st
+makes
+##able
+attention
+anyone
+africa
+##ri
+stated
+length
+ended
+fingers
+command
+staff
+skin
+foreign
+opening
+governor
+okay
+medal
+kill
+sun
+cover
+job
+1985
+introduced
+chest
+hell
+feeling
+##ies
+success
+meet
+reason
+standard
+meeting
+novel
+1984
+trade
+source
+buildings
+##land
+rose
+guy
+goal
+##ur
+chapter
+native
+husband
+previously
+unit
+limited
+entered
+weeks
+producer
+operations
+mountain
+takes
+covered
+forced
+related
+roman
+complete
+successful
+key
+texas
+cold
+##ya
+channel
+1980
+traditional
+films
+dance
+clear
+approximately
+500
+nine
+van
+prince
+question
+active
+tracks
+ireland
+regional
+silver
+author
+personal
+sense
+operation
+##ine
+economic
+1983
+holding
+twenty
+isbn
+additional
+speed
+hour
+edition
+regular
+historic
+places
+whom
+shook
+movie
+km²
+secretary
+prior
+report
+chicago
+read
+foundation
+view
+engine
+scored
+1982
+units
+ask
+airport
+property
+ready
+immediately
+lady
+month
+listed
+contract
+##de
+manager
+themselves
+lines
+##ki
+navy
+writer
+meant
+##ts
+runs
+##ro
+practice
+championships
+singer
+glass
+commission
+required
+forest
+starting
+culture
+generally
+giving
+access
+attended
+test
+couple
+stand
+catholic
+martin
+caught
+executive
+##less
+eye
+##ey
+thinking
+chair
+quite
+shoulder
+1979
+hope
+decision
+plays
+defeated
+municipality
+whether
+structure
+offered
+slowly
+pain
+ice
+direction
+##ion
+paper
+mission
+1981
+mostly
+200
+noted
+individual
+managed
+nature
+lives
+plant
+##ha
+helped
+except
+studied
+computer
+figure
+relationship
+issue
+significant
+loss
+die
+smiled
+gun
+ago
+highest
+1972
+##am
+male
+bring
+goals
+mexico
+problem
+distance
+commercial
+completely
+location
+annual
+famous
+drive
+1976
+neck
+1978
+surface
+caused
+italy
+understand
+greek
+highway
+wrong
+hotel
+comes
+appearance
+joseph
+double
+issues
+musical
+companies
+castle
+income
+review
+assembly
+bass
+initially
+parliament
+artists
+experience
+1974
+particular
+walk
+foot
+engineering
+talking
+window
+dropped
+##ter
+miss
+baby
+boys
+break
+1975
+stars
+edge
+remember
+policy
+carried
+train
+stadium
+bar
+sex
+angeles
+evidence
+##ge
+becoming
+assistant
+soviet
+1977
+upper
+step
+wing
+1970
+youth
+financial
+reach
+##ll
+actor
+numerous
+##se
+##st
+nodded
+arrived
+##ation
+minute
+##nt
+believed
+sorry
+complex
+beautiful
+victory
+associated
+temple
+1968
+1973
+chance
+perhaps
+metal
+##son
+1945
+bishop
+##et
+lee
+launched
+particularly
+tree
+le
+retired
+subject
+prize
+contains
+yeah
+theory
+empire
+##ce
+suddenly
+waiting
+trust
+recording
+##to
+happy
+terms
+camp
+champion
+1971
+religious
+pass
+zealand
+names
+2nd
+port
+ancient
+tom
+corner
+represented
+watch
+legal
+anti
+justice
+cause
+watched
+brothers
+45
+material
+changes
+simply
+response
+louis
+fast
+##ting
+answer
+60
+historical
+1969
+stories
+straight
+create
+feature
+increased
+rate
+administration
+virginia
+el
+activities
+cultural
+overall
+winner
+programs
+basketball
+legs
+guard
+beyond
+cast
+doctor
+mm
+flight
+results
+remains
+cost
+effect
+winter
+##ble
+larger
+islands
+problems
+chairman
+grew
+commander
+isn
+1967
+pay
+failed
+selected
+hurt
+fort
+box
+regiment
+majority
+journal
+35
+edward
+plans
+##ke
+##ni
+shown
+pretty
+irish
+characters
+directly
+scene
+likely
+operated
+allow
+spring
+##j
+junior
+matches
+looks
+mike
+houses
+fellow
+##tion
+beach
+marriage
+##ham
+##ive
+rules
+oil
+65
+florida
+expected
+nearby
+congress
+sam
+peace
+recent
+iii
+wait
+subsequently
+cell
+##do
+variety
+serving
+agreed
+please
+poor
+joe
+pacific
+attempt
+wood
+democratic
+piece
+prime
+##ca
+rural
+mile
+touch
+appears
+township
+1964
+1966
+soldiers
+##men
+##ized
+1965
+pennsylvania
+closer
+fighting
+claimed
+score
+jones
+physical
+editor
+##ous
+filled
+genus
+specific
+sitting
+super
+mom
+##va
+therefore
+supported
+status
+fear
+cases
+store
+meaning
+wales
+minor
+spain
+tower
+focus
+vice
+frank
+follow
+parish
+separate
+golden
+horse
+fifth
+remaining
+branch
+32
+presented
+stared
+##id
+uses
+secret
+forms
+##co
+baseball
+exactly
+##ck
+choice
+note
+discovered
+travel
+composed
+truth
+russia
+ball
+color
+kiss
+dad
+wind
+continue
+ring
+referred
+numbers
+digital
+greater
+##ns
+metres
+slightly
+direct
+increase
+1960
+responsible
+crew
+rule
+trees
+troops
+##no
+broke
+goes
+individuals
+hundred
+weight
+creek
+sleep
+memory
+defense
+provides
+ordered
+code
+value
+jewish
+windows
+1944
+safe
+judge
+whatever
+corps
+realized
+growing
+pre
+##ga
+cities
+alexander
+gaze
+lies
+spread
+scott
+letter
+showed
+situation
+mayor
+transport
+watching
+workers
+extended
+##li
+expression
+normal
+##ment
+chart
+multiple
+border
+##ba
+host
+##ner
+daily
+mrs
+walls
+piano
+##ko
+heat
+cannot
+##ate
+earned
+products
+drama
+era
+authority
+seasons
+join
+grade
+##io
+sign
+difficult
+machine
+1963
+territory
+mainly
+##wood
+stations
+squadron
+1962
+stepped
+iron
+19th
+##led
+serve
+appear
+sky
+speak
+broken
+charge
+knowledge
+kilometres
+removed
+ships
+article
+campus
+simple
+##ty
+pushed
+britain
+##ve
+leaves
+recently
+cd
+soft
+boston
+latter
+easy
+acquired
+poland
+##sa
+quality
+officers
+presence
+planned
+nations
+mass
+broadcast
+jean
+share
+image
+influence
+wild
+offer
+emperor
+electric
+reading
+headed
+ability
+promoted
+yellow
+ministry
+1942
+throat
+smaller
+politician
+##by
+latin
+spoke
+cars
+williams
+males
+lack
+pop
+80
+##ier
+acting
+seeing
+consists
+##ti
+estate
+1961
+pressure
+johnson
+newspaper
+jr
+chris
+olympics
+online
+conditions
+beat
+elements
+walking
+vote
+##field
+needs
+carolina
+text
+featuring
+global
+block
+shirt
+levels
+francisco
+purpose
+females
+et
+dutch
+duke
+ahead
+gas
+twice
+safety
+serious
+turning
+highly
+lieutenant
+firm
+maria
+amount
+mixed
+daniel
+proposed
+perfect
+agreement
+affairs
+3rd
+seconds
+contemporary
+paid
+1943
+prison
+save
+kitchen
+label
+administrative
+intended
+constructed
+academic
+nice
+teacher
+races
+1956
+formerly
+corporation
+ben
+nation
+issued
+shut
+1958
+drums
+housing
+victoria
+seems
+opera
+1959
+graduated
+function
+von
+mentioned
+picked
+build
+recognized
+shortly
+protection
+picture
+notable
+exchange
+elections
+1980s
+loved
+percent
+racing
+fish
+elizabeth
+garden
+volume
+hockey
+1941
+beside
+settled
+##ford
+1940
+competed
+replied
+drew
+1948
+actress
+marine
+scotland
+steel
+glanced
+farm
+steve
+1957
+risk
+tonight
+positive
+magic
+singles
+effects
+gray
+screen
+dog
+##ja
+residents
+bus
+sides
+none
+secondary
+literature
+polish
+destroyed
+flying
+founder
+households
+1939
+lay
+reserve
+usa
+gallery
+##ler
+1946
+industrial
+younger
+approach
+appearances
+urban
+ones
+1950
+finish
+avenue
+powerful
+fully
+growth
+page
+honor
+jersey
+projects
+advanced
+revealed
+basic
+90
+infantry
+pair
+equipment
+visit
+33
+evening
+search
+grant
+effort
+solo
+treatment
+buried
+republican
+primarily
+bottom
+owner
+1970s
+israel
+gives
+jim
+dream
+bob
+remain
+spot
+70
+notes
+produce
+champions
+contact
+ed
+soul
+accepted
+ways
+del
+##ally
+losing
+split
+price
+capacity
+basis
+trial
+questions
+##ina
+1955
+20th
+guess
+officially
+memorial
+naval
+initial
+##ization
+whispered
+median
+engineer
+##ful
+sydney
+##go
+columbia
+strength
+300
+1952
+tears
+senate
+00
+card
+asian
+agent
+1947
+software
+44
+draw
+warm
+supposed
+com
+pro
+##il
+transferred
+leaned
+##at
+candidate
+escape
+mountains
+asia
+potential
+activity
+entertainment
+seem
+traffic
+jackson
+murder
+36
+slow
+product
+orchestra
+haven
+agency
+bbc
+taught
+website
+comedy
+unable
+storm
+planning
+albums
+rugby
+environment
+scientific
+grabbed
+protect
+##hi
+boat
+typically
+1954
+1953
+damage
+principal
+divided
+dedicated
+mount
+ohio
+##berg
+pick
+fought
+driver
+##der
+empty
+shoulders
+sort
+thank
+berlin
+prominent
+account
+freedom
+necessary
+efforts
+alex
+headquarters
+follows
+alongside
+des
+simon
+andrew
+suggested
+operating
+learning
+steps
+1949
+sweet
+technical
+begin
+easily
+34
+teeth
+speaking
+settlement
+scale
+##sh
+renamed
+ray
+max
+enemy
+semi
+joint
+compared
+##rd
+scottish
+leadership
+analysis
+offers
+georgia
+pieces
+captured
+animal
+deputy
+guest
+organized
+##lin
+tony
+combined
+method
+challenge
+1960s
+huge
+wants
+battalion
+sons
+rise
+crime
+types
+facilities
+telling
+path
+1951
+platform
+sit
+1990s
+##lo
+tells
+assigned
+rich
+pull
+##ot
+commonly
+alive
+##za
+letters
+concept
+conducted
+wearing
+happen
+bought
+becomes
+holy
+gets
+ocean
+defeat
+languages
+purchased
+coffee
+occurred
+titled
+##q
+declared
+applied
+sciences
+concert
+sounds
+jazz
+brain
+##me
+painting
+fleet
+tax
+nick
+##ius
+michigan
+count
+animals
+leaders
+episodes
+##line
+content
+##den
+birth
+##it
+clubs
+64
+palace
+critical
+refused
+fair
+leg
+laughed
+returning
+surrounding
+participated
+formation
+lifted
+pointed
+connected
+rome
+medicine
+laid
+taylor
+santa
+powers
+adam
+tall
+shared
+focused
+knowing
+yards
+entrance
+falls
+##wa
+calling
+##ad
+sources
+chosen
+beneath
+resources
+yard
+##ite
+nominated
+silence
+zone
+defined
+##que
+gained
+thirty
+38
+bodies
+moon
+##ard
+adopted
+christmas
+widely
+register
+apart
+iran
+premier
+serves
+du
+unknown
+parties
+##les
+generation
+##ff
+continues
+quick
+fields
+brigade
+quiet
+teaching
+clothes
+impact
+weapons
+partner
+flat
+theater
+supreme
+1938
+37
+relations
+##tor
+plants
+suffered
+1936
+wilson
+kids
+begins
+##age
+1918
+seats
+armed
+internet
+models
+worth
+laws
+400
+communities
+classes
+background
+knows
+thanks
+quarter
+reaching
+humans
+carry
+killing
+format
+kong
+hong
+setting
+75
+architecture
+disease
+railroad
+inc
+possibly
+wish
+arthur
+thoughts
+harry
+doors
+density
+##di
+crowd
+illinois
+stomach
+tone
+unique
+reports
+anyway
+##ir
+liberal
+der
+vehicle
+thick
+dry
+drug
+faced
+largely
+facility
+theme
+holds
+creation
+strange
+colonel
+##mi
+revolution
+bell
+politics
+turns
+silent
+rail
+relief
+independence
+combat
+shape
+write
+determined
+sales
+learned
+4th
+finger
+oxford
+providing
+1937
+heritage
+fiction
+situated
+designated
+allowing
+distribution
+hosted
+##est
+sight
+interview
+estimated
+reduced
+##ria
+toronto
+footballer
+keeping
+guys
+damn
+claim
+motion
+sport
+sixth
+stayed
+##ze
+en
+rear
+receive
+handed
+twelve
+dress
+audience
+granted
+brazil
+##well
+spirit
+##ated
+noticed
+etc
+olympic
+representative
+eric
+tight
+trouble
+reviews
+drink
+vampire
+missing
+roles
+ranked
+newly
+household
+finals
+wave
+critics
+##ee
+phase
+massachusetts
+pilot
+unlike
+philadelphia
+bright
+guns
+crown
+organizations
+roof
+42
+respectively
+clearly
+tongue
+marked
+circle
+fox
+korea
+bronze
+brian
+expanded
+sexual
+supply
+yourself
+inspired
+labour
+fc
+##ah
+reference
+vision
+draft
+connection
+brand
+reasons
+1935
+classic
+driving
+trip
+jesus
+cells
+entry
+1920
+neither
+trail
+claims
+atlantic
+orders
+labor
+nose
+afraid
+identified
+intelligence
+calls
+cancer
+attacked
+passing
+stephen
+positions
+imperial
+grey
+jason
+39
+sunday
+48
+swedish
+avoid
+extra
+uncle
+message
+covers
+allows
+surprise
+materials
+fame
+hunter
+##ji
+1930
+citizens
+figures
+davis
+environmental
+confirmed
+shit
+titles
+di
+performing
+difference
+acts
+attacks
+##ov
+existing
+votes
+opportunity
+nor
+shop
+entirely
+trains
+opposite
+pakistan
+##pa
+develop
+resulted
+representatives
+actions
+reality
+pressed
+##ish
+barely
+wine
+conversation
+faculty
+northwest
+ends
+documentary
+nuclear
+stock
+grace
+sets
+eat
+alternative
+##ps
+bag
+resulting
+creating
+surprised
+cemetery
+1919
+drop
+finding
+sarah
+cricket
+streets
+tradition
+ride
+1933
+exhibition
+target
+ear
+explained
+rain
+composer
+injury
+apartment
+municipal
+educational
+occupied
+netherlands
+clean
+billion
+constitution
+learn
+1914
+maximum
+classical
+francis
+lose
+opposition
+jose
+ontario
+bear
+core
+hills
+rolled
+ending
+drawn
+permanent
+fun
+##tes
+##lla
+lewis
+sites
+chamber
+ryan
+##way
+scoring
+height
+1934
+##house
+lyrics
+staring
+55
+officials
+1917
+snow
+oldest
+##tic
+orange
+##ger
+qualified
+interior
+apparently
+succeeded
+thousand
+dinner
+lights
+existence
+fans
+heavily
+41
+greatest
+conservative
+send
+bowl
+plus
+enter
+catch
+##un
+economy
+duty
+1929
+speech
+authorities
+princess
+performances
+versions
+shall
+graduate
+pictures
+effective
+remembered
+poetry
+desk
+crossed
+starring
+starts
+passenger
+sharp
+##ant
+acres
+ass
+weather
+falling
+rank
+fund
+supporting
+check
+adult
+publishing
+heads
+cm
+southeast
+lane
+##burg
+application
+bc
+##ura
+les
+condition
+transfer
+prevent
+display
+ex
+regions
+earl
+federation
+cool
+relatively
+answered
+besides
+1928
+obtained
+portion
+##town
+mix
+##ding
+reaction
+liked
+dean
+express
+peak
+1932
+##tte
+counter
+religion
+chain
+rare
+miller
+convention
+aid
+lie
+vehicles
+mobile
+perform
+squad
+wonder
+lying
+crazy
+sword
+##ping
+attempted
+centuries
+weren
+philosophy
+category
+##ize
+anna
+interested
+47
+sweden
+wolf
+frequently
+abandoned
+kg
+literary
+alliance
+task
+entitled
+##ay
+threw
+promotion
+factory
+tiny
+soccer
+visited
+matt
+fm
+achieved
+52
+defence
+internal
+persian
+43
+methods
+##ging
+arrested
+otherwise
+cambridge
+programming
+villages
+elementary
+districts
+rooms
+criminal
+conflict
+worry
+trained
+1931
+attempts
+waited
+signal
+bird
+truck
+subsequent
+programme
+##ol
+ad
+49
+communist
+details
+faith
+sector
+patrick
+carrying
+laugh
+##ss
+controlled
+korean
+showing
+origin
+fuel
+evil
+1927
+##ent
+brief
+identity
+darkness
+address
+pool
+missed
+publication
+web
+planet
+ian
+anne
+wings
+invited
+##tt
+briefly
+standards
+kissed
+##be
+ideas
+climate
+causing
+walter
+worse
+albert
+articles
+winners
+desire
+aged
+northeast
+dangerous
+gate
+doubt
+1922
+wooden
+multi
+##ky
+poet
+rising
+funding
+46
+communications
+communication
+violence
+copies
+prepared
+ford
+investigation
+skills
+1924
+pulling
+electronic
+##ak
+##ial
+##han
+containing
+ultimately
+offices
+singing
+understanding
+restaurant
+tomorrow
+fashion
+christ
+ward
+da
+pope
+stands
+5th
+flow
+studios
+aired
+commissioned
+contained
+exist
+fresh
+americans
+##per
+wrestling
+approved
+kid
+employed
+respect
+suit
+1925
+angel
+asking
+increasing
+frame
+angry
+selling
+1950s
+thin
+finds
+##nd
+temperature
+statement
+ali
+explain
+inhabitants
+towns
+extensive
+narrow
+51
+jane
+flowers
+images
+promise
+somewhere
+object
+fly
+closely
+##ls
+1912
+bureau
+cape
+1926
+weekly
+presidential
+legislative
+1921
+##ai
+##au
+launch
+founding
+##ny
+978
+##ring
+artillery
+strike
+un
+institutions
+roll
+writers
+landing
+chose
+kevin
+anymore
+pp
+##ut
+attorney
+fit
+dan
+billboard
+receiving
+agricultural
+breaking
+sought
+dave
+admitted
+lands
+mexican
+##bury
+charlie
+specifically
+hole
+iv
+howard
+credit
+moscow
+roads
+accident
+1923
+proved
+wear
+struck
+hey
+guards
+stuff
+slid
+expansion
+1915
+cat
+anthony
+##kin
+melbourne
+opposed
+sub
+southwest
+architect
+failure
+plane
+1916
+##ron
+map
+camera
+tank
+listen
+regarding
+wet
+introduction
+metropolitan
+link
+ep
+fighter
+inch
+grown
+gene
+anger
+fixed
+buy
+dvd
+khan
+domestic
+worldwide
+chapel
+mill
+functions
+examples
+##head
+developing
+1910
+turkey
+hits
+pocket
+antonio
+papers
+grow
+unless
+circuit
+18th
+concerned
+attached
+journalist
+selection
+journey
+converted
+provincial
+painted
+hearing
+aren
+bands
+negative
+aside
+wondered
+knight
+lap
+survey
+ma
+##ow
+noise
+billy
+##ium
+shooting
+guide
+bedroom
+priest
+resistance
+motor
+homes
+sounded
+giant
+##mer
+150
+scenes
+equal
+comic
+patients
+hidden
+solid
+actual
+bringing
+afternoon
+touched
+funds
+wedding
+consisted
+marie
+canal
+sr
+kim
+treaty
+turkish
+recognition
+residence
+cathedral
+broad
+knees
+incident
+shaped
+fired
+norwegian
+handle
+cheek
+contest
+represent
+##pe
+representing
+beauty
+##sen
+birds
+advantage
+emergency
+wrapped
+drawing
+notice
+pink
+broadcasting
+##ong
+somehow
+bachelor
+seventh
+collected
+registered
+establishment
+alan
+assumed
+chemical
+personnel
+roger
+retirement
+jeff
+portuguese
+wore
+tied
+device
+threat
+progress
+advance
+##ised
+banks
+hired
+manchester
+nfl
+teachers
+structures
+forever
+##bo
+tennis
+helping
+saturday
+sale
+applications
+junction
+hip
+incorporated
+neighborhood
+dressed
+ceremony
+##ds
+influenced
+hers
+visual
+stairs
+decades
+inner
+kansas
+hung
+hoped
+gain
+scheduled
+downtown
+engaged
+austria
+clock
+norway
+certainly
+pale
+protected
+1913
+victor
+employees
+plate
+putting
+surrounded
+##ists
+finishing
+blues
+tropical
+##ries
+minnesota
+consider
+philippines
+accept
+54
+retrieved
+1900
+concern
+anderson
+properties
+institution
+gordon
+successfully
+vietnam
+##dy
+backing
+outstanding
+muslim
+crossing
+folk
+producing
+usual
+demand
+occurs
+observed
+lawyer
+educated
+##ana
+kelly
+string
+pleasure
+budget
+items
+quietly
+colorado
+philip
+typical
+##worth
+derived
+600
+survived
+asks
+mental
+##ide
+56
+jake
+jews
+distinguished
+ltd
+1911
+sri
+extremely
+53
+athletic
+loud
+thousands
+worried
+shadow
+transportation
+horses
+weapon
+arena
+importance
+users
+tim
+objects
+contributed
+dragon
+douglas
+aware
+senator
+johnny
+jordan
+sisters
+engines
+flag
+investment
+samuel
+shock
+capable
+clark
+row
+wheel
+refers
+session
+familiar
+biggest
+wins
+hate
+maintained
+drove
+hamilton
+request
+expressed
+injured
+underground
+churches
+walker
+wars
+tunnel
+passes
+stupid
+agriculture
+softly
+cabinet
+regarded
+joining
+indiana
+##ea
+##ms
+push
+dates
+spend
+behavior
+woods
+protein
+gently
+chase
+morgan
+mention
+burning
+wake
+combination
+occur
+mirror
+leads
+jimmy
+indeed
+impossible
+singapore
+paintings
+covering
+##nes
+soldier
+locations
+attendance
+sell
+historian
+wisconsin
+invasion
+argued
+painter
+diego
+changing
+egypt
+##don
+experienced
+inches
+##ku
+missouri
+vol
+grounds
+spoken
+switzerland
+##gan
+reform
+rolling
+ha
+forget
+massive
+resigned
+burned
+allen
+tennessee
+locked
+values
+improved
+##mo
+wounded
+universe
+sick
+dating
+facing
+pack
+purchase
+user
+##pur
+moments
+##ul
+merged
+anniversary
+1908
+coal
+brick
+understood
+causes
+dynasty
+queensland
+establish
+stores
+crisis
+promote
+hoping
+views
+cards
+referee
+extension
+##si
+raise
+arizona
+improve
+colonial
+formal
+charged
+##rt
+palm
+lucky
+hide
+rescue
+faces
+95
+feelings
+candidates
+juan
+##ell
+goods
+6th
+courses
+weekend
+59
+luke
+cash
+fallen
+##om
+delivered
+affected
+installed
+carefully
+tries
+swiss
+hollywood
+costs
+lincoln
+responsibility
+##he
+shore
+file
+proper
+normally
+maryland
+assistance
+jump
+constant
+offering
+friendly
+waters
+persons
+realize
+contain
+trophy
+800
+partnership
+factor
+58
+musicians
+cry
+bound
+oregon
+indicated
+hero
+houston
+medium
+##ure
+consisting
+somewhat
+##ara
+57
+cycle
+##che
+beer
+moore
+frederick
+gotten
+eleven
+worst
+weak
+approached
+arranged
+chin
+loan
+universal
+bond
+fifteen
+pattern
+disappeared
+##ney
+translated
+##zed
+lip
+arab
+capture
+interests
+insurance
+##chi
+shifted
+cave
+prix
+warning
+sections
+courts
+coat
+plot
+smell
+feed
+golf
+favorite
+maintain
+knife
+vs
+voted
+degrees
+finance
+quebec
+opinion
+translation
+manner
+ruled
+operate
+productions
+choose
+musician
+discovery
+confused
+tired
+separated
+stream
+techniques
+committed
+attend
+ranking
+kings
+throw
+passengers
+measure
+horror
+fan
+mining
+sand
+danger
+salt
+calm
+decade
+dam
+require
+runner
+##ik
+rush
+associate
+greece
+##ker
+rivers
+consecutive
+matthew
+##ski
+sighed
+sq
+documents
+steam
+edited
+closing
+tie
+accused
+1905
+##ini
+islamic
+distributed
+directors
+organisation
+bruce
+7th
+breathing
+mad
+lit
+arrival
+concrete
+taste
+08
+composition
+shaking
+faster
+amateur
+adjacent
+stating
+1906
+twin
+flew
+##ran
+tokyo
+publications
+##tone
+obviously
+ridge
+storage
+1907
+carl
+pages
+concluded
+desert
+driven
+universities
+ages
+terminal
+sequence
+borough
+250
+constituency
+creative
+cousin
+economics
+dreams
+margaret
+notably
+reduce
+montreal
+mode
+17th
+ears
+saved
+jan
+vocal
+##ica
+1909
+andy
+##jo
+riding
+roughly
+threatened
+##ise
+meters
+meanwhile
+landed
+compete
+repeated
+grass
+czech
+regularly
+charges
+tea
+sudden
+appeal
+##ung
+solution
+describes
+pierre
+classification
+glad
+parking
+##ning
+belt
+physics
+99
+rachel
+add
+hungarian
+participate
+expedition
+damaged
+gift
+childhood
+85
+fifty
+##red
+mathematics
+jumped
+letting
+defensive
+mph
+##ux
+##gh
+testing
+##hip
+hundreds
+shoot
+owners
+matters
+smoke
+israeli
+kentucky
+dancing
+mounted
+grandfather
+emma
+designs
+profit
+argentina
+##gs
+truly
+li
+lawrence
+cole
+begun
+detroit
+willing
+branches
+smiling
+decide
+miami
+enjoyed
+recordings
+##dale
+poverty
+ethnic
+gay
+##bi
+gary
+arabic
+09
+accompanied
+##one
+##ons
+fishing
+determine
+residential
+acid
+##ary
+alice
+returns
+starred
+mail
+##ang
+jonathan
+strategy
+##ue
+net
+forty
+cook
+businesses
+equivalent
+commonwealth
+distinct
+ill
+##cy
+seriously
+##ors
+##ped
+shift
+harris
+replace
+rio
+imagine
+formula
+ensure
+##ber
+additionally
+scheme
+conservation
+occasionally
+purposes
+feels
+favor
+##and
+##ore
+1930s
+contrast
+hanging
+hunt
+movies
+1904
+instruments
+victims
+danish
+christopher
+busy
+demon
+sugar
+earliest
+colony
+studying
+balance
+duties
+##ks
+belgium
+slipped
+carter
+05
+visible
+stages
+iraq
+fifa
+##im
+commune
+forming
+zero
+07
+continuing
+talked
+counties
+legend
+bathroom
+option
+tail
+clay
+daughters
+afterwards
+severe
+jaw
+visitors
+##ded
+devices
+aviation
+russell
+kate
+##vi
+entering
+subjects
+##ino
+temporary
+swimming
+forth
+smooth
+ghost
+audio
+bush
+operates
+rocks
+movements
+signs
+eddie
+##tz
+ann
+voices
+honorary
+06
+memories
+dallas
+pure
+measures
+racial
+promised
+66
+harvard
+ceo
+16th
+parliamentary
+indicate
+benefit
+flesh
+dublin
+louisiana
+1902
+1901
+patient
+sleeping
+1903
+membership
+coastal
+medieval
+wanting
+element
+scholars
+rice
+62
+limit
+survive
+makeup
+rating
+definitely
+collaboration
+obvious
+##tan
+boss
+ms
+baron
+birthday
+linked
+soil
+diocese
+##lan
+ncaa
+##mann
+offensive
+shell
+shouldn
+waist
+##tus
+plain
+ross
+organ
+resolution
+manufacturing
+adding
+relative
+kennedy
+98
+whilst
+moth
+marketing
+gardens
+crash
+72
+heading
+partners
+credited
+carlos
+moves
+cable
+##zi
+marshall
+##out
+depending
+bottle
+represents
+rejected
+responded
+existed
+04
+jobs
+denmark
+lock
+##ating
+treated
+graham
+routes
+talent
+commissioner
+drugs
+secure
+tests
+reign
+restored
+photography
+##gi
+contributions
+oklahoma
+designer
+disc
+grin
+seattle
+robin
+paused
+atlanta
+unusual
+##gate
+praised
+las
+laughing
+satellite
+hungary
+visiting
+##sky
+interesting
+factors
+deck
+poems
+norman
+##water
+stuck
+speaker
+rifle
+domain
+premiered
+##her
+dc
+comics
+actors
+01
+reputation
+eliminated
+8th
+ceiling
+prisoners
+script
+##nce
+leather
+austin
+mississippi
+rapidly
+admiral
+parallel
+charlotte
+guilty
+tools
+gender
+divisions
+fruit
+##bs
+laboratory
+nelson
+fantasy
+marry
+rapid
+aunt
+tribe
+requirements
+aspects
+suicide
+amongst
+adams
+bone
+ukraine
+abc
+kick
+sees
+edinburgh
+clothing
+column
+rough
+gods
+hunting
+broadway
+gathered
+concerns
+##ek
+spending
+ty
+12th
+snapped
+requires
+solar
+bones
+cavalry
+##tta
+iowa
+drinking
+waste
+index
+franklin
+charity
+thompson
+stewart
+tip
+flash
+landscape
+friday
+enjoy
+singh
+poem
+listening
+##back
+eighth
+fred
+differences
+adapted
+bomb
+ukrainian
+surgery
+corporate
+masters
+anywhere
+##more
+waves
+odd
+sean
+portugal
+orleans
+dick
+debate
+kent
+eating
+puerto
+cleared
+96
+expect
+cinema
+97
+guitarist
+blocks
+electrical
+agree
+involving
+depth
+dying
+panel
+struggle
+##ged
+peninsula
+adults
+novels
+emerged
+vienna
+metro
+debuted
+shoes
+tamil
+songwriter
+meets
+prove
+beating
+instance
+heaven
+scared
+sending
+marks
+artistic
+passage
+superior
+03
+significantly
+shopping
+##tive
+retained
+##izing
+malaysia
+technique
+cheeks
+##ola
+warren
+maintenance
+destroy
+extreme
+allied
+120
+appearing
+##yn
+fill
+advice
+alabama
+qualifying
+policies
+cleveland
+hat
+battery
+smart
+authors
+10th
+soundtrack
+acted
+dated
+lb
+glance
+equipped
+coalition
+funny
+outer
+ambassador
+roy
+possibility
+couples
+campbell
+dna
+loose
+ethan
+supplies
+1898
+gonna
+88
+monster
+##res
+shake
+agents
+frequency
+springs
+dogs
+practices
+61
+gang
+plastic
+easier
+suggests
+gulf
+blade
+exposed
+colors
+industries
+markets
+pan
+nervous
+electoral
+charts
+legislation
+ownership
+##idae
+mac
+appointment
+shield
+copy
+assault
+socialist
+abbey
+monument
+license
+throne
+employment
+jay
+93
+replacement
+charter
+cloud
+powered
+suffering
+accounts
+oak
+connecticut
+strongly
+wright
+colour
+crystal
+13th
+context
+welsh
+networks
+voiced
+gabriel
+jerry
+##cing
+forehead
+mp
+##ens
+manage
+schedule
+totally
+remix
+##ii
+forests
+occupation
+print
+nicholas
+brazilian
+strategic
+vampires
+engineers
+76
+roots
+seek
+correct
+instrumental
+und
+alfred
+backed
+hop
+##des
+stanley
+robinson
+traveled
+wayne
+welcome
+austrian
+achieve
+67
+exit
+rates
+1899
+strip
+whereas
+##cs
+sing
+deeply
+adventure
+bobby
+rick
+jamie
+careful
+components
+cap
+useful
+personality
+knee
+##shi
+pushing
+hosts
+02
+protest
+ca
+ottoman
+symphony
+##sis
+63
+boundary
+1890
+processes
+considering
+considerable
+tons
+##work
+##ft
+##nia
+cooper
+trading
+dear
+conduct
+91
+illegal
+apple
+revolutionary
+holiday
+definition
+harder
+##van
+jacob
+circumstances
+destruction
+##lle
+popularity
+grip
+classified
+liverpool
+donald
+baltimore
+flows
+seeking
+honour
+approval
+92
+mechanical
+till
+happening
+statue
+critic
+increasingly
+immediate
+describe
+commerce
+stare
+##ster
+indonesia
+meat
+rounds
+boats
+baker
+orthodox
+depression
+formally
+worn
+naked
+claire
+muttered
+sentence
+11th
+emily
+document
+77
+criticism
+wished
+vessel
+spiritual
+bent
+virgin
+parker
+minimum
+murray
+lunch
+danny
+printed
+compilation
+keyboards
+false
+blow
+belonged
+68
+raising
+78
+cutting
+##board
+pittsburgh
+##up
+9th
+shadows
+81
+hated
+indigenous
+jon
+15th
+barry
+scholar
+ah
+##zer
+oliver
+##gy
+stick
+susan
+meetings
+attracted
+spell
+romantic
+##ver
+ye
+1895
+photo
+demanded
+customers
+##ac
+1896
+logan
+revival
+keys
+modified
+commanded
+jeans
+##ious
+upset
+raw
+phil
+detective
+hiding
+resident
+vincent
+##bly
+experiences
+diamond
+defeating
+coverage
+lucas
+external
+parks
+franchise
+helen
+bible
+successor
+percussion
+celebrated
+il
+lift
+profile
+clan
+romania
+##ied
+mills
+##su
+nobody
+achievement
+shrugged
+fault
+1897
+rhythm
+initiative
+breakfast
+carbon
+700
+69
+lasted
+violent
+74
+wound
+ken
+killer
+gradually
+filmed
+°c
+dollars
+processing
+94
+remove
+criticized
+guests
+sang
+chemistry
+##vin
+legislature
+disney
+##bridge
+uniform
+escaped
+integrated
+proposal
+purple
+denied
+liquid
+karl
+influential
+morris
+nights
+stones
+intense
+experimental
+twisted
+71
+84
+##ld
+pace
+nazi
+mitchell
+ny
+blind
+reporter
+newspapers
+14th
+centers
+burn
+basin
+forgotten
+surviving
+filed
+collections
+monastery
+losses
+manual
+couch
+description
+appropriate
+merely
+tag
+missions
+sebastian
+restoration
+replacing
+triple
+73
+elder
+julia
+warriors
+benjamin
+julian
+convinced
+stronger
+amazing
+declined
+versus
+merchant
+happens
+output
+finland
+bare
+barbara
+absence
+ignored
+dawn
+injuries
+##port
+producers
+##ram
+82
+luis
+##ities
+kw
+admit
+expensive
+electricity
+nba
+exception
+symbol
+##ving
+ladies
+shower
+sheriff
+characteristics
+##je
+aimed
+button
+ratio
+effectively
+summit
+angle
+jury
+bears
+foster
+vessels
+pants
+executed
+evans
+dozen
+advertising
+kicked
+patrol
+1889
+competitions
+lifetime
+principles
+athletics
+##logy
+birmingham
+sponsored
+89
+rob
+nomination
+1893
+acoustic
+##sm
+creature
+longest
+##tra
+credits
+harbor
+dust
+josh
+##so
+territories
+milk
+infrastructure
+completion
+thailand
+indians
+leon
+archbishop
+##sy
+assist
+pitch
+blake
+arrangement
+girlfriend
+serbian
+operational
+hence
+sad
+scent
+fur
+dj
+sessions
+hp
+refer
+rarely
+##ora
+exists
+1892
+##ten
+scientists
+dirty
+penalty
+burst
+portrait
+seed
+79
+pole
+limits
+rival
+1894
+stable
+alpha
+grave
+constitutional
+alcohol
+arrest
+flower
+mystery
+devil
+architectural
+relationships
+greatly
+habitat
+##istic
+larry
+progressive
+remote
+cotton
+##ics
+##ok
+preserved
+reaches
+##ming
+cited
+86
+vast
+scholarship
+decisions
+cbs
+joy
+teach
+1885
+editions
+knocked
+eve
+searching
+partly
+participation
+gap
+animated
+fate
+excellent
+##ett
+na
+87
+alternate
+saints
+youngest
+##ily
+climbed
+##ita
+##tors
+suggest
+##ct
+discussion
+staying
+choir
+lakes
+jacket
+revenue
+nevertheless
+peaked
+instrument
+wondering
+annually
+managing
+neil
+1891
+signing
+terry
+##ice
+apply
+clinical
+brooklyn
+aim
+catherine
+fuck
+farmers
+figured
+ninth
+pride
+hugh
+evolution
+ordinary
+involvement
+comfortable
+shouted
+tech
+encouraged
+taiwan
+representation
+sharing
+##lia
+##em
+panic
+exact
+cargo
+competing
+fat
+cried
+83
+1920s
+occasions
+pa
+cabin
+borders
+utah
+marcus
+##isation
+badly
+muscles
+##ance
+victorian
+transition
+warner
+bet
+permission
+##rin
+slave
+terrible
+similarly
+shares
+seth
+uefa
+possession
+medals
+benefits
+colleges
+lowered
+perfectly
+mall
+transit
+##ye
+##kar
+publisher
+##ened
+harrison
+deaths
+elevation
+##ae
+asleep
+machines
+sigh
+ash
+hardly
+argument
+occasion
+parent
+leo
+decline
+1888
+contribution
+##ua
+concentration
+1000
+opportunities
+hispanic
+guardian
+extent
+emotions
+hips
+mason
+volumes
+bloody
+controversy
+diameter
+steady
+mistake
+phoenix
+identify
+violin
+##sk
+departure
+richmond
+spin
+funeral
+enemies
+1864
+gear
+literally
+connor
+random
+sergeant
+grab
+confusion
+1865
+transmission
+informed
+op
+leaning
+sacred
+suspended
+thinks
+gates
+portland
+luck
+agencies
+yours
+hull
+expert
+muscle
+layer
+practical
+sculpture
+jerusalem
+latest
+lloyd
+statistics
+deeper
+recommended
+warrior
+arkansas
+mess
+supports
+greg
+eagle
+1880
+recovered
+rated
+concerts
+rushed
+##ano
+stops
+eggs
+files
+premiere
+keith
+##vo
+delhi
+turner
+pit
+affair
+belief
+paint
+##zing
+mate
+##ach
+##ev
+victim
+##ology
+withdrew
+bonus
+styles
+fled
+##ud
+glasgow
+technologies
+funded
+nbc
+adaptation
+##ata
+portrayed
+cooperation
+supporters
+judges
+bernard
+justin
+hallway
+ralph
+##ick
+graduating
+controversial
+distant
+continental
+spider
+bite
+##ho
+recognize
+intention
+mixing
+##ese
+egyptian
+bow
+tourism
+suppose
+claiming
+tiger
+dominated
+participants
+vi
+##ru
+nurse
+partially
+tape
+##rum
+psychology
+##rn
+essential
+touring
+duo
+voting
+civilian
+emotional
+channels
+##king
+apparent
+hebrew
+1887
+tommy
+carrier
+intersection
+beast
+hudson
+##gar
+##zo
+lab
+nova
+bench
+discuss
+costa
+##ered
+detailed
+behalf
+drivers
+unfortunately
+obtain
+##lis
+rocky
+##dae
+siege
+friendship
+honey
+##rian
+1861
+amy
+hang
+posted
+governments
+collins
+respond
+wildlife
+preferred
+operator
+##po
+laura
+pregnant
+videos
+dennis
+suspected
+boots
+instantly
+weird
+automatic
+businessman
+alleged
+placing
+throwing
+ph
+mood
+1862
+perry
+venue
+jet
+remainder
+##lli
+##ci
+passion
+biological
+boyfriend
+1863
+dirt
+buffalo
+ron
+segment
+fa
+abuse
+##era
+genre
+thrown
+stroke
+colored
+stress
+exercise
+displayed
+##gen
+struggled
+##tti
+abroad
+dramatic
+wonderful
+thereafter
+madrid
+component
+widespread
+##sed
+tale
+citizen
+todd
+monday
+1886
+vancouver
+overseas
+forcing
+crying
+descent
+##ris
+discussed
+substantial
+ranks
+regime
+1870
+provinces
+switch
+drum
+zane
+ted
+tribes
+proof
+lp
+cream
+researchers
+volunteer
+manor
+silk
+milan
+donated
+allies
+venture
+principle
+delivery
+enterprise
+##ves
+##ans
+bars
+traditionally
+witch
+reminded
+copper
+##uk
+pete
+inter
+links
+colin
+grinned
+elsewhere
+competitive
+frequent
+##oy
+scream
+##hu
+tension
+texts
+submarine
+finnish
+defending
+defend
+pat
+detail
+1884
+affiliated
+stuart
+themes
+villa
+periods
+tool
+belgian
+ruling
+crimes
+answers
+folded
+licensed
+resort
+demolished
+hans
+lucy
+1881
+lion
+traded
+photographs
+writes
+craig
+##fa
+trials
+generated
+beth
+noble
+debt
+percentage
+yorkshire
+erected
+ss
+viewed
+grades
+confidence
+ceased
+islam
+telephone
+retail
+##ible
+chile
+m²
+roberts
+sixteen
+##ich
+commented
+hampshire
+innocent
+dual
+pounds
+checked
+regulations
+afghanistan
+sung
+rico
+liberty
+assets
+bigger
+options
+angels
+relegated
+tribute
+wells
+attending
+leaf
+##yan
+butler
+romanian
+forum
+monthly
+lisa
+patterns
+gmina
+##tory
+madison
+hurricane
+rev
+##ians
+bristol
+##ula
+elite
+valuable
+disaster
+democracy
+awareness
+germans
+freyja
+##ins
+loop
+absolutely
+paying
+populations
+maine
+sole
+prayer
+spencer
+releases
+doorway
+bull
+##ani
+lover
+midnight
+conclusion
+##sson
+thirteen
+lily
+mediterranean
+##lt
+nhl
+proud
+sample
+##hill
+drummer
+guinea
+##ova
+murphy
+climb
+##ston
+instant
+attributed
+horn
+ain
+railways
+steven
+##ao
+autumn
+ferry
+opponent
+root
+traveling
+secured
+corridor
+stretched
+tales
+sheet
+trinity
+cattle
+helps
+indicates
+manhattan
+murdered
+fitted
+1882
+gentle
+grandmother
+mines
+shocked
+vegas
+produces
+##light
+caribbean
+##ou
+belong
+continuous
+desperate
+drunk
+historically
+trio
+waved
+raf
+dealing
+nathan
+bat
+murmured
+interrupted
+residing
+scientist
+pioneer
+harold
+aaron
+##net
+delta
+attempting
+minority
+mini
+believes
+chorus
+tend
+lots
+eyed
+indoor
+load
+shots
+updated
+jail
+##llo
+concerning
+connecting
+wealth
+##ved
+slaves
+arrive
+rangers
+sufficient
+rebuilt
+##wick
+cardinal
+flood
+muhammad
+whenever
+relation
+runners
+moral
+repair
+viewers
+arriving
+revenge
+punk
+assisted
+bath
+fairly
+breathe
+lists
+innings
+illustrated
+whisper
+nearest
+voters
+clinton
+ties
+ultimate
+screamed
+beijing
+lions
+andre
+fictional
+gathering
+comfort
+radar
+suitable
+dismissed
+hms
+ban
+pine
+wrist
+atmosphere
+voivodeship
+bid
+timber
+##ned
+##nan
+giants
+##ane
+cameron
+recovery
+uss
+identical
+categories
+switched
+serbia
+laughter
+noah
+ensemble
+therapy
+peoples
+touching
+##off
+locally
+pearl
+platforms
+everywhere
+ballet
+tables
+lanka
+herbert
+outdoor
+toured
+derek
+1883
+spaces
+contested
+swept
+1878
+exclusive
+slight
+connections
+##dra
+winds
+prisoner
+collective
+bangladesh
+tube
+publicly
+wealthy
+thai
+##ys
+isolated
+select
+##ric
+insisted
+pen
+fortune
+ticket
+spotted
+reportedly
+animation
+enforcement
+tanks
+110
+decides
+wider
+lowest
+owen
+##time
+nod
+hitting
+##hn
+gregory
+furthermore
+magazines
+fighters
+solutions
+##ery
+pointing
+requested
+peru
+reed
+chancellor
+knights
+mask
+worker
+eldest
+flames
+reduction
+1860
+volunteers
+##tis
+reporting
+##hl
+wire
+advisory
+endemic
+origins
+settlers
+pursue
+knock
+consumer
+1876
+eu
+compound
+creatures
+mansion
+sentenced
+ivan
+deployed
+guitars
+frowned
+involves
+mechanism
+kilometers
+perspective
+shops
+maps
+terminus
+duncan
+alien
+fist
+bridges
+##pers
+heroes
+fed
+derby
+swallowed
+##ros
+patent
+sara
+illness
+characterized
+adventures
+slide
+hawaii
+jurisdiction
+##op
+organised
+##side
+adelaide
+walks
+biology
+se
+##ties
+rogers
+swing
+tightly
+boundaries
+##rie
+prepare
+implementation
+stolen
+##sha
+certified
+colombia
+edwards
+garage
+##mm
+recalled
+##ball
+rage
+harm
+nigeria
+breast
+##ren
+furniture
+pupils
+settle
+##lus
+cuba
+balls
+client
+alaska
+21st
+linear
+thrust
+celebration
+latino
+genetic
+terror
+##cia
+##ening
+lightning
+fee
+witness
+lodge
+establishing
+skull
+##ique
+earning
+hood
+##ei
+rebellion
+wang
+sporting
+warned
+missile
+devoted
+activist
+porch
+worship
+fourteen
+package
+1871
+decorated
+##shire
+housed
+##ock
+chess
+sailed
+doctors
+oscar
+joan
+treat
+garcia
+harbour
+jeremy
+##ire
+traditions
+dominant
+jacques
+##gon
+##wan
+relocated
+1879
+amendment
+sized
+companion
+simultaneously
+volleyball
+spun
+acre
+increases
+stopping
+loves
+belongs
+affect
+drafted
+tossed
+scout
+battles
+1875
+filming
+shoved
+munich
+tenure
+vertical
+romance
+pc
+##cher
+argue
+##ical
+craft
+ranging
+www
+opens
+honest
+tyler
+yesterday
+virtual
+##let
+muslims
+reveal
+snake
+immigrants
+radical
+screaming
+speakers
+firing
+saving
+belonging
+ease
+lighting
+prefecture
+blame
+farmer
+hungry
+grows
+rubbed
+beam
+sur
+subsidiary
+##cha
+armenian
+sao
+dropping
+conventional
+##fer
+microsoft
+reply
+qualify
+spots
+1867
+sweat
+festivals
+##ken
+immigration
+physician
+discover
+exposure
+sandy
+explanation
+isaac
+implemented
+##fish
+hart
+initiated
+connect
+stakes
+presents
+heights
+householder
+pleased
+tourist
+regardless
+slip
+closest
+##ction
+surely
+sultan
+brings
+riley
+preparation
+aboard
+slammed
+baptist
+experiment
+ongoing
+interstate
+organic
+playoffs
+##ika
+1877
+130
+##tar
+hindu
+error
+tours
+tier
+plenty
+arrangements
+talks
+trapped
+excited
+sank
+ho
+athens
+1872
+denver
+welfare
+suburb
+athletes
+trick
+diverse
+belly
+exclusively
+yelled
+1868
+##med
+conversion
+##ette
+1874
+internationally
+computers
+conductor
+abilities
+sensitive
+hello
+dispute
+measured
+globe
+rocket
+prices
+amsterdam
+flights
+tigers
+inn
+municipalities
+emotion
+references
+3d
+##mus
+explains
+airlines
+manufactured
+pm
+archaeological
+1873
+interpretation
+devon
+comment
+##ites
+settlements
+kissing
+absolute
+improvement
+suite
+impressed
+barcelona
+sullivan
+jefferson
+towers
+jesse
+julie
+##tin
+##lu
+grandson
+hi
+gauge
+regard
+rings
+interviews
+trace
+raymond
+thumb
+departments
+burns
+serial
+bulgarian
+scores
+demonstrated
+##ix
+1866
+kyle
+alberta
+underneath
+romanized
+##ward
+relieved
+acquisition
+phrase
+cliff
+reveals
+han
+cuts
+merger
+custom
+##dar
+nee
+gilbert
+graduation
+##nts
+assessment
+cafe
+difficulty
+demands
+swung
+democrat
+jennifer
+commons
+1940s
+grove
+##yo
+completing
+focuses
+sum
+substitute
+bearing
+stretch
+reception
+##py
+reflected
+essentially
+destination
+pairs
+##ched
+survival
+resource
+##bach
+promoting
+doubles
+messages
+tear
+##down
+##fully
+parade
+florence
+harvey
+incumbent
+partial
+framework
+900
+pedro
+frozen
+procedure
+olivia
+controls
+##mic
+shelter
+personally
+temperatures
+##od
+brisbane
+tested
+sits
+marble
+comprehensive
+oxygen
+leonard
+##kov
+inaugural
+iranian
+referring
+quarters
+attitude
+##ivity
+mainstream
+lined
+mars
+dakota
+norfolk
+unsuccessful
+##°
+explosion
+helicopter
+congressional
+##sing
+inspector
+bitch
+seal
+departed
+divine
+##ters
+coaching
+examination
+punishment
+manufacturer
+sink
+columns
+unincorporated
+signals
+nevada
+squeezed
+dylan
+dining
+photos
+martial
+manuel
+eighteen
+elevator
+brushed
+plates
+ministers
+ivy
+congregation
+##len
+slept
+specialized
+taxes
+curve
+restricted
+negotiations
+likes
+statistical
+arnold
+inspiration
+execution
+bold
+intermediate
+significance
+margin
+ruler
+wheels
+gothic
+intellectual
+dependent
+listened
+eligible
+buses
+widow
+syria
+earn
+cincinnati
+collapsed
+recipient
+secrets
+accessible
+philippine
+maritime
+goddess
+clerk
+surrender
+breaks
+playoff
+database
+##ified
+##lon
+ideal
+beetle
+aspect
+soap
+regulation
+strings
+expand
+anglo
+shorter
+crosses
+retreat
+tough
+coins
+wallace
+directions
+pressing
+##oon
+shipping
+locomotives
+comparison
+topics
+nephew
+##mes
+distinction
+honors
+travelled
+sierra
+ibn
+##over
+fortress
+sa
+recognised
+carved
+1869
+clients
+##dan
+intent
+##mar
+coaches
+describing
+bread
+##ington
+beaten
+northwestern
+##ona
+merit
+youtube
+collapse
+challenges
+em
+historians
+objective
+submitted
+virus
+attacking
+drake
+assume
+##ere
+diseases
+marc
+stem
+leeds
+##cus
+##ab
+farming
+glasses
+##lock
+visits
+nowhere
+fellowship
+relevant
+carries
+restaurants
+experiments
+101
+constantly
+bases
+targets
+shah
+tenth
+opponents
+verse
+territorial
+##ira
+writings
+corruption
+##hs
+instruction
+inherited
+reverse
+emphasis
+##vic
+employee
+arch
+keeps
+rabbi
+watson
+payment
+uh
+##ala
+nancy
+##tre
+venice
+fastest
+sexy
+banned
+adrian
+properly
+ruth
+touchdown
+dollar
+boards
+metre
+circles
+edges
+favour
+comments
+ok
+travels
+liberation
+scattered
+firmly
+##ular
+holland
+permitted
+diesel
+kenya
+den
+originated
+##ral
+demons
+resumed
+dragged
+rider
+##rus
+servant
+blinked
+extend
+torn
+##ias
+##sey
+input
+meal
+everybody
+cylinder
+kinds
+camps
+##fe
+bullet
+logic
+##wn
+croatian
+evolved
+healthy
+fool
+chocolate
+wise
+preserve
+pradesh
+##ess
+respective
+1850
+##ew
+chicken
+artificial
+gross
+corresponding
+convicted
+cage
+caroline
+dialogue
+##dor
+narrative
+stranger
+mario
+br
+christianity
+failing
+trent
+commanding
+buddhist
+1848
+maurice
+focusing
+yale
+bike
+altitude
+##ering
+mouse
+revised
+##sley
+veteran
+##ig
+pulls
+theology
+crashed
+campaigns
+legion
+##ability
+drag
+excellence
+customer
+cancelled
+intensity
+excuse
+##lar
+liga
+participating
+contributing
+printing
+##burn
+variable
+##rk
+curious
+bin
+legacy
+renaissance
+##my
+symptoms
+binding
+vocalist
+dancer
+##nie
+grammar
+gospel
+democrats
+ya
+enters
+sc
+diplomatic
+hitler
+##ser
+clouds
+mathematical
+quit
+defended
+oriented
+##heim
+fundamental
+hardware
+impressive
+equally
+convince
+confederate
+guilt
+chuck
+sliding
+##ware
+magnetic
+narrowed
+petersburg
+bulgaria
+otto
+phd
+skill
+##ama
+reader
+hopes
+pitcher
+reservoir
+hearts
+automatically
+expecting
+mysterious
+bennett
+extensively
+imagined
+seeds
+monitor
+fix
+##ative
+journalism
+struggling
+signature
+ranch
+encounter
+photographer
+observation
+protests
+##pin
+influences
+##hr
+calendar
+##all
+cruz
+croatia
+locomotive
+hughes
+naturally
+shakespeare
+basement
+hook
+uncredited
+faded
+theories
+approaches
+dare
+phillips
+filling
+fury
+obama
+##ain
+efficient
+arc
+deliver
+min
+raid
+breeding
+inducted
+leagues
+efficiency
+axis
+montana
+eagles
+##ked
+supplied
+instructions
+karen
+picking
+indicating
+trap
+anchor
+practically
+christians
+tomb
+vary
+occasional
+electronics
+lords
+readers
+newcastle
+faint
+innovation
+collect
+situations
+engagement
+160
+claude
+mixture
+##feld
+peer
+tissue
+logo
+lean
+##ration
+°f
+floors
+##ven
+architects
+reducing
+##our
+##ments
+rope
+1859
+ottawa
+##har
+samples
+banking
+declaration
+proteins
+resignation
+francois
+saudi
+advocate
+exhibited
+armor
+twins
+divorce
+##ras
+abraham
+reviewed
+jo
+temporarily
+matrix
+physically
+pulse
+curled
+##ena
+difficulties
+bengal
+usage
+##ban
+annie
+riders
+certificate
+##pi
+holes
+warsaw
+distinctive
+jessica
+##mon
+mutual
+1857
+customs
+circular
+eugene
+removal
+loaded
+mere
+vulnerable
+depicted
+generations
+dame
+heir
+enormous
+lightly
+climbing
+pitched
+lessons
+pilots
+nepal
+ram
+google
+preparing
+brad
+louise
+renowned
+##₂
+liam
+##ably
+plaza
+shaw
+sophie
+brilliant
+bills
+##bar
+##nik
+fucking
+mainland
+server
+pleasant
+seized
+veterans
+jerked
+fail
+beta
+brush
+radiation
+stored
+warmth
+southeastern
+nate
+sin
+raced
+berkeley
+joke
+athlete
+designation
+trunk
+##low
+roland
+qualification
+archives
+heels
+artwork
+receives
+judicial
+reserves
+##bed
+woke
+installation
+abu
+floating
+fake
+lesser
+excitement
+interface
+concentrated
+addressed
+characteristic
+amanda
+saxophone
+monk
+auto
+##bus
+releasing
+egg
+dies
+interaction
+defender
+ce
+outbreak
+glory
+loving
+##bert
+sequel
+consciousness
+http
+awake
+ski
+enrolled
+##ress
+handling
+rookie
+brow
+somebody
+biography
+warfare
+amounts
+contracts
+presentation
+fabric
+dissolved
+challenged
+meter
+psychological
+lt
+elevated
+rally
+accurate
+##tha
+hospitals
+undergraduate
+specialist
+venezuela
+exhibit
+shed
+nursing
+protestant
+fluid
+structural
+footage
+jared
+consistent
+prey
+##ska
+succession
+reflect
+exile
+lebanon
+wiped
+suspect
+shanghai
+resting
+integration
+preservation
+marvel
+variant
+pirates
+sheep
+rounded
+capita
+sailing
+colonies
+manuscript
+deemed
+variations
+clarke
+functional
+emerging
+boxing
+relaxed
+curse
+azerbaijan
+heavyweight
+nickname
+editorial
+rang
+grid
+tightened
+earthquake
+flashed
+miguel
+rushing
+##ches
+improvements
+boxes
+brooks
+180
+consumption
+molecular
+felix
+societies
+repeatedly
+variation
+aids
+civic
+graphics
+professionals
+realm
+autonomous
+receiver
+delayed
+workshop
+militia
+chairs
+trump
+canyon
+##point
+harsh
+extending
+lovely
+happiness
+##jan
+stake
+eyebrows
+embassy
+wellington
+hannah
+##ella
+sony
+corners
+bishops
+swear
+cloth
+contents
+xi
+namely
+commenced
+1854
+stanford
+nashville
+courage
+graphic
+commitment
+garrison
+##bin
+hamlet
+clearing
+rebels
+attraction
+literacy
+cooking
+ruins
+temples
+jenny
+humanity
+celebrate
+hasn
+freight
+sixty
+rebel
+bastard
+##art
+newton
+##ada
+deer
+##ges
+##ching
+smiles
+delaware
+singers
+##ets
+approaching
+assists
+flame
+##ph
+boulevard
+barrel
+planted
+##ome
+pursuit
+##sia
+consequences
+posts
+shallow
+invitation
+rode
+depot
+ernest
+kane
+rod
+concepts
+preston
+topic
+chambers
+striking
+blast
+arrives
+descendants
+montgomery
+ranges
+worlds
+##lay
+##ari
+span
+chaos
+praise
+##ag
+fewer
+1855
+sanctuary
+mud
+fbi
+##ions
+programmes
+maintaining
+unity
+harper
+bore
+handsome
+closure
+tournaments
+thunder
+nebraska
+linda
+facade
+puts
+satisfied
+argentine
+dale
+cork
+dome
+panama
+##yl
+1858
+tasks
+experts
+##ates
+feeding
+equation
+##las
+##ida
+##tu
+engage
+bryan
+##ax
+um
+quartet
+melody
+disbanded
+sheffield
+blocked
+gasped
+delay
+kisses
+maggie
+connects
+##non
+sts
+poured
+creator
+publishers
+##we
+guided
+ellis
+extinct
+hug
+gaining
+##ord
+complicated
+##bility
+poll
+clenched
+investigate
+##use
+thereby
+quantum
+spine
+cdp
+humor
+kills
+administered
+semifinals
+##du
+encountered
+ignore
+##bu
+commentary
+##maker
+bother
+roosevelt
+140
+plains
+halfway
+flowing
+cultures
+crack
+imprisoned
+neighboring
+airline
+##ses
+##view
+##mate
+##ec
+gather
+wolves
+marathon
+transformed
+##ill
+cruise
+organisations
+carol
+punch
+exhibitions
+numbered
+alarm
+ratings
+daddy
+silently
+##stein
+queens
+colours
+impression
+guidance
+liu
+tactical
+##rat
+marshal
+della
+arrow
+##ings
+rested
+feared
+tender
+owns
+bitter
+advisor
+escort
+##ides
+spare
+farms
+grants
+##ene
+dragons
+encourage
+colleagues
+cameras
+##und
+sucked
+pile
+spirits
+prague
+statements
+suspension
+landmark
+fence
+torture
+recreation
+bags
+permanently
+survivors
+pond
+spy
+predecessor
+bombing
+coup
+##og
+protecting
+transformation
+glow
+##lands
+##book
+dug
+priests
+andrea
+feat
+barn
+jumping
+##chen
+##ologist
+##con
+casualties
+stern
+auckland
+pipe
+serie
+revealing
+ba
+##bel
+trevor
+mercy
+spectrum
+yang
+consist
+governing
+collaborated
+possessed
+epic
+comprises
+blew
+shane
+##ack
+lopez
+honored
+magical
+sacrifice
+judgment
+perceived
+hammer
+mtv
+baronet
+tune
+das
+missionary
+sheets
+350
+neutral
+oral
+threatening
+attractive
+shade
+aims
+seminary
+##master
+estates
+1856
+michel
+wounds
+refugees
+manufacturers
+##nic
+mercury
+syndrome
+porter
+##iya
+##din
+hamburg
+identification
+upstairs
+purse
+widened
+pause
+cared
+breathed
+affiliate
+santiago
+prevented
+celtic
+fisher
+125
+recruited
+byzantine
+reconstruction
+farther
+##mp
+diet
+sake
+au
+spite
+sensation
+##ert
+blank
+separation
+105
+##hon
+vladimir
+armies
+anime
+##lie
+accommodate
+orbit
+cult
+sofia
+archive
+##ify
+##box
+founders
+sustained
+disorder
+honours
+northeastern
+mia
+crops
+violet
+threats
+blanket
+fires
+canton
+followers
+southwestern
+prototype
+voyage
+assignment
+altered
+moderate
+protocol
+pistol
+##eo
+questioned
+brass
+lifting
+1852
+math
+authored
+##ual
+doug
+dimensional
+dynamic
+##san
+1851
+pronounced
+grateful
+quest
+uncomfortable
+boom
+presidency
+stevens
+relating
+politicians
+chen
+barrier
+quinn
+diana
+mosque
+tribal
+cheese
+palmer
+portions
+sometime
+chester
+treasure
+wu
+bend
+download
+millions
+reforms
+registration
+##osa
+consequently
+monitoring
+ate
+preliminary
+brandon
+invented
+ps
+eaten
+exterior
+intervention
+ports
+documented
+log
+displays
+lecture
+sally
+favourite
+##itz
+vermont
+lo
+invisible
+isle
+breed
+##ator
+journalists
+relay
+speaks
+backward
+explore
+midfielder
+actively
+stefan
+procedures
+cannon
+blond
+kenneth
+centered
+servants
+chains
+libraries
+malcolm
+essex
+henri
+slavery
+##hal
+facts
+fairy
+coached
+cassie
+cats
+washed
+cop
+##fi
+announcement
+item
+2000s
+vinyl
+activated
+marco
+frontier
+growled
+curriculum
+##das
+loyal
+accomplished
+leslie
+ritual
+kenny
+##00
+vii
+napoleon
+hollow
+hybrid
+jungle
+stationed
+friedrich
+counted
+##ulated
+platinum
+theatrical
+seated
+col
+rubber
+glen
+1840
+diversity
+healing
+extends
+id
+provisions
+administrator
+columbus
+##oe
+tributary
+te
+assured
+org
+##uous
+prestigious
+examined
+lectures
+grammy
+ronald
+associations
+bailey
+allan
+essays
+flute
+believing
+consultant
+proceedings
+travelling
+1853
+kit
+kerala
+yugoslavia
+buddy
+methodist
+##ith
+burial
+centres
+batman
+##nda
+discontinued
+bo
+dock
+stockholm
+lungs
+severely
+##nk
+citing
+manga
+##ugh
+steal
+mumbai
+iraqi
+robot
+celebrity
+bride
+broadcasts
+abolished
+pot
+joel
+overhead
+franz
+packed
+reconnaissance
+johann
+acknowledged
+introduce
+handled
+doctorate
+developments
+drinks
+alley
+palestine
+##nis
+##aki
+proceeded
+recover
+bradley
+grain
+patch
+afford
+infection
+nationalist
+legendary
+##ath
+interchange
+virtually
+gen
+gravity
+exploration
+amber
+vital
+wishes
+powell
+doctrine
+elbow
+screenplay
+##bird
+contribute
+indonesian
+pet
+creates
+##com
+enzyme
+kylie
+discipline
+drops
+manila
+hunger
+##ien
+layers
+suffer
+fever
+bits
+monica
+keyboard
+manages
+##hood
+searched
+appeals
+##bad
+testament
+grande
+reid
+##war
+beliefs
+congo
+##ification
+##dia
+si
+requiring
+##via
+casey
+1849
+regret
+streak
+rape
+depends
+syrian
+sprint
+pound
+tourists
+upcoming
+pub
+##xi
+tense
+##els
+practiced
+echo
+nationwide
+guild
+motorcycle
+liz
+##zar
+chiefs
+desired
+elena
+bye
+precious
+absorbed
+relatives
+booth
+pianist
+##mal
+citizenship
+exhausted
+wilhelm
+##ceae
+##hed
+noting
+quarterback
+urge
+hectares
+##gue
+ace
+holly
+##tal
+blonde
+davies
+parked
+sustainable
+stepping
+twentieth
+airfield
+galaxy
+nest
+chip
+##nell
+tan
+shaft
+paulo
+requirement
+##zy
+paradise
+tobacco
+trans
+renewed
+vietnamese
+##cker
+##ju
+suggesting
+catching
+holmes
+enjoying
+md
+trips
+colt
+holder
+butterfly
+nerve
+reformed
+cherry
+bowling
+trailer
+carriage
+goodbye
+appreciate
+toy
+joshua
+interactive
+enabled
+involve
+##kan
+collar
+determination
+bunch
+facebook
+recall
+shorts
+superintendent
+episcopal
+frustration
+giovanni
+nineteenth
+laser
+privately
+array
+circulation
+##ovic
+armstrong
+deals
+painful
+permit
+discrimination
+##wi
+aires
+retiring
+cottage
+ni
+##sta
+horizon
+ellen
+jamaica
+ripped
+fernando
+chapters
+playstation
+patron
+lecturer
+navigation
+behaviour
+genes
+georgian
+export
+solomon
+rivals
+swift
+seventeen
+rodriguez
+princeton
+independently
+sox
+1847
+arguing
+entity
+casting
+hank
+criteria
+oakland
+geographic
+milwaukee
+reflection
+expanding
+conquest
+dubbed
+##tv
+halt
+brave
+brunswick
+doi
+arched
+curtis
+divorced
+predominantly
+somerset
+streams
+ugly
+zoo
+horrible
+curved
+buenos
+fierce
+dictionary
+vector
+theological
+unions
+handful
+stability
+chan
+punjab
+segments
+##lly
+altar
+ignoring
+gesture
+monsters
+pastor
+##stone
+thighs
+unexpected
+operators
+abruptly
+coin
+compiled
+associates
+improving
+migration
+pin
+##ose
+compact
+collegiate
+reserved
+##urs
+quarterfinals
+roster
+restore
+assembled
+hurry
+oval
+##cies
+1846
+flags
+martha
+##del
+victories
+sharply
+##rated
+argues
+deadly
+neo
+drawings
+symbols
+performer
+##iel
+griffin
+restrictions
+editing
+andrews
+java
+journals
+arabia
+compositions
+dee
+pierce
+removing
+hindi
+casino
+runway
+civilians
+minds
+nasa
+hotels
+##zation
+refuge
+rent
+retain
+potentially
+conferences
+suburban
+conducting
+##tto
+##tions
+##tle
+descended
+massacre
+##cal
+ammunition
+terrain
+fork
+souls
+counts
+chelsea
+durham
+drives
+cab
+##bank
+perth
+realizing
+palestinian
+finn
+simpson
+##dal
+betty
+##ule
+moreover
+particles
+cardinals
+tent
+evaluation
+extraordinary
+##oid
+inscription
+##works
+wednesday
+chloe
+maintains
+panels
+ashley
+trucks
+##nation
+cluster
+sunlight
+strikes
+zhang
+##wing
+dialect
+canon
+##ap
+tucked
+##ws
+collecting
+##mas
+##can
+##sville
+maker
+quoted
+evan
+franco
+aria
+buying
+cleaning
+eva
+closet
+provision
+apollo
+clinic
+rat
+##ez
+necessarily
+ac
+##gle
+##ising
+venues
+flipped
+cent
+spreading
+trustees
+checking
+authorized
+##sco
+disappointed
+##ado
+notion
+duration
+trumpet
+hesitated
+topped
+brussels
+rolls
+theoretical
+hint
+define
+aggressive
+repeat
+wash
+peaceful
+optical
+width
+allegedly
+mcdonald
+strict
+copyright
+##illa
+investors
+mar
+jam
+witnesses
+sounding
+miranda
+michelle
+privacy
+hugo
+harmony
+##pp
+valid
+lynn
+glared
+nina
+102
+headquartered
+diving
+boarding
+gibson
+##ncy
+albanian
+marsh
+routine
+dealt
+enhanced
+er
+intelligent
+substance
+targeted
+enlisted
+discovers
+spinning
+observations
+pissed
+smoking
+rebecca
+capitol
+visa
+varied
+costume
+seemingly
+indies
+compensation
+surgeon
+thursday
+arsenal
+westminster
+suburbs
+rid
+anglican
+##ridge
+knots
+foods
+alumni
+lighter
+fraser
+whoever
+portal
+scandal
+##ray
+gavin
+advised
+instructor
+flooding
+terrorist
+##ale
+teenage
+interim
+senses
+duck
+teen
+thesis
+abby
+eager
+overcome
+##ile
+newport
+glenn
+rises
+shame
+##cc
+prompted
+priority
+forgot
+bomber
+nicolas
+protective
+360
+cartoon
+katherine
+breeze
+lonely
+trusted
+henderson
+richardson
+relax
+banner
+candy
+palms
+remarkable
+##rio
+legends
+cricketer
+essay
+ordained
+edmund
+rifles
+trigger
+##uri
+##away
+sail
+alert
+1830
+audiences
+penn
+sussex
+siblings
+pursued
+indianapolis
+resist
+rosa
+consequence
+succeed
+avoided
+1845
+##ulation
+inland
+##tie
+##nna
+counsel
+profession
+chronicle
+hurried
+##una
+eyebrow
+eventual
+bleeding
+innovative
+cure
+##dom
+committees
+accounting
+con
+scope
+hardy
+heather
+tenor
+gut
+herald
+codes
+tore
+scales
+wagon
+##oo
+luxury
+tin
+prefer
+fountain
+triangle
+bonds
+darling
+convoy
+dried
+traced
+beings
+troy
+accidentally
+slam
+findings
+smelled
+joey
+lawyers
+outcome
+steep
+bosnia
+configuration
+shifting
+toll
+brook
+performers
+lobby
+philosophical
+construct
+shrine
+aggregate
+boot
+cox
+phenomenon
+savage
+insane
+solely
+reynolds
+lifestyle
+##ima
+nationally
+holdings
+consideration
+enable
+edgar
+mo
+mama
+##tein
+fights
+relegation
+chances
+atomic
+hub
+conjunction
+awkward
+reactions
+currency
+finale
+kumar
+underwent
+steering
+elaborate
+gifts
+comprising
+melissa
+veins
+reasonable
+sunshine
+chi
+solve
+trails
+inhabited
+elimination
+ethics
+huh
+ana
+molly
+consent
+apartments
+layout
+marines
+##ces
+hunters
+bulk
+##oma
+hometown
+##wall
+##mont
+cracked
+reads
+neighbouring
+withdrawn
+admission
+wingspan
+damned
+anthology
+lancashire
+brands
+batting
+forgive
+cuban
+awful
+##lyn
+104
+dimensions
+imagination
+##ade
+dante
+##ship
+tracking
+desperately
+goalkeeper
+##yne
+groaned
+workshops
+confident
+burton
+gerald
+milton
+circus
+uncertain
+slope
+copenhagen
+sophia
+fog
+philosopher
+portraits
+accent
+cycling
+varying
+gripped
+larvae
+garrett
+specified
+scotia
+mature
+luther
+kurt
+rap
+##kes
+aerial
+750
+ferdinand
+heated
+es
+transported
+##shan
+safely
+nonetheless
+##orn
+##gal
+motors
+demanding
+##sburg
+startled
+##brook
+ally
+generate
+caps
+ghana
+stained
+demo
+mentions
+beds
+ap
+afterward
+diary
+##bling
+utility
+##iro
+richards
+1837
+conspiracy
+conscious
+shining
+footsteps
+observer
+cyprus
+urged
+loyalty
+developer
+probability
+olive
+upgraded
+gym
+miracle
+insects
+graves
+1844
+ourselves
+hydrogen
+amazon
+katie
+tickets
+poets
+##pm
+planes
+##pan
+prevention
+witnessed
+dense
+jin
+randy
+tang
+warehouse
+monroe
+bang
+archived
+elderly
+investigations
+alec
+granite
+mineral
+conflicts
+controlling
+aboriginal
+carlo
+##zu
+mechanics
+stan
+stark
+rhode
+skirt
+est
+##berry
+bombs
+respected
+##horn
+imposed
+limestone
+deny
+nominee
+memphis
+grabbing
+disabled
+##als
+amusement
+aa
+frankfurt
+corn
+referendum
+varies
+slowed
+disk
+firms
+unconscious
+incredible
+clue
+sue
+##zhou
+twist
+##cio
+joins
+idaho
+chad
+developers
+computing
+destroyer
+103
+mortal
+tucker
+kingston
+choices
+yu
+carson
+1800
+os
+whitney
+geneva
+pretend
+dimension
+staged
+plateau
+maya
+##une
+freestyle
+##bc
+rovers
+hiv
+##ids
+tristan
+classroom
+prospect
+##hus
+honestly
+diploma
+lied
+thermal
+auxiliary
+feast
+unlikely
+iata
+##tel
+morocco
+pounding
+treasury
+lithuania
+considerably
+1841
+dish
+1812
+geological
+matching
+stumbled
+destroying
+marched
+brien
+advances
+cake
+nicole
+belle
+settling
+measuring
+directing
+##mie
+tuesday
+bassist
+capabilities
+stunned
+fraud
+torpedo
+##list
+##phone
+anton
+wisdom
+surveillance
+ruined
+##ulate
+lawsuit
+healthcare
+theorem
+halls
+trend
+aka
+horizontal
+dozens
+acquire
+lasting
+swim
+hawk
+gorgeous
+fees
+vicinity
+decrease
+adoption
+tactics
+##ography
+pakistani
+##ole
+draws
+##hall
+willie
+burke
+heath
+algorithm
+integral
+powder
+elliott
+brigadier
+jackie
+tate
+varieties
+darker
+##cho
+lately
+cigarette
+specimens
+adds
+##ree
+##ensis
+##inger
+exploded
+finalist
+cia
+murders
+wilderness
+arguments
+nicknamed
+acceptance
+onwards
+manufacture
+robertson
+jets
+tampa
+enterprises
+blog
+loudly
+composers
+nominations
+1838
+ai
+malta
+inquiry
+automobile
+hosting
+viii
+rays
+tilted
+grief
+museums
+strategies
+furious
+euro
+equality
+cohen
+poison
+surrey
+wireless
+governed
+ridiculous
+moses
+##esh
+##room
+vanished
+##ito
+barnes
+attract
+morrison
+istanbul
+##iness
+absent
+rotation
+petition
+janet
+##logical
+satisfaction
+custody
+deliberately
+observatory
+comedian
+surfaces
+pinyin
+novelist
+strictly
+canterbury
+oslo
+monks
+embrace
+ibm
+jealous
+photograph
+continent
+dorothy
+marina
+doc
+excess
+holden
+allegations
+explaining
+stack
+avoiding
+lance
+storyline
+majesty
+poorly
+spike
+dos
+bradford
+raven
+travis
+classics
+proven
+voltage
+pillow
+fists
+butt
+1842
+interpreted
+##car
+1839
+gage
+telegraph
+lens
+promising
+expelled
+casual
+collector
+zones
+##min
+silly
+nintendo
+##kh
+##bra
+downstairs
+chef
+suspicious
+afl
+flies
+vacant
+uganda
+pregnancy
+condemned
+lutheran
+estimates
+cheap
+decree
+saxon
+proximity
+stripped
+idiot
+deposits
+contrary
+presenter
+magnus
+glacier
+im
+offense
+edwin
+##ori
+upright
+##long
+bolt
+##ois
+toss
+geographical
+##izes
+environments
+delicate
+marking
+abstract
+xavier
+nails
+windsor
+plantation
+occurring
+equity
+saskatchewan
+fears
+drifted
+sequences
+vegetation
+revolt
+##stic
+1843
+sooner
+fusion
+opposing
+nato
+skating
+1836
+secretly
+ruin
+lease
+##oc
+edit
+##nne
+flora
+anxiety
+ruby
+##ological
+##mia
+tel
+bout
+taxi
+emmy
+frost
+rainbow
+compounds
+foundations
+rainfall
+assassination
+nightmare
+dominican
+##win
+achievements
+deserve
+orlando
+intact
+armenia
+##nte
+calgary
+valentine
+106
+marion
+proclaimed
+theodore
+bells
+courtyard
+thigh
+gonzalez
+console
+troop
+minimal
+monte
+everyday
+##ence
+##if
+supporter
+terrorism
+buck
+openly
+presbyterian
+activists
+carpet
+##iers
+rubbing
+uprising
+##yi
+cute
+conceived
+legally
+##cht
+millennium
+cello
+velocity
+ji
+rescued
+cardiff
+1835
+rex
+concentrate
+senators
+beard
+rendered
+glowing
+battalions
+scouts
+competitors
+sculptor
+catalogue
+arctic
+ion
+raja
+bicycle
+wow
+glancing
+lawn
+##woman
+gentleman
+lighthouse
+publish
+predicted
+calculated
+##val
+variants
+##gne
+strain
+##ui
+winston
+deceased
+##nus
+touchdowns
+brady
+caleb
+sinking
+echoed
+crush
+hon
+blessed
+protagonist
+hayes
+endangered
+magnitude
+editors
+##tine
+estimate
+responsibilities
+##mel
+backup
+laying
+consumed
+sealed
+zurich
+lovers
+frustrated
+##eau
+ahmed
+kicking
+mit
+treasurer
+1832
+biblical
+refuse
+terrified
+pump
+agrees
+genuine
+imprisonment
+refuses
+plymouth
+##hen
+lou
+##nen
+tara
+trembling
+antarctic
+ton
+learns
+##tas
+crap
+crucial
+faction
+atop
+##borough
+wrap
+lancaster
+odds
+hopkins
+erik
+lyon
+##eon
+bros
+##ode
+snap
+locality
+tips
+empress
+crowned
+cal
+acclaimed
+chuckled
+##ory
+clara
+sends
+mild
+towel
+##fl
+##day
+##а
+wishing
+assuming
+interviewed
+##bal
+##die
+interactions
+eden
+cups
+helena
+##lf
+indie
+beck
+##fire
+batteries
+filipino
+wizard
+parted
+##lam
+traces
+##born
+rows
+idol
+albany
+delegates
+##ees
+##sar
+discussions
+##ex
+notre
+instructed
+belgrade
+highways
+suggestion
+lauren
+possess
+orientation
+alexandria
+abdul
+beats
+salary
+reunion
+ludwig
+alright
+wagner
+intimate
+pockets
+slovenia
+hugged
+brighton
+merchants
+cruel
+stole
+trek
+slopes
+repairs
+enrollment
+politically
+underlying
+promotional
+counting
+boeing
+##bb
+isabella
+naming
+##и
+keen
+bacteria
+listing
+separately
+belfast
+ussr
+450
+lithuanian
+anybody
+ribs
+sphere
+martinez
+cock
+embarrassed
+proposals
+fragments
+nationals
+##fs
+##wski
+premises
+fin
+1500
+alpine
+matched
+freely
+bounded
+jace
+sleeve
+##af
+gaming
+pier
+populated
+evident
+##like
+frances
+flooded
+##dle
+frightened
+pour
+trainer
+framed
+visitor
+challenging
+pig
+wickets
+##fold
+infected
+email
+##pes
+arose
+##aw
+reward
+ecuador
+oblast
+vale
+ch
+shuttle
+##usa
+bach
+rankings
+forbidden
+cornwall
+accordance
+salem
+consumers
+bruno
+fantastic
+toes
+machinery
+resolved
+julius
+remembering
+propaganda
+iceland
+bombardment
+tide
+contacts
+wives
+##rah
+concerto
+macdonald
+albania
+implement
+daisy
+tapped
+sudan
+helmet
+angela
+mistress
+##lic
+crop
+sunk
+finest
+##craft
+hostile
+##ute
+##tsu
+boxer
+fr
+paths
+adjusted
+habit
+ballot
+supervision
+soprano
+##zen
+bullets
+wicked
+sunset
+regiments
+disappear
+lamp
+performs
+app
+##gia
+##oa
+rabbit
+digging
+incidents
+entries
+##cion
+dishes
+##oi
+introducing
+##ati
+##fied
+freshman
+slot
+jill
+tackles
+baroque
+backs
+##iest
+lone
+sponsor
+destiny
+altogether
+convert
+##aro
+consensus
+shapes
+demonstration
+basically
+feminist
+auction
+artifacts
+##bing
+strongest
+twitter
+halifax
+2019
+allmusic
+mighty
+smallest
+precise
+alexandra
+viola
+##los
+##ille
+manuscripts
+##illo
+dancers
+ari
+managers
+monuments
+blades
+barracks
+springfield
+maiden
+consolidated
+electron
+##end
+berry
+airing
+wheat
+nobel
+inclusion
+blair
+payments
+geography
+bee
+cc
+eleanor
+react
+##hurst
+afc
+manitoba
+##yu
+su
+lineup
+fitness
+recreational
+investments
+airborne
+disappointment
+##dis
+edmonton
+viewing
+##row
+renovation
+##cast
+infant
+bankruptcy
+roses
+aftermath
+pavilion
+##yer
+carpenter
+withdrawal
+ladder
+##hy
+discussing
+popped
+reliable
+agreements
+rochester
+##abad
+curves
+bombers
+220
+rao
+reverend
+decreased
+choosing
+107
+stiff
+consulting
+naples
+crawford
+tracy
+ka
+ribbon
+cops
+##lee
+crushed
+deciding
+unified
+teenager
+accepting
+flagship
+explorer
+poles
+sanchez
+inspection
+revived
+skilled
+induced
+exchanged
+flee
+locals
+tragedy
+swallow
+loading
+hanna
+demonstrate
+##ela
+salvador
+flown
+contestants
+civilization
+##ines
+wanna
+rhodes
+fletcher
+hector
+knocking
+considers
+##ough
+nash
+mechanisms
+sensed
+mentally
+walt
+unclear
+##eus
+renovated
+madame
+##cks
+crews
+governmental
+##hin
+undertaken
+monkey
+##ben
+##ato
+fatal
+armored
+copa
+caves
+governance
+grasp
+perception
+certification
+froze
+damp
+tugged
+wyoming
+##rg
+##ero
+newman
+##lor
+nerves
+curiosity
+graph
+115
+##ami
+withdraw
+tunnels
+dull
+meredith
+moss
+exhibits
+neighbors
+communicate
+accuracy
+explored
+raiders
+republicans
+secular
+kat
+superman
+penny
+criticised
+##tch
+freed
+update
+conviction
+wade
+ham
+likewise
+delegation
+gotta
+doll
+promises
+technological
+myth
+nationality
+resolve
+convent
+##mark
+sharon
+dig
+sip
+coordinator
+entrepreneur
+fold
+##dine
+capability
+councillor
+synonym
+blown
+swan
+cursed
+1815
+jonas
+haired
+sofa
+canvas
+keeper
+rivalry
+##hart
+rapper
+speedway
+swords
+postal
+maxwell
+estonia
+potter
+recurring
+##nn
+##ave
+errors
+##oni
+cognitive
+1834
+##²
+claws
+nadu
+roberto
+bce
+wrestler
+ellie
+##ations
+infinite
+ink
+##tia
+presumably
+finite
+staircase
+108
+noel
+patricia
+nacional
+##cation
+chill
+eternal
+tu
+preventing
+prussia
+fossil
+limbs
+##logist
+ernst
+frog
+perez
+rene
+##ace
+pizza
+prussian
+##ios
+##vy
+molecules
+regulatory
+answering
+opinions
+sworn
+lengths
+supposedly
+hypothesis
+upward
+habitats
+seating
+ancestors
+drank
+yield
+hd
+synthesis
+researcher
+modest
+##var
+mothers
+peered
+voluntary
+homeland
+##the
+acclaim
+##igan
+static
+valve
+luxembourg
+alto
+carroll
+fe
+receptor
+norton
+ambulance
+##tian
+johnston
+catholics
+depicting
+jointly
+elephant
+gloria
+mentor
+badge
+ahmad
+distinguish
+remarked
+councils
+precisely
+allison
+advancing
+detection
+crowded
+##10
+cooperative
+ankle
+mercedes
+dagger
+surrendered
+pollution
+commit
+subway
+jeffrey
+lesson
+sculptures
+provider
+##fication
+membrane
+timothy
+rectangular
+fiscal
+heating
+teammate
+basket
+particle
+anonymous
+deployment
+##ple
+missiles
+courthouse
+proportion
+shoe
+sec
+##ller
+complaints
+forbes
+blacks
+abandon
+remind
+sizes
+overwhelming
+autobiography
+natalie
+##awa
+risks
+contestant
+countryside
+babies
+scorer
+invaded
+enclosed
+proceed
+hurling
+disorders
+##cu
+reflecting
+continuously
+cruiser
+graduates
+freeway
+investigated
+ore
+deserved
+maid
+blocking
+phillip
+jorge
+shakes
+dove
+mann
+variables
+lacked
+burden
+accompanying
+que
+consistently
+organizing
+provisional
+complained
+endless
+##rm
+tubes
+juice
+georges
+krishna
+mick
+labels
+thriller
+##uch
+laps
+arcade
+sage
+snail
+##table
+shannon
+fi
+laurence
+seoul
+vacation
+presenting
+hire
+churchill
+surprisingly
+prohibited
+savannah
+technically
+##oli
+170
+##lessly
+testimony
+suited
+speeds
+toys
+romans
+mlb
+flowering
+measurement
+talented
+kay
+settings
+charleston
+expectations
+shattered
+achieving
+triumph
+ceremonies
+portsmouth
+lanes
+mandatory
+loser
+stretching
+cologne
+realizes
+seventy
+cornell
+careers
+webb
+##ulating
+americas
+budapest
+ava
+suspicion
+##ison
+yo
+conrad
+##hai
+sterling
+jessie
+rector
+##az
+1831
+transform
+organize
+loans
+christine
+volcanic
+warrant
+slender
+summers
+subfamily
+newer
+danced
+dynamics
+rhine
+proceeds
+heinrich
+gastropod
+commands
+sings
+facilitate
+easter
+ra
+positioned
+responses
+expense
+fruits
+yanked
+imported
+25th
+velvet
+vic
+primitive
+tribune
+baldwin
+neighbourhood
+donna
+rip
+hay
+pr
+##uro
+1814
+espn
+welcomed
+##aria
+qualifier
+glare
+highland
+timing
+##cted
+shells
+eased
+geometry
+louder
+exciting
+slovakia
+##sion
+##iz
+##lot
+savings
+prairie
+##ques
+marching
+rafael
+tonnes
+##lled
+curtain
+preceding
+shy
+heal
+greene
+worthy
+##pot
+detachment
+bury
+sherman
+##eck
+reinforced
+seeks
+bottles
+contracted
+duchess
+outfit
+walsh
+##sc
+mickey
+##ase
+geoffrey
+archer
+squeeze
+dawson
+eliminate
+invention
+##enberg
+neal
+##eth
+stance
+dealer
+coral
+maple
+retire
+polo
+simplified
+##ht
+1833
+hid
+watts
+backwards
+jules
+##oke
+genesis
+mt
+frames
+rebounds
+burma
+woodland
+moist
+santos
+whispers
+drained
+subspecies
+##aa
+streaming
+ulster
+burnt
+correspondence
+maternal
+gerard
+denis
+stealing
+##load
+genius
+duchy
+##oria
+inaugurated
+momentum
+suits
+placement
+sovereign
+clause
+thames
+##hara
+confederation
+reservation
+sketch
+yankees
+lets
+rotten
+charm
+hal
+verses
+ultra
+commercially
+dot
+salon
+citation
+adopt
+winnipeg
+mist
+allocated
+cairo
+##boy
+jenkins
+interference
+objectives
+##wind
+1820
+portfolio
+armoured
+sectors
+##eh
+initiatives
+##world
+integrity
+exercises
+robe
+tap
+ab
+gazed
+##tones
+distracted
+rulers
+111
+favorable
+jerome
+tended
+cart
+factories
+##eri
+diplomat
+valued
+gravel
+charitable
+##try
+calvin
+exploring
+chang
+shepherd
+terrace
+pdf
+pupil
+##ural
+reflects
+ups
+##rch
+governors
+shelf
+depths
+##nberg
+trailed
+crest
+tackle
+##nian
+##ats
+hatred
+##kai
+clare
+makers
+ethiopia
+longtime
+detected
+embedded
+lacking
+slapped
+rely
+thomson
+anticipation
+iso
+morton
+successive
+agnes
+screenwriter
+straightened
+philippe
+playwright
+haunted
+licence
+iris
+intentions
+sutton
+112
+logical
+correctly
+##weight
+branded
+licked
+tipped
+silva
+ricky
+narrator
+requests
+##ents
+greeted
+supernatural
+cow
+##wald
+lung
+refusing
+employer
+strait
+gaelic
+liner
+##piece
+zoe
+sabha
+##mba
+driveway
+harvest
+prints
+bates
+reluctantly
+threshold
+algebra
+ira
+wherever
+coupled
+240
+assumption
+picks
+##air
+designers
+raids
+gentlemen
+##ean
+roller
+blowing
+leipzig
+locks
+screw
+dressing
+strand
+##lings
+scar
+dwarf
+depicts
+##nu
+nods
+##mine
+differ
+boris
+##eur
+yuan
+flip
+##gie
+mob
+invested
+questioning
+applying
+##ture
+shout
+##sel
+gameplay
+blamed
+illustrations
+bothered
+weakness
+rehabilitation
+##of
+##zes
+envelope
+rumors
+miners
+leicester
+subtle
+kerry
+##ico
+ferguson
+##fu
+premiership
+ne
+##cat
+bengali
+prof
+catches
+remnants
+dana
+##rily
+shouting
+presidents
+baltic
+ought
+ghosts
+dances
+sailors
+shirley
+fancy
+dominic
+##bie
+madonna
+##rick
+bark
+buttons
+gymnasium
+ashes
+liver
+toby
+oath
+providence
+doyle
+evangelical
+nixon
+cement
+carnegie
+embarked
+hatch
+surroundings
+guarantee
+needing
+pirate
+essence
+##bee
+filter
+crane
+hammond
+projected
+immune
+percy
+twelfth
+##ult
+regent
+doctoral
+damon
+mikhail
+##ichi
+lu
+critically
+elect
+realised
+abortion
+acute
+screening
+mythology
+steadily
+##fc
+frown
+nottingham
+kirk
+wa
+minneapolis
+##rra
+module
+algeria
+mc
+nautical
+encounters
+surprising
+statues
+availability
+shirts
+pie
+alma
+brows
+munster
+mack
+soup
+crater
+tornado
+sanskrit
+cedar
+explosive
+bordered
+dixon
+planets
+stamp
+exam
+happily
+##bble
+carriers
+kidnapped
+##vis
+accommodation
+emigrated
+##met
+knockout
+correspondent
+violation
+profits
+peaks
+lang
+specimen
+agenda
+ancestry
+pottery
+spelling
+equations
+obtaining
+ki
+linking
+1825
+debris
+asylum
+##20
+buddhism
+teddy
+##ants
+gazette
+##nger
+##sse
+dental
+eligibility
+utc
+fathers
+averaged
+zimbabwe
+francesco
+coloured
+hissed
+translator
+lynch
+mandate
+humanities
+mackenzie
+uniforms
+lin
+##iana
+##gio
+asset
+mhz
+fitting
+samantha
+genera
+wei
+rim
+beloved
+shark
+riot
+entities
+expressions
+indo
+carmen
+slipping
+owing
+abbot
+neighbor
+sidney
+##av
+rats
+recommendations
+encouraging
+squadrons
+anticipated
+commanders
+conquered
+##oto
+donations
+diagnosed
+##mond
+divide
+##iva
+guessed
+decoration
+vernon
+auditorium
+revelation
+conversations
+##kers
+##power
+herzegovina
+dash
+alike
+protested
+lateral
+herman
+accredited
+mg
+##gent
+freeman
+mel
+fiji
+crow
+crimson
+##rine
+livestock
+##pped
+humanitarian
+bored
+oz
+whip
+##lene
+##ali
+legitimate
+alter
+grinning
+spelled
+anxious
+oriental
+wesley
+##nin
+##hole
+carnival
+controller
+detect
+##ssa
+bowed
+educator
+kosovo
+macedonia
+##sin
+occupy
+mastering
+stephanie
+janeiro
+para
+unaware
+nurses
+noon
+135
+cam
+hopefully
+ranger
+combine
+sociology
+polar
+rica
+##eer
+neill
+##sman
+holocaust
+##ip
+doubled
+lust
+1828
+109
+decent
+cooling
+unveiled
+##card
+1829
+nsw
+homer
+chapman
+meyer
+##gin
+dive
+mae
+reagan
+expertise
+##gled
+darwin
+brooke
+sided
+prosecution
+investigating
+comprised
+petroleum
+genres
+reluctant
+differently
+trilogy
+johns
+vegetables
+corpse
+highlighted
+lounge
+pension
+unsuccessfully
+elegant
+aided
+ivory
+beatles
+amelia
+cain
+dubai
+sunny
+immigrant
+babe
+click
+##nder
+underwater
+pepper
+combining
+mumbled
+atlas
+horns
+accessed
+ballad
+physicians
+homeless
+gestured
+rpm
+freak
+louisville
+corporations
+patriots
+prizes
+rational
+warn
+modes
+decorative
+overnight
+din
+troubled
+phantom
+##ort
+monarch
+sheer
+##dorf
+generals
+guidelines
+organs
+addresses
+##zon
+enhance
+curling
+parishes
+cord
+##kie
+linux
+caesar
+deutsche
+bavaria
+##bia
+coleman
+cyclone
+##eria
+bacon
+petty
+##yama
+##old
+hampton
+diagnosis
+1824
+throws
+complexity
+rita
+disputed
+##₃
+pablo
+##sch
+marketed
+trafficking
+##ulus
+examine
+plague
+formats
+##oh
+vault
+faithful
+##bourne
+webster
+##ox
+highlights
+##ient
+##ann
+phones
+vacuum
+sandwich
+modeling
+##gated
+bolivia
+clergy
+qualities
+isabel
+##nas
+##ars
+wears
+screams
+reunited
+annoyed
+bra
+##ancy
+##rate
+differential
+transmitter
+tattoo
+container
+poker
+##och
+excessive
+resides
+cowboys
+##tum
+augustus
+trash
+providers
+statute
+retreated
+balcony
+reversed
+void
+storey
+preceded
+masses
+leap
+laughs
+neighborhoods
+wards
+schemes
+falcon
+santo
+battlefield
+pad
+ronnie
+thread
+lesbian
+venus
+##dian
+beg
+sandstone
+daylight
+punched
+gwen
+analog
+stroked
+wwe
+acceptable
+measurements
+dec
+toxic
+##kel
+adequate
+surgical
+economist
+parameters
+varsity
+##sberg
+quantity
+ella
+##chy
+##rton
+countess
+generating
+precision
+diamonds
+expressway
+ga
+##ı
+1821
+uruguay
+talents
+galleries
+expenses
+scanned
+colleague
+outlets
+ryder
+lucien
+##ila
+paramount
+##bon
+syracuse
+dim
+fangs
+gown
+sweep
+##sie
+toyota
+missionaries
+websites
+##nsis
+sentences
+adviser
+val
+trademark
+spells
+##plane
+patience
+starter
+slim
+##borg
+toe
+incredibly
+shoots
+elliot
+nobility
+##wyn
+cowboy
+endorsed
+gardner
+tendency
+persuaded
+organisms
+emissions
+kazakhstan
+amused
+boring
+chips
+themed
+##hand
+llc
+constantinople
+chasing
+systematic
+guatemala
+borrowed
+erin
+carey
+##hard
+highlands
+struggles
+1810
+##ifying
+##ced
+wong
+exceptions
+develops
+enlarged
+kindergarten
+castro
+##ern
+##rina
+leigh
+zombie
+juvenile
+##most
+consul
+##nar
+sailor
+hyde
+clarence
+intensive
+pinned
+nasty
+useless
+jung
+clayton
+stuffed
+exceptional
+ix
+apostolic
+230
+transactions
+##dge
+exempt
+swinging
+cove
+religions
+##ash
+shields
+dairy
+bypass
+190
+pursuing
+bug
+joyce
+bombay
+chassis
+southampton
+chat
+interact
+redesignated
+##pen
+nascar
+pray
+salmon
+rigid
+regained
+malaysian
+grim
+publicity
+constituted
+capturing
+toilet
+delegate
+purely
+tray
+drift
+loosely
+striker
+weakened
+trinidad
+mitch
+itv
+defines
+transmitted
+ming
+scarlet
+nodding
+fitzgerald
+fu
+narrowly
+sp
+tooth
+standings
+virtue
+##₁
+##wara
+##cting
+chateau
+gloves
+lid
+##nel
+hurting
+conservatory
+##pel
+sinclair
+reopened
+sympathy
+nigerian
+strode
+advocated
+optional
+chronic
+discharge
+##rc
+suck
+compatible
+laurel
+stella
+shi
+fails
+wage
+dodge
+128
+informal
+sorts
+levi
+buddha
+villagers
+##aka
+chronicles
+heavier
+summoned
+gateway
+3000
+eleventh
+jewelry
+translations
+accordingly
+seas
+##ency
+fiber
+pyramid
+cubic
+dragging
+##ista
+caring
+##ops
+android
+contacted
+lunar
+##dt
+kai
+lisbon
+patted
+1826
+sacramento
+theft
+madagascar
+subtropical
+disputes
+ta
+holidays
+piper
+willow
+mare
+cane
+itunes
+newfoundland
+benny
+companions
+dong
+raj
+observe
+roar
+charming
+plaque
+tibetan
+fossils
+enacted
+manning
+bubble
+tina
+tanzania
+##eda
+##hir
+funk
+swamp
+deputies
+cloak
+ufc
+scenario
+par
+scratch
+metals
+anthem
+guru
+engaging
+specially
+##boat
+dialects
+nineteen
+cecil
+duet
+disability
+messenger
+unofficial
+##lies
+defunct
+eds
+moonlight
+drainage
+surname
+puzzle
+honda
+switching
+conservatives
+mammals
+knox
+broadcaster
+sidewalk
+cope
+##ried
+benson
+princes
+peterson
+##sal
+bedford
+sharks
+eli
+wreck
+alberto
+gasp
+archaeology
+lgbt
+teaches
+securities
+madness
+compromise
+waving
+coordination
+davidson
+visions
+leased
+possibilities
+eighty
+jun
+fernandez
+enthusiasm
+assassin
+sponsorship
+reviewer
+kingdoms
+estonian
+laboratories
+##fy
+##nal
+applies
+verb
+celebrations
+##zzo
+rowing
+lightweight
+sadness
+submit
+mvp
+balanced
+dude
+##vas
+explicitly
+metric
+magnificent
+mound
+brett
+mohammad
+mistakes
+irregular
+##hing
+##ass
+sanders
+betrayed
+shipped
+surge
+##enburg
+reporters
+termed
+georg
+pity
+verbal
+bulls
+abbreviated
+enabling
+appealed
+##are
+##atic
+sicily
+sting
+heel
+sweetheart
+bart
+spacecraft
+brutal
+monarchy
+##tter
+aberdeen
+cameo
+diane
+##ub
+survivor
+clyde
+##aries
+complaint
+##makers
+clarinet
+delicious
+chilean
+karnataka
+coordinates
+1818
+panties
+##rst
+pretending
+ar
+dramatically
+kiev
+bella
+tends
+distances
+113
+catalog
+launching
+instances
+telecommunications
+portable
+lindsay
+vatican
+##eim
+angles
+aliens
+marker
+stint
+screens
+bolton
+##rne
+judy
+wool
+benedict
+plasma
+europa
+spark
+imaging
+filmmaker
+swiftly
+##een
+contributor
+##nor
+opted
+stamps
+apologize
+financing
+butter
+gideon
+sophisticated
+alignment
+avery
+chemicals
+yearly
+speculation
+prominence
+professionally
+##ils
+immortal
+institutional
+inception
+wrists
+identifying
+tribunal
+derives
+gains
+##wo
+papal
+preference
+linguistic
+vince
+operative
+brewery
+##ont
+unemployment
+boyd
+##ured
+##outs
+albeit
+prophet
+1813
+bi
+##rr
+##face
+##rad
+quarterly
+asteroid
+cleaned
+radius
+temper
+##llen
+telugu
+jerk
+viscount
+menu
+##ote
+glimpse
+##aya
+yacht
+hawaiian
+baden
+##rl
+laptop
+readily
+##gu
+monetary
+offshore
+scots
+watches
+##yang
+##arian
+upgrade
+needle
+xbox
+lea
+encyclopedia
+flank
+fingertips
+##pus
+delight
+teachings
+confirm
+roth
+beaches
+midway
+winters
+##iah
+teasing
+daytime
+beverly
+gambling
+bonnie
+##backs
+regulated
+clement
+hermann
+tricks
+knot
+##shing
+##uring
+##vre
+detached
+ecological
+owed
+specialty
+byron
+inventor
+bats
+stays
+screened
+unesco
+midland
+trim
+affection
+##ander
+##rry
+jess
+thoroughly
+feedback
+##uma
+chennai
+strained
+heartbeat
+wrapping
+overtime
+pleaded
+##sworth
+mon
+leisure
+oclc
+##tate
+##ele
+feathers
+angelo
+thirds
+nuts
+surveys
+clever
+gill
+commentator
+##dos
+darren
+rides
+gibraltar
+##nc
+##mu
+dissolution
+dedication
+shin
+meals
+saddle
+elvis
+reds
+chaired
+taller
+appreciation
+functioning
+niece
+favored
+advocacy
+robbie
+criminals
+suffolk
+yugoslav
+passport
+constable
+congressman
+hastings
+vera
+##rov
+consecrated
+sparks
+ecclesiastical
+confined
+##ovich
+muller
+floyd
+nora
+1822
+paved
+1827
+cumberland
+ned
+saga
+spiral
+##flow
+appreciated
+yi
+collaborative
+treating
+similarities
+feminine
+finishes
+##ib
+jade
+import
+##nse
+##hot
+champagne
+mice
+securing
+celebrities
+helsinki
+attributes
+##gos
+cousins
+phases
+ache
+lucia
+gandhi
+submission
+vicar
+spear
+shine
+tasmania
+biting
+detention
+constitute
+tighter
+seasonal
+##gus
+terrestrial
+matthews
+##oka
+effectiveness
+parody
+philharmonic
+##onic
+1816
+strangers
+encoded
+consortium
+guaranteed
+regards
+shifts
+tortured
+collision
+supervisor
+inform
+broader
+insight
+theaters
+armour
+emeritus
+blink
+incorporates
+mapping
+##50
+##ein
+handball
+flexible
+##nta
+substantially
+generous
+thief
+##own
+carr
+loses
+1793
+prose
+ucla
+romeo
+generic
+metallic
+realization
+damages
+mk
+commissioners
+zach
+default
+##ther
+helicopters
+lengthy
+stems
+spa
+partnered
+spectators
+rogue
+indication
+penalties
+teresa
+1801
+sen
+##tric
+dalton
+##wich
+irving
+photographic
+##vey
+dell
+deaf
+peters
+excluded
+unsure
+##vable
+patterson
+crawled
+##zio
+resided
+whipped
+latvia
+slower
+ecole
+pipes
+employers
+maharashtra
+comparable
+va
+textile
+pageant
+##gel
+alphabet
+binary
+irrigation
+chartered
+choked
+antoine
+offs
+waking
+supplement
+##wen
+quantities
+demolition
+regain
+locate
+urdu
+folks
+alt
+114
+##mc
+scary
+andreas
+whites
+##ava
+classrooms
+mw
+aesthetic
+publishes
+valleys
+guides
+cubs
+johannes
+bryant
+conventions
+affecting
+##itt
+drain
+awesome
+isolation
+prosecutor
+ambitious
+apology
+captive
+downs
+atmospheric
+lorenzo
+aisle
+beef
+foul
+##onia
+kidding
+composite
+disturbed
+illusion
+natives
+##ffer
+emi
+rockets
+riverside
+wartime
+painters
+adolf
+melted
+##ail
+uncertainty
+simulation
+hawks
+progressed
+meantime
+builder
+spray
+breach
+unhappy
+regina
+russians
+##urg
+determining
+##tation
+tram
+1806
+##quin
+aging
+##12
+1823
+garion
+rented
+mister
+diaz
+terminated
+clip
+1817
+depend
+nervously
+disco
+owe
+defenders
+shiva
+notorious
+disbelief
+shiny
+worcester
+##gation
+##yr
+trailing
+undertook
+islander
+belarus
+limitations
+watershed
+fuller
+overlooking
+utilized
+raphael
+1819
+synthetic
+breakdown
+klein
+##nate
+moaned
+memoir
+lamb
+practicing
+##erly
+cellular
+arrows
+exotic
+##graphy
+witches
+117
+charted
+rey
+hut
+hierarchy
+subdivision
+freshwater
+giuseppe
+aloud
+reyes
+qatar
+marty
+sideways
+utterly
+sexually
+jude
+prayers
+mccarthy
+softball
+blend
+damien
+##gging
+##metric
+wholly
+erupted
+lebanese
+negro
+revenues
+tasted
+comparative
+teamed
+transaction
+labeled
+maori
+sovereignty
+parkway
+trauma
+gran
+malay
+121
+advancement
+descendant
+2020
+buzz
+salvation
+inventory
+symbolic
+##making
+antarctica
+mps
+##gas
+##bro
+mohammed
+myanmar
+holt
+submarines
+tones
+##lman
+locker
+patriarch
+bangkok
+emerson
+remarks
+predators
+kin
+afghan
+confession
+norwich
+rental
+emerge
+advantages
+##zel
+rca
+##hold
+shortened
+storms
+aidan
+##matic
+autonomy
+compliance
+##quet
+dudley
+atp
+##osis
+1803
+motto
+documentation
+summary
+professors
+spectacular
+christina
+archdiocese
+flashing
+innocence
+remake
+##dell
+psychic
+reef
+scare
+employ
+rs
+sticks
+meg
+gus
+leans
+##ude
+accompany
+bergen
+tomas
+##iko
+doom
+wages
+pools
+##nch
+##bes
+breasts
+scholarly
+alison
+outline
+brittany
+breakthrough
+willis
+realistic
+##cut
+##boro
+competitor
+##stan
+pike
+picnic
+icon
+designing
+commercials
+washing
+villain
+skiing
+micro
+costumes
+auburn
+halted
+executives
+##hat
+logistics
+cycles
+vowel
+applicable
+barrett
+exclaimed
+eurovision
+eternity
+ramon
+##umi
+##lls
+modifications
+sweeping
+disgust
+##uck
+torch
+aviv
+ensuring
+rude
+dusty
+sonic
+donovan
+outskirts
+cu
+pathway
+##band
+##gun
+##lines
+disciplines
+acids
+cadet
+paired
+##40
+sketches
+##sive
+marriages
+##⁺
+folding
+peers
+slovak
+implies
+admired
+##beck
+1880s
+leopold
+instinct
+attained
+weston
+megan
+horace
+##ination
+dorsal
+ingredients
+evolutionary
+##its
+complications
+deity
+lethal
+brushing
+levy
+deserted
+institutes
+posthumously
+delivering
+telescope
+coronation
+motivated
+rapids
+luc
+flicked
+pays
+volcano
+tanner
+weighed
+##nica
+crowds
+frankie
+gifted
+addressing
+granddaughter
+winding
+##rna
+constantine
+gomez
+##front
+landscapes
+rudolf
+anthropology
+slate
+werewolf
+##lio
+astronomy
+circa
+rouge
+dreaming
+sack
+knelt
+drowned
+naomi
+prolific
+tracked
+freezing
+herb
+##dium
+agony
+randall
+twisting
+wendy
+deposit
+touches
+vein
+wheeler
+##bbled
+##bor
+batted
+retaining
+tire
+presently
+compare
+specification
+daemon
+nigel
+##grave
+merry
+recommendation
+czechoslovakia
+sandra
+ng
+roma
+##sts
+lambert
+inheritance
+sheikh
+winchester
+cries
+examining
+##yle
+comeback
+cuisine
+nave
+##iv
+ko
+retrieve
+tomatoes
+barker
+polished
+defining
+irene
+lantern
+personalities
+begging
+tract
+swore
+1809
+175
+##gic
+omaha
+brotherhood
+##rley
+haiti
+##ots
+exeter
+##ete
+##zia
+steele
+dumb
+pearson
+210
+surveyed
+elisabeth
+trends
+##ef
+fritz
+##rf
+premium
+bugs
+fraction
+calmly
+viking
+##birds
+tug
+inserted
+unusually
+##ield
+confronted
+distress
+crashing
+brent
+turks
+resign
+##olo
+cambodia
+gabe
+sauce
+##kal
+evelyn
+116
+extant
+clusters
+quarry
+teenagers
+luna
+##lers
+##ister
+affiliation
+drill
+##ashi
+panthers
+scenic
+libya
+anita
+strengthen
+inscriptions
+##cated
+lace
+sued
+judith
+riots
+##uted
+mint
+##eta
+preparations
+midst
+dub
+challenger
+##vich
+mock
+cf
+displaced
+wicket
+breaths
+enables
+schmidt
+analyst
+##lum
+ag
+highlight
+automotive
+axe
+josef
+newark
+sufficiently
+resembles
+50th
+##pal
+flushed
+mum
+traits
+##ante
+commodore
+incomplete
+warming
+titular
+ceremonial
+ethical
+118
+celebrating
+eighteenth
+cao
+lima
+medalist
+mobility
+strips
+snakes
+##city
+miniature
+zagreb
+barton
+escapes
+umbrella
+automated
+doubted
+differs
+cooled
+georgetown
+dresden
+cooked
+fade
+wyatt
+rna
+jacobs
+carlton
+abundant
+stereo
+boost
+madras
+inning
+##hia
+spur
+ip
+malayalam
+begged
+osaka
+groan
+escaping
+charging
+dose
+vista
+##aj
+bud
+papa
+communists
+advocates
+edged
+tri
+##cent
+resemble
+peaking
+necklace
+fried
+montenegro
+saxony
+goose
+glances
+stuttgart
+curator
+recruit
+grocery
+sympathetic
+##tting
+##fort
+127
+lotus
+randolph
+ancestor
+##rand
+succeeding
+jupiter
+1798
+macedonian
+##heads
+hiking
+1808
+handing
+fischer
+##itive
+garbage
+node
+##pies
+prone
+singular
+papua
+inclined
+attractions
+italia
+pouring
+motioned
+grandma
+garnered
+jacksonville
+corp
+ego
+ringing
+aluminum
+##hausen
+ordering
+##foot
+drawer
+traders
+synagogue
+##play
+##kawa
+resistant
+wandering
+fragile
+fiona
+teased
+var
+hardcore
+soaked
+jubilee
+decisive
+exposition
+mercer
+poster
+valencia
+hale
+kuwait
+1811
+##ises
+##wr
+##eed
+tavern
+gamma
+122
+johan
+##uer
+airways
+amino
+gil
+##ury
+vocational
+domains
+torres
+##sp
+generator
+folklore
+outcomes
+##keeper
+canberra
+shooter
+fl
+beams
+confrontation
+##lling
+##gram
+feb
+aligned
+forestry
+pipeline
+jax
+motorway
+conception
+decay
+##tos
+coffin
+##cott
+stalin
+1805
+escorted
+minded
+##nam
+sitcom
+purchasing
+twilight
+veronica
+additions
+passive
+tensions
+straw
+123
+frequencies
+1804
+refugee
+cultivation
+##iate
+christie
+clary
+bulletin
+crept
+disposal
+##rich
+##zong
+processor
+crescent
+##rol
+bmw
+emphasized
+whale
+nazis
+aurora
+##eng
+dwelling
+hauled
+sponsors
+toledo
+mega
+ideology
+theatres
+tessa
+cerambycidae
+saves
+turtle
+cone
+suspects
+kara
+rusty
+yelling
+greeks
+mozart
+shades
+cocked
+participant
+##tro
+shire
+spit
+freeze
+necessity
+##cos
+inmates
+nielsen
+councillors
+loaned
+uncommon
+omar
+peasants
+botanical
+offspring
+daniels
+formations
+jokes
+1794
+pioneers
+sigma
+licensing
+##sus
+wheelchair
+polite
+1807
+liquor
+pratt
+trustee
+##uta
+forewings
+balloon
+##zz
+kilometre
+camping
+explicit
+casually
+shawn
+foolish
+teammates
+nm
+hassan
+carrie
+judged
+satisfy
+vanessa
+knives
+selective
+cnn
+flowed
+##lice
+eclipse
+stressed
+eliza
+mathematician
+cease
+cultivated
+##roy
+commissions
+browns
+##ania
+destroyers
+sheridan
+meadow
+##rius
+minerals
+##cial
+downstream
+clash
+gram
+memoirs
+ventures
+baha
+seymour
+archie
+midlands
+edith
+fare
+flynn
+invite
+canceled
+tiles
+stabbed
+boulder
+incorporate
+amended
+camden
+facial
+mollusk
+unreleased
+descriptions
+yoga
+grabs
+550
+raises
+ramp
+shiver
+##rose
+coined
+pioneering
+tunes
+qing
+warwick
+tops
+119
+melanie
+giles
+##rous
+wandered
+##inal
+annexed
+nov
+30th
+unnamed
+##ished
+organizational
+airplane
+normandy
+stoke
+whistle
+blessing
+violations
+chased
+holders
+shotgun
+##ctic
+outlet
+reactor
+##vik
+tires
+tearing
+shores
+fortified
+mascot
+constituencies
+nc
+columnist
+productive
+tibet
+##rta
+lineage
+hooked
+oct
+tapes
+judging
+cody
+##gger
+hansen
+kashmir
+triggered
+##eva
+solved
+cliffs
+##tree
+resisted
+anatomy
+protesters
+transparent
+implied
+##iga
+injection
+mattress
+excluding
+##mbo
+defenses
+helpless
+devotion
+##elli
+growl
+liberals
+weber
+phenomena
+atoms
+plug
+##iff
+mortality
+apprentice
+howe
+convincing
+aaa
+swimmer
+barber
+leone
+promptly
+sodium
+def
+nowadays
+arise
+##oning
+gloucester
+corrected
+dignity
+norm
+erie
+##ders
+elders
+evacuated
+sylvia
+compression
+##yar
+hartford
+pose
+backpack
+reasoning
+accepts
+24th
+wipe
+millimetres
+marcel
+##oda
+dodgers
+albion
+1790
+overwhelmed
+aerospace
+oaks
+1795
+showcase
+acknowledge
+recovering
+nolan
+ashe
+hurts
+geology
+fashioned
+disappearance
+farewell
+swollen
+shrug
+marquis
+wimbledon
+124
+rue
+1792
+commemorate
+reduces
+experiencing
+inevitable
+calcutta
+intel
+##court
+murderer
+sticking
+fisheries
+imagery
+bloom
+280
+brake
+##inus
+gustav
+hesitation
+memorable
+po
+viral
+beans
+accidents
+tunisia
+antenna
+spilled
+consort
+treatments
+aye
+perimeter
+##gard
+donation
+hostage
+migrated
+banker
+addiction
+apex
+lil
+trout
+##ously
+conscience
+##nova
+rams
+sands
+genome
+passionate
+troubles
+##lets
+##set
+amid
+##ibility
+##ret
+higgins
+exceed
+vikings
+##vie
+payne
+##zan
+muscular
+##ste
+defendant
+sucking
+##wal
+ibrahim
+fuselage
+claudia
+vfl
+europeans
+snails
+interval
+##garh
+preparatory
+statewide
+tasked
+lacrosse
+viktor
+##lation
+angola
+##hra
+flint
+implications
+employs
+teens
+patrons
+stall
+weekends
+barriers
+scrambled
+nucleus
+tehran
+jenna
+parsons
+lifelong
+robots
+displacement
+5000
+##bles
+precipitation
+##gt
+knuckles
+clutched
+1802
+marrying
+ecology
+marx
+accusations
+declare
+scars
+kolkata
+mat
+meadows
+bermuda
+skeleton
+finalists
+vintage
+crawl
+coordinate
+affects
+subjected
+orchestral
+mistaken
+##tc
+mirrors
+dipped
+relied
+260
+arches
+candle
+##nick
+incorporating
+wildly
+fond
+basilica
+owl
+fringe
+rituals
+whispering
+stirred
+feud
+tertiary
+slick
+goat
+honorable
+whereby
+skip
+ricardo
+stripes
+parachute
+adjoining
+submerged
+synthesizer
+##gren
+intend
+positively
+ninety
+phi
+beaver
+partition
+fellows
+alexis
+prohibition
+carlisle
+bizarre
+fraternity
+##bre
+doubts
+icy
+cbc
+aquatic
+sneak
+sonny
+combines
+airports
+crude
+supervised
+spatial
+merge
+alfonso
+##bic
+corrupt
+scan
+undergo
+##ams
+disabilities
+colombian
+comparing
+dolphins
+perkins
+##lish
+reprinted
+unanimous
+bounced
+hairs
+underworld
+midwest
+semester
+bucket
+paperback
+miniseries
+coventry
+demise
+##leigh
+demonstrations
+sensor
+rotating
+yan
+##hler
+arrange
+soils
+##idge
+hyderabad
+labs
+##dr
+brakes
+grandchildren
+##nde
+negotiated
+rover
+ferrari
+continuation
+directorate
+augusta
+stevenson
+counterpart
+gore
+##rda
+nursery
+rican
+ave
+collectively
+broadly
+pastoral
+repertoire
+asserted
+discovering
+nordic
+styled
+fiba
+cunningham
+harley
+middlesex
+survives
+tumor
+tempo
+zack
+aiming
+lok
+urgent
+##rade
+##nto
+devils
+##ement
+contractor
+turin
+##wl
+##ool
+bliss
+repaired
+simmons
+moan
+astronomical
+cr
+negotiate
+lyric
+1890s
+lara
+bred
+clad
+angus
+pbs
+##ience
+engineered
+posed
+##lk
+hernandez
+possessions
+elbows
+psychiatric
+strokes
+confluence
+electorate
+lifts
+campuses
+lava
+alps
+##ep
+##ution
+##date
+physicist
+woody
+##page
+##ographic
+##itis
+juliet
+reformation
+sparhawk
+320
+complement
+suppressed
+jewel
+##½
+floated
+##kas
+continuity
+sadly
+##ische
+inability
+melting
+scanning
+paula
+flour
+judaism
+safer
+vague
+##lm
+solving
+curb
+##stown
+financially
+gable
+bees
+expired
+miserable
+cassidy
+dominion
+1789
+cupped
+145
+robbery
+facto
+amos
+warden
+resume
+tallest
+marvin
+ing
+pounded
+usd
+declaring
+gasoline
+##aux
+darkened
+270
+650
+sophomore
+##mere
+erection
+gossip
+televised
+risen
+dial
+##eu
+pillars
+##link
+passages
+profound
+##tina
+arabian
+ashton
+silicon
+nail
+##ead
+##lated
+##wer
+##hardt
+fleming
+firearms
+ducked
+circuits
+blows
+waterloo
+titans
+##lina
+atom
+fireplace
+cheshire
+financed
+activation
+algorithms
+##zzi
+constituent
+catcher
+cherokee
+partnerships
+sexuality
+platoon
+tragic
+vivian
+guarded
+whiskey
+meditation
+poetic
+##late
+##nga
+##ake
+porto
+listeners
+dominance
+kendra
+mona
+chandler
+factions
+22nd
+salisbury
+attitudes
+derivative
+##ido
+##haus
+intake
+paced
+javier
+illustrator
+barrels
+bias
+cockpit
+burnett
+dreamed
+ensuing
+##anda
+receptors
+someday
+hawkins
+mattered
+##lal
+slavic
+1799
+jesuit
+cameroon
+wasted
+tai
+wax
+lowering
+victorious
+freaking
+outright
+hancock
+librarian
+sensing
+bald
+calcium
+myers
+tablet
+announcing
+barack
+shipyard
+pharmaceutical
+##uan
+greenwich
+flush
+medley
+patches
+wolfgang
+pt
+speeches
+acquiring
+exams
+nikolai
+##gg
+hayden
+kannada
+##type
+reilly
+##pt
+waitress
+abdomen
+devastated
+capped
+pseudonym
+pharmacy
+fulfill
+paraguay
+1796
+clicked
+##trom
+archipelago
+syndicated
+##hman
+lumber
+orgasm
+rejection
+clifford
+lorraine
+advent
+mafia
+rodney
+brock
+##ght
+##used
+##elia
+cassette
+chamberlain
+despair
+mongolia
+sensors
+developmental
+upstream
+##eg
+##alis
+spanning
+165
+trombone
+basque
+seeded
+interred
+renewable
+rhys
+leapt
+revision
+molecule
+##ages
+chord
+vicious
+nord
+shivered
+23rd
+arlington
+debts
+corpus
+sunrise
+bays
+blackburn
+centimetres
+##uded
+shuddered
+gm
+strangely
+gripping
+cartoons
+isabelle
+orbital
+##ppa
+seals
+proving
+##lton
+refusal
+strengthened
+bust
+assisting
+baghdad
+batsman
+portrayal
+mara
+pushes
+spears
+og
+##cock
+reside
+nathaniel
+brennan
+1776
+confirmation
+caucus
+##worthy
+markings
+yemen
+nobles
+ku
+lazy
+viewer
+catalan
+encompasses
+sawyer
+##fall
+sparked
+substances
+patents
+braves
+arranger
+evacuation
+sergio
+persuade
+dover
+tolerance
+penguin
+cum
+jockey
+insufficient
+townships
+occupying
+declining
+plural
+processed
+projection
+puppet
+flanders
+introduces
+liability
+##yon
+gymnastics
+antwerp
+taipei
+hobart
+candles
+jeep
+wes
+observers
+126
+chaplain
+bundle
+glorious
+##hine
+hazel
+flung
+sol
+excavations
+dumped
+stares
+sh
+bangalore
+triangular
+icelandic
+intervals
+expressing
+turbine
+##vers
+songwriting
+crafts
+##igo
+jasmine
+ditch
+rite
+##ways
+entertaining
+comply
+sorrow
+wrestlers
+basel
+emirates
+marian
+rivera
+helpful
+##some
+caution
+downward
+networking
+##atory
+##tered
+darted
+genocide
+emergence
+replies
+specializing
+spokesman
+convenient
+unlocked
+fading
+augustine
+concentrations
+resemblance
+elijah
+investigator
+andhra
+##uda
+promotes
+bean
+##rrell
+fleeing
+wan
+simone
+announcer
+##ame
+##bby
+lydia
+weaver
+132
+residency
+modification
+##fest
+stretches
+##ast
+alternatively
+nat
+lowe
+lacks
+##ented
+pam
+tile
+concealed
+inferior
+abdullah
+residences
+tissues
+vengeance
+##ided
+moisture
+peculiar
+groove
+zip
+bologna
+jennings
+ninja
+oversaw
+zombies
+pumping
+batch
+livingston
+emerald
+installations
+1797
+peel
+nitrogen
+rama
+##fying
+##star
+schooling
+strands
+responding
+werner
+##ost
+lime
+casa
+accurately
+targeting
+##rod
+underway
+##uru
+hemisphere
+lester
+##yard
+occupies
+2d
+griffith
+angrily
+reorganized
+##owing
+courtney
+deposited
+##dd
+##30
+estadio
+##ifies
+dunn
+exiled
+##ying
+checks
+##combe
+##о
+##fly
+successes
+unexpectedly
+blu
+assessed
+##flower
+##ه
+observing
+sacked
+spiders
+kn
+##tail
+mu
+nodes
+prosperity
+audrey
+divisional
+155
+broncos
+tangled
+adjust
+feeds
+erosion
+paolo
+surf
+directory
+snatched
+humid
+admiralty
+screwed
+gt
+reddish
+##nese
+modules
+trench
+lamps
+bind
+leah
+bucks
+competes
+##nz
+##form
+transcription
+##uc
+isles
+violently
+clutching
+pga
+cyclist
+inflation
+flats
+ragged
+unnecessary
+##hian
+stubborn
+coordinated
+harriet
+baba
+disqualified
+330
+insect
+wolfe
+##fies
+reinforcements
+rocked
+duel
+winked
+embraced
+bricks
+##raj
+hiatus
+defeats
+pending
+brightly
+jealousy
+##xton
+##hm
+##uki
+lena
+gdp
+colorful
+##dley
+stein
+kidney
+##shu
+underwear
+wanderers
+##haw
+##icus
+guardians
+m³
+roared
+habits
+##wise
+permits
+gp
+uranium
+punished
+disguise
+bundesliga
+elise
+dundee
+erotic
+partisan
+pi
+collectors
+float
+individually
+rendering
+behavioral
+bucharest
+ser
+hare
+valerie
+corporal
+nutrition
+proportional
+##isa
+immense
+##kis
+pavement
+##zie
+##eld
+sutherland
+crouched
+1775
+##lp
+suzuki
+trades
+endurance
+operas
+crosby
+prayed
+priory
+rory
+socially
+##urn
+gujarat
+##pu
+walton
+cube
+pasha
+privilege
+lennon
+floods
+thorne
+waterfall
+nipple
+scouting
+approve
+##lov
+minorities
+voter
+dwight
+extensions
+assure
+ballroom
+slap
+dripping
+privileges
+rejoined
+confessed
+demonstrating
+patriotic
+yell
+investor
+##uth
+pagan
+slumped
+squares
+##cle
+##kins
+confront
+bert
+embarrassment
+##aid
+aston
+urging
+sweater
+starr
+yuri
+brains
+williamson
+commuter
+mortar
+structured
+selfish
+exports
+##jon
+cds
+##him
+unfinished
+##rre
+mortgage
+destinations
+##nagar
+canoe
+solitary
+buchanan
+delays
+magistrate
+fk
+##pling
+motivation
+##lier
+##vier
+recruiting
+assess
+##mouth
+malik
+antique
+1791
+pius
+rahman
+reich
+tub
+zhou
+smashed
+airs
+galway
+xii
+conditioning
+honduras
+discharged
+dexter
+##pf
+lionel
+129
+debates
+lemon
+tiffany
+volunteered
+dom
+dioxide
+procession
+devi
+sic
+tremendous
+advertisements
+colts
+transferring
+verdict
+hanover
+decommissioned
+utter
+relate
+pac
+racism
+##top
+beacon
+limp
+similarity
+terra
+occurrence
+ant
+##how
+becky
+capt
+updates
+armament
+richie
+pal
+##graph
+halloween
+mayo
+##ssen
+##bone
+cara
+serena
+fcc
+dolls
+obligations
+##dling
+violated
+lafayette
+jakarta
+exploitation
+##ime
+infamous
+iconic
+##lah
+##park
+kitty
+moody
+reginald
+dread
+spill
+crystals
+olivier
+modeled
+bluff
+equilibrium
+separating
+notices
+ordnance
+extinction
+onset
+cosmic
+attachment
+sammy
+expose
+privy
+anchored
+##bil
+abbott
+admits
+bending
+baritone
+emmanuel
+policeman
+vaughan
+winged
+climax
+dresses
+denny
+polytechnic
+mohamed
+burmese
+authentic
+nikki
+genetics
+grandparents
+homestead
+gaza
+postponed
+metacritic
+una
+##sby
+##bat
+unstable
+dissertation
+##rial
+##cian
+curls
+obscure
+uncovered
+bronx
+praying
+disappearing
+##hoe
+prehistoric
+coke
+turret
+mutations
+nonprofit
+pits
+monaco
+##ي
+##usion
+prominently
+dispatched
+podium
+##mir
+uci
+##uation
+133
+fortifications
+birthplace
+kendall
+##lby
+##oll
+preacher
+rack
+goodman
+##rman
+persistent
+##ott
+countless
+jaime
+recorder
+lexington
+persecution
+jumps
+renewal
+wagons
+##11
+crushing
+##holder
+decorations
+##lake
+abundance
+wrath
+laundry
+£1
+garde
+##rp
+jeanne
+beetles
+peasant
+##sl
+splitting
+caste
+sergei
+##rer
+##ema
+scripts
+##ively
+rub
+satellites
+##vor
+inscribed
+verlag
+scrapped
+gale
+packages
+chick
+potato
+slogan
+kathleen
+arabs
+##culture
+counterparts
+reminiscent
+choral
+##tead
+rand
+retains
+bushes
+dane
+accomplish
+courtesy
+closes
+##oth
+slaughter
+hague
+krakow
+lawson
+tailed
+elias
+ginger
+##ttes
+canopy
+betrayal
+rebuilding
+turf
+##hof
+frowning
+allegiance
+brigades
+kicks
+rebuild
+polls
+alias
+nationalism
+td
+rowan
+audition
+bowie
+fortunately
+recognizes
+harp
+dillon
+horrified
+##oro
+renault
+##tics
+ropes
+##α
+presumed
+rewarded
+infrared
+wiping
+accelerated
+illustration
+##rid
+presses
+practitioners
+badminton
+##iard
+detained
+##tera
+recognizing
+relates
+misery
+##sies
+##tly
+reproduction
+piercing
+potatoes
+thornton
+esther
+manners
+hbo
+##aan
+ours
+bullshit
+ernie
+perennial
+sensitivity
+illuminated
+rupert
+##jin
+##iss
+##ear
+rfc
+nassau
+##dock
+staggered
+socialism
+##haven
+appointments
+nonsense
+prestige
+sharma
+haul
+##tical
+solidarity
+gps
+##ook
+##rata
+igor
+pedestrian
+##uit
+baxter
+tenants
+wires
+medication
+unlimited
+guiding
+impacts
+diabetes
+##rama
+sasha
+pas
+clive
+extraction
+131
+continually
+constraints
+##bilities
+sonata
+hunted
+sixteenth
+chu
+planting
+quote
+mayer
+pretended
+abs
+spat
+##hua
+ceramic
+##cci
+curtains
+pigs
+pitching
+##dad
+latvian
+sore
+dayton
+##sted
+##qi
+patrols
+slice
+playground
+##nted
+shone
+stool
+apparatus
+inadequate
+mates
+treason
+##ija
+desires
+##liga
+##croft
+somalia
+laurent
+mir
+leonardo
+oracle
+grape
+obliged
+chevrolet
+thirteenth
+stunning
+enthusiastic
+##ede
+accounted
+concludes
+currents
+basil
+##kovic
+drought
+##rica
+mai
+##aire
+shove
+posting
+##shed
+pilgrimage
+humorous
+packing
+fry
+pencil
+wines
+smells
+144
+marilyn
+aching
+newest
+clung
+bon
+neighbours
+sanctioned
+##pie
+mug
+##stock
+drowning
+##mma
+hydraulic
+##vil
+hiring
+reminder
+lilly
+investigators
+##ncies
+sour
+##eous
+compulsory
+packet
+##rion
+##graphic
+##elle
+cannes
+##inate
+depressed
+##rit
+heroic
+importantly
+theresa
+##tled
+conway
+saturn
+marginal
+rae
+##xia
+corresponds
+royce
+pact
+jasper
+explosives
+packaging
+aluminium
+##ttered
+denotes
+rhythmic
+spans
+assignments
+hereditary
+outlined
+originating
+sundays
+lad
+reissued
+greeting
+beatrice
+##dic
+pillar
+marcos
+plots
+handbook
+alcoholic
+judiciary
+avant
+slides
+extract
+masculine
+blur
+##eum
+##force
+homage
+trembled
+owens
+hymn
+trey
+omega
+signaling
+socks
+accumulated
+reacted
+attic
+theo
+lining
+angie
+distraction
+primera
+talbot
+##key
+1200
+ti
+creativity
+billed
+##hey
+deacon
+eduardo
+identifies
+proposition
+dizzy
+gunner
+hogan
+##yam
+##pping
+##hol
+ja
+##chan
+jensen
+reconstructed
+##berger
+clearance
+darius
+##nier
+abe
+harlem
+plea
+dei
+circled
+emotionally
+notation
+fascist
+neville
+exceeded
+upwards
+viable
+ducks
+##fo
+workforce
+racer
+limiting
+shri
+##lson
+possesses
+1600
+kerr
+moths
+devastating
+laden
+disturbing
+locking
+##cture
+gal
+fearing
+accreditation
+flavor
+aide
+1870s
+mountainous
+##baum
+melt
+##ures
+motel
+texture
+servers
+soda
+##mb
+herd
+##nium
+erect
+puzzled
+hum
+peggy
+examinations
+gould
+testified
+geoff
+ren
+devised
+sacks
+##law
+denial
+posters
+grunted
+cesar
+tutor
+ec
+gerry
+offerings
+byrne
+falcons
+combinations
+ct
+incoming
+pardon
+rocking
+26th
+avengers
+flared
+mankind
+seller
+uttar
+loch
+nadia
+stroking
+exposing
+##hd
+fertile
+ancestral
+instituted
+##has
+noises
+prophecy
+taxation
+eminent
+vivid
+pol
+##bol
+dart
+indirect
+multimedia
+notebook
+upside
+displaying
+adrenaline
+referenced
+geometric
+##iving
+progression
+##ddy
+blunt
+announce
+##far
+implementing
+##lav
+aggression
+liaison
+cooler
+cares
+headache
+plantations
+gorge
+dots
+impulse
+thickness
+ashamed
+averaging
+kathy
+obligation
+precursor
+137
+fowler
+symmetry
+thee
+225
+hears
+##rai
+undergoing
+ads
+butcher
+bowler
+##lip
+cigarettes
+subscription
+goodness
+##ically
+browne
+##hos
+##tech
+kyoto
+donor
+##erty
+damaging
+friction
+drifting
+expeditions
+hardened
+prostitution
+152
+fauna
+blankets
+claw
+tossing
+snarled
+butterflies
+recruits
+investigative
+coated
+healed
+138
+communal
+hai
+xiii
+academics
+boone
+psychologist
+restless
+lahore
+stephens
+mba
+brendan
+foreigners
+printer
+##pc
+ached
+explode
+27th
+deed
+scratched
+dared
+##pole
+cardiac
+1780
+okinawa
+proto
+commando
+compelled
+oddly
+electrons
+##base
+replica
+thanksgiving
+##rist
+sheila
+deliberate
+stafford
+tidal
+representations
+hercules
+ou
+##path
+##iated
+kidnapping
+lenses
+##tling
+deficit
+samoa
+mouths
+consuming
+computational
+maze
+granting
+smirk
+razor
+fixture
+ideals
+inviting
+aiden
+nominal
+##vs
+issuing
+julio
+pitt
+ramsey
+docks
+##oss
+exhaust
+##owed
+bavarian
+draped
+anterior
+mating
+ethiopian
+explores
+noticing
+##nton
+discarded
+convenience
+hoffman
+endowment
+beasts
+cartridge
+mormon
+paternal
+probe
+sleeves
+interfere
+lump
+deadline
+##rail
+jenks
+bulldogs
+scrap
+alternating
+justified
+reproductive
+nam
+seize
+descending
+secretariat
+kirby
+coupe
+grouped
+smash
+panther
+sedan
+tapping
+##18
+lola
+cheer
+germanic
+unfortunate
+##eter
+unrelated
+##fan
+subordinate
+##sdale
+suzanne
+advertisement
+##ility
+horsepower
+##lda
+cautiously
+discourse
+luigi
+##mans
+##fields
+noun
+prevalent
+mao
+schneider
+everett
+surround
+governorate
+kira
+##avia
+westward
+##take
+misty
+rails
+sustainability
+134
+unused
+##rating
+packs
+toast
+unwilling
+regulate
+thy
+suffrage
+nile
+awe
+assam
+definitions
+travelers
+affordable
+##rb
+conferred
+sells
+undefeated
+beneficial
+torso
+basal
+repeating
+remixes
+##pass
+bahrain
+cables
+fang
+##itated
+excavated
+numbering
+statutory
+##rey
+deluxe
+##lian
+forested
+ramirez
+derbyshire
+zeus
+slamming
+transfers
+astronomer
+banana
+lottery
+berg
+histories
+bamboo
+##uchi
+resurrection
+posterior
+bowls
+vaguely
+##thi
+thou
+preserving
+tensed
+offence
+##inas
+meyrick
+callum
+ridden
+watt
+langdon
+tying
+lowland
+snorted
+daring
+truman
+##hale
+##girl
+aura
+overly
+filing
+weighing
+goa
+infections
+philanthropist
+saunders
+eponymous
+##owski
+latitude
+perspectives
+reviewing
+mets
+commandant
+radial
+##kha
+flashlight
+reliability
+koch
+vowels
+amazed
+ada
+elaine
+supper
+##rth
+##encies
+predator
+debated
+soviets
+cola
+##boards
+##nah
+compartment
+crooked
+arbitrary
+fourteenth
+##ctive
+havana
+majors
+steelers
+clips
+profitable
+ambush
+exited
+packers
+##tile
+nude
+cracks
+fungi
+##е
+limb
+trousers
+josie
+shelby
+tens
+frederic
+##ος
+definite
+smoothly
+constellation
+insult
+baton
+discs
+lingering
+##nco
+conclusions
+lent
+staging
+becker
+grandpa
+shaky
+##tron
+einstein
+obstacles
+sk
+adverse
+elle
+economically
+##moto
+mccartney
+thor
+dismissal
+motions
+readings
+nostrils
+treatise
+##pace
+squeezing
+evidently
+prolonged
+1783
+venezuelan
+je
+marguerite
+beirut
+takeover
+shareholders
+##vent
+denise
+digit
+airplay
+norse
+##bbling
+imaginary
+pills
+hubert
+blaze
+vacated
+eliminating
+##ello
+vine
+mansfield
+##tty
+retrospective
+barrow
+borne
+clutch
+bail
+forensic
+weaving
+##nett
+##witz
+desktop
+citadel
+promotions
+worrying
+dorset
+ieee
+subdivided
+##iating
+manned
+expeditionary
+pickup
+synod
+chuckle
+185
+barney
+##rz
+##ffin
+functionality
+karachi
+litigation
+meanings
+uc
+lick
+turbo
+anders
+##ffed
+execute
+curl
+oppose
+ankles
+typhoon
+##د
+##ache
+##asia
+linguistics
+compassion
+pressures
+grazing
+perfection
+##iting
+immunity
+monopoly
+muddy
+backgrounds
+136
+namibia
+francesca
+monitors
+attracting
+stunt
+tuition
+##ии
+vegetable
+##mates
+##quent
+mgm
+jen
+complexes
+forts
+##ond
+cellar
+bites
+seventeenth
+royals
+flemish
+failures
+mast
+charities
+##cular
+peruvian
+capitals
+macmillan
+ipswich
+outward
+frigate
+postgraduate
+folds
+employing
+##ouse
+concurrently
+fiery
+##tai
+contingent
+nightmares
+monumental
+nicaragua
+##kowski
+lizard
+mal
+fielding
+gig
+reject
+##pad
+harding
+##ipe
+coastline
+##cin
+##nos
+beethoven
+humphrey
+innovations
+##tam
+##nge
+norris
+doris
+solicitor
+huang
+obey
+141
+##lc
+niagara
+##tton
+shelves
+aug
+bourbon
+curry
+nightclub
+specifications
+hilton
+##ndo
+centennial
+dispersed
+worm
+neglected
+briggs
+sm
+font
+kuala
+uneasy
+plc
+##nstein
+##bound
+##aking
+##burgh
+awaiting
+pronunciation
+##bbed
+##quest
+eh
+optimal
+zhu
+raped
+greens
+presided
+brenda
+worries
+##life
+venetian
+marxist
+turnout
+##lius
+refined
+braced
+sins
+grasped
+sunderland
+nickel
+speculated
+lowell
+cyrillic
+communism
+fundraising
+resembling
+colonists
+mutant
+freddie
+usc
+##mos
+gratitude
+##run
+mural
+##lous
+chemist
+wi
+reminds
+28th
+steals
+tess
+pietro
+##ingen
+promoter
+ri
+microphone
+honoured
+rai
+sant
+##qui
+feather
+##nson
+burlington
+kurdish
+terrorists
+deborah
+sickness
+##wed
+##eet
+hazard
+irritated
+desperation
+veil
+clarity
+##rik
+jewels
+xv
+##gged
+##ows
+##cup
+berkshire
+unfair
+mysteries
+orchid
+winced
+exhaustion
+renovations
+stranded
+obe
+infinity
+##nies
+adapt
+redevelopment
+thanked
+registry
+olga
+domingo
+noir
+tudor
+ole
+##atus
+commenting
+behaviors
+##ais
+crisp
+pauline
+probable
+stirling
+wigan
+##bian
+paralympics
+panting
+surpassed
+##rew
+luca
+barred
+pony
+famed
+##sters
+cassandra
+waiter
+carolyn
+exported
+##orted
+andres
+destructive
+deeds
+jonah
+castles
+vacancy
+suv
+##glass
+1788
+orchard
+yep
+famine
+belarusian
+sprang
+##forth
+skinny
+##mis
+administrators
+rotterdam
+zambia
+zhao
+boiler
+discoveries
+##ride
+##physics
+lucius
+disappointing
+outreach
+spoon
+##frame
+qualifications
+unanimously
+enjoys
+regency
+##iidae
+stade
+realism
+veterinary
+rodgers
+dump
+alain
+chestnut
+castile
+censorship
+rumble
+gibbs
+##itor
+communion
+reggae
+inactivated
+logs
+loads
+##houses
+homosexual
+##iano
+ale
+informs
+##cas
+phrases
+plaster
+linebacker
+ambrose
+kaiser
+fascinated
+850
+limerick
+recruitment
+forge
+mastered
+##nding
+leinster
+rooted
+threaten
+##strom
+borneo
+##hes
+suggestions
+scholarships
+propeller
+documentaries
+patronage
+coats
+constructing
+invest
+neurons
+comet
+entirety
+shouts
+identities
+annoying
+unchanged
+wary
+##antly
+##ogy
+neat
+oversight
+##kos
+phillies
+replay
+constance
+##kka
+incarnation
+humble
+skies
+minus
+##acy
+smithsonian
+##chel
+guerrilla
+jar
+cadets
+##plate
+surplus
+audit
+##aru
+cracking
+joanna
+louisa
+pacing
+##lights
+intentionally
+##iri
+diner
+nwa
+imprint
+australians
+tong
+unprecedented
+bunker
+naive
+specialists
+ark
+nichols
+railing
+leaked
+pedal
+##uka
+shrub
+longing
+roofs
+v8
+captains
+neural
+tuned
+##ntal
+##jet
+emission
+medina
+frantic
+codex
+definitive
+sid
+abolition
+intensified
+stocks
+enrique
+sustain
+genoa
+oxide
+##written
+clues
+cha
+##gers
+tributaries
+fragment
+venom
+##rity
+##ente
+##sca
+muffled
+vain
+sire
+laos
+##ingly
+##hana
+hastily
+snapping
+surfaced
+sentiment
+motive
+##oft
+contests
+approximate
+mesa
+luckily
+dinosaur
+exchanges
+propelled
+accord
+bourne
+relieve
+tow
+masks
+offended
+##ues
+cynthia
+##mmer
+rains
+bartender
+zinc
+reviewers
+lois
+##sai
+legged
+arrogant
+rafe
+rosie
+comprise
+handicap
+blockade
+inlet
+lagoon
+copied
+drilling
+shelley
+petals
+##inian
+mandarin
+obsolete
+##inated
+onward
+arguably
+productivity
+cindy
+praising
+seldom
+busch
+discusses
+raleigh
+shortage
+ranged
+stanton
+encouragement
+firstly
+conceded
+overs
+temporal
+##uke
+cbe
+##bos
+woo
+certainty
+pumps
+##pton
+stalked
+##uli
+lizzie
+periodic
+thieves
+weaker
+##night
+gases
+shoving
+chooses
+wc
+##chemical
+prompting
+weights
+##kill
+robust
+flanked
+sticky
+hu
+tuberculosis
+##eb
+##eal
+christchurch
+resembled
+wallet
+reese
+inappropriate
+pictured
+distract
+fixing
+fiddle
+giggled
+burger
+heirs
+hairy
+mechanic
+torque
+apache
+obsessed
+chiefly
+cheng
+logging
+##tag
+extracted
+meaningful
+numb
+##vsky
+gloucestershire
+reminding
+##bay
+unite
+##lit
+breeds
+diminished
+clown
+glove
+1860s
+##ن
+##ug
+archibald
+focal
+freelance
+sliced
+depiction
+##yk
+organism
+switches
+sights
+stray
+crawling
+##ril
+lever
+leningrad
+interpretations
+loops
+anytime
+reel
+alicia
+delighted
+##ech
+inhaled
+xiv
+suitcase
+bernie
+vega
+licenses
+northampton
+exclusion
+induction
+monasteries
+racecourse
+homosexuality
+##right
+##sfield
+##rky
+dimitri
+michele
+alternatives
+ions
+commentators
+genuinely
+objected
+pork
+hospitality
+fencing
+stephan
+warships
+peripheral
+wit
+drunken
+wrinkled
+quentin
+spends
+departing
+chung
+numerical
+spokesperson
+##zone
+johannesburg
+caliber
+killers
+##udge
+assumes
+neatly
+demographic
+abigail
+bloc
+##vel
+mounting
+##lain
+bentley
+slightest
+xu
+recipients
+##jk
+merlin
+##writer
+seniors
+prisons
+blinking
+hindwings
+flickered
+kappa
+##hel
+80s
+strengthening
+appealing
+brewing
+gypsy
+mali
+lashes
+hulk
+unpleasant
+harassment
+bio
+treaties
+predict
+instrumentation
+pulp
+troupe
+boiling
+mantle
+##ffe
+ins
+##vn
+dividing
+handles
+verbs
+##onal
+coconut
+senegal
+340
+thorough
+gum
+momentarily
+##sto
+cocaine
+panicked
+destined
+##turing
+teatro
+denying
+weary
+captained
+mans
+##hawks
+##code
+wakefield
+bollywood
+thankfully
+##16
+cyril
+##wu
+amendments
+##bahn
+consultation
+stud
+reflections
+kindness
+1787
+internally
+##ovo
+tex
+mosaic
+distribute
+paddy
+seeming
+143
+##hic
+piers
+##15
+##mura
+##verse
+popularly
+winger
+kang
+sentinel
+mccoy
+##anza
+covenant
+##bag
+verge
+fireworks
+suppress
+thrilled
+dominate
+##jar
+swansea
+##60
+142
+reconciliation
+##ndi
+stiffened
+cue
+dorian
+##uf
+damascus
+amor
+ida
+foremost
+##aga
+porsche
+unseen
+dir
+##had
+##azi
+stony
+lexi
+melodies
+##nko
+angular
+integer
+podcast
+ants
+inherent
+jaws
+justify
+persona
+##olved
+josephine
+##nr
+##ressed
+customary
+flashes
+gala
+cyrus
+glaring
+backyard
+ariel
+physiology
+greenland
+html
+stir
+avon
+atletico
+finch
+methodology
+ked
+##lent
+mas
+catholicism
+townsend
+branding
+quincy
+fits
+containers
+1777
+ashore
+aragon
+##19
+forearm
+poisoning
+##sd
+adopting
+conquer
+grinding
+amnesty
+keller
+finances
+evaluate
+forged
+lankan
+instincts
+##uto
+guam
+bosnian
+photographed
+workplace
+desirable
+protector
+##dog
+allocation
+intently
+encourages
+willy
+##sten
+bodyguard
+electro
+brighter
+##ν
+bihar
+##chev
+lasts
+opener
+amphibious
+sal
+verde
+arte
+##cope
+captivity
+vocabulary
+yields
+##tted
+agreeing
+desmond
+pioneered
+##chus
+strap
+campaigned
+railroads
+##ович
+emblem
+##dre
+stormed
+501
+##ulous
+marijuana
+northumberland
+##gn
+##nath
+bowen
+landmarks
+beaumont
+##qua
+danube
+##bler
+attorneys
+th
+ge
+flyers
+critique
+villains
+cass
+mutation
+acc
+##0s
+colombo
+mckay
+motif
+sampling
+concluding
+syndicate
+##rell
+neon
+stables
+ds
+warnings
+clint
+mourning
+wilkinson
+##tated
+merrill
+leopard
+evenings
+exhaled
+emil
+sonia
+ezra
+discrete
+stove
+farrell
+fifteenth
+prescribed
+superhero
+##rier
+worms
+helm
+wren
+##duction
+##hc
+expo
+##rator
+hq
+unfamiliar
+antony
+prevents
+acceleration
+fiercely
+mari
+painfully
+calculations
+cheaper
+ign
+clifton
+irvine
+davenport
+mozambique
+##np
+pierced
+##evich
+wonders
+##wig
+##cate
+##iling
+crusade
+ware
+##uel
+enzymes
+reasonably
+mls
+##coe
+mater
+ambition
+bunny
+eliot
+kernel
+##fin
+asphalt
+headmaster
+torah
+aden
+lush
+pins
+waived
+##care
+##yas
+joao
+substrate
+enforce
+##grad
+##ules
+alvarez
+selections
+epidemic
+tempted
+##bit
+bremen
+translates
+ensured
+waterfront
+29th
+forrest
+manny
+malone
+kramer
+reigning
+cookies
+simpler
+absorption
+205
+engraved
+##ffy
+evaluated
+1778
+haze
+146
+comforting
+crossover
+##abe
+thorn
+##rift
+##imo
+##pop
+suppression
+fatigue
+cutter
+##tr
+201
+wurttemberg
+##orf
+enforced
+hovering
+proprietary
+gb
+samurai
+syllable
+ascent
+lacey
+tick
+lars
+tractor
+merchandise
+rep
+bouncing
+defendants
+##yre
+huntington
+##ground
+##oko
+standardized
+##hor
+##hima
+assassinated
+nu
+predecessors
+rainy
+liar
+assurance
+lyrical
+##uga
+secondly
+flattened
+ios
+parameter
+undercover
+##mity
+bordeaux
+punish
+ridges
+markers
+exodus
+inactive
+hesitate
+debbie
+nyc
+pledge
+savoy
+nagar
+offset
+organist
+##tium
+hesse
+marin
+converting
+##iver
+diagram
+propulsion
+pu
+validity
+reverted
+supportive
+##dc
+ministries
+clans
+responds
+proclamation
+##inae
+##ø
+##rea
+ein
+pleading
+patriot
+sf
+birch
+islanders
+strauss
+hates
+##dh
+brandenburg
+concession
+rd
+##ob
+1900s
+killings
+textbook
+antiquity
+cinematography
+wharf
+embarrassing
+setup
+creed
+farmland
+inequality
+centred
+signatures
+fallon
+370
+##ingham
+##uts
+ceylon
+gazing
+directive
+laurie
+##tern
+globally
+##uated
+##dent
+allah
+excavation
+threads
+##cross
+148
+frantically
+icc
+utilize
+determines
+respiratory
+thoughtful
+receptions
+##dicate
+merging
+chandra
+seine
+147
+builders
+builds
+diagnostic
+dev
+visibility
+goddamn
+analyses
+dhaka
+cho
+proves
+chancel
+concurrent
+curiously
+canadians
+pumped
+restoring
+1850s
+turtles
+jaguar
+sinister
+spinal
+traction
+declan
+vows
+1784
+glowed
+capitalism
+swirling
+install
+universidad
+##lder
+##oat
+soloist
+##genic
+##oor
+coincidence
+beginnings
+nissan
+dip
+resorts
+caucasus
+combustion
+infectious
+##eno
+pigeon
+serpent
+##itating
+conclude
+masked
+salad
+jew
+##gr
+surreal
+toni
+##wc
+harmonica
+151
+##gins
+##etic
+##coat
+fishermen
+intending
+bravery
+##wave
+klaus
+titan
+wembley
+taiwanese
+ransom
+40th
+incorrect
+hussein
+eyelids
+jp
+cooke
+dramas
+utilities
+##etta
+##print
+eisenhower
+principally
+granada
+lana
+##rak
+openings
+concord
+##bl
+bethany
+connie
+morality
+sega
+##mons
+##nard
+earnings
+##kara
+##cine
+wii
+communes
+##rel
+coma
+composing
+softened
+severed
+grapes
+##17
+nguyen
+analyzed
+warlord
+hubbard
+heavenly
+behave
+slovenian
+##hit
+##ony
+hailed
+filmmakers
+trance
+caldwell
+skye
+unrest
+coward
+likelihood
+##aging
+bern
+sci
+taliban
+honolulu
+propose
+##wang
+1700
+browser
+imagining
+cobra
+contributes
+dukes
+instinctively
+conan
+violinist
+##ores
+accessories
+gradual
+##amp
+quotes
+sioux
+##dating
+undertake
+intercepted
+sparkling
+compressed
+139
+fungus
+tombs
+haley
+imposing
+rests
+degradation
+lincolnshire
+retailers
+wetlands
+tulsa
+distributor
+dungeon
+nun
+greenhouse
+convey
+atlantis
+aft
+exits
+oman
+dresser
+lyons
+##sti
+joking
+eddy
+judgement
+omitted
+digits
+##cts
+##game
+juniors
+##rae
+cents
+stricken
+une
+##ngo
+wizards
+weir
+breton
+nan
+technician
+fibers
+liking
+royalty
+##cca
+154
+persia
+terribly
+magician
+##rable
+##unt
+vance
+cafeteria
+booker
+camille
+warmer
+##static
+consume
+cavern
+gaps
+compass
+contemporaries
+foyer
+soothing
+graveyard
+maj
+plunged
+blush
+##wear
+cascade
+demonstrates
+ordinance
+##nov
+boyle
+##lana
+rockefeller
+shaken
+banjo
+izzy
+##ense
+breathless
+vines
+##32
+##eman
+alterations
+chromosome
+dwellings
+feudal
+mole
+153
+catalonia
+relics
+tenant
+mandated
+##fm
+fridge
+hats
+honesty
+patented
+raul
+heap
+cruisers
+accusing
+enlightenment
+infants
+wherein
+chatham
+contractors
+zen
+affinity
+hc
+osborne
+piston
+156
+traps
+maturity
+##rana
+lagos
+##zal
+peering
+##nay
+attendant
+dealers
+protocols
+subset
+prospects
+biographical
+##cre
+artery
+##zers
+insignia
+nuns
+endured
+##eration
+recommend
+schwartz
+serbs
+berger
+cromwell
+crossroads
+##ctor
+enduring
+clasped
+grounded
+##bine
+marseille
+twitched
+abel
+choke
+https
+catalyst
+moldova
+italians
+##tist
+disastrous
+wee
+##oured
+##nti
+wwf
+nope
+##piration
+##asa
+expresses
+thumbs
+167
+##nza
+coca
+1781
+cheating
+##ption
+skipped
+sensory
+heidelberg
+spies
+satan
+dangers
+semifinal
+202
+bohemia
+whitish
+confusing
+shipbuilding
+relies
+surgeons
+landings
+ravi
+baku
+moor
+suffix
+alejandro
+##yana
+litre
+upheld
+##unk
+rajasthan
+##rek
+coaster
+insists
+posture
+scenarios
+etienne
+favoured
+appoint
+transgender
+elephants
+poked
+greenwood
+defences
+fulfilled
+militant
+somali
+1758
+chalk
+potent
+##ucci
+migrants
+wink
+assistants
+nos
+restriction
+activism
+niger
+##ario
+colon
+shaun
+##sat
+daphne
+##erated
+swam
+congregations
+reprise
+considerations
+magnet
+playable
+xvi
+##р
+overthrow
+tobias
+knob
+chavez
+coding
+##mers
+propped
+katrina
+orient
+newcomer
+##suke
+temperate
+##pool
+farmhouse
+interrogation
+##vd
+committing
+##vert
+forthcoming
+strawberry
+joaquin
+macau
+ponds
+shocking
+siberia
+##cellular
+chant
+contributors
+##nant
+##ologists
+sped
+absorb
+hail
+1782
+spared
+##hore
+barbados
+karate
+opus
+originates
+saul
+##xie
+evergreen
+leaped
+##rock
+correlation
+exaggerated
+weekday
+unification
+bump
+tracing
+brig
+afb
+pathways
+utilizing
+##ners
+mod
+mb
+disturbance
+kneeling
+##stad
+##guchi
+100th
+pune
+##thy
+decreasing
+168
+manipulation
+miriam
+academia
+ecosystem
+occupational
+rbi
+##lem
+rift
+##14
+rotary
+stacked
+incorporation
+awakening
+generators
+guerrero
+racist
+##omy
+cyber
+derivatives
+culminated
+allie
+annals
+panzer
+sainte
+wikipedia
+pops
+zu
+austro
+##vate
+algerian
+politely
+nicholson
+mornings
+educate
+tastes
+thrill
+dartmouth
+##gating
+db
+##jee
+regan
+differing
+concentrating
+choreography
+divinity
+##media
+pledged
+alexandre
+routing
+gregor
+madeline
+##idal
+apocalypse
+##hora
+gunfire
+culminating
+elves
+fined
+liang
+lam
+programmed
+tar
+guessing
+transparency
+gabrielle
+##gna
+cancellation
+flexibility
+##lining
+accession
+shea
+stronghold
+nets
+specializes
+##rgan
+abused
+hasan
+sgt
+ling
+exceeding
+##₄
+admiration
+supermarket
+##ark
+photographers
+specialised
+tilt
+resonance
+hmm
+perfume
+380
+sami
+threatens
+garland
+botany
+guarding
+boiled
+greet
+puppy
+russo
+supplier
+wilmington
+vibrant
+vijay
+##bius
+paralympic
+grumbled
+paige
+faa
+licking
+margins
+hurricanes
+##gong
+fest
+grenade
+ripping
+##uz
+counseling
+weigh
+##sian
+needles
+wiltshire
+edison
+costly
+##not
+fulton
+tramway
+redesigned
+staffordshire
+cache
+gasping
+watkins
+sleepy
+candidacy
+##group
+monkeys
+timeline
+throbbing
+##bid
+##sos
+berth
+uzbekistan
+vanderbilt
+bothering
+overturned
+ballots
+gem
+##iger
+sunglasses
+subscribers
+hooker
+compelling
+ang
+exceptionally
+saloon
+stab
+##rdi
+carla
+terrifying
+rom
+##vision
+coil
+##oids
+satisfying
+vendors
+31st
+mackay
+deities
+overlooked
+ambient
+bahamas
+felipe
+olympia
+whirled
+botanist
+advertised
+tugging
+##dden
+disciples
+morales
+unionist
+rites
+foley
+morse
+motives
+creepy
+##₀
+soo
+##sz
+bargain
+highness
+frightening
+turnpike
+tory
+reorganization
+##cer
+depict
+biographer
+##walk
+unopposed
+manifesto
+##gles
+institut
+emile
+accidental
+kapoor
+##dam
+kilkenny
+cortex
+lively
+##13
+romanesque
+jain
+shan
+cannons
+##ood
+##ske
+petrol
+echoing
+amalgamated
+disappears
+cautious
+proposes
+sanctions
+trenton
+##ر
+flotilla
+aus
+contempt
+tor
+canary
+cote
+theirs
+##hun
+conceptual
+deleted
+fascinating
+paso
+blazing
+elf
+honourable
+hutchinson
+##eiro
+##outh
+##zin
+surveyor
+tee
+amidst
+wooded
+reissue
+intro
+##ono
+cobb
+shelters
+newsletter
+hanson
+brace
+encoding
+confiscated
+dem
+caravan
+marino
+scroll
+melodic
+cows
+imam
+##adi
+##aneous
+northward
+searches
+biodiversity
+cora
+310
+roaring
+##bers
+connell
+theologian
+halo
+compose
+pathetic
+unmarried
+dynamo
+##oot
+az
+calculation
+toulouse
+deserves
+humour
+nr
+forgiveness
+tam
+undergone
+martyr
+pamela
+myths
+whore
+counselor
+hicks
+290
+heavens
+battleship
+electromagnetic
+##bbs
+stellar
+establishments
+presley
+hopped
+##chin
+temptation
+90s
+wills
+nas
+##yuan
+nhs
+##nya
+seminars
+##yev
+adaptations
+gong
+asher
+lex
+indicator
+sikh
+tobago
+cites
+goin
+##yte
+satirical
+##gies
+characterised
+correspond
+bubbles
+lure
+participates
+##vid
+eruption
+skate
+therapeutic
+1785
+canals
+wholesale
+defaulted
+sac
+460
+petit
+##zzled
+virgil
+leak
+ravens
+256
+portraying
+##yx
+ghetto
+creators
+dams
+portray
+vicente
+##rington
+fae
+namesake
+bounty
+##arium
+joachim
+##ota
+##iser
+aforementioned
+axle
+snout
+depended
+dismantled
+reuben
+480
+##ibly
+gallagher
+##lau
+##pd
+earnest
+##ieu
+##iary
+inflicted
+objections
+##llar
+asa
+gritted
+##athy
+jericho
+##sea
+##was
+flick
+underside
+ceramics
+undead
+substituted
+195
+eastward
+undoubtedly
+wheeled
+chimney
+##iche
+guinness
+cb
+##ager
+siding
+##bell
+traitor
+baptiste
+disguised
+inauguration
+149
+tipperary
+choreographer
+perched
+warmed
+stationary
+eco
+##ike
+##ntes
+bacterial
+##aurus
+flores
+phosphate
+##core
+attacker
+invaders
+alvin
+intersects
+a1
+indirectly
+immigrated
+businessmen
+cornelius
+valves
+narrated
+pill
+sober
+ul
+nationale
+monastic
+applicants
+scenery
+##jack
+161
+motifs
+constitutes
+cpu
+##osh
+jurisdictions
+sd
+tuning
+irritation
+woven
+##uddin
+fertility
+gao
+##erie
+antagonist
+impatient
+glacial
+hides
+boarded
+denominations
+interception
+##jas
+cookie
+nicola
+##tee
+algebraic
+marquess
+bahn
+parole
+buyers
+bait
+turbines
+paperwork
+bestowed
+natasha
+renee
+oceans
+purchases
+157
+vaccine
+215
+##tock
+fixtures
+playhouse
+integrate
+jai
+oswald
+intellectuals
+##cky
+booked
+nests
+mortimer
+##isi
+obsession
+sept
+##gler
+##sum
+440
+scrutiny
+simultaneous
+squinted
+##shin
+collects
+oven
+shankar
+penned
+remarkably
+##я
+slips
+luggage
+spectral
+1786
+collaborations
+louie
+consolidation
+##ailed
+##ivating
+420
+hoover
+blackpool
+harness
+ignition
+vest
+tails
+belmont
+mongol
+skinner
+##nae
+visually
+mage
+derry
+##tism
+##unce
+stevie
+transitional
+##rdy
+redskins
+drying
+prep
+prospective
+##21
+annoyance
+oversee
+##loaded
+fills
+##books
+##iki
+announces
+fda
+scowled
+respects
+prasad
+mystic
+tucson
+##vale
+revue
+springer
+bankrupt
+1772
+aristotle
+salvatore
+habsburg
+##geny
+dal
+natal
+nut
+pod
+chewing
+darts
+moroccan
+walkover
+rosario
+lenin
+punjabi
+##ße
+grossed
+scattering
+wired
+invasive
+hui
+polynomial
+corridors
+wakes
+gina
+portrays
+##cratic
+arid
+retreating
+erich
+irwin
+sniper
+##dha
+linen
+lindsey
+maneuver
+butch
+shutting
+socio
+bounce
+commemorative
+postseason
+jeremiah
+pines
+275
+mystical
+beads
+bp
+abbas
+furnace
+bidding
+consulted
+assaulted
+empirical
+rubble
+enclosure
+sob
+weakly
+cancel
+polly
+yielded
+##emann
+curly
+prediction
+battered
+70s
+vhs
+jacqueline
+render
+sails
+barked
+detailing
+grayson
+riga
+sloane
+raging
+##yah
+herbs
+bravo
+##athlon
+alloy
+giggle
+imminent
+suffers
+assumptions
+waltz
+##itate
+accomplishments
+##ited
+bathing
+remixed
+deception
+prefix
+##emia
+deepest
+##tier
+##eis
+balkan
+frogs
+##rong
+slab
+##pate
+philosophers
+peterborough
+grains
+imports
+dickinson
+rwanda
+##atics
+1774
+dirk
+lan
+tablets
+##rove
+clone
+##rice
+caretaker
+hostilities
+mclean
+##gre
+regimental
+treasures
+norms
+impose
+tsar
+tango
+diplomacy
+variously
+complain
+192
+recognise
+arrests
+1779
+celestial
+pulitzer
+##dus
+bing
+libretto
+##moor
+adele
+splash
+##rite
+expectation
+lds
+confronts
+##izer
+spontaneous
+harmful
+wedge
+entrepreneurs
+buyer
+##ope
+bilingual
+translate
+rugged
+conner
+circulated
+uae
+eaton
+##gra
+##zzle
+lingered
+lockheed
+vishnu
+reelection
+alonso
+##oom
+joints
+yankee
+headline
+cooperate
+heinz
+laureate
+invading
+##sford
+echoes
+scandinavian
+##dham
+hugging
+vitamin
+salute
+micah
+hind
+trader
+##sper
+radioactive
+##ndra
+militants
+poisoned
+ratified
+remark
+campeonato
+deprived
+wander
+prop
+##dong
+outlook
+##tani
+##rix
+##eye
+chiang
+darcy
+##oping
+mandolin
+spice
+statesman
+babylon
+182
+walled
+forgetting
+afro
+##cap
+158
+giorgio
+buffer
+##polis
+planetary
+##gis
+overlap
+terminals
+kinda
+centenary
+##bir
+arising
+manipulate
+elm
+ke
+1770
+ak
+##tad
+chrysler
+mapped
+moose
+pomeranian
+quad
+macarthur
+assemblies
+shoreline
+recalls
+stratford
+##rted
+noticeable
+##evic
+imp
+##rita
+##sque
+accustomed
+supplying
+tents
+disgusted
+vogue
+sipped
+filters
+khz
+reno
+selecting
+luftwaffe
+mcmahon
+tyne
+masterpiece
+carriages
+collided
+dunes
+exercised
+flare
+remembers
+muzzle
+##mobile
+heck
+##rson
+burgess
+lunged
+middleton
+boycott
+bilateral
+##sity
+hazardous
+lumpur
+multiplayer
+spotlight
+jackets
+goldman
+liege
+porcelain
+rag
+waterford
+benz
+attracts
+hopeful
+battling
+ottomans
+kensington
+baked
+hymns
+cheyenne
+lattice
+levine
+borrow
+polymer
+clashes
+michaels
+monitored
+commitments
+denounced
+##25
+##von
+cavity
+##oney
+hobby
+akin
+##holders
+futures
+intricate
+cornish
+patty
+##oned
+illegally
+dolphin
+##lag
+barlow
+yellowish
+maddie
+apologized
+luton
+plagued
+##puram
+nana
+##rds
+sway
+fanny
+łodz
+##rino
+psi
+suspicions
+hanged
+##eding
+initiate
+charlton
+##por
+nak
+competent
+235
+analytical
+annex
+wardrobe
+reservations
+##rma
+sect
+162
+fairfax
+hedge
+piled
+buckingham
+uneven
+bauer
+simplicity
+snyder
+interpret
+accountability
+donors
+moderately
+byrd
+continents
+##cite
+##max
+disciple
+hr
+jamaican
+ping
+nominees
+##uss
+mongolian
+diver
+attackers
+eagerly
+ideological
+pillows
+miracles
+apartheid
+revolver
+sulfur
+clinics
+moran
+163
+##enko
+ile
+katy
+rhetoric
+##icated
+chronology
+recycling
+##hrer
+elongated
+mughal
+pascal
+profiles
+vibration
+databases
+domination
+##fare
+##rant
+matthias
+digest
+rehearsal
+polling
+weiss
+initiation
+reeves
+clinging
+flourished
+impress
+ngo
+##hoff
+##ume
+buckley
+symposium
+rhythms
+weed
+emphasize
+transforming
+##taking
+##gence
+##yman
+accountant
+analyze
+flicker
+foil
+priesthood
+voluntarily
+decreases
+##80
+##hya
+slater
+sv
+charting
+mcgill
+##lde
+moreno
+##iu
+besieged
+zur
+robes
+##phic
+admitting
+api
+deported
+turmoil
+peyton
+earthquakes
+##ares
+nationalists
+beau
+clair
+brethren
+interrupt
+welch
+curated
+galerie
+requesting
+164
+##ested
+impending
+steward
+viper
+##vina
+complaining
+beautifully
+brandy
+foam
+nl
+1660
+##cake
+alessandro
+punches
+laced
+explanations
+##lim
+attribute
+clit
+reggie
+discomfort
+##cards
+smoothed
+whales
+##cene
+adler
+countered
+duffy
+disciplinary
+widening
+recipe
+reliance
+conducts
+goats
+gradient
+preaching
+##shaw
+matilda
+quasi
+striped
+meridian
+cannabis
+cordoba
+certificates
+##agh
+##tering
+graffiti
+hangs
+pilgrims
+repeats
+##ych
+revive
+urine
+etat
+##hawk
+fueled
+belts
+fuzzy
+susceptible
+##hang
+mauritius
+salle
+sincere
+beers
+hooks
+##cki
+arbitration
+entrusted
+advise
+sniffed
+seminar
+junk
+donnell
+processors
+principality
+strapped
+celia
+mendoza
+everton
+fortunes
+prejudice
+starving
+reassigned
+steamer
+##lund
+tuck
+evenly
+foreman
+##ffen
+dans
+375
+envisioned
+slit
+##xy
+baseman
+liberia
+rosemary
+##weed
+electrified
+periodically
+potassium
+stride
+contexts
+sperm
+slade
+mariners
+influx
+bianca
+subcommittee
+##rane
+spilling
+icao
+estuary
+##nock
+delivers
+iphone
+##ulata
+isa
+mira
+bohemian
+dessert
+##sbury
+welcoming
+proudly
+slowing
+##chs
+musee
+ascension
+russ
+##vian
+waits
+##psy
+africans
+exploit
+##morphic
+gov
+eccentric
+crab
+peck
+##ull
+entrances
+formidable
+marketplace
+groom
+bolted
+metabolism
+patton
+robbins
+courier
+payload
+endure
+##ifier
+andes
+refrigerator
+##pr
+ornate
+##uca
+ruthless
+illegitimate
+masonry
+strasbourg
+bikes
+adobe
+##³
+apples
+quintet
+willingly
+niche
+bakery
+corpses
+energetic
+##cliffe
+##sser
+##ards
+177
+centimeters
+centro
+fuscous
+cretaceous
+rancho
+##yde
+andrei
+telecom
+tottenham
+oasis
+ordination
+vulnerability
+presiding
+corey
+cp
+penguins
+sims
+##pis
+malawi
+piss
+##48
+correction
+##cked
+##ffle
+##ryn
+countdown
+detectives
+psychiatrist
+psychedelic
+dinosaurs
+blouse
+##get
+choi
+vowed
+##oz
+randomly
+##pol
+49ers
+scrub
+blanche
+bruins
+dusseldorf
+##using
+unwanted
+##ums
+212
+dominique
+elevations
+headlights
+om
+laguna
+##oga
+1750
+famously
+ignorance
+shrewsbury
+##aine
+ajax
+breuning
+che
+confederacy
+greco
+overhaul
+##screen
+paz
+skirts
+disagreement
+cruelty
+jagged
+phoebe
+shifter
+hovered
+viruses
+##wes
+mandy
+##lined
+##gc
+landlord
+squirrel
+dashed
+##ι
+ornamental
+gag
+wally
+grange
+literal
+spurs
+undisclosed
+proceeding
+yin
+##text
+billie
+orphan
+spanned
+humidity
+indy
+weighted
+presentations
+explosions
+lucian
+##tary
+vaughn
+hindus
+##anga
+##hell
+psycho
+171
+daytona
+protects
+efficiently
+rematch
+sly
+tandem
+##oya
+rebranded
+impaired
+hee
+metropolis
+peach
+godfrey
+diaspora
+ethnicity
+prosperous
+gleaming
+dar
+grossing
+playback
+##rden
+stripe
+pistols
+##tain
+births
+labelled
+##cating
+172
+rudy
+alba
+##onne
+aquarium
+hostility
+##gb
+##tase
+shudder
+sumatra
+hardest
+lakers
+consonant
+creeping
+demos
+homicide
+capsule
+zeke
+liberties
+expulsion
+pueblo
+##comb
+trait
+transporting
+##ddin
+##neck
+##yna
+depart
+gregg
+mold
+ledge
+hangar
+oldham
+playboy
+termination
+analysts
+gmbh
+romero
+##itic
+insist
+cradle
+filthy
+brightness
+slash
+shootout
+deposed
+bordering
+##truct
+isis
+microwave
+tumbled
+sheltered
+cathy
+werewolves
+messy
+andersen
+convex
+clapped
+clinched
+satire
+wasting
+edo
+vc
+rufus
+##jak
+mont
+##etti
+poznan
+##keeping
+restructuring
+transverse
+##rland
+azerbaijani
+slovene
+gestures
+roommate
+choking
+shear
+##quist
+vanguard
+oblivious
+##hiro
+disagreed
+baptism
+##lich
+coliseum
+##aceae
+salvage
+societe
+cory
+locke
+relocation
+relying
+versailles
+ahl
+swelling
+##elo
+cheerful
+##word
+##edes
+gin
+sarajevo
+obstacle
+diverted
+##nac
+messed
+thoroughbred
+fluttered
+utrecht
+chewed
+acquaintance
+assassins
+dispatch
+mirza
+##wart
+nike
+salzburg
+swell
+yen
+##gee
+idle
+ligue
+samson
+##nds
+##igh
+playful
+spawned
+##cise
+tease
+##case
+burgundy
+##bot
+stirring
+skeptical
+interceptions
+marathi
+##dies
+bedrooms
+aroused
+pinch
+##lik
+preferences
+tattoos
+buster
+digitally
+projecting
+rust
+##ital
+kitten
+priorities
+addison
+pseudo
+##guard
+dusk
+icons
+sermon
+##psis
+##iba
+bt
+##lift
+##xt
+ju
+truce
+rink
+##dah
+##wy
+defects
+psychiatry
+offences
+calculate
+glucose
+##iful
+##rized
+##unda
+francaise
+##hari
+richest
+warwickshire
+carly
+1763
+purity
+redemption
+lending
+##cious
+muse
+bruises
+cerebral
+aero
+carving
+##name
+preface
+terminology
+invade
+monty
+##int
+anarchist
+blurred
+##iled
+rossi
+treats
+guts
+shu
+foothills
+ballads
+undertaking
+premise
+cecilia
+affiliates
+blasted
+conditional
+wilder
+minors
+drone
+rudolph
+buffy
+swallowing
+horton
+attested
+##hop
+rutherford
+howell
+primetime
+livery
+penal
+##bis
+minimize
+hydro
+wrecked
+wrought
+palazzo
+##gling
+cans
+vernacular
+friedman
+nobleman
+shale
+walnut
+danielle
+##ection
+##tley
+sears
+##kumar
+chords
+lend
+flipping
+streamed
+por
+dracula
+gallons
+sacrifices
+gamble
+orphanage
+##iman
+mckenzie
+##gible
+boxers
+daly
+##balls
+##ان
+208
+##ific
+##rative
+##iq
+exploited
+slated
+##uity
+circling
+hillary
+pinched
+goldberg
+provost
+campaigning
+lim
+piles
+ironically
+jong
+mohan
+successors
+usaf
+##tem
+##ught
+autobiographical
+haute
+preserves
+##ending
+acquitted
+comparisons
+203
+hydroelectric
+gangs
+cypriot
+torpedoes
+rushes
+chrome
+derive
+bumps
+instability
+fiat
+pets
+##mbe
+silas
+dye
+reckless
+settler
+##itation
+info
+heats
+##writing
+176
+canonical
+maltese
+fins
+mushroom
+stacy
+aspen
+avid
+##kur
+##loading
+vickers
+gaston
+hillside
+statutes
+wilde
+gail
+kung
+sabine
+comfortably
+motorcycles
+##rgo
+169
+pneumonia
+fetch
+##sonic
+axel
+faintly
+parallels
+##oop
+mclaren
+spouse
+compton
+interdisciplinary
+miner
+##eni
+181
+clamped
+##chal
+##llah
+separates
+versa
+##mler
+scarborough
+labrador
+##lity
+##osing
+rutgers
+hurdles
+como
+166
+burt
+divers
+##100
+wichita
+cade
+coincided
+##erson
+bruised
+mla
+##pper
+vineyard
+##ili
+##brush
+notch
+mentioning
+jase
+hearted
+kits
+doe
+##acle
+pomerania
+##ady
+ronan
+seizure
+pavel
+problematic
+##zaki
+domenico
+##ulin
+catering
+penelope
+dependence
+parental
+emilio
+ministerial
+atkinson
+##bolic
+clarkson
+chargers
+colby
+grill
+peeked
+arises
+summon
+##aged
+fools
+##grapher
+faculties
+qaeda
+##vial
+garner
+refurbished
+##hwa
+geelong
+disasters
+nudged
+bs
+shareholder
+lori
+algae
+reinstated
+rot
+##ades
+##nous
+invites
+stainless
+183
+inclusive
+##itude
+diocesan
+til
+##icz
+denomination
+##xa
+benton
+floral
+registers
+##ider
+##erman
+##kell
+absurd
+brunei
+guangzhou
+hitter
+retaliation
+##uled
+##eve
+blanc
+nh
+consistency
+contamination
+##eres
+##rner
+dire
+palermo
+broadcasters
+diaries
+inspire
+vols
+brewer
+tightening
+ky
+mixtape
+hormone
+##tok
+stokes
+##color
+##dly
+##ssi
+pg
+##ometer
+##lington
+sanitation
+##tility
+intercontinental
+apps
+##adt
+¹⁄₂
+cylinders
+economies
+favourable
+unison
+croix
+gertrude
+odyssey
+vanity
+dangling
+##logists
+upgrades
+dice
+middleweight
+practitioner
+##ight
+206
+henrik
+parlor
+orion
+angered
+lac
+python
+blurted
+##rri
+sensual
+intends
+swings
+angled
+##phs
+husky
+attain
+peerage
+precinct
+textiles
+cheltenham
+shuffled
+dai
+confess
+tasting
+bhutan
+##riation
+tyrone
+segregation
+abrupt
+ruiz
+##rish
+smirked
+blackwell
+confidential
+browning
+amounted
+##put
+vase
+scarce
+fabulous
+raided
+staple
+guyana
+unemployed
+glider
+shay
+##tow
+carmine
+troll
+intervene
+squash
+superstar
+##uce
+cylindrical
+len
+roadway
+researched
+handy
+##rium
+##jana
+meta
+lao
+declares
+##rring
+##tadt
+##elin
+##kova
+willem
+shrubs
+napoleonic
+realms
+skater
+qi
+volkswagen
+##ł
+tad
+hara
+archaeologist
+awkwardly
+eerie
+##kind
+wiley
+##heimer
+##24
+titus
+organizers
+cfl
+crusaders
+lama
+usb
+vent
+enraged
+thankful
+occupants
+maximilian
+##gaard
+possessing
+textbooks
+##oran
+collaborator
+quaker
+##ulo
+avalanche
+mono
+silky
+straits
+isaiah
+mustang
+surged
+resolutions
+potomac
+descend
+cl
+kilograms
+plato
+strains
+saturdays
+##olin
+bernstein
+##ype
+holstein
+ponytail
+##watch
+belize
+conversely
+heroine
+perpetual
+##ylus
+charcoal
+piedmont
+glee
+negotiating
+backdrop
+prologue
+##jah
+##mmy
+pasadena
+climbs
+ramos
+sunni
+##holm
+##tner
+##tri
+anand
+deficiency
+hertfordshire
+stout
+##avi
+aperture
+orioles
+##irs
+doncaster
+intrigued
+bombed
+coating
+otis
+##mat
+cocktail
+##jit
+##eto
+amir
+arousal
+sar
+##proof
+##act
+##ories
+dixie
+pots
+##bow
+whereabouts
+159
+##fted
+drains
+bullying
+cottages
+scripture
+coherent
+fore
+poe
+appetite
+##uration
+sampled
+##ators
+##dp
+derrick
+rotor
+jays
+peacock
+installment
+##rro
+advisors
+##coming
+rodeo
+scotch
+##mot
+##db
+##fen
+##vant
+ensued
+rodrigo
+dictatorship
+martyrs
+twenties
+##н
+towed
+incidence
+marta
+rainforest
+sai
+scaled
+##cles
+oceanic
+qualifiers
+symphonic
+mcbride
+dislike
+generalized
+aubrey
+colonization
+##iation
+##lion
+##ssing
+disliked
+lublin
+salesman
+##ulates
+spherical
+whatsoever
+sweating
+avalon
+contention
+punt
+severity
+alderman
+atari
+##dina
+##grant
+##rop
+scarf
+seville
+vertices
+annexation
+fairfield
+fascination
+inspiring
+launches
+palatinate
+regretted
+##rca
+feral
+##iom
+elk
+nap
+olsen
+reddy
+yong
+##leader
+##iae
+garment
+transports
+feng
+gracie
+outrage
+viceroy
+insides
+##esis
+breakup
+grady
+organizer
+softer
+grimaced
+222
+murals
+galicia
+arranging
+vectors
+##rsten
+bas
+##sb
+##cens
+sloan
+##eka
+bitten
+ara
+fender
+nausea
+bumped
+kris
+banquet
+comrades
+detector
+persisted
+##llan
+adjustment
+endowed
+cinemas
+##shot
+sellers
+##uman
+peek
+epa
+kindly
+neglect
+simpsons
+talon
+mausoleum
+runaway
+hangul
+lookout
+##cic
+rewards
+coughed
+acquainted
+chloride
+##ald
+quicker
+accordion
+neolithic
+##qa
+artemis
+coefficient
+lenny
+pandora
+tx
+##xed
+ecstasy
+litter
+segunda
+chairperson
+gemma
+hiss
+rumor
+vow
+nasal
+antioch
+compensate
+patiently
+transformers
+##eded
+judo
+morrow
+penis
+posthumous
+philips
+bandits
+husbands
+denote
+flaming
+##any
+##phones
+langley
+yorker
+1760
+walters
+##uo
+##kle
+gubernatorial
+fatty
+samsung
+leroy
+outlaw
+##nine
+unpublished
+poole
+jakob
+##ᵢ
+##ₙ
+crete
+distorted
+superiority
+##dhi
+intercept
+crust
+mig
+claus
+crashes
+positioning
+188
+stallion
+301
+frontal
+armistice
+##estinal
+elton
+aj
+encompassing
+camel
+commemorated
+malaria
+woodward
+calf
+cigar
+penetrate
+##oso
+willard
+##rno
+##uche
+illustrate
+amusing
+convergence
+noteworthy
+##lma
+##rva
+journeys
+realise
+manfred
+##sable
+410
+##vocation
+hearings
+fiance
+##posed
+educators
+provoked
+adjusting
+##cturing
+modular
+stockton
+paterson
+vlad
+rejects
+electors
+selena
+maureen
+##tres
+uber
+##rce
+swirled
+##num
+proportions
+nanny
+pawn
+naturalist
+parma
+apostles
+awoke
+ethel
+wen
+##bey
+monsoon
+overview
+##inating
+mccain
+rendition
+risky
+adorned
+##ih
+equestrian
+germain
+nj
+conspicuous
+confirming
+##yoshi
+shivering
+##imeter
+milestone
+rumours
+flinched
+bounds
+smacked
+token
+##bei
+lectured
+automobiles
+##shore
+impacted
+##iable
+nouns
+nero
+##leaf
+ismail
+prostitute
+trams
+##lace
+bridget
+sud
+stimulus
+impressions
+reins
+revolves
+##oud
+##gned
+giro
+honeymoon
+##swell
+criterion
+##sms
+##uil
+libyan
+prefers
+##osition
+211
+preview
+sucks
+accusation
+bursts
+metaphor
+diffusion
+tolerate
+faye
+betting
+cinematographer
+liturgical
+specials
+bitterly
+humboldt
+##ckle
+flux
+rattled
+##itzer
+archaeologists
+odor
+authorised
+marshes
+discretion
+##ов
+alarmed
+archaic
+inverse
+##leton
+explorers
+##pine
+drummond
+tsunami
+woodlands
+##minate
+##tland
+booklet
+insanity
+owning
+insert
+crafted
+calculus
+##tore
+receivers
+##bt
+stung
+##eca
+##nched
+prevailing
+travellers
+eyeing
+lila
+graphs
+##borne
+178
+julien
+##won
+morale
+adaptive
+therapist
+erica
+cw
+libertarian
+bowman
+pitches
+vita
+##ional
+crook
+##ads
+##entation
+caledonia
+mutiny
+##sible
+1840s
+automation
+##ß
+flock
+##pia
+ironic
+pathology
+##imus
+remarried
+##22
+joker
+withstand
+energies
+##att
+shropshire
+hostages
+madeleine
+tentatively
+conflicting
+mateo
+recipes
+euros
+ol
+mercenaries
+nico
+##ndon
+albuquerque
+augmented
+mythical
+bel
+freud
+##child
+cough
+##lica
+365
+freddy
+lillian
+genetically
+nuremberg
+calder
+209
+bonn
+outdoors
+paste
+suns
+urgency
+vin
+restraint
+tyson
+##cera
+##selle
+barrage
+bethlehem
+kahn
+##par
+mounts
+nippon
+barony
+happier
+ryu
+makeshift
+sheldon
+blushed
+castillo
+barking
+listener
+taped
+bethel
+fluent
+headlines
+pornography
+rum
+disclosure
+sighing
+mace
+doubling
+gunther
+manly
+##plex
+rt
+interventions
+physiological
+forwards
+emerges
+##tooth
+##gny
+compliment
+rib
+recession
+visibly
+barge
+faults
+connector
+exquisite
+prefect
+##rlin
+patio
+##cured
+elevators
+brandt
+italics
+pena
+173
+wasp
+satin
+ea
+botswana
+graceful
+respectable
+##jima
+##rter
+##oic
+franciscan
+generates
+##dl
+alfredo
+disgusting
+##olate
+##iously
+sherwood
+warns
+cod
+promo
+cheryl
+sino
+##ة
+##escu
+twitch
+##zhi
+brownish
+thom
+ortiz
+##dron
+densely
+##beat
+carmel
+reinforce
+##bana
+187
+anastasia
+downhill
+vertex
+contaminated
+remembrance
+harmonic
+homework
+##sol
+fiancee
+gears
+olds
+angelica
+loft
+ramsay
+quiz
+colliery
+sevens
+##cape
+autism
+##hil
+walkway
+##boats
+ruben
+abnormal
+ounce
+khmer
+##bbe
+zachary
+bedside
+morphology
+punching
+##olar
+sparrow
+convinces
+##35
+hewitt
+queer
+remastered
+rods
+mabel
+solemn
+notified
+lyricist
+symmetric
+##xide
+174
+encore
+passports
+wildcats
+##uni
+baja
+##pac
+mildly
+##ease
+bleed
+commodity
+mounds
+glossy
+orchestras
+##omo
+damian
+prelude
+ambitions
+##vet
+awhile
+remotely
+##aud
+asserts
+imply
+##iques
+distinctly
+modelling
+remedy
+##dded
+windshield
+dani
+xiao
+##endra
+audible
+powerplant
+1300
+invalid
+elemental
+acquisitions
+##hala
+immaculate
+libby
+plata
+smuggling
+ventilation
+denoted
+minh
+##morphism
+430
+differed
+dion
+kelley
+lore
+mocking
+sabbath
+spikes
+hygiene
+drown
+runoff
+stylized
+tally
+liberated
+aux
+interpreter
+righteous
+aba
+siren
+reaper
+pearce
+millie
+##cier
+##yra
+gaius
+##iso
+captures
+##ttering
+dorm
+claudio
+##sic
+benches
+knighted
+blackness
+##ored
+discount
+fumble
+oxidation
+routed
+##ς
+novak
+perpendicular
+spoiled
+fracture
+splits
+##urt
+pads
+topology
+##cats
+axes
+fortunate
+offenders
+protestants
+esteem
+221
+broadband
+convened
+frankly
+hound
+prototypes
+isil
+facilitated
+keel
+##sher
+sahara
+awaited
+bubba
+orb
+prosecutors
+186
+hem
+520
+##xing
+relaxing
+remnant
+romney
+sorted
+slalom
+stefano
+ulrich
+##active
+exemption
+folder
+pauses
+foliage
+hitchcock
+epithet
+204
+criticisms
+##aca
+ballistic
+brody
+hinduism
+chaotic
+youths
+equals
+##pala
+pts
+thicker
+analogous
+capitalist
+improvised
+overseeing
+sinatra
+ascended
+beverage
+##tl
+straightforward
+##kon
+curran
+##west
+bois
+325
+induce
+surveying
+emperors
+sax
+unpopular
+##kk
+cartoonist
+fused
+##mble
+unto
+##yuki
+localities
+##cko
+##ln
+darlington
+slain
+academie
+lobbying
+sediment
+puzzles
+##grass
+defiance
+dickens
+manifest
+tongues
+alumnus
+arbor
+coincide
+184
+appalachian
+mustafa
+examiner
+cabaret
+traumatic
+yves
+bracelet
+draining
+heroin
+magnum
+baths
+odessa
+consonants
+mitsubishi
+##gua
+kellan
+vaudeville
+##fr
+joked
+null
+straps
+probation
+##ław
+ceded
+interfaces
+##pas
+##zawa
+blinding
+viet
+224
+rothschild
+museo
+640
+huddersfield
+##vr
+tactic
+##storm
+brackets
+dazed
+incorrectly
+##vu
+reg
+glazed
+fearful
+manifold
+benefited
+irony
+##sun
+stumbling
+##rte
+willingness
+balkans
+mei
+wraps
+##aba
+injected
+##lea
+gu
+syed
+harmless
+##hammer
+bray
+takeoff
+poppy
+timor
+cardboard
+astronaut
+purdue
+weeping
+southbound
+cursing
+stalls
+diagonal
+##neer
+lamar
+bryce
+comte
+weekdays
+harrington
+##uba
+negatively
+##see
+lays
+grouping
+##cken
+##henko
+affirmed
+halle
+modernist
+##lai
+hodges
+smelling
+aristocratic
+baptized
+dismiss
+justification
+oilers
+##now
+coupling
+qin
+snack
+healer
+##qing
+gardener
+layla
+battled
+formulated
+stephenson
+gravitational
+##gill
+##jun
+1768
+granny
+coordinating
+suites
+##cd
+##ioned
+monarchs
+##cote
+##hips
+sep
+blended
+apr
+barrister
+deposition
+fia
+mina
+policemen
+paranoid
+##pressed
+churchyard
+covert
+crumpled
+creep
+abandoning
+tr
+transmit
+conceal
+barr
+understands
+readiness
+spire
+##cology
+##enia
+##erry
+610
+startling
+unlock
+vida
+bowled
+slots
+##nat
+##islav
+spaced
+trusting
+admire
+rig
+##ink
+slack
+##70
+mv
+207
+casualty
+##wei
+classmates
+##odes
+##rar
+##rked
+amherst
+furnished
+evolve
+foundry
+menace
+mead
+##lein
+flu
+wesleyan
+##kled
+monterey
+webber
+##vos
+wil
+##mith
+##на
+bartholomew
+justices
+restrained
+##cke
+amenities
+191
+mediated
+sewage
+trenches
+ml
+mainz
+##thus
+1800s
+##cula
+##inski
+caine
+bonding
+213
+converts
+spheres
+superseded
+marianne
+crypt
+sweaty
+ensign
+historia
+##br
+spruce
+##post
+##ask
+forks
+thoughtfully
+yukon
+pamphlet
+ames
+##uter
+karma
+##yya
+bryn
+negotiation
+sighs
+incapable
+##mbre
+##ntial
+actresses
+taft
+##mill
+luce
+prevailed
+##amine
+1773
+motionless
+envoy
+testify
+investing
+sculpted
+instructors
+provence
+kali
+cullen
+horseback
+##while
+goodwin
+##jos
+gaa
+norte
+##ldon
+modify
+wavelength
+abd
+214
+skinned
+sprinter
+forecast
+scheduling
+marries
+squared
+tentative
+##chman
+boer
+##isch
+bolts
+swap
+fisherman
+assyrian
+impatiently
+guthrie
+martins
+murdoch
+194
+tanya
+nicely
+dolly
+lacy
+med
+##45
+syn
+decks
+fashionable
+millionaire
+##ust
+surfing
+##ml
+##ision
+heaved
+tammy
+consulate
+attendees
+routinely
+197
+fuse
+saxophonist
+backseat
+malaya
+##lord
+scowl
+tau
+##ishly
+193
+sighted
+steaming
+##rks
+303
+911
+##holes
+##hong
+ching
+##wife
+bless
+conserved
+jurassic
+stacey
+unix
+zion
+chunk
+rigorous
+blaine
+198
+peabody
+slayer
+dismay
+brewers
+nz
+##jer
+det
+##glia
+glover
+postwar
+int
+penetration
+sylvester
+imitation
+vertically
+airlift
+heiress
+knoxville
+viva
+##uin
+390
+macon
+##rim
+##fighter
+##gonal
+janice
+##orescence
+##wari
+marius
+belongings
+leicestershire
+196
+blanco
+inverted
+preseason
+sanity
+sobbing
+##due
+##elt
+##dled
+collingwood
+regeneration
+flickering
+shortest
+##mount
+##osi
+feminism
+##lat
+sherlock
+cabinets
+fumbled
+northbound
+precedent
+snaps
+##mme
+researching
+##akes
+guillaume
+insights
+manipulated
+vapor
+neighbour
+sap
+gangster
+frey
+f1
+stalking
+scarcely
+callie
+barnett
+tendencies
+audi
+doomed
+assessing
+slung
+panchayat
+ambiguous
+bartlett
+##etto
+distributing
+violating
+wolverhampton
+##hetic
+swami
+histoire
+##urus
+liable
+pounder
+groin
+hussain
+larsen
+popping
+surprises
+##atter
+vie
+curt
+##station
+mute
+relocate
+musicals
+authorization
+richter
+##sef
+immortality
+tna
+bombings
+##press
+deteriorated
+yiddish
+##acious
+robbed
+colchester
+cs
+pmid
+ao
+verified
+balancing
+apostle
+swayed
+recognizable
+oxfordshire
+retention
+nottinghamshire
+contender
+judd
+invitational
+shrimp
+uhf
+##icient
+cleaner
+longitudinal
+tanker
+##mur
+acronym
+broker
+koppen
+sundance
+suppliers
+##gil
+4000
+clipped
+fuels
+petite
+##anne
+landslide
+helene
+diversion
+populous
+landowners
+auspices
+melville
+quantitative
+##xes
+ferries
+nicky
+##llus
+doo
+haunting
+roche
+carver
+downed
+unavailable
+##pathy
+approximation
+hiroshima
+##hue
+garfield
+valle
+comparatively
+keyboardist
+traveler
+##eit
+congestion
+calculating
+subsidiaries
+##bate
+serb
+modernization
+fairies
+deepened
+ville
+averages
+##lore
+inflammatory
+tonga
+##itch
+co₂
+squads
+##hea
+gigantic
+serum
+enjoyment
+retailer
+verona
+35th
+cis
+##phobic
+magna
+technicians
+##vati
+arithmetic
+##sport
+levin
+##dation
+amtrak
+chow
+sienna
+##eyer
+backstage
+entrepreneurship
+##otic
+learnt
+tao
+##udy
+worcestershire
+formulation
+baggage
+hesitant
+bali
+sabotage
+##kari
+barren
+enhancing
+murmur
+pl
+freshly
+putnam
+syntax
+aces
+medicines
+resentment
+bandwidth
+##sier
+grins
+chili
+guido
+##sei
+framing
+implying
+gareth
+lissa
+genevieve
+pertaining
+admissions
+geo
+thorpe
+proliferation
+sato
+bela
+analyzing
+parting
+##gor
+awakened
+##isman
+huddled
+secrecy
+##kling
+hush
+gentry
+540
+dungeons
+##ego
+coasts
+##utz
+sacrificed
+##chule
+landowner
+mutually
+prevalence
+programmer
+adolescent
+disrupted
+seaside
+gee
+trusts
+vamp
+georgie
+##nesian
+##iol
+schedules
+sindh
+##market
+etched
+hm
+sparse
+bey
+beaux
+scratching
+gliding
+unidentified
+216
+collaborating
+gems
+jesuits
+oro
+accumulation
+shaping
+mbe
+anal
+##xin
+231
+enthusiasts
+newscast
+##egan
+janata
+dewey
+parkinson
+179
+ankara
+biennial
+towering
+dd
+inconsistent
+950
+##chet
+thriving
+terminate
+cabins
+furiously
+eats
+advocating
+donkey
+marley
+muster
+phyllis
+leiden
+##user
+grassland
+glittering
+iucn
+loneliness
+217
+memorandum
+armenians
+##ddle
+popularized
+rhodesia
+60s
+lame
+##illon
+sans
+bikini
+header
+orbits
+##xx
+##finger
+##ulator
+sharif
+spines
+biotechnology
+strolled
+naughty
+yates
+##wire
+fremantle
+milo
+##mour
+abducted
+removes
+##atin
+humming
+wonderland
+##chrome
+##ester
+hume
+pivotal
+##rates
+armand
+grams
+believers
+elector
+rte
+apron
+bis
+scraped
+##yria
+endorsement
+initials
+##llation
+eps
+dotted
+hints
+buzzing
+emigration
+nearer
+##tom
+indicators
+##ulu
+coarse
+neutron
+protectorate
+##uze
+directional
+exploits
+pains
+loire
+1830s
+proponents
+guggenheim
+rabbits
+ritchie
+305
+hectare
+inputs
+hutton
+##raz
+verify
+##ako
+boilers
+longitude
+##lev
+skeletal
+yer
+emilia
+citrus
+compromised
+##gau
+pokemon
+prescription
+paragraph
+eduard
+cadillac
+attire
+categorized
+kenyan
+weddings
+charley
+##bourg
+entertain
+monmouth
+##lles
+nutrients
+davey
+mesh
+incentive
+practised
+ecosystems
+kemp
+subdued
+overheard
+##rya
+bodily
+maxim
+##nius
+apprenticeship
+ursula
+##fight
+lodged
+rug
+silesian
+unconstitutional
+patel
+inspected
+coyote
+unbeaten
+##hak
+34th
+disruption
+convict
+parcel
+##cl
+##nham
+collier
+implicated
+mallory
+##iac
+##lab
+susannah
+winkler
+##rber
+shia
+phelps
+sediments
+graphical
+robotic
+##sner
+adulthood
+mart
+smoked
+##isto
+kathryn
+clarified
+##aran
+divides
+convictions
+oppression
+pausing
+burying
+##mt
+federico
+mathias
+eileen
+##tana
+kite
+hunched
+##acies
+189
+##atz
+disadvantage
+liza
+kinetic
+greedy
+paradox
+yokohama
+dowager
+trunks
+ventured
+##gement
+gupta
+vilnius
+olaf
+##thest
+crimean
+hopper
+##ej
+progressively
+arturo
+mouthed
+arrondissement
+##fusion
+rubin
+simulcast
+oceania
+##orum
+##stra
+##rred
+busiest
+intensely
+navigator
+cary
+##vine
+##hini
+##bies
+fife
+rowe
+rowland
+posing
+insurgents
+shafts
+lawsuits
+activate
+conor
+inward
+culturally
+garlic
+265
+##eering
+eclectic
+##hui
+##kee
+##nl
+furrowed
+vargas
+meteorological
+rendezvous
+##aus
+culinary
+commencement
+##dition
+quota
+##notes
+mommy
+salaries
+overlapping
+mule
+##iology
+##mology
+sums
+wentworth
+##isk
+##zione
+mainline
+subgroup
+##illy
+hack
+plaintiff
+verdi
+bulb
+differentiation
+engagements
+multinational
+supplemented
+bertrand
+caller
+regis
+##naire
+##sler
+##arts
+##imated
+blossom
+propagation
+kilometer
+viaduct
+vineyards
+##uate
+beckett
+optimization
+golfer
+songwriters
+seminal
+semitic
+thud
+volatile
+evolving
+ridley
+##wley
+trivial
+distributions
+scandinavia
+jiang
+##ject
+wrestled
+insistence
+##dio
+emphasizes
+napkin
+##ods
+adjunct
+rhyme
+##ricted
+##eti
+hopeless
+surrounds
+tremble
+32nd
+smoky
+##ntly
+oils
+medicinal
+padded
+steer
+wilkes
+219
+255
+concessions
+hue
+uniquely
+blinded
+landon
+yahoo
+##lane
+hendrix
+commemorating
+dex
+specify
+chicks
+##ggio
+intercity
+1400
+morley
+##torm
+highlighting
+##oting
+pang
+oblique
+stalled
+##liner
+flirting
+newborn
+1769
+bishopric
+shaved
+232
+currie
+##ush
+dharma
+spartan
+##ooped
+favorites
+smug
+novella
+sirens
+abusive
+creations
+espana
+##lage
+paradigm
+semiconductor
+sheen
+##rdo
+##yen
+##zak
+nrl
+renew
+##pose
+##tur
+adjutant
+marches
+norma
+##enity
+ineffective
+weimar
+grunt
+##gat
+lordship
+plotting
+expenditure
+infringement
+lbs
+refrain
+av
+mimi
+mistakenly
+postmaster
+1771
+##bara
+ras
+motorsports
+tito
+199
+subjective
+##zza
+bully
+stew
+##kaya
+prescott
+1a
+##raphic
+##zam
+bids
+styling
+paranormal
+reeve
+sneaking
+exploding
+katz
+akbar
+migrant
+syllables
+indefinitely
+##ogical
+destroys
+replaces
+applause
+##phine
+pest
+##fide
+218
+articulated
+bertie
+##thing
+##cars
+##ptic
+courtroom
+crowley
+aesthetics
+cummings
+tehsil
+hormones
+titanic
+dangerously
+##ibe
+stadion
+jaenelle
+auguste
+ciudad
+##chu
+mysore
+partisans
+##sio
+lucan
+philipp
+##aly
+debating
+henley
+interiors
+##rano
+##tious
+homecoming
+beyonce
+usher
+henrietta
+prepares
+weeds
+##oman
+ely
+plucked
+##pire
+##dable
+luxurious
+##aq
+artifact
+password
+pasture
+juno
+maddy
+minsk
+##dder
+##ologies
+##rone
+assessments
+martian
+royalist
+1765
+examines
+##mani
+##rge
+nino
+223
+parry
+scooped
+relativity
+##eli
+##uting
+##cao
+congregational
+noisy
+traverse
+##agawa
+strikeouts
+nickelodeon
+obituary
+transylvania
+binds
+depictions
+polk
+trolley
+##yed
+##lard
+breeders
+##under
+dryly
+hokkaido
+1762
+strengths
+stacks
+bonaparte
+connectivity
+neared
+prostitutes
+stamped
+anaheim
+gutierrez
+sinai
+##zzling
+bram
+fresno
+madhya
+##86
+proton
+##lena
+##llum
+##phon
+reelected
+wanda
+##anus
+##lb
+ample
+distinguishing
+##yler
+grasping
+sermons
+tomato
+bland
+stimulation
+avenues
+##eux
+spreads
+scarlett
+fern
+pentagon
+assert
+baird
+chesapeake
+ir
+calmed
+distortion
+fatalities
+##olis
+correctional
+pricing
+##astic
+##gina
+prom
+dammit
+ying
+collaborate
+##chia
+welterweight
+33rd
+pointer
+substitution
+bonded
+umpire
+communicating
+multitude
+paddle
+##obe
+federally
+intimacy
+##insky
+betray
+ssr
+##lett
+##lean
+##lves
+##therapy
+airbus
+##tery
+functioned
+ud
+bearer
+biomedical
+netflix
+##hire
+##nca
+condom
+brink
+ik
+##nical
+macy
+##bet
+flap
+gma
+experimented
+jelly
+lavender
+##icles
+##ulia
+munro
+##mian
+##tial
+rye
+##rle
+60th
+gigs
+hottest
+rotated
+predictions
+fuji
+bu
+##erence
+##omi
+barangay
+##fulness
+##sas
+clocks
+##rwood
+##liness
+cereal
+roe
+wight
+decker
+uttered
+babu
+onion
+xml
+forcibly
+##df
+petra
+sarcasm
+hartley
+peeled
+storytelling
+##42
+##xley
+##ysis
+##ffa
+fibre
+kiel
+auditor
+fig
+harald
+greenville
+##berries
+geographically
+nell
+quartz
+##athic
+cemeteries
+##lr
+crossings
+nah
+holloway
+reptiles
+chun
+sichuan
+snowy
+660
+corrections
+##ivo
+zheng
+ambassadors
+blacksmith
+fielded
+fluids
+hardcover
+turnover
+medications
+melvin
+academies
+##erton
+ro
+roach
+absorbing
+spaniards
+colton
+##founded
+outsider
+espionage
+kelsey
+245
+edible
+##ulf
+dora
+establishes
+##sham
+##tries
+contracting
+##tania
+cinematic
+costello
+nesting
+##uron
+connolly
+duff
+##nology
+mma
+##mata
+fergus
+sexes
+gi
+optics
+spectator
+woodstock
+banning
+##hee
+##fle
+differentiate
+outfielder
+refinery
+226
+312
+gerhard
+horde
+lair
+drastically
+##udi
+landfall
+##cheng
+motorsport
+odi
+##achi
+predominant
+quay
+skins
+##ental
+edna
+harshly
+complementary
+murdering
+##aves
+wreckage
+##90
+ono
+outstretched
+lennox
+munitions
+galen
+reconcile
+470
+scalp
+bicycles
+gillespie
+questionable
+rosenberg
+guillermo
+hostel
+jarvis
+kabul
+volvo
+opium
+yd
+##twined
+abuses
+decca
+outpost
+##cino
+sensible
+neutrality
+##64
+ponce
+anchorage
+atkins
+turrets
+inadvertently
+disagree
+libre
+vodka
+reassuring
+weighs
+##yal
+glide
+jumper
+ceilings
+repertory
+outs
+stain
+##bial
+envy
+##ucible
+smashing
+heightened
+policing
+hyun
+mixes
+lai
+prima
+##ples
+celeste
+##bina
+lucrative
+intervened
+kc
+manually
+##rned
+stature
+staffed
+bun
+bastards
+nairobi
+priced
+##auer
+thatcher
+##kia
+tripped
+comune
+##ogan
+##pled
+brasil
+incentives
+emanuel
+hereford
+musica
+##kim
+benedictine
+biennale
+##lani
+eureka
+gardiner
+rb
+knocks
+sha
+##ael
+##elled
+##onate
+efficacy
+ventura
+masonic
+sanford
+maize
+leverage
+##feit
+capacities
+santana
+##aur
+novelty
+vanilla
+##cter
+##tour
+benin
+##oir
+##rain
+neptune
+drafting
+tallinn
+##cable
+humiliation
+##boarding
+schleswig
+fabian
+bernardo
+liturgy
+spectacle
+sweeney
+pont
+routledge
+##tment
+cosmos
+ut
+hilt
+sleek
+universally
+##eville
+##gawa
+typed
+##dry
+favors
+allegheny
+glaciers
+##rly
+recalling
+aziz
+##log
+parasite
+requiem
+auf
+##berto
+##llin
+illumination
+##breaker
+##issa
+festivities
+bows
+govern
+vibe
+vp
+333
+sprawled
+larson
+pilgrim
+bwf
+leaping
+##rts
+##ssel
+alexei
+greyhound
+hoarse
+##dler
+##oration
+seneca
+##cule
+gaping
+##ulously
+##pura
+cinnamon
+##gens
+##rricular
+craven
+fantasies
+houghton
+engined
+reigned
+dictator
+supervising
+##oris
+bogota
+commentaries
+unnatural
+fingernails
+spirituality
+tighten
+##tm
+canadiens
+protesting
+intentional
+cheers
+sparta
+##ytic
+##iere
+##zine
+widen
+belgarath
+controllers
+dodd
+iaaf
+navarre
+##ication
+defect
+squire
+steiner
+whisky
+##mins
+560
+inevitably
+tome
+##gold
+chew
+##uid
+##lid
+elastic
+##aby
+streaked
+alliances
+jailed
+regal
+##ined
+##phy
+czechoslovak
+narration
+absently
+##uld
+bluegrass
+guangdong
+quran
+criticizing
+hose
+hari
+##liest
+##owa
+skier
+streaks
+deploy
+##lom
+raft
+bose
+dialed
+huff
+##eira
+haifa
+simplest
+bursting
+endings
+ib
+sultanate
+##titled
+franks
+whitman
+ensures
+sven
+##ggs
+collaborators
+forster
+organising
+ui
+banished
+napier
+injustice
+teller
+layered
+thump
+##otti
+roc
+battleships
+evidenced
+fugitive
+sadie
+robotics
+##roud
+equatorial
+geologist
+##iza
+yielding
+##bron
+##sr
+internationale
+mecca
+##diment
+sbs
+skyline
+toad
+uploaded
+reflective
+undrafted
+lal
+leafs
+bayern
+##dai
+lakshmi
+shortlisted
+##stick
+##wicz
+camouflage
+donate
+af
+christi
+lau
+##acio
+disclosed
+nemesis
+1761
+assemble
+straining
+northamptonshire
+tal
+##asi
+bernardino
+premature
+heidi
+42nd
+coefficients
+galactic
+reproduce
+buzzed
+sensations
+zionist
+monsieur
+myrtle
+##eme
+archery
+strangled
+musically
+viewpoint
+antiquities
+bei
+trailers
+seahawks
+cured
+pee
+preferring
+tasmanian
+lange
+sul
+##mail
+##working
+colder
+overland
+lucivar
+massey
+gatherings
+haitian
+##smith
+disapproval
+flaws
+##cco
+##enbach
+1766
+npr
+##icular
+boroughs
+creole
+forums
+techno
+1755
+dent
+abdominal
+streetcar
+##eson
+##stream
+procurement
+gemini
+predictable
+##tya
+acheron
+christoph
+feeder
+fronts
+vendor
+bernhard
+jammu
+tumors
+slang
+##uber
+goaltender
+twists
+curving
+manson
+vuelta
+mer
+peanut
+confessions
+pouch
+unpredictable
+allowance
+theodor
+vascular
+##factory
+bala
+authenticity
+metabolic
+coughing
+nanjing
+##cea
+pembroke
+##bard
+splendid
+36th
+ff
+hourly
+##ahu
+elmer
+handel
+##ivate
+awarding
+thrusting
+dl
+experimentation
+##hesion
+##46
+caressed
+entertained
+steak
+##rangle
+biologist
+orphans
+baroness
+oyster
+stepfather
+##dridge
+mirage
+reefs
+speeding
+##31
+barons
+1764
+227
+inhabit
+preached
+repealed
+##tral
+honoring
+boogie
+captives
+administer
+johanna
+##imate
+gel
+suspiciously
+1767
+sobs
+##dington
+backbone
+hayward
+garry
+##folding
+##nesia
+maxi
+##oof
+##ppe
+ellison
+galileo
+##stand
+crimea
+frenzy
+amour
+bumper
+matrices
+natalia
+baking
+garth
+palestinians
+##grove
+smack
+conveyed
+ensembles
+gardening
+##manship
+##rup
+##stituting
+1640
+harvesting
+topography
+jing
+shifters
+dormitory
+##carriage
+##lston
+ist
+skulls
+##stadt
+dolores
+jewellery
+sarawak
+##wai
+##zier
+fences
+christy
+confinement
+tumbling
+credibility
+fir
+stench
+##bria
+##plication
+##nged
+##sam
+virtues
+##belt
+marjorie
+pba
+##eem
+##made
+celebrates
+schooner
+agitated
+barley
+fulfilling
+anthropologist
+##pro
+restrict
+novi
+regulating
+##nent
+padres
+##rani
+##hesive
+loyola
+tabitha
+milky
+olson
+proprietor
+crambidae
+guarantees
+intercollegiate
+ljubljana
+hilda
+##sko
+ignorant
+hooded
+##lts
+sardinia
+##lidae
+##vation
+frontman
+privileged
+witchcraft
+##gp
+jammed
+laude
+poking
+##than
+bracket
+amazement
+yunnan
+##erus
+maharaja
+linnaeus
+264
+commissioning
+milano
+peacefully
+##logies
+akira
+rani
+regulator
+##36
+grasses
+##rance
+luzon
+crows
+compiler
+gretchen
+seaman
+edouard
+tab
+buccaneers
+ellington
+hamlets
+whig
+socialists
+##anto
+directorial
+easton
+mythological
+##kr
+##vary
+rhineland
+semantic
+taut
+dune
+inventions
+succeeds
+##iter
+replication
+branched
+##pired
+jul
+prosecuted
+kangaroo
+penetrated
+##avian
+middlesbrough
+doses
+bleak
+madam
+predatory
+relentless
+##vili
+reluctance
+##vir
+hailey
+crore
+silvery
+1759
+monstrous
+swimmers
+transmissions
+hawthorn
+informing
+##eral
+toilets
+caracas
+crouch
+kb
+##sett
+295
+cartel
+hadley
+##aling
+alexia
+yvonne
+##biology
+cinderella
+eton
+superb
+blizzard
+stabbing
+industrialist
+maximus
+##gm
+##orus
+groves
+maud
+clade
+oversized
+comedic
+##bella
+rosen
+nomadic
+fulham
+montane
+beverages
+galaxies
+redundant
+swarm
+##rot
+##folia
+##llis
+buckinghamshire
+fen
+bearings
+bahadur
+##rom
+gilles
+phased
+dynamite
+faber
+benoit
+vip
+##ount
+##wd
+booking
+fractured
+tailored
+anya
+spices
+westwood
+cairns
+auditions
+inflammation
+steamed
+##rocity
+##acion
+##urne
+skyla
+thereof
+watford
+torment
+archdeacon
+transforms
+lulu
+demeanor
+fucked
+serge
+##sor
+mckenna
+minas
+entertainer
+##icide
+caress
+originate
+residue
+##sty
+1740
+##ilised
+##org
+beech
+##wana
+subsidies
+##ghton
+emptied
+gladstone
+ru
+firefighters
+voodoo
+##rcle
+het
+nightingale
+tamara
+edmond
+ingredient
+weaknesses
+silhouette
+285
+compatibility
+withdrawing
+hampson
+##mona
+anguish
+giggling
+##mber
+bookstore
+##jiang
+southernmost
+tilting
+##vance
+bai
+economical
+rf
+briefcase
+dreadful
+hinted
+projections
+shattering
+totaling
+##rogate
+analogue
+indicted
+periodical
+fullback
+##dman
+haynes
+##tenberg
+##ffs
+##ishment
+1745
+thirst
+stumble
+penang
+vigorous
+##ddling
+##kor
+##lium
+octave
+##ove
+##enstein
+##inen
+##ones
+siberian
+##uti
+cbn
+repeal
+swaying
+##vington
+khalid
+tanaka
+unicorn
+otago
+plastered
+lobe
+riddle
+##rella
+perch
+##ishing
+croydon
+filtered
+graeme
+tripoli
+##ossa
+crocodile
+##chers
+sufi
+mined
+##tung
+inferno
+lsu
+##phi
+swelled
+utilizes
+£2
+cale
+periodicals
+styx
+hike
+informally
+coop
+lund
+##tidae
+ala
+hen
+qui
+transformations
+disposed
+sheath
+chickens
+##cade
+fitzroy
+sas
+silesia
+unacceptable
+odisha
+1650
+sabrina
+pe
+spokane
+ratios
+athena
+massage
+shen
+dilemma
+##drum
+##riz
+##hul
+corona
+doubtful
+niall
+##pha
+##bino
+fines
+cite
+acknowledging
+bangor
+ballard
+bathurst
+##resh
+huron
+mustered
+alzheimer
+garments
+kinase
+tyre
+warship
+##cp
+flashback
+pulmonary
+braun
+cheat
+kamal
+cyclists
+constructions
+grenades
+ndp
+traveller
+excuses
+stomped
+signalling
+trimmed
+futsal
+mosques
+relevance
+##wine
+wta
+##23
+##vah
+##lter
+hoc
+##riding
+optimistic
+##´s
+deco
+sim
+interacting
+rejecting
+moniker
+waterways
+##ieri
+##oku
+mayors
+gdansk
+outnumbered
+pearls
+##ended
+##hampton
+fairs
+totals
+dominating
+262
+notions
+stairway
+compiling
+pursed
+commodities
+grease
+yeast
+##jong
+carthage
+griffiths
+residual
+amc
+contraction
+laird
+sapphire
+##marine
+##ivated
+amalgamation
+dissolve
+inclination
+lyle
+packaged
+altitudes
+suez
+canons
+graded
+lurched
+narrowing
+boasts
+guise
+wed
+enrico
+##ovsky
+rower
+scarred
+bree
+cub
+iberian
+protagonists
+bargaining
+proposing
+trainers
+voyages
+vans
+fishes
+##aea
+##ivist
+##verance
+encryption
+artworks
+kazan
+sabre
+cleopatra
+hepburn
+rotting
+supremacy
+mecklenburg
+##brate
+burrows
+hazards
+outgoing
+flair
+organizes
+##ctions
+scorpion
+##usions
+boo
+234
+chevalier
+dunedin
+slapping
+##34
+ineligible
+pensions
+##38
+##omic
+manufactures
+emails
+bismarck
+238
+weakening
+blackish
+ding
+mcgee
+quo
+##rling
+northernmost
+xx
+manpower
+greed
+sampson
+clicking
+##ange
+##horpe
+##inations
+##roving
+torre
+##eptive
+##moral
+symbolism
+38th
+asshole
+meritorious
+outfits
+splashed
+biographies
+sprung
+astros
+##tale
+302
+737
+filly
+raoul
+nw
+tokugawa
+linden
+clubhouse
+##apa
+tracts
+romano
+##pio
+putin
+tags
+##note
+chained
+dickson
+gunshot
+moe
+gunn
+rashid
+##tails
+zipper
+##bas
+##nea
+contrasted
+##ply
+##udes
+plum
+pharaoh
+##pile
+aw
+comedies
+ingrid
+sandwiches
+subdivisions
+1100
+mariana
+nokia
+kamen
+hz
+delaney
+veto
+herring
+##words
+possessive
+outlines
+##roup
+siemens
+stairwell
+rc
+gallantry
+messiah
+palais
+yells
+233
+zeppelin
+##dm
+bolivar
+##cede
+smackdown
+mckinley
+##mora
+##yt
+muted
+geologic
+finely
+unitary
+avatar
+hamas
+maynard
+rees
+bog
+contrasting
+##rut
+liv
+chico
+disposition
+pixel
+##erate
+becca
+dmitry
+yeshiva
+narratives
+##lva
+##ulton
+mercenary
+sharpe
+tempered
+navigate
+stealth
+amassed
+keynes
+##lini
+untouched
+##rrie
+havoc
+lithium
+##fighting
+abyss
+graf
+southward
+wolverine
+balloons
+implements
+ngos
+transitions
+##icum
+ambushed
+concacaf
+dormant
+economists
+##dim
+costing
+csi
+rana
+universite
+boulders
+verity
+##llon
+collin
+mellon
+misses
+cypress
+fluorescent
+lifeless
+spence
+##ulla
+crewe
+shepard
+pak
+revelations
+##م
+jolly
+gibbons
+paw
+##dro
+##quel
+freeing
+##test
+shack
+fries
+palatine
+##51
+##hiko
+accompaniment
+cruising
+recycled
+##aver
+erwin
+sorting
+synthesizers
+dyke
+realities
+sg
+strides
+enslaved
+wetland
+##ghan
+competence
+gunpowder
+grassy
+maroon
+reactors
+objection
+##oms
+carlson
+gearbox
+macintosh
+radios
+shelton
+##sho
+clergyman
+prakash
+254
+mongols
+trophies
+oricon
+228
+stimuli
+twenty20
+cantonese
+cortes
+mirrored
+##saurus
+bhp
+cristina
+melancholy
+##lating
+enjoyable
+nuevo
+##wny
+downfall
+schumacher
+##ind
+banging
+lausanne
+rumbled
+paramilitary
+reflex
+ax
+amplitude
+migratory
+##gall
+##ups
+midi
+barnard
+lastly
+sherry
+##hp
+##nall
+keystone
+##kra
+carleton
+slippery
+##53
+coloring
+foe
+socket
+otter
+##rgos
+mats
+##tose
+consultants
+bafta
+bison
+topping
+##km
+490
+primal
+abandonment
+transplant
+atoll
+hideous
+mort
+pained
+reproduced
+tae
+howling
+##turn
+unlawful
+billionaire
+hotter
+poised
+lansing
+##chang
+dinamo
+retro
+messing
+nfc
+domesday
+##mina
+blitz
+timed
+##athing
+##kley
+ascending
+gesturing
+##izations
+signaled
+tis
+chinatown
+mermaid
+savanna
+jameson
+##aint
+catalina
+##pet
+##hers
+cochrane
+cy
+chatting
+##kus
+alerted
+computation
+mused
+noelle
+majestic
+mohawk
+campo
+octagonal
+##sant
+##hend
+241
+aspiring
+##mart
+comprehend
+iona
+paralyzed
+shimmering
+swindon
+rhone
+##eley
+reputed
+configurations
+pitchfork
+agitation
+francais
+gillian
+lipstick
+##ilo
+outsiders
+pontifical
+resisting
+bitterness
+sewer
+rockies
+##edd
+##ucher
+misleading
+1756
+exiting
+galloway
+##nging
+risked
+##heart
+246
+commemoration
+schultz
+##rka
+integrating
+##rsa
+poses
+shrieked
+##weiler
+guineas
+gladys
+jerking
+owls
+goldsmith
+nightly
+penetrating
+##unced
+lia
+##33
+ignited
+betsy
+##aring
+##thorpe
+follower
+vigorously
+##rave
+coded
+kiran
+knit
+zoology
+tbilisi
+##28
+##bered
+repository
+govt
+deciduous
+dino
+growling
+##bba
+enhancement
+unleashed
+chanting
+pussy
+biochemistry
+##eric
+kettle
+repression
+toxicity
+nrhp
+##arth
+##kko
+##bush
+ernesto
+commended
+outspoken
+242
+mca
+parchment
+sms
+kristen
+##aton
+bisexual
+raked
+glamour
+navajo
+a2
+conditioned
+showcased
+##hma
+spacious
+youthful
+##esa
+usl
+appliances
+junta
+brest
+layne
+conglomerate
+enchanted
+chao
+loosened
+picasso
+circulating
+inspect
+montevideo
+##centric
+##kti
+piazza
+spurred
+##aith
+bari
+freedoms
+poultry
+stamford
+lieu
+##ect
+indigo
+sarcastic
+bahia
+stump
+attach
+dvds
+frankenstein
+lille
+approx
+scriptures
+pollen
+##script
+nmi
+overseen
+##ivism
+tides
+proponent
+newmarket
+inherit
+milling
+##erland
+centralized
+##rou
+distributors
+credentials
+drawers
+abbreviation
+##lco
+##xon
+downing
+uncomfortably
+ripe
+##oes
+erase
+franchises
+##ever
+populace
+##bery
+##khar
+decomposition
+pleas
+##tet
+daryl
+sabah
+##stle
+##wide
+fearless
+genie
+lesions
+annette
+##ogist
+oboe
+appendix
+nair
+dripped
+petitioned
+maclean
+mosquito
+parrot
+rpg
+hampered
+1648
+operatic
+reservoirs
+##tham
+irrelevant
+jolt
+summarized
+##fp
+medallion
+##taff
+##−
+clawed
+harlow
+narrower
+goddard
+marcia
+bodied
+fremont
+suarez
+altering
+tempest
+mussolini
+porn
+##isms
+sweetly
+oversees
+walkers
+solitude
+grimly
+shrines
+hk
+ich
+supervisors
+hostess
+dietrich
+legitimacy
+brushes
+expressive
+##yp
+dissipated
+##rse
+localized
+systemic
+##nikov
+gettysburg
+##js
+##uaries
+dialogues
+muttering
+251
+housekeeper
+sicilian
+discouraged
+##frey
+beamed
+kaladin
+halftime
+kidnap
+##amo
+##llet
+1754
+synonymous
+depleted
+instituto
+insulin
+reprised
+##opsis
+clashed
+##ctric
+interrupting
+radcliffe
+insisting
+medici
+1715
+ejected
+playfully
+turbulent
+##47
+starvation
+##rini
+shipment
+rebellious
+petersen
+verification
+merits
+##rified
+cakes
+##charged
+1757
+milford
+shortages
+spying
+fidelity
+##aker
+emitted
+storylines
+harvested
+seismic
+##iform
+cheung
+kilda
+theoretically
+barbie
+lynx
+##rgy
+##tius
+goblin
+mata
+poisonous
+##nburg
+reactive
+residues
+obedience
+##евич
+conjecture
+##rac
+401
+hating
+sixties
+kicker
+moaning
+motown
+##bha
+emancipation
+neoclassical
+##hering
+consoles
+ebert
+professorship
+##tures
+sustaining
+assaults
+obeyed
+affluent
+incurred
+tornadoes
+##eber
+##zow
+emphasizing
+highlanders
+cheated
+helmets
+##ctus
+internship
+terence
+bony
+executions
+legislators
+berries
+peninsular
+tinged
+##aco
+1689
+amplifier
+corvette
+ribbons
+lavish
+pennant
+##lander
+worthless
+##chfield
+##forms
+mariano
+pyrenees
+expenditures
+##icides
+chesterfield
+mandir
+tailor
+39th
+sergey
+nestled
+willed
+aristocracy
+devotees
+goodnight
+raaf
+rumored
+weaponry
+remy
+appropriations
+harcourt
+burr
+riaa
+##lence
+limitation
+unnoticed
+guo
+soaking
+swamps
+##tica
+collapsing
+tatiana
+descriptive
+brigham
+psalm
+##chment
+maddox
+##lization
+patti
+caliph
+##aja
+akron
+injuring
+serra
+##ganj
+basins
+##sari
+astonished
+launcher
+##church
+hilary
+wilkins
+sewing
+##sf
+stinging
+##fia
+##ncia
+underwood
+startup
+##ition
+compilations
+vibrations
+embankment
+jurist
+##nity
+bard
+juventus
+groundwater
+kern
+palaces
+helium
+boca
+cramped
+marissa
+soto
+##worm
+jae
+princely
+##ggy
+faso
+bazaar
+warmly
+##voking
+229
+pairing
+##lite
+##grate
+##nets
+wien
+freaked
+ulysses
+rebirth
+##alia
+##rent
+mummy
+guzman
+jimenez
+stilled
+##nitz
+trajectory
+tha
+woken
+archival
+professions
+##pts
+##pta
+hilly
+shadowy
+shrink
+##bolt
+norwood
+glued
+migrate
+stereotypes
+devoid
+##pheus
+625
+evacuate
+horrors
+infancy
+gotham
+knowles
+optic
+downloaded
+sachs
+kingsley
+parramatta
+darryl
+mor
+##onale
+shady
+commence
+confesses
+kan
+##meter
+##placed
+marlborough
+roundabout
+regents
+frigates
+io
+##imating
+gothenburg
+revoked
+carvings
+clockwise
+convertible
+intruder
+##sche
+banged
+##ogo
+vicky
+bourgeois
+##mony
+dupont
+footing
+##gum
+pd
+##real
+buckle
+yun
+penthouse
+sane
+720
+serviced
+stakeholders
+neumann
+bb
+##eers
+comb
+##gam
+catchment
+pinning
+rallies
+typing
+##elles
+forefront
+freiburg
+sweetie
+giacomo
+widowed
+goodwill
+worshipped
+aspirations
+midday
+##vat
+fishery
+##trick
+bournemouth
+turk
+243
+hearth
+ethanol
+guadalajara
+murmurs
+sl
+##uge
+afforded
+scripted
+##hta
+wah
+##jn
+coroner
+translucent
+252
+memorials
+puck
+progresses
+clumsy
+##race
+315
+candace
+recounted
+##27
+##slin
+##uve
+filtering
+##mac
+howl
+strata
+heron
+leveled
+##ays
+dubious
+##oja
+##т
+##wheel
+citations
+exhibiting
+##laya
+##mics
+##pods
+turkic
+##lberg
+injunction
+##ennial
+##mit
+antibodies
+##44
+organise
+##rigues
+cardiovascular
+cushion
+inverness
+##zquez
+dia
+cocoa
+sibling
+##tman
+##roid
+expanse
+feasible
+tunisian
+algiers
+##relli
+rus
+bloomberg
+dso
+westphalia
+bro
+tacoma
+281
+downloads
+##ours
+konrad
+duran
+##hdi
+continuum
+jett
+compares
+legislator
+secession
+##nable
+##gues
+##zuka
+translating
+reacher
+##gley
+##ła
+aleppo
+##agi
+tc
+orchards
+trapping
+linguist
+versatile
+drumming
+postage
+calhoun
+superiors
+##mx
+barefoot
+leary
+##cis
+ignacio
+alfa
+kaplan
+##rogen
+bratislava
+mori
+##vot
+disturb
+haas
+313
+cartridges
+gilmore
+radiated
+salford
+tunic
+hades
+##ulsive
+archeological
+delilah
+magistrates
+auditioned
+brewster
+charters
+empowerment
+blogs
+cappella
+dynasties
+iroquois
+whipping
+##krishna
+raceway
+truths
+myra
+weaken
+judah
+mcgregor
+##horse
+mic
+refueling
+37th
+burnley
+bosses
+markus
+premio
+query
+##gga
+dunbar
+##economic
+darkest
+lyndon
+sealing
+commendation
+reappeared
+##mun
+addicted
+ezio
+slaughtered
+satisfactory
+shuffle
+##eves
+##thic
+##uj
+fortification
+warrington
+##otto
+resurrected
+fargo
+mane
+##utable
+##lei
+##space
+foreword
+ox
+##aris
+##vern
+abrams
+hua
+##mento
+sakura
+##alo
+uv
+sentimental
+##skaya
+midfield
+##eses
+sturdy
+scrolls
+macleod
+##kyu
+entropy
+##lance
+mitochondrial
+cicero
+excelled
+thinner
+convoys
+perceive
+##oslav
+##urable
+systematically
+grind
+burkina
+287
+##tagram
+ops
+##aman
+guantanamo
+##cloth
+##tite
+forcefully
+wavy
+##jou
+pointless
+##linger
+##tze
+layton
+portico
+superficial
+clerical
+outlaws
+##hism
+burials
+muir
+##inn
+creditors
+hauling
+rattle
+##leg
+calais
+monde
+archers
+reclaimed
+dwell
+wexford
+hellenic
+falsely
+remorse
+##tek
+dough
+furnishings
+##uttered
+gabon
+neurological
+novice
+##igraphy
+contemplated
+pulpit
+nightstand
+saratoga
+##istan
+documenting
+pulsing
+taluk
+##firmed
+busted
+marital
+##rien
+disagreements
+wasps
+##yes
+hodge
+mcdonnell
+mimic
+fran
+pendant
+dhabi
+musa
+##nington
+congratulations
+argent
+darrell
+concussion
+losers
+regrets
+thessaloniki
+reversal
+donaldson
+hardwood
+thence
+achilles
+ritter
+##eran
+demonic
+jurgen
+prophets
+goethe
+eki
+classmate
+buff
+##cking
+yank
+irrational
+##inging
+perished
+seductive
+qur
+sourced
+##crat
+##typic
+mustard
+ravine
+barre
+horizontally
+characterization
+phylogenetic
+boise
+##dit
+##runner
+##tower
+brutally
+intercourse
+seduce
+##bbing
+fay
+ferris
+ogden
+amar
+nik
+unarmed
+##inator
+evaluating
+kyrgyzstan
+sweetness
+##lford
+##oki
+mccormick
+meiji
+notoriety
+stimulate
+disrupt
+figuring
+instructional
+mcgrath
+##zoo
+groundbreaking
+##lto
+flinch
+khorasan
+agrarian
+bengals
+mixer
+radiating
+##sov
+ingram
+pitchers
+nad
+tariff
+##cript
+tata
+##codes
+##emi
+##ungen
+appellate
+lehigh
+##bled
+##giri
+brawl
+duct
+texans
+##ciation
+##ropolis
+skipper
+speculative
+vomit
+doctrines
+stresses
+253
+davy
+graders
+whitehead
+jozef
+timely
+cumulative
+haryana
+paints
+appropriately
+boon
+cactus
+##ales
+##pid
+dow
+legions
+##pit
+perceptions
+1730
+picturesque
+##yse
+periphery
+rune
+wr
+##aha
+celtics
+sentencing
+whoa
+##erin
+confirms
+variance
+425
+moines
+mathews
+spade
+rave
+m1
+fronted
+fx
+blending
+alleging
+reared
+##gl
+237
+##paper
+grassroots
+eroded
+##free
+##physical
+directs
+ordeal
+##sław
+accelerate
+hacker
+rooftop
+##inia
+lev
+buys
+cebu
+devote
+##lce
+specialising
+##ulsion
+choreographed
+repetition
+warehouses
+##ryl
+paisley
+tuscany
+analogy
+sorcerer
+hash
+huts
+shards
+descends
+exclude
+nix
+chaplin
+gaga
+ito
+vane
+##drich
+causeway
+misconduct
+limo
+orchestrated
+glands
+jana
+##kot
+u2
+##mple
+##sons
+branching
+contrasts
+scoop
+longed
+##virus
+chattanooga
+##75
+syrup
+cornerstone
+##tized
+##mind
+##iaceae
+careless
+precedence
+frescoes
+##uet
+chilled
+consult
+modelled
+snatch
+peat
+##thermal
+caucasian
+humane
+relaxation
+spins
+temperance
+##lbert
+occupations
+lambda
+hybrids
+moons
+mp3
+##oese
+247
+rolf
+societal
+yerevan
+ness
+##ssler
+befriended
+mechanized
+nominate
+trough
+boasted
+cues
+seater
+##hom
+bends
+##tangle
+conductors
+emptiness
+##lmer
+eurasian
+adriatic
+tian
+##cie
+anxiously
+lark
+propellers
+chichester
+jock
+ev
+2a
+##holding
+credible
+recounts
+tori
+loyalist
+abduction
+##hoot
+##redo
+nepali
+##mite
+ventral
+tempting
+##ango
+##crats
+steered
+##wice
+javelin
+dipping
+laborers
+prentice
+looming
+titanium
+##ː
+badges
+emir
+tensor
+##ntation
+egyptians
+rash
+denies
+hawthorne
+lombard
+showers
+wehrmacht
+dietary
+trojan
+##reus
+welles
+executing
+horseshoe
+lifeboat
+##lak
+elsa
+infirmary
+nearing
+roberta
+boyer
+mutter
+trillion
+joanne
+##fine
+##oked
+sinks
+vortex
+uruguayan
+clasp
+sirius
+##block
+accelerator
+prohibit
+sunken
+byu
+chronological
+diplomats
+ochreous
+510
+symmetrical
+1644
+maia
+##tology
+salts
+reigns
+atrocities
+##ия
+hess
+bared
+issn
+##vyn
+cater
+saturated
+##cycle
+##isse
+sable
+voyager
+dyer
+yusuf
+##inge
+fountains
+wolff
+##39
+##nni
+engraving
+rollins
+atheist
+ominous
+##ault
+herr
+chariot
+martina
+strung
+##fell
+##farlane
+horrific
+sahib
+gazes
+saetan
+erased
+ptolemy
+##olic
+flushing
+lauderdale
+analytic
+##ices
+530
+navarro
+beak
+gorilla
+herrera
+broom
+guadalupe
+raiding
+sykes
+311
+bsc
+deliveries
+1720
+invasions
+carmichael
+tajikistan
+thematic
+ecumenical
+sentiments
+onstage
+##rians
+##brand
+##sume
+catastrophic
+flanks
+molten
+##arns
+waller
+aimee
+terminating
+##icing
+alternately
+##oche
+nehru
+printers
+outraged
+##eving
+empires
+template
+banners
+repetitive
+za
+##oise
+vegetarian
+##tell
+guiana
+opt
+cavendish
+lucknow
+synthesized
+##hani
+##mada
+finalized
+##ctable
+fictitious
+mayoral
+unreliable
+##enham
+embracing
+peppers
+rbis
+##chio
+##neo
+inhibition
+slashed
+togo
+orderly
+embroidered
+safari
+salty
+236
+barron
+benito
+totaled
+##dak
+pubs
+simulated
+caden
+devin
+tolkien
+momma
+welding
+sesame
+##ept
+gottingen
+hardness
+630
+shaman
+temeraire
+620
+adequately
+pediatric
+##kit
+ck
+assertion
+radicals
+composure
+cadence
+seafood
+beaufort
+lazarus
+mani
+warily
+cunning
+kurdistan
+249
+cantata
+##kir
+ares
+##41
+##clusive
+nape
+townland
+geared
+insulted
+flutter
+boating
+violate
+draper
+dumping
+malmo
+##hh
+##romatic
+firearm
+alta
+bono
+obscured
+##clave
+exceeds
+panorama
+unbelievable
+##train
+preschool
+##essed
+disconnected
+installing
+rescuing
+secretaries
+accessibility
+##castle
+##drive
+##ifice
+##film
+bouts
+slug
+waterway
+mindanao
+##buro
+##ratic
+halves
+##ل
+calming
+liter
+maternity
+adorable
+bragg
+electrification
+mcc
+##dote
+roxy
+schizophrenia
+##body
+munoz
+kaye
+whaling
+239
+mil
+tingling
+tolerant
+##ago
+unconventional
+volcanoes
+##finder
+deportivo
+##llie
+robson
+kaufman
+neuroscience
+wai
+deportation
+masovian
+scraping
+converse
+##bh
+hacking
+bulge
+##oun
+administratively
+yao
+580
+amp
+mammoth
+booster
+claremont
+hooper
+nomenclature
+pursuits
+mclaughlin
+melinda
+##sul
+catfish
+barclay
+substrates
+taxa
+zee
+originals
+kimberly
+packets
+padma
+##ality
+borrowing
+ostensibly
+solvent
+##bri
+##genesis
+##mist
+lukas
+shreveport
+veracruz
+##ь
+##lou
+##wives
+cheney
+tt
+anatolia
+hobbs
+##zyn
+cyclic
+radiant
+alistair
+greenish
+siena
+dat
+independents
+##bation
+conform
+pieter
+hyper
+applicant
+bradshaw
+spores
+telangana
+vinci
+inexpensive
+nuclei
+322
+jang
+nme
+soho
+spd
+##ign
+cradled
+receptionist
+pow
+##43
+##rika
+fascism
+##ifer
+experimenting
+##ading
+##iec
+##region
+345
+jocelyn
+maris
+stair
+nocturnal
+toro
+constabulary
+elgin
+##kker
+msc
+##giving
+##schen
+##rase
+doherty
+doping
+sarcastically
+batter
+maneuvers
+##cano
+##apple
+##gai
+##git
+intrinsic
+##nst
+##stor
+1753
+showtime
+cafes
+gasps
+lviv
+ushered
+##thed
+fours
+restart
+astonishment
+transmitting
+flyer
+shrugs
+##sau
+intriguing
+cones
+dictated
+mushrooms
+medial
+##kovsky
+##elman
+escorting
+gaped
+##26
+godfather
+##door
+##sell
+djs
+recaptured
+timetable
+vila
+1710
+3a
+aerodrome
+mortals
+scientology
+##orne
+angelina
+mag
+convection
+unpaid
+insertion
+intermittent
+lego
+##nated
+endeavor
+kota
+pereira
+##lz
+304
+bwv
+glamorgan
+insults
+agatha
+fey
+##cend
+fleetwood
+mahogany
+protruding
+steamship
+zeta
+##arty
+mcguire
+suspense
+##sphere
+advising
+urges
+##wala
+hurriedly
+meteor
+gilded
+inline
+arroyo
+stalker
+##oge
+excitedly
+revered
+##cure
+earle
+introductory
+##break
+##ilde
+mutants
+puff
+pulses
+reinforcement
+##haling
+curses
+lizards
+stalk
+correlated
+##fixed
+fallout
+macquarie
+##unas
+bearded
+denton
+heaving
+802
+##ocation
+winery
+assign
+dortmund
+##lkirk
+everest
+invariant
+charismatic
+susie
+##elling
+bled
+lesley
+telegram
+sumner
+bk
+##ogen
+##к
+wilcox
+needy
+colbert
+duval
+##iferous
+##mbled
+allotted
+attends
+imperative
+##hita
+replacements
+hawker
+##inda
+insurgency
+##zee
+##eke
+casts
+##yla
+680
+ives
+transitioned
+##pack
+##powering
+authoritative
+baylor
+flex
+cringed
+plaintiffs
+woodrow
+##skie
+drastic
+ape
+aroma
+unfolded
+commotion
+nt
+preoccupied
+theta
+routines
+lasers
+privatization
+wand
+domino
+ek
+clenching
+nsa
+strategically
+showered
+bile
+handkerchief
+pere
+storing
+christophe
+insulting
+316
+nakamura
+romani
+asiatic
+magdalena
+palma
+cruises
+stripping
+405
+konstantin
+soaring
+##berman
+colloquially
+forerunner
+havilland
+incarcerated
+parasites
+sincerity
+##utus
+disks
+plank
+saigon
+##ining
+corbin
+homo
+ornaments
+powerhouse
+##tlement
+chong
+fastened
+feasibility
+idf
+morphological
+usable
+##nish
+##zuki
+aqueduct
+jaguars
+keepers
+##flies
+aleksandr
+faust
+assigns
+ewing
+bacterium
+hurled
+tricky
+hungarians
+integers
+wallis
+321
+yamaha
+##isha
+hushed
+oblivion
+aviator
+evangelist
+friars
+##eller
+monograph
+ode
+##nary
+airplanes
+labourers
+charms
+##nee
+1661
+hagen
+tnt
+rudder
+fiesta
+transcript
+dorothea
+ska
+inhibitor
+maccabi
+retorted
+raining
+encompassed
+clauses
+menacing
+1642
+lineman
+##gist
+vamps
+##ape
+##dick
+gloom
+##rera
+dealings
+easing
+seekers
+##nut
+##pment
+helens
+unmanned
+##anu
+##isson
+basics
+##amy
+##ckman
+adjustments
+1688
+brutality
+horne
+##zell
+sui
+##55
+##mable
+aggregator
+##thal
+rhino
+##drick
+##vira
+counters
+zoom
+##01
+##rting
+mn
+montenegrin
+packard
+##unciation
+##♭
+##kki
+reclaim
+scholastic
+thugs
+pulsed
+##icia
+syriac
+quan
+saddam
+banda
+kobe
+blaming
+buddies
+dissent
+##lusion
+##usia
+corbett
+jaya
+delle
+erratic
+lexie
+##hesis
+435
+amiga
+hermes
+##pressing
+##leen
+chapels
+gospels
+jamal
+##uating
+compute
+revolving
+warp
+##sso
+##thes
+armory
+##eras
+##gol
+antrim
+loki
+##kow
+##asian
+##good
+##zano
+braid
+handwriting
+subdistrict
+funky
+pantheon
+##iculate
+concurrency
+estimation
+improper
+juliana
+##his
+newcomers
+johnstone
+staten
+communicated
+##oco
+##alle
+sausage
+stormy
+##stered
+##tters
+superfamily
+##grade
+acidic
+collateral
+tabloid
+##oped
+##rza
+bladder
+austen
+##ellant
+mcgraw
+##hay
+hannibal
+mein
+aquino
+lucifer
+wo
+badger
+boar
+cher
+christensen
+greenberg
+interruption
+##kken
+jem
+244
+mocked
+bottoms
+cambridgeshire
+##lide
+sprawling
+##bbly
+eastwood
+ghent
+synth
+##buck
+advisers
+##bah
+nominally
+hapoel
+qu
+daggers
+estranged
+fabricated
+towels
+vinnie
+wcw
+misunderstanding
+anglia
+nothin
+unmistakable
+##dust
+##lova
+chilly
+marquette
+truss
+##edge
+##erine
+reece
+##lty
+##chemist
+##connected
+272
+308
+41st
+bash
+raion
+waterfalls
+##ump
+##main
+labyrinth
+queue
+theorist
+##istle
+bharatiya
+flexed
+soundtracks
+rooney
+leftist
+patrolling
+wharton
+plainly
+alleviate
+eastman
+schuster
+topographic
+engages
+immensely
+unbearable
+fairchild
+1620
+dona
+lurking
+parisian
+oliveira
+ia
+indictment
+hahn
+bangladeshi
+##aster
+vivo
+##uming
+##ential
+antonia
+expects
+indoors
+kildare
+harlan
+##logue
+##ogenic
+##sities
+forgiven
+##wat
+childish
+tavi
+##mide
+##orra
+plausible
+grimm
+successively
+scooted
+##bola
+##dget
+##rith
+spartans
+emery
+flatly
+azure
+epilogue
+##wark
+flourish
+##iny
+##tracted
+##overs
+##oshi
+bestseller
+distressed
+receipt
+spitting
+hermit
+topological
+##cot
+drilled
+subunit
+francs
+##layer
+eel
+##fk
+##itas
+octopus
+footprint
+petitions
+ufo
+##say
+##foil
+interfering
+leaking
+palo
+##metry
+thistle
+valiant
+##pic
+narayan
+mcpherson
+##fast
+gonzales
+##ym
+##enne
+dustin
+novgorod
+solos
+##zman
+doin
+##raph
+##patient
+##meyer
+soluble
+ashland
+cuffs
+carole
+pendleton
+whistling
+vassal
+##river
+deviation
+revisited
+constituents
+rallied
+rotate
+loomed
+##eil
+##nting
+amateurs
+augsburg
+auschwitz
+crowns
+skeletons
+##cona
+bonnet
+257
+dummy
+globalization
+simeon
+sleeper
+mandal
+differentiated
+##crow
+##mare
+milne
+bundled
+exasperated
+talmud
+owes
+segregated
+##feng
+##uary
+dentist
+piracy
+props
+##rang
+devlin
+##torium
+malicious
+paws
+##laid
+dependency
+##ergy
+##fers
+##enna
+258
+pistons
+rourke
+jed
+grammatical
+tres
+maha
+wig
+512
+ghostly
+jayne
+##achal
+##creen
+##ilis
+##lins
+##rence
+designate
+##with
+arrogance
+cambodian
+clones
+showdown
+throttle
+twain
+##ception
+lobes
+metz
+nagoya
+335
+braking
+##furt
+385
+roaming
+##minster
+amin
+crippled
+##37
+##llary
+indifferent
+hoffmann
+idols
+intimidating
+1751
+261
+influenza
+memo
+onions
+1748
+bandage
+consciously
+##landa
+##rage
+clandestine
+observes
+swiped
+tangle
+##ener
+##jected
+##trum
+##bill
+##lta
+hugs
+congresses
+josiah
+spirited
+##dek
+humanist
+managerial
+filmmaking
+inmate
+rhymes
+debuting
+grimsby
+ur
+##laze
+duplicate
+vigor
+##tf
+republished
+bolshevik
+refurbishment
+antibiotics
+martini
+methane
+newscasts
+royale
+horizons
+levant
+iain
+visas
+##ischen
+paler
+##around
+manifestation
+snuck
+alf
+chop
+futile
+pedestal
+rehab
+##kat
+bmg
+kerman
+res
+fairbanks
+jarrett
+abstraction
+saharan
+##zek
+1746
+procedural
+clearer
+kincaid
+sash
+luciano
+##ffey
+crunch
+helmut
+##vara
+revolutionaries
+##tute
+creamy
+leach
+##mmon
+1747
+permitting
+nes
+plight
+wendell
+##lese
+contra
+ts
+clancy
+ipa
+mach
+staples
+autopsy
+disturbances
+nueva
+karin
+pontiac
+##uding
+proxy
+venerable
+haunt
+leto
+bergman
+expands
+##helm
+wal
+##pipe
+canning
+celine
+cords
+obesity
+##enary
+intrusion
+planner
+##phate
+reasoned
+sequencing
+307
+harrow
+##chon
+##dora
+marred
+mcintyre
+repay
+tarzan
+darting
+248
+harrisburg
+margarita
+repulsed
+##hur
+##lding
+belinda
+hamburger
+novo
+compliant
+runways
+bingham
+registrar
+skyscraper
+ic
+cuthbert
+improvisation
+livelihood
+##corp
+##elial
+admiring
+##dened
+sporadic
+believer
+casablanca
+popcorn
+##29
+asha
+shovel
+##bek
+##dice
+coiled
+tangible
+##dez
+casper
+elsie
+resin
+tenderness
+rectory
+##ivision
+avail
+sonar
+##mori
+boutique
+##dier
+guerre
+bathed
+upbringing
+vaulted
+sandals
+blessings
+##naut
+##utnant
+1680
+306
+foxes
+pia
+corrosion
+hesitantly
+confederates
+crystalline
+footprints
+shapiro
+tirana
+valentin
+drones
+45th
+microscope
+shipments
+texted
+inquisition
+wry
+guernsey
+unauthorized
+resigning
+760
+ripple
+schubert
+stu
+reassure
+felony
+##ardo
+brittle
+koreans
+##havan
+##ives
+dun
+implicit
+tyres
+##aldi
+##lth
+magnolia
+##ehan
+##puri
+##poulos
+aggressively
+fei
+gr
+familiarity
+##poo
+indicative
+##trust
+fundamentally
+jimmie
+overrun
+395
+anchors
+moans
+##opus
+britannia
+armagh
+##ggle
+purposely
+seizing
+##vao
+bewildered
+mundane
+avoidance
+cosmopolitan
+geometridae
+quartermaster
+caf
+415
+chatter
+engulfed
+gleam
+purge
+##icate
+juliette
+jurisprudence
+guerra
+revisions
+##bn
+casimir
+brew
+##jm
+1749
+clapton
+cloudy
+conde
+hermitage
+278
+simulations
+torches
+vincenzo
+matteo
+##rill
+hidalgo
+booming
+westbound
+accomplishment
+tentacles
+unaffected
+##sius
+annabelle
+flopped
+sloping
+##litz
+dreamer
+interceptor
+vu
+##loh
+consecration
+copying
+messaging
+breaker
+climates
+hospitalized
+1752
+torino
+afternoons
+winfield
+witnessing
+##teacher
+breakers
+choirs
+sawmill
+coldly
+##ege
+sipping
+haste
+uninhabited
+conical
+bibliography
+pamphlets
+severn
+edict
+##oca
+deux
+illnesses
+grips
+##pl
+rehearsals
+sis
+thinkers
+tame
+##keepers
+1690
+acacia
+reformer
+##osed
+##rys
+shuffling
+##iring
+##shima
+eastbound
+ionic
+rhea
+flees
+littered
+##oum
+rocker
+vomiting
+groaning
+champ
+overwhelmingly
+civilizations
+paces
+sloop
+adoptive
+##tish
+skaters
+##vres
+aiding
+mango
+##joy
+nikola
+shriek
+##ignon
+pharmaceuticals
+##mg
+tuna
+calvert
+gustavo
+stocked
+yearbook
+##urai
+##mana
+computed
+subsp
+riff
+hanoi
+kelvin
+hamid
+moors
+pastures
+summons
+jihad
+nectar
+##ctors
+bayou
+untitled
+pleasing
+vastly
+republics
+intellect
+##η
+##ulio
+##tou
+crumbling
+stylistic
+sb
+##ی
+consolation
+frequented
+h₂o
+walden
+widows
+##iens
+404
+##ignment
+chunks
+improves
+288
+grit
+recited
+##dev
+snarl
+sociological
+##arte
+##gul
+inquired
+##held
+bruise
+clube
+consultancy
+homogeneous
+hornets
+multiplication
+pasta
+prick
+savior
+##grin
+##kou
+##phile
+yoon
+##gara
+grimes
+vanishing
+cheering
+reacting
+bn
+distillery
+##quisite
+##vity
+coe
+dockyard
+massif
+##jord
+escorts
+voss
+##valent
+byte
+chopped
+hawke
+illusions
+workings
+floats
+##koto
+##vac
+kv
+annapolis
+madden
+##onus
+alvaro
+noctuidae
+##cum
+##scopic
+avenge
+steamboat
+forte
+illustrates
+erika
+##trip
+570
+dew
+nationalities
+bran
+manifested
+thirsty
+diversified
+muscled
+reborn
+##standing
+arson
+##lessness
+##dran
+##logram
+##boys
+##kushima
+##vious
+willoughby
+##phobia
+286
+alsace
+dashboard
+yuki
+##chai
+granville
+myspace
+publicized
+tricked
+##gang
+adjective
+##ater
+relic
+reorganisation
+enthusiastically
+indications
+saxe
+##lassified
+consolidate
+iec
+padua
+helplessly
+ramps
+renaming
+regulars
+pedestrians
+accents
+convicts
+inaccurate
+lowers
+mana
+##pati
+barrie
+bjp
+outta
+someplace
+berwick
+flanking
+invoked
+marrow
+sparsely
+excerpts
+clothed
+rei
+##ginal
+wept
+##straße
+##vish
+alexa
+excel
+##ptive
+membranes
+aquitaine
+creeks
+cutler
+sheppard
+implementations
+ns
+##dur
+fragrance
+budge
+concordia
+magnesium
+marcelo
+##antes
+gladly
+vibrating
+##rral
+##ggles
+montrose
+##omba
+lew
+seamus
+1630
+cocky
+##ament
+##uen
+bjorn
+##rrick
+fielder
+fluttering
+##lase
+methyl
+kimberley
+mcdowell
+reductions
+barbed
+##jic
+##tonic
+aeronautical
+condensed
+distracting
+##promising
+huffed
+##cala
+##sle
+claudius
+invincible
+missy
+pious
+balthazar
+ci
+##lang
+butte
+combo
+orson
+##dication
+myriad
+1707
+silenced
+##fed
+##rh
+coco
+netball
+yourselves
+##oza
+clarify
+heller
+peg
+durban
+etudes
+offender
+roast
+blackmail
+curvature
+##woods
+vile
+309
+illicit
+suriname
+##linson
+overture
+1685
+bubbling
+gymnast
+tucking
+##mming
+##ouin
+maldives
+##bala
+gurney
+##dda
+##eased
+##oides
+backside
+pinto
+jars
+racehorse
+tending
+##rdial
+baronetcy
+wiener
+duly
+##rke
+barbarian
+cupping
+flawed
+##thesis
+bertha
+pleistocene
+puddle
+swearing
+##nob
+##tically
+fleeting
+prostate
+amulet
+educating
+##mined
+##iti
+##tler
+75th
+jens
+respondents
+analytics
+cavaliers
+papacy
+raju
+##iente
+##ulum
+##tip
+funnel
+271
+disneyland
+##lley
+sociologist
+##iam
+2500
+faulkner
+louvre
+menon
+##dson
+276
+##ower
+afterlife
+mannheim
+peptide
+referees
+comedians
+meaningless
+##anger
+##laise
+fabrics
+hurley
+renal
+sleeps
+##bour
+##icle
+breakout
+kristin
+roadside
+animator
+clover
+disdain
+unsafe
+redesign
+##urity
+firth
+barnsley
+portage
+reset
+narrows
+268
+commandos
+expansive
+speechless
+tubular
+##lux
+essendon
+eyelashes
+smashwords
+##yad
+##bang
+##claim
+craved
+sprinted
+chet
+somme
+astor
+wrocław
+orton
+266
+bane
+##erving
+##uing
+mischief
+##amps
+##sund
+scaling
+terre
+##xious
+impairment
+offenses
+undermine
+moi
+soy
+contiguous
+arcadia
+inuit
+seam
+##tops
+macbeth
+rebelled
+##icative
+##iot
+590
+elaborated
+frs
+uniformed
+##dberg
+259
+powerless
+priscilla
+stimulated
+980
+qc
+arboretum
+frustrating
+trieste
+bullock
+##nified
+enriched
+glistening
+intern
+##adia
+locus
+nouvelle
+ollie
+ike
+lash
+starboard
+ee
+tapestry
+headlined
+hove
+rigged
+##vite
+pollock
+##yme
+thrive
+clustered
+cas
+roi
+gleamed
+olympiad
+##lino
+pressured
+regimes
+##hosis
+##lick
+ripley
+##ophone
+kickoff
+gallon
+rockwell
+##arable
+crusader
+glue
+revolutions
+scrambling
+1714
+grover
+##jure
+englishman
+aztec
+263
+contemplating
+coven
+ipad
+preach
+triumphant
+tufts
+##esian
+rotational
+##phus
+328
+falkland
+##brates
+strewn
+clarissa
+rejoin
+environmentally
+glint
+banded
+drenched
+moat
+albanians
+johor
+rr
+maestro
+malley
+nouveau
+shaded
+taxonomy
+v6
+adhere
+bunk
+airfields
+##ritan
+1741
+encompass
+remington
+tran
+##erative
+amelie
+mazda
+friar
+morals
+passions
+##zai
+breadth
+vis
+##hae
+argus
+burnham
+caressing
+insider
+rudd
+##imov
+##mini
+##rso
+italianate
+murderous
+textual
+wainwright
+armada
+bam
+weave
+timer
+##taken
+##nh
+fra
+##crest
+ardent
+salazar
+taps
+tunis
+##ntino
+allegro
+gland
+philanthropic
+##chester
+implication
+##optera
+esq
+judas
+noticeably
+wynn
+##dara
+inched
+indexed
+crises
+villiers
+bandit
+royalties
+patterned
+cupboard
+interspersed
+accessory
+isla
+kendrick
+entourage
+stitches
+##esthesia
+headwaters
+##ior
+interlude
+distraught
+draught
+1727
+##basket
+biased
+sy
+transient
+triad
+subgenus
+adapting
+kidd
+shortstop
+##umatic
+dimly
+spiked
+mcleod
+reprint
+nellie
+pretoria
+windmill
+##cek
+singled
+##mps
+273
+reunite
+##orous
+747
+bankers
+outlying
+##omp
+##ports
+##tream
+apologies
+cosmetics
+patsy
+##deh
+##ocks
+##yson
+bender
+nantes
+serene
+##nad
+lucha
+mmm
+323
+##cius
+##gli
+cmll
+coinage
+nestor
+juarez
+##rook
+smeared
+sprayed
+twitching
+sterile
+irina
+embodied
+juveniles
+enveloped
+miscellaneous
+cancers
+dq
+gulped
+luisa
+crested
+swat
+donegal
+ref
+##anov
+##acker
+hearst
+mercantile
+##lika
+doorbell
+ua
+vicki
+##alla
+##som
+bilbao
+psychologists
+stryker
+sw
+horsemen
+turkmenistan
+wits
+##national
+anson
+mathew
+screenings
+##umb
+rihanna
+##agne
+##nessy
+aisles
+##iani
+##osphere
+hines
+kenton
+saskatoon
+tasha
+truncated
+##champ
+##itan
+mildred
+advises
+fredrik
+interpreting
+inhibitors
+##athi
+spectroscopy
+##hab
+##kong
+karim
+panda
+##oia
+##nail
+##vc
+conqueror
+kgb
+leukemia
+##dity
+arrivals
+cheered
+pisa
+phosphorus
+shielded
+##riated
+mammal
+unitarian
+urgently
+chopin
+sanitary
+##mission
+spicy
+drugged
+hinges
+##tort
+tipping
+trier
+impoverished
+westchester
+##caster
+267
+epoch
+nonstop
+##gman
+##khov
+aromatic
+centrally
+cerro
+##tively
+##vio
+billions
+modulation
+sedimentary
+283
+facilitating
+outrageous
+goldstein
+##eak
+##kt
+ld
+maitland
+penultimate
+pollard
+##dance
+fleets
+spaceship
+vertebrae
+##nig
+alcoholism
+als
+recital
+##bham
+##ference
+##omics
+m2
+##bm
+trois
+##tropical
+##в
+commemorates
+##meric
+marge
+##raction
+1643
+670
+cosmetic
+ravaged
+##ige
+catastrophe
+eng
+##shida
+albrecht
+arterial
+bellamy
+decor
+harmon
+##rde
+bulbs
+synchronized
+vito
+easiest
+shetland
+shielding
+wnba
+##glers
+##ssar
+##riam
+brianna
+cumbria
+##aceous
+##rard
+cores
+thayer
+##nsk
+brood
+hilltop
+luminous
+carts
+keynote
+larkin
+logos
+##cta
+##ا
+##mund
+##quay
+lilith
+tinted
+277
+wrestle
+mobilization
+##uses
+sequential
+siam
+bloomfield
+takahashi
+274
+##ieving
+presenters
+ringo
+blazed
+witty
+##oven
+##ignant
+devastation
+haydn
+harmed
+newt
+therese
+##peed
+gershwin
+molina
+rabbis
+sudanese
+001
+innate
+restarted
+##sack
+##fus
+slices
+wb
+##shah
+enroll
+hypothetical
+hysterical
+1743
+fabio
+indefinite
+warped
+##hg
+exchanging
+525
+unsuitable
+##sboro
+gallo
+1603
+bret
+cobalt
+homemade
+##hunter
+mx
+operatives
+##dhar
+terraces
+durable
+latch
+pens
+whorls
+##ctuated
+##eaux
+billing
+ligament
+succumbed
+##gly
+regulators
+spawn
+##brick
+##stead
+filmfare
+rochelle
+##nzo
+1725
+circumstance
+saber
+supplements
+##nsky
+##tson
+crowe
+wellesley
+carrot
+##9th
+##movable
+primate
+drury
+sincerely
+topical
+##mad
+##rao
+callahan
+kyiv
+smarter
+tits
+undo
+##yeh
+announcements
+anthologies
+barrio
+nebula
+##islaus
+##shaft
+##tyn
+bodyguards
+2021
+assassinate
+barns
+emmett
+scully
+##mah
+##yd
+##eland
+##tino
+##itarian
+demoted
+gorman
+lashed
+prized
+adventist
+writ
+##gui
+alla
+invertebrates
+##ausen
+1641
+amman
+1742
+align
+healy
+redistribution
+##gf
+##rize
+insulation
+##drop
+adherents
+hezbollah
+vitro
+ferns
+yanking
+269
+php
+registering
+uppsala
+cheerleading
+confines
+mischievous
+tully
+##ross
+49th
+docked
+roam
+stipulated
+pumpkin
+##bry
+prompt
+##ezer
+blindly
+shuddering
+craftsmen
+frail
+scented
+katharine
+scramble
+shaggy
+sponge
+helix
+zaragoza
+279
+##52
+43rd
+backlash
+fontaine
+seizures
+posse
+cowan
+nonfiction
+telenovela
+wwii
+hammered
+undone
+##gpur
+encircled
+irs
+##ivation
+artefacts
+oneself
+searing
+smallpox
+##belle
+##osaurus
+shandong
+breached
+upland
+blushing
+rankin
+infinitely
+psyche
+tolerated
+docking
+evicted
+##col
+unmarked
+##lving
+gnome
+lettering
+litres
+musique
+##oint
+benevolent
+##jal
+blackened
+##anna
+mccall
+racers
+tingle
+##ocene
+##orestation
+introductions
+radically
+292
+##hiff
+##باد
+1610
+1739
+munchen
+plead
+##nka
+condo
+scissors
+##sight
+##tens
+apprehension
+##cey
+##yin
+hallmark
+watering
+formulas
+sequels
+##llas
+aggravated
+bae
+commencing
+##building
+enfield
+prohibits
+marne
+vedic
+civilized
+euclidean
+jagger
+beforehand
+blasts
+dumont
+##arney
+##nem
+740
+conversions
+hierarchical
+rios
+simulator
+##dya
+##lellan
+hedges
+oleg
+thrusts
+shadowed
+darby
+maximize
+1744
+gregorian
+##nded
+##routed
+sham
+unspecified
+##hog
+emory
+factual
+##smo
+##tp
+fooled
+##rger
+ortega
+wellness
+marlon
+##oton
+##urance
+casket
+keating
+ley
+enclave
+##ayan
+char
+influencing
+jia
+##chenko
+412
+ammonia
+erebidae
+incompatible
+violins
+cornered
+##arat
+grooves
+astronauts
+columbian
+rampant
+fabrication
+kyushu
+mahmud
+vanish
+##dern
+mesopotamia
+##lete
+ict
+##rgen
+caspian
+kenji
+pitted
+##vered
+999
+grimace
+roanoke
+tchaikovsky
+twinned
+##analysis
+##awan
+xinjiang
+arias
+clemson
+kazakh
+sizable
+1662
+##khand
+##vard
+plunge
+tatum
+vittorio
+##nden
+cholera
+##dana
+##oper
+bracing
+indifference
+projectile
+superliga
+##chee
+realises
+upgrading
+299
+porte
+retribution
+##vies
+nk
+stil
+##resses
+ama
+bureaucracy
+blackberry
+bosch
+testosterone
+collapses
+greer
+##pathic
+ioc
+fifties
+malls
+##erved
+bao
+baskets
+adolescents
+siegfried
+##osity
+##tosis
+mantra
+detecting
+existent
+fledgling
+##cchi
+dissatisfied
+gan
+telecommunication
+mingled
+sobbed
+6000
+controversies
+outdated
+taxis
+##raus
+fright
+slams
+##lham
+##fect
+##tten
+detectors
+fetal
+tanned
+##uw
+fray
+goth
+olympian
+skipping
+mandates
+scratches
+sheng
+unspoken
+hyundai
+tracey
+hotspur
+restrictive
+##buch
+americana
+mundo
+##bari
+burroughs
+diva
+vulcan
+##6th
+distinctions
+thumping
+##ngen
+mikey
+sheds
+fide
+rescues
+springsteen
+vested
+valuation
+##ece
+##ely
+pinnacle
+rake
+sylvie
+##edo
+almond
+quivering
+##irus
+alteration
+faltered
+##wad
+51st
+hydra
+ticked
+##kato
+recommends
+##dicated
+antigua
+arjun
+stagecoach
+wilfred
+trickle
+pronouns
+##pon
+aryan
+nighttime
+##anian
+gall
+pea
+stitch
+##hei
+leung
+milos
+##dini
+eritrea
+nexus
+starved
+snowfall
+kant
+parasitic
+cot
+discus
+hana
+strikers
+appleton
+kitchens
+##erina
+##partisan
+##itha
+##vius
+disclose
+metis
+##channel
+1701
+tesla
+##vera
+fitch
+1735
+blooded
+##tila
+decimal
+##tang
+##bai
+cyclones
+eun
+bottled
+peas
+pensacola
+basha
+bolivian
+crabs
+boil
+lanterns
+partridge
+roofed
+1645
+necks
+##phila
+opined
+patting
+##kla
+##lland
+chuckles
+volta
+whereupon
+##nche
+devout
+euroleague
+suicidal
+##dee
+inherently
+involuntary
+knitting
+nasser
+##hide
+puppets
+colourful
+courageous
+southend
+stills
+miraculous
+hodgson
+richer
+rochdale
+ethernet
+greta
+uniting
+prism
+umm
+##haya
+##itical
+##utation
+deterioration
+pointe
+prowess
+##ropriation
+lids
+scranton
+billings
+subcontinent
+##koff
+##scope
+brute
+kellogg
+psalms
+degraded
+##vez
+stanisław
+##ructured
+ferreira
+pun
+astonishing
+gunnar
+##yat
+arya
+prc
+gottfried
+##tight
+excursion
+##ographer
+dina
+##quil
+##nare
+huffington
+illustrious
+wilbur
+gundam
+verandah
+##zard
+naacp
+##odle
+constructive
+fjord
+kade
+##naud
+generosity
+thrilling
+baseline
+cayman
+frankish
+plastics
+accommodations
+zoological
+##fting
+cedric
+qb
+motorized
+##dome
+##otted
+squealed
+tackled
+canucks
+budgets
+situ
+asthma
+dail
+gabled
+grasslands
+whimpered
+writhing
+judgments
+##65
+minnie
+pv
+##carbon
+bananas
+grille
+domes
+monique
+odin
+maguire
+markham
+tierney
+##estra
+##chua
+libel
+poke
+speedy
+atrium
+laval
+notwithstanding
+##edly
+fai
+kala
+##sur
+robb
+##sma
+listings
+luz
+supplementary
+tianjin
+##acing
+enzo
+jd
+ric
+scanner
+croats
+transcribed
+##49
+arden
+cv
+##hair
+##raphy
+##lver
+##uy
+357
+seventies
+staggering
+alam
+horticultural
+hs
+regression
+timbers
+blasting
+##ounded
+montagu
+manipulating
+##cit
+catalytic
+1550
+troopers
+##meo
+condemnation
+fitzpatrick
+##oire
+##roved
+inexperienced
+1670
+castes
+##lative
+outing
+314
+dubois
+flicking
+quarrel
+ste
+learners
+1625
+iq
+whistled
+##class
+282
+classify
+tariffs
+temperament
+355
+folly
+liszt
+##yles
+immersed
+jordanian
+ceasefire
+apparel
+extras
+maru
+fished
+##bio
+harta
+stockport
+assortment
+craftsman
+paralysis
+transmitters
+##cola
+blindness
+##wk
+fatally
+proficiency
+solemnly
+##orno
+repairing
+amore
+groceries
+ultraviolet
+##chase
+schoolhouse
+##tua
+resurgence
+nailed
+##otype
+##×
+ruse
+saliva
+diagrams
+##tructing
+albans
+rann
+thirties
+1b
+antennas
+hilarious
+cougars
+paddington
+stats
+##eger
+breakaway
+ipod
+reza
+authorship
+prohibiting
+scoffed
+##etz
+##ttle
+conscription
+defected
+trondheim
+##fires
+ivanov
+keenan
+##adan
+##ciful
+##fb
+##slow
+locating
+##ials
+##tford
+cadiz
+basalt
+blankly
+interned
+rags
+rattling
+##tick
+carpathian
+reassured
+sync
+bum
+guildford
+iss
+staunch
+##onga
+astronomers
+sera
+sofie
+emergencies
+susquehanna
+##heard
+duc
+mastery
+vh1
+williamsburg
+bayer
+buckled
+craving
+##khan
+##rdes
+bloomington
+##write
+alton
+barbecue
+##bians
+justine
+##hri
+##ndt
+delightful
+smartphone
+newtown
+photon
+retrieval
+peugeot
+hissing
+##monium
+##orough
+flavors
+lighted
+relaunched
+tainted
+##games
+##lysis
+anarchy
+microscopic
+hopping
+adept
+evade
+evie
+##beau
+inhibit
+sinn
+adjustable
+hurst
+intuition
+wilton
+cisco
+44th
+lawful
+lowlands
+stockings
+thierry
+##dalen
+##hila
+##nai
+fates
+prank
+tb
+maison
+lobbied
+provocative
+1724
+4a
+utopia
+##qual
+carbonate
+gujarati
+purcell
+##rford
+curtiss
+##mei
+overgrown
+arenas
+mediation
+swallows
+##rnik
+respectful
+turnbull
+##hedron
+##hope
+alyssa
+ozone
+##ʻi
+ami
+gestapo
+johansson
+snooker
+canteen
+cuff
+declines
+empathy
+stigma
+##ags
+##iner
+##raine
+taxpayers
+gui
+volga
+##wright
+##copic
+lifespan
+overcame
+tattooed
+enactment
+giggles
+##ador
+##camp
+barrington
+bribe
+obligatory
+orbiting
+peng
+##enas
+elusive
+sucker
+##vating
+cong
+hardship
+empowered
+anticipating
+estrada
+cryptic
+greasy
+detainees
+planck
+sudbury
+plaid
+dod
+marriott
+kayla
+##ears
+##vb
+##zd
+mortally
+##hein
+cognition
+radha
+319
+liechtenstein
+meade
+richly
+argyle
+harpsichord
+liberalism
+trumpets
+lauded
+tyrant
+salsa
+tiled
+lear
+promoters
+reused
+slicing
+trident
+##chuk
+##gami
+##lka
+cantor
+checkpoint
+##points
+gaul
+leger
+mammalian
+##tov
+##aar
+##schaft
+doha
+frenchman
+nirvana
+##vino
+delgado
+headlining
+##eron
+##iography
+jug
+tko
+1649
+naga
+intersections
+##jia
+benfica
+nawab
+##suka
+ashford
+gulp
+##deck
+##vill
+##rug
+brentford
+frazier
+pleasures
+dunne
+potsdam
+shenzhen
+dentistry
+##tec
+flanagan
+##dorff
+##hear
+chorale
+dinah
+prem
+quezon
+##rogated
+relinquished
+sutra
+terri
+##pani
+flaps
+##rissa
+poly
+##rnet
+homme
+aback
+##eki
+linger
+womb
+##kson
+##lewood
+doorstep
+orthodoxy
+threaded
+westfield
+##rval
+dioceses
+fridays
+subsided
+##gata
+loyalists
+##biotic
+##ettes
+letterman
+lunatic
+prelate
+tenderly
+invariably
+souza
+thug
+winslow
+##otide
+furlongs
+gogh
+jeopardy
+##runa
+pegasus
+##umble
+humiliated
+standalone
+tagged
+##roller
+freshmen
+klan
+##bright
+attaining
+initiating
+transatlantic
+logged
+viz
+##uance
+1723
+combatants
+intervening
+stephane
+chieftain
+despised
+grazed
+317
+cdc
+galveston
+godzilla
+macro
+simulate
+##planes
+parades
+##esses
+960
+##ductive
+##unes
+equator
+overdose
+##cans
+##hosh
+##lifting
+joshi
+epstein
+sonora
+treacherous
+aquatics
+manchu
+responsive
+##sation
+supervisory
+##christ
+##llins
+##ibar
+##balance
+##uso
+kimball
+karlsruhe
+mab
+##emy
+ignores
+phonetic
+reuters
+spaghetti
+820
+almighty
+danzig
+rumbling
+tombstone
+designations
+lured
+outset
+##felt
+supermarkets
+##wt
+grupo
+kei
+kraft
+susanna
+##blood
+comprehension
+genealogy
+##aghan
+##verted
+redding
+##ythe
+1722
+bowing
+##pore
+##roi
+lest
+sharpened
+fulbright
+valkyrie
+sikhs
+##unds
+swans
+bouquet
+merritt
+##tage
+##venting
+commuted
+redhead
+clerks
+leasing
+cesare
+dea
+hazy
+##vances
+fledged
+greenfield
+servicemen
+##gical
+armando
+blackout
+dt
+sagged
+downloadable
+intra
+potion
+pods
+##4th
+##mism
+xp
+attendants
+gambia
+stale
+##ntine
+plump
+asteroids
+rediscovered
+buds
+flea
+hive
+##neas
+1737
+classifications
+debuts
+##eles
+olympus
+scala
+##eurs
+##gno
+##mute
+hummed
+sigismund
+visuals
+wiggled
+await
+pilasters
+clench
+sulfate
+##ances
+bellevue
+enigma
+trainee
+snort
+##sw
+clouded
+denim
+##rank
+##rder
+churning
+hartman
+lodges
+riches
+sima
+##missible
+accountable
+socrates
+regulates
+mueller
+##cr
+1702
+avoids
+solids
+himalayas
+nutrient
+pup
+##jevic
+squat
+fades
+nec
+##lates
+##pina
+##rona
+##ου
+privateer
+tequila
+##gative
+##mpton
+apt
+hornet
+immortals
+##dou
+asturias
+cleansing
+dario
+##rries
+##anta
+etymology
+servicing
+zhejiang
+##venor
+##nx
+horned
+erasmus
+rayon
+relocating
+£10
+##bags
+escalated
+promenade
+stubble
+2010s
+artisans
+axial
+liquids
+mora
+sho
+yoo
+##tsky
+bundles
+oldies
+##nally
+notification
+bastion
+##ths
+sparkle
+##lved
+1728
+leash
+pathogen
+highs
+##hmi
+immature
+880
+gonzaga
+ignatius
+mansions
+monterrey
+sweets
+bryson
+##loe
+polled
+regatta
+brightest
+pei
+rosy
+squid
+hatfield
+payroll
+addict
+meath
+cornerback
+heaviest
+lodging
+##mage
+capcom
+rippled
+##sily
+barnet
+mayhem
+ymca
+snuggled
+rousseau
+##cute
+blanchard
+284
+fragmented
+leighton
+chromosomes
+risking
+##md
+##strel
+##utter
+corinne
+coyotes
+cynical
+hiroshi
+yeomanry
+##ractive
+ebook
+grading
+mandela
+plume
+agustin
+magdalene
+##rkin
+bea
+femme
+trafford
+##coll
+##lun
+##tance
+52nd
+fourier
+upton
+##mental
+camilla
+gust
+iihf
+islamabad
+longevity
+##kala
+feldman
+netting
+##rization
+endeavour
+foraging
+mfa
+orr
+##open
+greyish
+contradiction
+graz
+##ruff
+handicapped
+marlene
+tweed
+oaxaca
+spp
+campos
+miocene
+pri
+configured
+cooks
+pluto
+cozy
+pornographic
+##entes
+70th
+fairness
+glided
+jonny
+lynne
+rounding
+sired
+##emon
+##nist
+remade
+uncover
+##mack
+complied
+lei
+newsweek
+##jured
+##parts
+##enting
+##pg
+293
+finer
+guerrillas
+athenian
+deng
+disused
+stepmother
+accuse
+gingerly
+seduction
+521
+confronting
+##walker
+##going
+gora
+nostalgia
+sabres
+virginity
+wrenched
+##minated
+syndication
+wielding
+eyre
+##56
+##gnon
+##igny
+behaved
+taxpayer
+sweeps
+##growth
+childless
+gallant
+##ywood
+amplified
+geraldine
+scrape
+##ffi
+babylonian
+fresco
+##rdan
+##kney
+##position
+1718
+restricting
+tack
+fukuoka
+osborn
+selector
+partnering
+##dlow
+318
+gnu
+kia
+tak
+whitley
+gables
+##54
+##mania
+mri
+softness
+immersion
+##bots
+##evsky
+1713
+chilling
+insignificant
+pcs
+##uis
+elites
+lina
+purported
+supplemental
+teaming
+##americana
+##dding
+##inton
+proficient
+rouen
+##nage
+##rret
+niccolo
+selects
+##bread
+fluffy
+1621
+gruff
+knotted
+mukherjee
+polgara
+thrash
+nicholls
+secluded
+smoothing
+thru
+corsica
+loaf
+whitaker
+inquiries
+##rrier
+##kam
+indochina
+289
+marlins
+myles
+peking
+##tea
+extracts
+pastry
+superhuman
+connacht
+vogel
+##ditional
+##het
+##udged
+##lash
+gloss
+quarries
+refit
+teaser
+##alic
+##gaon
+20s
+materialized
+sling
+camped
+pickering
+tung
+tracker
+pursuant
+##cide
+cranes
+soc
+##cini
+##typical
+##viere
+anhalt
+overboard
+workout
+chores
+fares
+orphaned
+stains
+##logie
+fenton
+surpassing
+joyah
+triggers
+##itte
+grandmaster
+##lass
+##lists
+clapping
+fraudulent
+ledger
+nagasaki
+##cor
+##nosis
+##tsa
+eucalyptus
+tun
+##icio
+##rney
+##tara
+dax
+heroism
+ina
+wrexham
+onboard
+unsigned
+##dates
+moshe
+galley
+winnie
+droplets
+exiles
+praises
+watered
+noodles
+##aia
+fein
+adi
+leland
+multicultural
+stink
+bingo
+comets
+erskine
+modernized
+canned
+constraint
+domestically
+chemotherapy
+featherweight
+stifled
+##mum
+darkly
+irresistible
+refreshing
+hasty
+isolate
+##oys
+kitchener
+planners
+##wehr
+cages
+yarn
+implant
+toulon
+elects
+childbirth
+yue
+##lind
+##lone
+cn
+rightful
+sportsman
+junctions
+remodeled
+specifies
+##rgh
+291
+##oons
+complimented
+##urgent
+lister
+ot
+##logic
+bequeathed
+cheekbones
+fontana
+gabby
+##dial
+amadeus
+corrugated
+maverick
+resented
+triangles
+##hered
+##usly
+nazareth
+tyrol
+1675
+assent
+poorer
+sectional
+aegean
+##cous
+296
+nylon
+ghanaian
+##egorical
+##weig
+cushions
+forbid
+fusiliers
+obstruction
+somerville
+##scia
+dime
+earrings
+elliptical
+leyte
+oder
+polymers
+timmy
+atm
+midtown
+piloted
+settles
+continual
+externally
+mayfield
+##uh
+enrichment
+henson
+keane
+persians
+1733
+benji
+braden
+pep
+324
+##efe
+contenders
+pepsi
+valet
+##isches
+298
+##asse
+##earing
+goofy
+stroll
+##amen
+authoritarian
+occurrences
+adversary
+ahmedabad
+tangent
+toppled
+dorchester
+1672
+modernism
+marxism
+islamist
+charlemagne
+exponential
+racks
+unicode
+brunette
+mbc
+pic
+skirmish
+##bund
+##lad
+##powered
+##yst
+hoisted
+messina
+shatter
+##ctum
+jedi
+vantage
+##music
+##neil
+clemens
+mahmoud
+corrupted
+authentication
+lowry
+nils
+##washed
+omnibus
+wounding
+jillian
+##itors
+##opped
+serialized
+narcotics
+handheld
+##arm
+##plicity
+intersecting
+stimulating
+##onis
+crate
+fellowships
+hemingway
+casinos
+climatic
+fordham
+copeland
+drip
+beatty
+leaflets
+robber
+brothel
+madeira
+##hedral
+sphinx
+ultrasound
+##vana
+valor
+forbade
+leonid
+villas
+##aldo
+duane
+marquez
+##cytes
+disadvantaged
+forearms
+kawasaki
+reacts
+consular
+lax
+uncles
+uphold
+##hopper
+concepcion
+dorsey
+lass
+##izan
+arching
+passageway
+1708
+researches
+tia
+internationals
+##graphs
+##opers
+distinguishes
+javanese
+divert
+##uven
+plotted
+##listic
+##rwin
+##erik
+##tify
+affirmative
+signifies
+validation
+##bson
+kari
+felicity
+georgina
+zulu
+##eros
+##rained
+##rath
+overcoming
+##dot
+argyll
+##rbin
+1734
+chiba
+ratification
+windy
+earls
+parapet
+##marks
+hunan
+pristine
+astrid
+punta
+##gart
+brodie
+##kota
+##oder
+malaga
+minerva
+rouse
+##phonic
+bellowed
+pagoda
+portals
+reclamation
+##gur
+##odies
+##⁄₄
+parentheses
+quoting
+allergic
+palette
+showcases
+benefactor
+heartland
+nonlinear
+##tness
+bladed
+cheerfully
+scans
+##ety
+##hone
+1666
+girlfriends
+pedersen
+hiram
+sous
+##liche
+##nator
+1683
+##nery
+##orio
+##umen
+bobo
+primaries
+smiley
+##cb
+unearthed
+uniformly
+fis
+metadata
+1635
+ind
+##oted
+recoil
+##titles
+##tura
+##ια
+406
+hilbert
+jamestown
+mcmillan
+tulane
+seychelles
+##frid
+antics
+coli
+fated
+stucco
+##grants
+1654
+bulky
+accolades
+arrays
+caledonian
+carnage
+optimism
+puebla
+##tative
+##cave
+enforcing
+rotherham
+seo
+dunlop
+aeronautics
+chimed
+incline
+zoning
+archduke
+hellenistic
+##oses
+##sions
+candi
+thong
+##ople
+magnate
+rustic
+##rsk
+projective
+slant
+##offs
+danes
+hollis
+vocalists
+##ammed
+congenital
+contend
+gesellschaft
+##ocating
+##pressive
+douglass
+quieter
+##cm
+##kshi
+howled
+salim
+spontaneously
+townsville
+buena
+southport
+##bold
+kato
+1638
+faerie
+stiffly
+##vus
+##rled
+297
+flawless
+realising
+taboo
+##7th
+bytes
+straightening
+356
+jena
+##hid
+##rmin
+cartwright
+berber
+bertram
+soloists
+411
+noses
+417
+coping
+fission
+hardin
+inca
+##cen
+1717
+mobilized
+vhf
+##raf
+biscuits
+curate
+##85
+##anial
+331
+gaunt
+neighbourhoods
+1540
+##abas
+blanca
+bypassed
+sockets
+behold
+coincidentally
+##bane
+nara
+shave
+splinter
+terrific
+##arion
+##erian
+commonplace
+juris
+redwood
+waistband
+boxed
+caitlin
+fingerprints
+jennie
+naturalized
+##ired
+balfour
+craters
+jody
+bungalow
+hugely
+quilt
+glitter
+pigeons
+undertaker
+bulging
+constrained
+goo
+##sil
+##akh
+assimilation
+reworked
+##person
+persuasion
+##pants
+felicia
+##cliff
+##ulent
+1732
+explodes
+##dun
+##inium
+##zic
+lyman
+vulture
+hog
+overlook
+begs
+northwards
+ow
+spoil
+##urer
+fatima
+favorably
+accumulate
+sargent
+sorority
+corresponded
+dispersal
+kochi
+toned
+##imi
+##lita
+internacional
+newfound
+##agger
+##lynn
+##rigue
+booths
+peanuts
+##eborg
+medicare
+muriel
+nur
+##uram
+crates
+millennia
+pajamas
+worsened
+##breakers
+jimi
+vanuatu
+yawned
+##udeau
+carousel
+##hony
+hurdle
+##ccus
+##mounted
+##pod
+rv
+##eche
+airship
+ambiguity
+compulsion
+recapture
+##claiming
+arthritis
+##osomal
+1667
+asserting
+ngc
+sniffing
+dade
+discontent
+glendale
+ported
+##amina
+defamation
+rammed
+##scent
+fling
+livingstone
+##fleet
+875
+##ppy
+apocalyptic
+comrade
+lcd
+##lowe
+cessna
+eine
+persecuted
+subsistence
+demi
+hoop
+reliefs
+710
+coptic
+progressing
+stemmed
+perpetrators
+1665
+priestess
+##nio
+dobson
+ebony
+rooster
+itf
+tortricidae
+##bbon
+##jian
+cleanup
+##jean
+##øy
+1721
+eighties
+taxonomic
+holiness
+##hearted
+##spar
+antilles
+showcasing
+stabilized
+##nb
+gia
+mascara
+michelangelo
+dawned
+##uria
+##vinsky
+extinguished
+fitz
+grotesque
+£100
+##fera
+##loid
+##mous
+barges
+neue
+throbbed
+cipher
+johnnie
+##a1
+##mpt
+outburst
+##swick
+spearheaded
+administrations
+c1
+heartbreak
+pixels
+pleasantly
+##enay
+lombardy
+plush
+##nsed
+bobbie
+##hly
+reapers
+tremor
+xiang
+minogue
+substantive
+hitch
+barak
+##wyl
+kwan
+##encia
+910
+obscene
+elegance
+indus
+surfer
+bribery
+conserve
+##hyllum
+##masters
+horatio
+##fat
+apes
+rebound
+psychotic
+##pour
+iteration
+##mium
+##vani
+botanic
+horribly
+antiques
+dispose
+paxton
+##hli
+##wg
+timeless
+1704
+disregard
+engraver
+hounds
+##bau
+##version
+looted
+uno
+facilitates
+groans
+masjid
+rutland
+antibody
+disqualification
+decatur
+footballers
+quake
+slacks
+48th
+rein
+scribe
+stabilize
+commits
+exemplary
+tho
+##hort
+##chison
+pantry
+traversed
+##hiti
+disrepair
+identifiable
+vibrated
+baccalaureate
+##nnis
+csa
+interviewing
+##iensis
+##raße
+greaves
+wealthiest
+343
+classed
+jogged
+£5
+##58
+##atal
+illuminating
+knicks
+respecting
+##uno
+scrubbed
+##iji
+##dles
+kruger
+moods
+growls
+raider
+silvia
+chefs
+kam
+vr
+cree
+percival
+##terol
+gunter
+counterattack
+defiant
+henan
+ze
+##rasia
+##riety
+equivalence
+submissions
+##fra
+##thor
+bautista
+mechanically
+##heater
+cornice
+herbal
+templar
+##mering
+outputs
+ruining
+ligand
+renumbered
+extravagant
+mika
+blockbuster
+eta
+insurrection
+##ilia
+darkening
+ferocious
+pianos
+strife
+kinship
+##aer
+melee
+##anor
+##iste
+##may
+##oue
+decidedly
+weep
+##jad
+##missive
+##ppel
+354
+puget
+unease
+##gnant
+1629
+hammering
+kassel
+ob
+wessex
+##lga
+bromwich
+egan
+paranoia
+utilization
+##atable
+##idad
+contradictory
+provoke
+##ols
+##ouring
+##tangled
+knesset
+##very
+##lette
+plumbing
+##sden
+##¹
+greensboro
+occult
+sniff
+338
+zev
+beaming
+gamer
+haggard
+mahal
+##olt
+##pins
+mendes
+utmost
+briefing
+gunnery
+##gut
+##pher
+##zh
+##rok
+1679
+khalifa
+sonya
+##boot
+principals
+urbana
+wiring
+##liffe
+##minating
+##rrado
+dahl
+nyu
+skepticism
+np
+townspeople
+ithaca
+lobster
+somethin
+##fur
+##arina
+##−1
+freighter
+zimmerman
+biceps
+contractual
+##herton
+amend
+hurrying
+subconscious
+##anal
+336
+meng
+clermont
+spawning
+##eia
+##lub
+dignitaries
+impetus
+snacks
+spotting
+twigs
+##bilis
+##cz
+##ouk
+libertadores
+nic
+skylar
+##aina
+##firm
+gustave
+asean
+##anum
+dieter
+legislatures
+flirt
+bromley
+trolls
+umar
+##bbies
+##tyle
+blah
+parc
+bridgeport
+crank
+negligence
+##nction
+46th
+constantin
+molded
+bandages
+seriousness
+00pm
+siegel
+carpets
+compartments
+upbeat
+statehood
+##dner
+##edging
+marko
+730
+platt
+##hane
+paving
+##iy
+1738
+abbess
+impatience
+limousine
+nbl
+##talk
+441
+lucille
+mojo
+nightfall
+robbers
+##nais
+karel
+brisk
+calves
+replicate
+ascribed
+telescopes
+##olf
+intimidated
+##reen
+ballast
+specialization
+##sit
+aerodynamic
+caliphate
+rainer
+visionary
+##arded
+epsilon
+##aday
+##onte
+aggregation
+auditory
+boosted
+reunification
+kathmandu
+loco
+robyn
+402
+acknowledges
+appointing
+humanoid
+newell
+redeveloped
+restraints
+##tained
+barbarians
+chopper
+1609
+italiana
+##lez
+##lho
+investigates
+wrestlemania
+##anies
+##bib
+690
+##falls
+creaked
+dragoons
+gravely
+minions
+stupidity
+volley
+##harat
+##week
+musik
+##eries
+##uously
+fungal
+massimo
+semantics
+malvern
+##ahl
+##pee
+discourage
+embryo
+imperialism
+1910s
+profoundly
+##ddled
+jiangsu
+sparkled
+stat
+##holz
+sweatshirt
+tobin
+##iction
+sneered
+##cheon
+##oit
+brit
+causal
+smyth
+##neuve
+diffuse
+perrin
+silvio
+##ipes
+##recht
+detonated
+iqbal
+selma
+##nism
+##zumi
+roasted
+##riders
+tay
+##ados
+##mament
+##mut
+##rud
+840
+completes
+nipples
+cfa
+flavour
+hirsch
+##laus
+calderon
+sneakers
+moravian
+##ksha
+1622
+rq
+294
+##imeters
+bodo
+##isance
+##pre
+##ronia
+anatomical
+excerpt
+##lke
+dh
+kunst
+##tablished
+##scoe
+biomass
+panted
+unharmed
+gael
+housemates
+montpellier
+##59
+coa
+rodents
+tonic
+hickory
+singleton
+##taro
+451
+1719
+aldo
+breaststroke
+dempsey
+och
+rocco
+##cuit
+merton
+dissemination
+midsummer
+serials
+##idi
+haji
+polynomials
+##rdon
+gs
+enoch
+prematurely
+shutter
+taunton
+£3
+##grating
+##inates
+archangel
+harassed
+##asco
+326
+archway
+dazzling
+##ecin
+1736
+sumo
+wat
+##kovich
+1086
+honneur
+##ently
+##nostic
+##ttal
+##idon
+1605
+403
+1716
+blogger
+rents
+##gnan
+hires
+##ikh
+##dant
+howie
+##rons
+handler
+retracted
+shocks
+1632
+arun
+duluth
+kepler
+trumpeter
+##lary
+peeking
+seasoned
+trooper
+##mara
+laszlo
+##iciencies
+##rti
+heterosexual
+##inatory
+##ssion
+indira
+jogging
+##inga
+##lism
+beit
+dissatisfaction
+malice
+##ately
+nedra
+peeling
+##rgeon
+47th
+stadiums
+475
+vertigo
+##ains
+iced
+restroom
+##plify
+##tub
+illustrating
+pear
+##chner
+##sibility
+inorganic
+rappers
+receipts
+watery
+##kura
+lucinda
+##oulos
+reintroduced
+##8th
+##tched
+gracefully
+saxons
+nutritional
+wastewater
+rained
+favourites
+bedrock
+fisted
+hallways
+likeness
+upscale
+##lateral
+1580
+blinds
+prequel
+##pps
+##tama
+deter
+humiliating
+restraining
+tn
+vents
+1659
+laundering
+recess
+rosary
+tractors
+coulter
+federer
+##ifiers
+##plin
+persistence
+##quitable
+geschichte
+pendulum
+quakers
+##beam
+bassett
+pictorial
+buffet
+koln
+##sitor
+drills
+reciprocal
+shooters
+##57
+##cton
+##tees
+converge
+pip
+dmitri
+donnelly
+yamamoto
+aqua
+azores
+demographics
+hypnotic
+spitfire
+suspend
+wryly
+roderick
+##rran
+sebastien
+##asurable
+mavericks
+##fles
+##200
+himalayan
+prodigy
+##iance
+transvaal
+demonstrators
+handcuffs
+dodged
+mcnamara
+sublime
+1726
+crazed
+##efined
+##till
+ivo
+pondered
+reconciled
+shrill
+sava
+##duk
+bal
+cad
+heresy
+jaipur
+goran
+##nished
+341
+lux
+shelly
+whitehall
+##hre
+israelis
+peacekeeping
+##wled
+1703
+demetrius
+ousted
+##arians
+##zos
+beale
+anwar
+backstroke
+raged
+shrinking
+cremated
+##yck
+benign
+towing
+wadi
+darmstadt
+landfill
+parana
+soothe
+colleen
+sidewalks
+mayfair
+tumble
+hepatitis
+ferrer
+superstructure
+##gingly
+##urse
+##wee
+anthropological
+translators
+##mies
+closeness
+hooves
+##pw
+mondays
+##roll
+##vita
+landscaping
+##urized
+purification
+sock
+thorns
+thwarted
+jalan
+tiberius
+##taka
+saline
+##rito
+confidently
+khyber
+sculptors
+##ij
+brahms
+hammersmith
+inspectors
+battista
+fivb
+fragmentation
+hackney
+##uls
+arresting
+exercising
+antoinette
+bedfordshire
+##zily
+dyed
+##hema
+1656
+racetrack
+variability
+##tique
+1655
+austrians
+deteriorating
+madman
+theorists
+aix
+lehman
+weathered
+1731
+decreed
+eruptions
+1729
+flaw
+quinlan
+sorbonne
+flutes
+nunez
+1711
+adored
+downwards
+fable
+rasped
+1712
+moritz
+mouthful
+renegade
+shivers
+stunts
+dysfunction
+restrain
+translit
+327
+pancakes
+##avio
+##cision
+##tray
+351
+vial
+##lden
+bain
+##maid
+##oxide
+chihuahua
+malacca
+vimes
+##rba
+##rnier
+1664
+donnie
+plaques
+##ually
+337
+bangs
+floppy
+huntsville
+loretta
+nikolay
+##otte
+eater
+handgun
+ubiquitous
+##hett
+eras
+zodiac
+1634
+##omorphic
+1820s
+##zog
+cochran
+##bula
+##lithic
+warring
+##rada
+dalai
+excused
+blazers
+mcconnell
+reeling
+bot
+este
+##abi
+geese
+hoax
+taxon
+##bla
+guitarists
+##icon
+condemning
+hunts
+inversion
+moffat
+taekwondo
+##lvis
+1624
+stammered
+##rest
+##rzy
+sousa
+fundraiser
+marylebone
+navigable
+uptown
+cabbage
+daniela
+salman
+shitty
+whimper
+##kian
+##utive
+programmers
+protections
+rm
+##rmi
+##rued
+forceful
+##enes
+fuss
+##tao
+##wash
+brat
+oppressive
+reykjavik
+spartak
+ticking
+##inkles
+##kiewicz
+adolph
+horst
+maui
+protege
+straighten
+cpc
+landau
+concourse
+clements
+resultant
+##ando
+imaginative
+joo
+reactivated
+##rem
+##ffled
+##uising
+consultative
+##guide
+flop
+kaitlyn
+mergers
+parenting
+somber
+##vron
+supervise
+vidhan
+##imum
+courtship
+exemplified
+harmonies
+medallist
+refining
+##rrow
+##ка
+amara
+##hum
+780
+goalscorer
+sited
+overshadowed
+rohan
+displeasure
+secretive
+multiplied
+osman
+##orth
+engravings
+padre
+##kali
+##veda
+miniatures
+mis
+##yala
+clap
+pali
+rook
+##cana
+1692
+57th
+antennae
+astro
+oskar
+1628
+bulldog
+crotch
+hackett
+yucatan
+##sure
+amplifiers
+brno
+ferrara
+migrating
+##gree
+thanking
+turing
+##eza
+mccann
+ting
+andersson
+onslaught
+gaines
+ganga
+incense
+standardization
+##mation
+sentai
+scuba
+stuffing
+turquoise
+waivers
+alloys
+##vitt
+regaining
+vaults
+##clops
+##gizing
+digger
+furry
+memorabilia
+probing
+##iad
+payton
+rec
+deutschland
+filippo
+opaque
+seamen
+zenith
+afrikaans
+##filtration
+disciplined
+inspirational
+##merie
+banco
+confuse
+grafton
+tod
+##dgets
+championed
+simi
+anomaly
+biplane
+##ceptive
+electrode
+##para
+1697
+cleavage
+crossbow
+swirl
+informant
+##lars
+##osta
+afi
+bonfire
+spec
+##oux
+lakeside
+slump
+##culus
+##lais
+##qvist
+##rrigan
+1016
+facades
+borg
+inwardly
+cervical
+xl
+pointedly
+050
+stabilization
+##odon
+chests
+1699
+hacked
+ctv
+orthogonal
+suzy
+##lastic
+gaulle
+jacobite
+rearview
+##cam
+##erted
+ashby
+##drik
+##igate
+##mise
+##zbek
+affectionately
+canine
+disperse
+latham
+##istles
+##ivar
+spielberg
+##orin
+##idium
+ezekiel
+cid
+##sg
+durga
+middletown
+##cina
+customized
+frontiers
+harden
+##etano
+##zzy
+1604
+bolsheviks
+##66
+coloration
+yoko
+##bedo
+briefs
+slabs
+debra
+liquidation
+plumage
+##oin
+blossoms
+dementia
+subsidy
+1611
+proctor
+relational
+jerseys
+parochial
+ter
+##ici
+esa
+peshawar
+cavalier
+loren
+cpi
+idiots
+shamrock
+1646
+dutton
+malabar
+mustache
+##endez
+##ocytes
+referencing
+terminates
+marche
+yarmouth
+##sop
+acton
+mated
+seton
+subtly
+baptised
+beige
+extremes
+jolted
+kristina
+telecast
+##actic
+safeguard
+waldo
+##baldi
+##bular
+endeavors
+sloppy
+subterranean
+##ensburg
+##itung
+delicately
+pigment
+tq
+##scu
+1626
+##ound
+collisions
+coveted
+herds
+##personal
+##meister
+##nberger
+chopra
+##ricting
+abnormalities
+defective
+galician
+lucie
+##dilly
+alligator
+likened
+##genase
+burundi
+clears
+complexion
+derelict
+deafening
+diablo
+fingered
+champaign
+dogg
+enlist
+isotope
+labeling
+mrna
+##erre
+brilliance
+marvelous
+##ayo
+1652
+crawley
+ether
+footed
+dwellers
+deserts
+hamish
+rubs
+warlock
+skimmed
+##lizer
+870
+buick
+embark
+heraldic
+irregularities
+##ajan
+kiara
+##kulam
+##ieg
+antigen
+kowalski
+##lge
+oakley
+visitation
+##mbit
+vt
+##suit
+1570
+murderers
+##miento
+##rites
+chimneys
+##sling
+condemn
+custer
+exchequer
+havre
+##ghi
+fluctuations
+##rations
+dfb
+hendricks
+vaccines
+##tarian
+nietzsche
+biking
+juicy
+##duced
+brooding
+scrolling
+selangor
+##ragan
+352
+annum
+boomed
+seminole
+sugarcane
+##dna
+departmental
+dismissing
+innsbruck
+arteries
+ashok
+batavia
+daze
+kun
+overtook
+##rga
+##tlan
+beheaded
+gaddafi
+holm
+electronically
+faulty
+galilee
+fractures
+kobayashi
+##lized
+gunmen
+magma
+aramaic
+mala
+eastenders
+inference
+messengers
+bf
+##qu
+407
+bathrooms
+##vere
+1658
+flashbacks
+ideally
+misunderstood
+##jali
+##weather
+mendez
+##grounds
+505
+uncanny
+##iii
+1709
+friendships
+##nbc
+sacrament
+accommodated
+reiterated
+logistical
+pebbles
+thumped
+##escence
+administering
+decrees
+drafts
+##flight
+##cased
+##tula
+futuristic
+picket
+intimidation
+winthrop
+##fahan
+interfered
+339
+afar
+francoise
+morally
+uta
+cochin
+croft
+dwarfs
+##bruck
+##dents
+##nami
+biker
+##hner
+##meral
+nano
+##isen
+##ometric
+##pres
+##ан
+brightened
+meek
+parcels
+securely
+gunners
+##jhl
+##zko
+agile
+hysteria
+##lten
+##rcus
+bukit
+champs
+chevy
+cuckoo
+leith
+sadler
+theologians
+welded
+##section
+1663
+jj
+plurality
+xander
+##rooms
+##formed
+shredded
+temps
+intimately
+pau
+tormented
+##lok
+##stellar
+1618
+charred
+ems
+essen
+##mmel
+alarms
+spraying
+ascot
+blooms
+twinkle
+##abia
+##apes
+internment
+obsidian
+##chaft
+snoop
+##dav
+##ooping
+malibu
+##tension
+quiver
+##itia
+hays
+mcintosh
+travers
+walsall
+##ffie
+1623
+beverley
+schwarz
+plunging
+structurally
+m3
+rosenthal
+vikram
+##tsk
+770
+ghz
+##onda
+##tiv
+chalmers
+groningen
+pew
+reckon
+unicef
+##rvis
+55th
+##gni
+1651
+sulawesi
+avila
+cai
+metaphysical
+screwing
+turbulence
+##mberg
+augusto
+samba
+56th
+baffled
+momentary
+toxin
+##urian
+##wani
+aachen
+condoms
+dali
+steppe
+##3d
+##app
+##oed
+##year
+adolescence
+dauphin
+electrically
+inaccessible
+microscopy
+nikita
+##ega
+atv
+##cel
+##enter
+##oles
+##oteric
+##ы
+accountants
+punishments
+wrongly
+bribes
+adventurous
+clinch
+flinders
+southland
+##hem
+##kata
+gough
+##ciency
+lads
+soared
+##ה
+undergoes
+deformation
+outlawed
+rubbish
+##arus
+##mussen
+##nidae
+##rzburg
+arcs
+##ingdon
+##tituted
+1695
+wheelbase
+wheeling
+bombardier
+campground
+zebra
+##lices
+##oj
+##bain
+lullaby
+##ecure
+donetsk
+wylie
+grenada
+##arding
+##ης
+squinting
+eireann
+opposes
+##andra
+maximal
+runes
+##broken
+##cuting
+##iface
+##ror
+##rosis
+additive
+britney
+adultery
+triggering
+##drome
+detrimental
+aarhus
+containment
+jc
+swapped
+vichy
+##ioms
+madly
+##oric
+##rag
+brant
+##ckey
+##trix
+1560
+1612
+broughton
+rustling
+##stems
+##uder
+asbestos
+mentoring
+##nivorous
+finley
+leaps
+##isan
+apical
+pry
+slits
+substitutes
+##dict
+intuitive
+fantasia
+insistent
+unreasonable
+##igen
+##vna
+domed
+hannover
+margot
+ponder
+##zziness
+impromptu
+jian
+lc
+rampage
+stemming
+##eft
+andrey
+gerais
+whichever
+amnesia
+appropriated
+anzac
+clicks
+modifying
+ultimatum
+cambrian
+maids
+verve
+yellowstone
+##mbs
+conservatoire
+##scribe
+adherence
+dinners
+spectra
+imperfect
+mysteriously
+sidekick
+tatar
+tuba
+##aks
+##ifolia
+distrust
+##athan
+##zle
+c2
+ronin
+zac
+##pse
+celaena
+instrumentalist
+scents
+skopje
+##mbling
+comical
+compensated
+vidal
+condor
+intersect
+jingle
+wavelengths
+##urrent
+mcqueen
+##izzly
+carp
+weasel
+422
+kanye
+militias
+postdoctoral
+eugen
+gunslinger
+##ɛ
+faux
+hospice
+##for
+appalled
+derivation
+dwarves
+##elis
+dilapidated
+##folk
+astoria
+philology
+##lwyn
+##otho
+##saka
+inducing
+philanthropy
+##bf
+##itative
+geek
+markedly
+sql
+##yce
+bessie
+indices
+rn
+##flict
+495
+frowns
+resolving
+weightlifting
+tugs
+cleric
+contentious
+1653
+mania
+rms
+##miya
+##reate
+##ruck
+##tucket
+bien
+eels
+marek
+##ayton
+##cence
+discreet
+unofficially
+##ife
+leaks
+##bber
+1705
+332
+dung
+compressor
+hillsborough
+pandit
+shillings
+distal
+##skin
+381
+##tat
+##you
+nosed
+##nir
+mangrove
+undeveloped
+##idia
+textures
+##inho
+##500
+##rise
+ae
+irritating
+nay
+amazingly
+bancroft
+apologetic
+compassionate
+kata
+symphonies
+##lovic
+airspace
+##lch
+930
+gifford
+precautions
+fulfillment
+sevilla
+vulgar
+martinique
+##urities
+looting
+piccolo
+tidy
+##dermott
+quadrant
+armchair
+incomes
+mathematicians
+stampede
+nilsson
+##inking
+##scan
+foo
+quarterfinal
+##ostal
+shang
+shouldered
+squirrels
+##owe
+344
+vinegar
+##bner
+##rchy
+##systems
+delaying
+##trics
+ars
+dwyer
+rhapsody
+sponsoring
+##gration
+bipolar
+cinder
+starters
+##olio
+##urst
+421
+signage
+##nty
+aground
+figurative
+mons
+acquaintances
+duets
+erroneously
+soyuz
+elliptic
+recreated
+##cultural
+##quette
+##ssed
+##tma
+##zcz
+moderator
+scares
+##itaire
+##stones
+##udence
+juniper
+sighting
+##just
+##nsen
+britten
+calabria
+ry
+bop
+cramer
+forsyth
+stillness
+##л
+airmen
+gathers
+unfit
+##umber
+##upt
+taunting
+##rip
+seeker
+streamlined
+##bution
+holster
+schumann
+tread
+vox
+##gano
+##onzo
+strive
+dil
+reforming
+covent
+newbury
+predicting
+##orro
+decorate
+tre
+##puted
+andover
+ie
+asahi
+dept
+dunkirk
+gills
+##tori
+buren
+huskies
+##stis
+##stov
+abstracts
+bets
+loosen
+##opa
+1682
+yearning
+##glio
+##sir
+berman
+effortlessly
+enamel
+napoli
+persist
+##peration
+##uez
+attache
+elisa
+b1
+invitations
+##kic
+accelerating
+reindeer
+boardwalk
+clutches
+nelly
+polka
+starbucks
+##kei
+adamant
+huey
+lough
+unbroken
+adventurer
+embroidery
+inspecting
+stanza
+##ducted
+naia
+taluka
+##pone
+##roids
+chases
+deprivation
+florian
+##jing
+##ppet
+earthly
+##lib
+##ssee
+colossal
+foreigner
+vet
+freaks
+patrice
+rosewood
+triassic
+upstate
+##pkins
+dominates
+ata
+chants
+ks
+vo
+##400
+##bley
+##raya
+##rmed
+555
+agra
+infiltrate
+##ailing
+##ilation
+##tzer
+##uppe
+##werk
+binoculars
+enthusiast
+fujian
+squeak
+##avs
+abolitionist
+almeida
+boredom
+hampstead
+marsden
+rations
+##ands
+inflated
+334
+bonuses
+rosalie
+patna
+##rco
+329
+detachments
+penitentiary
+54th
+flourishing
+woolf
+##dion
+##etched
+papyrus
+##lster
+##nsor
+##toy
+bobbed
+dismounted
+endelle
+inhuman
+motorola
+tbs
+wince
+wreath
+##ticus
+hideout
+inspections
+sanjay
+disgrace
+infused
+pudding
+stalks
+##urbed
+arsenic
+leases
+##hyl
+##rrard
+collarbone
+##waite
+##wil
+dowry
+##bant
+##edance
+genealogical
+nitrate
+salamanca
+scandals
+thyroid
+necessitated
+##!
+##"
+###
+##$
+##%
+##&
+##'
+##(
+##)
+##*
+##+
+##,
+##-
+##.
+##/
+##:
+##;
+##<
+##=
+##>
+##?
+##@
+##[
+##\
+##]
+##^
+##_
+##`
+##{
+##|
+##}
+##~
+##¡
+##¢
+##£
+##¤
+##¥
+##¦
+##§
+##¨
+##©
+##ª
+##«
+##¬
+##®
+##±
+##´
+##µ
+##¶
+##·
+##º
+##»
+##¼
+##¾
+##¿
+##æ
+##ð
+##÷
+##þ
+##đ
+##ħ
+##ŋ
+##œ
+##ƒ
+##ɐ
+##ɑ
+##ɒ
+##ɔ
+##ɕ
+##ə
+##ɡ
+##ɣ
+##ɨ
+##ɪ
+##ɫ
+##ɬ
+##ɯ
+##ɲ
+##ɴ
+##ɹ
+##ɾ
+##ʀ
+##ʁ
+##ʂ
+##ʃ
+##ʉ
+##ʊ
+##ʋ
+##ʌ
+##ʎ
+##ʐ
+##ʑ
+##ʒ
+##ʔ
+##ʰ
+##ʲ
+##ʳ
+##ʷ
+##ʸ
+##ʻ
+##ʼ
+##ʾ
+##ʿ
+##ˈ
+##ˡ
+##ˢ
+##ˣ
+##ˤ
+##β
+##γ
+##δ
+##ε
+##ζ
+##θ
+##κ
+##λ
+##μ
+##ξ
+##ο
+##π
+##ρ
+##σ
+##τ
+##υ
+##φ
+##χ
+##ψ
+##ω
+##б
+##г
+##д
+##ж
+##з
+##м
+##п
+##с
+##у
+##ф
+##х
+##ц
+##ч
+##ш
+##щ
+##ъ
+##э
+##ю
+##ђ
+##є
+##і
+##ј
+##љ
+##њ
+##ћ
+##ӏ
+##ա
+##բ
+##գ
+##դ
+##ե
+##թ
+##ի
+##լ
+##կ
+##հ
+##մ
+##յ
+##ն
+##ո
+##պ
+##ս
+##վ
+##տ
+##ր
+##ւ
+##ք
+##־
+##א
+##ב
+##ג
+##ד
+##ו
+##ז
+##ח
+##ט
+##י
+##ך
+##כ
+##ל
+##ם
+##מ
+##ן
+##נ
+##ס
+##ע
+##ף
+##פ
+##ץ
+##צ
+##ק
+##ר
+##ש
+##ת
+##،
+##ء
+##ب
+##ت
+##ث
+##ج
+##ح
+##خ
+##ذ
+##ز
+##س
+##ش
+##ص
+##ض
+##ط
+##ظ
+##ع
+##غ
+##ـ
+##ف
+##ق
+##ك
+##و
+##ى
+##ٹ
+##پ
+##چ
+##ک
+##گ
+##ں
+##ھ
+##ہ
+##ے
+##अ
+##आ
+##उ
+##ए
+##क
+##ख
+##ग
+##च
+##ज
+##ट
+##ड
+##ण
+##त
+##थ
+##द
+##ध
+##न
+##प
+##ब
+##भ
+##म
+##य
+##र
+##ल
+##व
+##श
+##ष
+##स
+##ह
+##ा
+##ि
+##ी
+##ो
+##।
+##॥
+##ং
+##অ
+##আ
+##ই
+##উ
+##এ
+##ও
+##ক
+##খ
+##গ
+##চ
+##ছ
+##জ
+##ট
+##ড
+##ণ
+##ত
+##থ
+##দ
+##ধ
+##ন
+##প
+##ব
+##ভ
+##ম
+##য
+##র
+##ল
+##শ
+##ষ
+##স
+##হ
+##া
+##ি
+##ী
+##ে
+##க
+##ச
+##ட
+##த
+##ந
+##ன
+##ப
+##ம
+##ய
+##ர
+##ல
+##ள
+##வ
+##ா
+##ி
+##ு
+##ே
+##ை
+##ನ
+##ರ
+##ಾ
+##ක
+##ය
+##ර
+##ල
+##ව
+##ා
+##ก
+##ง
+##ต
+##ท
+##น
+##พ
+##ม
+##ย
+##ร
+##ล
+##ว
+##ส
+##อ
+##า
+##เ
+##་
+##།
+##ག
+##ང
+##ད
+##ན
+##པ
+##བ
+##མ
+##འ
+##ར
+##ལ
+##ས
+##မ
+##ა
+##ბ
+##გ
+##დ
+##ე
+##ვ
+##თ
+##ი
+##კ
+##ლ
+##მ
+##ნ
+##ო
+##რ
+##ს
+##ტ
+##უ
+##ᄀ
+##ᄂ
+##ᄃ
+##ᄅ
+##ᄆ
+##ᄇ
+##ᄉ
+##ᄊ
+##ᄋ
+##ᄌ
+##ᄎ
+##ᄏ
+##ᄐ
+##ᄑ
+##ᄒ
+##ᅡ
+##ᅢ
+##ᅥ
+##ᅦ
+##ᅧ
+##ᅩ
+##ᅪ
+##ᅭ
+##ᅮ
+##ᅯ
+##ᅲ
+##ᅳ
+##ᅴ
+##ᅵ
+##ᆨ
+##ᆫ
+##ᆯ
+##ᆷ
+##ᆸ
+##ᆼ
+##ᴬ
+##ᴮ
+##ᴰ
+##ᴵ
+##ᴺ
+##ᵀ
+##ᵃ
+##ᵇ
+##ᵈ
+##ᵉ
+##ᵍ
+##ᵏ
+##ᵐ
+##ᵒ
+##ᵖ
+##ᵗ
+##ᵘ
+##ᵣ
+##ᵤ
+##ᵥ
+##ᶜ
+##ᶠ
+##‐
+##‑
+##‒
+##–
+##—
+##―
+##‖
+##‘
+##’
+##‚
+##“
+##”
+##„
+##†
+##‡
+##•
+##…
+##‰
+##′
+##″
+##›
+##‿
+##⁄
+##⁰
+##ⁱ
+##⁴
+##⁵
+##⁶
+##⁷
+##⁸
+##⁹
+##⁻
+##ⁿ
+##₅
+##₆
+##₇
+##₈
+##₉
+##₊
+##₍
+##₎
+##ₐ
+##ₑ
+##ₒ
+##ₓ
+##ₕ
+##ₖ
+##ₗ
+##ₘ
+##ₚ
+##ₛ
+##ₜ
+##₤
+##₩
+##€
+##₱
+##₹
+##ℓ
+##№
+##ℝ
+##™
+##⅓
+##⅔
+##←
+##↑
+##→
+##↓
+##↔
+##↦
+##⇄
+##⇌
+##⇒
+##∂
+##∅
+##∆
+##∇
+##∈
+##∗
+##∘
+##√
+##∞
+##∧
+##∨
+##∩
+##∪
+##≈
+##≡
+##≤
+##≥
+##⊂
+##⊆
+##⊕
+##⊗
+##⋅
+##─
+##│
+##■
+##▪
+##●
+##★
+##☆
+##☉
+##♠
+##♣
+##♥
+##♦
+##♯
+##⟨
+##⟩
+##ⱼ
+##⺩
+##⺼
+##⽥
+##、
+##。
+##〈
+##〉
+##《
+##》
+##「
+##」
+##『
+##』
+##〜
+##あ
+##い
+##う
+##え
+##お
+##か
+##き
+##く
+##け
+##こ
+##さ
+##し
+##す
+##せ
+##そ
+##た
+##ち
+##っ
+##つ
+##て
+##と
+##な
+##に
+##ぬ
+##ね
+##の
+##は
+##ひ
+##ふ
+##へ
+##ほ
+##ま
+##み
+##む
+##め
+##も
+##や
+##ゆ
+##よ
+##ら
+##り
+##る
+##れ
+##ろ
+##を
+##ん
+##ァ
+##ア
+##ィ
+##イ
+##ウ
+##ェ
+##エ
+##オ
+##カ
+##キ
+##ク
+##ケ
+##コ
+##サ
+##シ
+##ス
+##セ
+##タ
+##チ
+##ッ
+##ツ
+##テ
+##ト
+##ナ
+##ニ
+##ノ
+##ハ
+##ヒ
+##フ
+##ヘ
+##ホ
+##マ
+##ミ
+##ム
+##メ
+##モ
+##ャ
+##ュ
+##ョ
+##ラ
+##リ
+##ル
+##レ
+##ロ
+##ワ
+##ン
+##・
+##ー
+##一
+##三
+##上
+##下
+##不
+##世
+##中
+##主
+##久
+##之
+##也
+##事
+##二
+##五
+##井
+##京
+##人
+##亻
+##仁
+##介
+##代
+##仮
+##伊
+##会
+##佐
+##侍
+##保
+##信
+##健
+##元
+##光
+##八
+##公
+##内
+##出
+##分
+##前
+##劉
+##力
+##加
+##勝
+##北
+##区
+##十
+##千
+##南
+##博
+##原
+##口
+##古
+##史
+##司
+##合
+##吉
+##同
+##名
+##和
+##囗
+##四
+##国
+##國
+##土
+##地
+##坂
+##城
+##堂
+##場
+##士
+##夏
+##外
+##大
+##天
+##太
+##夫
+##奈
+##女
+##子
+##学
+##宀
+##宇
+##安
+##宗
+##定
+##宣
+##宮
+##家
+##宿
+##寺
+##將
+##小
+##尚
+##山
+##岡
+##島
+##崎
+##川
+##州
+##巿
+##帝
+##平
+##年
+##幸
+##广
+##弘
+##張
+##彳
+##後
+##御
+##德
+##心
+##忄
+##志
+##忠
+##愛
+##成
+##我
+##戦
+##戸
+##手
+##扌
+##政
+##文
+##新
+##方
+##日
+##明
+##星
+##春
+##昭
+##智
+##曲
+##書
+##月
+##有
+##朝
+##木
+##本
+##李
+##村
+##東
+##松
+##林
+##森
+##楊
+##樹
+##橋
+##歌
+##止
+##正
+##武
+##比
+##氏
+##民
+##水
+##氵
+##氷
+##永
+##江
+##沢
+##河
+##治
+##法
+##海
+##清
+##漢
+##瀬
+##火
+##版
+##犬
+##王
+##生
+##田
+##男
+##疒
+##発
+##白
+##的
+##皇
+##目
+##相
+##省
+##真
+##石
+##示
+##社
+##神
+##福
+##禾
+##秀
+##秋
+##空
+##立
+##章
+##竹
+##糹
+##美
+##義
+##耳
+##良
+##艹
+##花
+##英
+##華
+##葉
+##藤
+##行
+##街
+##西
+##見
+##訁
+##語
+##谷
+##貝
+##貴
+##車
+##軍
+##辶
+##道
+##郎
+##郡
+##部
+##都
+##里
+##野
+##金
+##鈴
+##镇
+##長
+##門
+##間
+##阝
+##阿
+##陳
+##陽
+##雄
+##青
+##面
+##風
+##食
+##香
+##馬
+##高
+##龍
+##龸
+##fi
+##fl
+##!
+##(
+##)
+##,
+##-
+##.
+##/
+##:
+##?
+##~
diff --git a/native/annotator/pod_ner/utils.cc b/native/annotator/pod_ner/utils.cc
new file mode 100644
index 0000000..136a996
--- /dev/null
+++ b/native/annotator/pod_ner/utils.cc
@@ -0,0 +1,436 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/pod_ner/utils.h"
+
+#include <algorithm>
+#include <iostream>
+#include <unordered_map>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Returns true if the needle string is contained in the haystack.
+bool StrIsOneOf(const std::string &needle,
+ const std::vector<std::string> &haystack) {
+ return std::find(haystack.begin(), haystack.end(), needle) != haystack.end();
+}
+
+// Finds the wordpiece span of the tokens in the given span.
+WordpieceSpan CodepointSpanToWordpieceSpan(
+ const CodepointSpan &span, const std::vector<Token> &tokens,
+ const std::vector<int32_t> &word_starts, int num_wordpieces) {
+ int span_first_wordpiece_index = 0;
+ int span_last_wordpiece_index = num_wordpieces;
+ for (int i = 0; i < tokens.size(); i++) {
+ if (tokens[i].start <= span.first && span.first < tokens[i].end) {
+ span_first_wordpiece_index = word_starts[i];
+ }
+ if (tokens[i].start <= span.second && span.second <= tokens[i].end) {
+ span_last_wordpiece_index =
+ (i + 1) < word_starts.size() ? word_starts[i + 1] : num_wordpieces;
+ break;
+ }
+ }
+ return WordpieceSpan(span_first_wordpiece_index, span_last_wordpiece_index);
+}
+
+} // namespace
+
+std::string SaftLabelToCollection(absl::string_view saft_label) {
+ return std::string(saft_label.substr(saft_label.rfind('/') + 1));
+}
+
+namespace internal {
+
+int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts,
+ int num_wordpieces, int wordpiece_end) {
+ if (word_starts.empty()) {
+ return 0;
+ }
+ if (*word_starts.rbegin() < wordpiece_end &&
+ num_wordpieces <= wordpiece_end) {
+ // Last token.
+ return word_starts.size() - 1;
+ }
+ for (int i = word_starts.size() - 1; i > 0; --i) {
+ if (word_starts[i] <= wordpiece_end) {
+ return (i - 1);
+ }
+ }
+ return 0;
+}
+
+int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts,
+ int first_wordpiece_index) {
+ for (int i = 0; i < word_starts.size(); ++i) {
+ if (word_starts[i] == first_wordpiece_index) {
+ return i;
+ } else if (word_starts[i] > first_wordpiece_index) {
+ return std::max(0, i - 1);
+ }
+ }
+
+ return std::max(0, static_cast<int>(word_starts.size()) - 1);
+}
+
+WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window,
+ int num_wordpieces,
+ WordpieceSpan wordpiece_span_to_expand) {
+ if (wordpiece_span_to_expand.length() >= max_num_wordpieces_in_window) {
+ return wordpiece_span_to_expand;
+ }
+ int window_first_wordpiece_index = std::max(
+ 0, wordpiece_span_to_expand.begin - ((max_num_wordpieces_in_window -
+ wordpiece_span_to_expand.length()) /
+ 2));
+ if ((window_first_wordpiece_index + max_num_wordpieces_in_window) >
+ num_wordpieces) {
+ window_first_wordpiece_index =
+ std::max(num_wordpieces - max_num_wordpieces_in_window, 0);
+ }
+ return WordpieceSpan(
+ window_first_wordpiece_index,
+ std::min(window_first_wordpiece_index + max_num_wordpieces_in_window,
+ num_wordpieces));
+}
+
+WordpieceSpan FindWordpiecesWindowAroundSpan(
+ const CodepointSpan &span_of_interest, const std::vector<Token> &tokens,
+ const std::vector<int32_t> &word_starts, int num_wordpieces,
+ int max_num_wordpieces_in_window) {
+ WordpieceSpan wordpiece_span_to_expand = CodepointSpanToWordpieceSpan(
+ span_of_interest, tokens, word_starts, num_wordpieces);
+ WordpieceSpan max_wordpiece_span = ExpandWindowAndAlign(
+ max_num_wordpieces_in_window, num_wordpieces, wordpiece_span_to_expand);
+ return max_wordpiece_span;
+}
+
+WordpieceSpan FindFullTokensSpanInWindow(
+ const std::vector<int32_t> &word_starts,
+ const WordpieceSpan &wordpiece_span, int max_num_wordpieces,
+ int num_wordpieces, int *first_token_index, int *num_tokens) {
+ int window_first_wordpiece_index = wordpiece_span.begin;
+ *first_token_index = internal::FindFirstFullTokenIndex(
+ word_starts, window_first_wordpiece_index);
+ window_first_wordpiece_index = word_starts[*first_token_index];
+
+ // Need to update the last index in case the first moved backward.
+ int wordpiece_window_end = std::min(
+ wordpiece_span.end, window_first_wordpiece_index + max_num_wordpieces);
+ int last_token_index;
+ last_token_index = internal::FindLastFullTokenIndex(
+ word_starts, num_wordpieces, wordpiece_window_end);
+ wordpiece_window_end = last_token_index == (word_starts.size() - 1)
+ ? num_wordpieces
+ : word_starts[last_token_index + 1];
+
+ *num_tokens = last_token_index - *first_token_index + 1;
+ return WordpieceSpan(window_first_wordpiece_index, wordpiece_window_end);
+}
+
+} // namespace internal
+
+WindowGenerator::WindowGenerator(const std::vector<int32_t> &wordpiece_indices,
+ const std::vector<int32_t> &token_starts,
+ const std::vector<Token> &tokens,
+ int max_num_wordpieces,
+ int sliding_window_overlap,
+ const CodepointSpan &span_of_interest)
+ : wordpiece_indices_(&wordpiece_indices),
+ token_starts_(&token_starts),
+ tokens_(&tokens),
+ max_num_effective_wordpieces_(max_num_wordpieces),
+ sliding_window_num_wordpieces_overlap_(sliding_window_overlap) {
+ entire_wordpiece_span_ = internal::FindWordpiecesWindowAroundSpan(
+ span_of_interest, tokens, token_starts, wordpiece_indices.size(),
+ max_num_wordpieces);
+ next_wordpiece_span_ = WordpieceSpan(
+ entire_wordpiece_span_.begin,
+ std::min(entire_wordpiece_span_.begin + max_num_effective_wordpieces_,
+ entire_wordpiece_span_.end));
+ previous_wordpiece_span_ = WordpieceSpan(-1, -1);
+}
+
+bool WindowGenerator::Next(VectorSpan<int32_t> *cur_wordpiece_indices,
+ VectorSpan<int32_t> *cur_token_starts,
+ VectorSpan<Token> *cur_tokens) {
+ if (Done()) {
+ return false;
+ }
+ // Update the span to cover full tokens.
+ int cur_first_token_index, cur_num_tokens;
+ next_wordpiece_span_ = internal::FindFullTokensSpanInWindow(
+ *token_starts_, next_wordpiece_span_, max_num_effective_wordpieces_,
+ wordpiece_indices_->size(), &cur_first_token_index, &cur_num_tokens);
+ *cur_token_starts = VectorSpan<int32_t>(
+ token_starts_->begin() + cur_first_token_index,
+ token_starts_->begin() + cur_first_token_index + cur_num_tokens);
+ *cur_tokens = VectorSpan<Token>(
+ tokens_->begin() + cur_first_token_index,
+ tokens_->begin() + cur_first_token_index + cur_num_tokens);
+
+ // Handle the edge case where the tokens are composed of many wordpieces and
+ // the window doesn't advance.
+ if (next_wordpiece_span_.begin <= previous_wordpiece_span_.begin ||
+ next_wordpiece_span_.end <= previous_wordpiece_span_.end) {
+ return false;
+ }
+ previous_wordpiece_span_ = next_wordpiece_span_;
+
+ int next_wordpiece_first = std::max(
+ previous_wordpiece_span_.end - sliding_window_num_wordpieces_overlap_,
+ previous_wordpiece_span_.begin + 1);
+ next_wordpiece_span_ = WordpieceSpan(
+ next_wordpiece_first,
+ std::min(next_wordpiece_first + max_num_effective_wordpieces_,
+ entire_wordpiece_span_.end));
+
+ *cur_wordpiece_indices = VectorSpan<int>(
+ wordpiece_indices_->begin() + previous_wordpiece_span_.begin,
+ wordpiece_indices_->begin() + previous_wordpiece_span_.begin +
+ previous_wordpiece_span_.length());
+
+ return true;
+}
+
+bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
+ const std::vector<std::string> &tags,
+ const std::vector<std::string> &label_filter,
+ bool relaxed_inside_label_matching,
+ bool relaxed_label_category_matching,
+ float priority_score,
+ std::vector<AnnotatedSpan> *results) {
+ AnnotatedSpan current_span;
+ std::string current_tag_type;
+ if (tags.size() > tokens.size()) {
+ return false;
+ }
+ for (int i = 0; i < tags.size(); i++) {
+ if (tags[i].empty()) {
+ return false;
+ }
+
+ std::vector<absl::string_view> tag_parts = absl::StrSplit(tags[i], '-');
+ TC3_CHECK_GT(tag_parts.size(), 0);
+ if (tag_parts[0].size() != 1) {
+ return false;
+ }
+
+ std::string tag_type = "";
+ if (tag_parts.size() > 2) {
+ // Skip if the current label doesn't match the filter.
+ if (!StrIsOneOf(std::string(tag_parts[1]), label_filter)) {
+ current_tag_type = "";
+ current_span = {};
+ continue;
+ }
+
+ // Relax the matching of the label category if specified.
+ tag_type = relaxed_label_category_matching
+ ? std::string(tag_parts[2])
+ : absl::StrCat(tag_parts[1], "-", tag_parts[2]);
+ }
+
+ switch (tag_parts[0][0]) {
+ case 'S': {
+ if (tag_parts.size() != 3) {
+ return false;
+ }
+
+ current_span = {};
+ current_tag_type = "";
+ results->push_back(AnnotatedSpan{
+ {tokens[i].start, tokens[i].end},
+ {{/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
+ /*arg_score=*/1.0, priority_score}}});
+ break;
+ };
+
+ case 'B': {
+ if (tag_parts.size() != 3) {
+ return false;
+ }
+ current_tag_type = tag_type;
+ current_span = {};
+ current_span.classification.push_back(
+ {/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
+ /*arg_score=*/1.0, priority_score});
+ current_span.span.first = tokens[i].start;
+ break;
+ };
+
+ case 'I': {
+ if (tag_parts.size() != 3) {
+ return false;
+ }
+ if (!relaxed_inside_label_matching && current_tag_type != tag_type) {
+ current_tag_type = "";
+ current_span = {};
+ }
+ break;
+ }
+
+ case 'E': {
+ if (tag_parts.size() != 3) {
+ return false;
+ }
+ if (!current_tag_type.empty() && current_tag_type == tag_type) {
+ current_span.span.second = tokens[i].end;
+ results->push_back(current_span);
+ current_span = {};
+ current_tag_type = "";
+ }
+ break;
+ };
+
+ case 'O': {
+ current_tag_type = "";
+ current_span = {};
+ break;
+ };
+
+ default: {
+ TC3_LOG(ERROR) << "Unrecognized tag: " << tags[i];
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+using PodNerModel_::CollectionT;
+using PodNerModel_::LabelT;
+using PodNerModel_::Label_::BoiseType;
+using PodNerModel_::Label_::MentionType;
+
+bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
+ const std::vector<LabelT> &labels,
+ const std::vector<CollectionT> &collections,
+ const std::vector<MentionType> &mention_filter,
+ bool relaxed_inside_label_matching,
+ bool relaxed_mention_type_matching,
+ std::vector<AnnotatedSpan> *results) {
+ if (labels.size() > tokens.size()) {
+ return false;
+ }
+
+ AnnotatedSpan current_span;
+ std::string current_collection_name = "";
+
+ for (int i = 0; i < labels.size(); i++) {
+ const LabelT &label = labels[i];
+
+ if (label.collection_id < 0 || label.collection_id >= collections.size()) {
+ return false;
+ }
+
+ if (std::find(mention_filter.begin(), mention_filter.end(),
+ label.mention_type) == mention_filter.end()) {
+ // Skip if the current label doesn't match the filter.
+ current_span = {};
+ current_collection_name = "";
+ continue;
+ }
+
+ switch (label.boise_type) {
+ case BoiseType::BoiseType_SINGLE: {
+ current_span = {};
+ current_collection_name = "";
+ results->push_back(AnnotatedSpan{
+ {tokens[i].start, tokens[i].end},
+ {{/*arg_collection=*/collections[label.collection_id].name,
+ /*arg_score=*/1.0,
+ collections[label.collection_id].single_token_priority_score}}});
+ break;
+ };
+
+ case BoiseType::BoiseType_BEGIN: {
+ current_span = {};
+ current_span.classification.push_back(
+ {/*arg_collection=*/collections[label.collection_id].name,
+ /*arg_score=*/1.0,
+ collections[label.collection_id].multi_token_priority_score});
+ current_span.span.first = tokens[i].start;
+ current_collection_name = collections[label.collection_id].name;
+ break;
+ };
+
+ case BoiseType::BoiseType_INTERMEDIATE: {
+ if (current_collection_name.empty() ||
+ (!relaxed_mention_type_matching &&
+ labels[i - 1].mention_type != label.mention_type) ||
+ (!relaxed_inside_label_matching &&
+ labels[i - 1].collection_id != label.collection_id)) {
+ current_span = {};
+ current_collection_name = "";
+ }
+ break;
+ }
+
+ case BoiseType::BoiseType_END: {
+ if (!current_collection_name.empty() &&
+ current_collection_name == collections[label.collection_id].name &&
+ (relaxed_mention_type_matching ||
+ labels[i - 1].mention_type == label.mention_type)) {
+ current_span.span.second = tokens[i].end;
+ results->push_back(current_span);
+ }
+ current_span = {};
+ current_collection_name = "";
+ break;
+ };
+
+ case BoiseType::BoiseType_O: {
+ current_span = {};
+ current_collection_name = "";
+ break;
+ };
+
+ default: {
+ TC3_LOG(ERROR) << "Unrecognized tag: " << labels[i].boise_type;
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+bool MergeLabelsIntoLeftSequence(
+ const std::vector<PodNerModel_::LabelT> &labels_right,
+ int index_first_right_tag_in_left,
+ std::vector<PodNerModel_::LabelT> *labels_left) {
+ if (index_first_right_tag_in_left > labels_left->size()) {
+ return false;
+ }
+
+ int overlaping_from_left =
+ (labels_left->size() - index_first_right_tag_in_left) / 2;
+
+ labels_left->resize(index_first_right_tag_in_left + labels_right.size());
+ std::copy(labels_right.begin() + overlaping_from_left, labels_right.end(),
+ labels_left->begin() + index_first_right_tag_in_left +
+ overlaping_from_left);
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/pod_ner/utils.h b/native/annotator/pod_ner/utils.h
new file mode 100644
index 0000000..6c4a902
--- /dev/null
+++ b/native/annotator/pod_ner/utils.h
@@ -0,0 +1,147 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "absl/strings/string_view.h"
+
+namespace libtextclassifier3 {
+// Converts saft labels like /saft/person to collection name 'person'.
+std::string SaftLabelToCollection(absl::string_view saft_label);
+
+struct WordpieceSpan {
+ // Beginning index is inclusive, end index is exclusive.
+ WordpieceSpan() : begin(0), end(0) {}
+ WordpieceSpan(int begin, int end) : begin(begin), end(end) {}
+ int begin;
+ int end;
+ bool operator==(const WordpieceSpan &other) const {
+ return this->begin == other.begin && this->end == other.end;
+ }
+ int length() { return end - begin; }
+};
+
+namespace internal {
+// Finds the wordpiece window around the given span_of_interest. If the number
+// of wordpieces in this window is smaller than max_num_wordpieces_in_window
+// it is expanded around the span of interest.
+WordpieceSpan FindWordpiecesWindowAroundSpan(
+ const CodepointSpan &span_of_interest, const std::vector<Token> &tokens,
+ const std::vector<int32_t> &word_starts, int num_wordpieces,
+ int max_num_wordpieces_in_window);
+// Expands the given wordpiece window to be the maximal possible while
+// making sure it includes only full tokens.
+WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window,
+ int num_wordpieces,
+ WordpieceSpan wordpiece_span_to_expand);
+// Returns the index of the last token which ends before wordpiece_end.
+int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts,
+ int num_wordpieces, int wordpiece_end);
+// Returns the index of the token which includes first_wordpiece_index.
+int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts,
+ int first_wordpiece_index);
+// Given wordpiece_span, and max_num_wordpieces, finds:
+// 1. The first token which includes wordpiece_span.begin.
+// 2. The length of tokens sequence which starts from this token and:
+// a. Its last token's last wordpiece index ends before wordpiece_span.end.
+// b. Its overall number of wordpieces is at most max_num_wordpieces.
+// Returns the updated wordpiece_span: begin and end wordpieces of this token
+// sequence.
+WordpieceSpan FindFullTokensSpanInWindow(
+ const std::vector<int32_t> &word_starts,
+ const WordpieceSpan &wordpiece_span, int max_num_wordpieces,
+ int num_wordpieces, int *first_token_index, int *num_tokens);
+
+} // namespace internal
+// Converts sequence of IOB tags to AnnotatedSpans. Ignores illegal sequences.
+// Setting label_filter can also help ignore certain label tags like "NAM" or
+// "NOM".
+// The inside tag can be ignored when setting relaxed_inside_label_matching,
+// e.g. B-NAM-location, I-NAM-other, E-NAM-location would be considered a valid
+// sequence.
+// The label category matching can be ignored when setting
+// relaxed_label_category_matching. The matching will only operate at the entity
+// level, e.g. B-NAM-location, E-NOM-location would be considered a valid
+// sequence.
+bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
+ const std::vector<std::string> &tags,
+ const std::vector<std::string> &label_filter,
+ bool relaxed_inside_label_matching,
+ bool relaxed_label_category_matching,
+ float priority_score,
+ std::vector<AnnotatedSpan> *results);
+
+// Like the previous function, but instead of getting the tags as strings,
+// the input is PodNerModel_::LabelT along with the collections vector which
+// holds the collection names and priorities. E.g. if a tag was
+// "B-NAM-location" and the priority_score was 1.0, it would be
+// Label(BoiseType_BEGIN, MentionType_NAM, 1) and collections={{"xxx", 1., 1.},
+// {"location", 1., 1.}, {"yyy", 1., 1.}, ...}.
+bool ConvertTagsToAnnotatedSpans(
+ const VectorSpan<Token> &tokens,
+ const std::vector<PodNerModel_::LabelT> &labels,
+ const std::vector<PodNerModel_::CollectionT> &collections,
+ const std::vector<PodNerModel_::Label_::MentionType> &mention_filter,
+ bool relaxed_inside_label_matching, bool relaxed_mention_type_matching,
+ std::vector<AnnotatedSpan> *results);
+
+// Merges two overlapping sequences of labels; the result is placed into the
+// left sequence. In the overlapping part, takes the labels from the left
+// sequence on the first half and from the right on the second half.
+bool MergeLabelsIntoLeftSequence(
+ const std::vector<PodNerModel_::LabelT> &labels_right,
+ int index_first_right_tag_in_left,
+ std::vector<PodNerModel_::LabelT> *labels_left);
+
+// This class is used to slide over {wordpiece_indices, token_starts, tokens} in
+// windows of at most max_num_wordpieces while assuring that each window
+// contains only full tokens.
+class WindowGenerator {
+ public:
+ WindowGenerator(const std::vector<int32_t> &wordpiece_indices,
+ const std::vector<int32_t> &token_starts,
+ const std::vector<Token> &tokens, int max_num_wordpieces,
+ int sliding_window_overlap,
+ const CodepointSpan &span_of_interest);
+
+ bool Next(VectorSpan<int32_t> *cur_wordpiece_indices,
+ VectorSpan<int32_t> *cur_token_starts,
+ VectorSpan<Token> *cur_tokens);
+
+ bool Done() const {
+ return previous_wordpiece_span_.end >= entire_wordpiece_span_.end;
+ }
+
+ private:
+ const std::vector<int32_t> *wordpiece_indices_;
+ const std::vector<int32_t> *token_starts_;
+ const std::vector<Token> *tokens_;
+ int max_num_effective_wordpieces_;
+ int sliding_window_num_wordpieces_overlap_;
+ WordpieceSpan entire_wordpiece_span_;
+ WordpieceSpan next_wordpiece_span_;
+ WordpieceSpan previous_wordpiece_span_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
diff --git a/native/annotator/pod_ner/utils_test.cc b/native/annotator/pod_ner/utils_test.cc
new file mode 100644
index 0000000..fdc82f2
--- /dev/null
+++ b/native/annotator/pod_ner/utils_test.cc
@@ -0,0 +1,905 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/pod_ner/utils.h"
+
+#include <iterator>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/tokenizer-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_split.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::testing::IsEmpty;
+using ::testing::Not;
+
+using PodNerModel_::CollectionT;
+using PodNerModel_::LabelT;
+using PodNerModel_::Label_::BoiseType;
+using PodNerModel_::Label_::BoiseType_BEGIN;
+using PodNerModel_::Label_::BoiseType_END;
+using PodNerModel_::Label_::BoiseType_INTERMEDIATE;
+using PodNerModel_::Label_::BoiseType_O;
+using PodNerModel_::Label_::BoiseType_SINGLE;
+using PodNerModel_::Label_::MentionType;
+using PodNerModel_::Label_::MentionType_NAM;
+using PodNerModel_::Label_::MentionType_NOM;
+using PodNerModel_::Label_::MentionType_UNDEFINED;
+
+constexpr float kPriorityScore = 0.;
+const std::vector<std::string>& kCollectionNames =
+ *new std::vector<std::string>{"undefined", "location", "person", "art",
+ "organization", "entitiy", "xxx"};
+const auto& kStringToBoiseType = *new absl::flat_hash_map<
+ absl::string_view, libtextclassifier3::PodNerModel_::Label_::BoiseType>({
+ {"B", libtextclassifier3::PodNerModel_::Label_::BoiseType_BEGIN},
+ {"O", libtextclassifier3::PodNerModel_::Label_::BoiseType_O},
+ {"I", libtextclassifier3::PodNerModel_::Label_::BoiseType_INTERMEDIATE},
+ {"S", libtextclassifier3::PodNerModel_::Label_::BoiseType_SINGLE},
+ {"E", libtextclassifier3::PodNerModel_::Label_::BoiseType_END},
+});
+const auto& kStringToMentionType = *new absl::flat_hash_map<
+ absl::string_view, libtextclassifier3::PodNerModel_::Label_::MentionType>(
+ {{"NAM", libtextclassifier3::PodNerModel_::Label_::MentionType_NAM},
+ {"NOM", libtextclassifier3::PodNerModel_::Label_::MentionType_NOM}});
+LabelT CreateLabel(BoiseType boise_type, MentionType mention_type,
+ int collection_id) {
+ LabelT label;
+ label.boise_type = boise_type;
+ label.mention_type = mention_type;
+ label.collection_id = collection_id;
+ return label;
+}
+std::vector<PodNerModel_::LabelT> TagsToLabels(
+ const std::vector<std::string>& tags) {
+ std::vector<PodNerModel_::LabelT> labels;
+ for (const auto& tag : tags) {
+ if (tag == "O") {
+ labels.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
+ } else {
+ std::vector<absl::string_view> tag_parts = absl::StrSplit(tag, '-');
+ labels.emplace_back(CreateLabel(
+ kStringToBoiseType.at(tag_parts[0]),
+ kStringToMentionType.at(tag_parts[1]),
+ std::distance(
+ kCollectionNames.begin(),
+ std::find(kCollectionNames.begin(), kCollectionNames.end(),
+ std::string(tag_parts[2].substr(
+ tag_parts[2].rfind('/') + 1))))));
+ }
+ }
+ return labels;
+}
+
+std::vector<CollectionT> GetCollections() {
+ std::vector<CollectionT> collections;
+ for (const std::string& collection_name : kCollectionNames) {
+ CollectionT collection;
+ collection.name = collection_name;
+ collection.single_token_priority_score = kPriorityScore;
+ collection.multi_token_priority_score = kPriorityScore;
+ collections.emplace_back(collection);
+ }
+ return collections;
+}
+
+class ConvertTagsToAnnotatedSpansTest : public testing::TestWithParam<bool> {};
+INSTANTIATE_TEST_SUITE_P(TagsAndLabelsTest, ConvertTagsToAnnotatedSpansTest,
+ testing::Values(true, false));
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansHandlesBIESequence) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NAM-/saft/location",
+ "I-NAM-/saft/location",
+ "E-NAM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+
+ EXPECT_EQ(annotations.size(), 1);
+ EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
+ EXPECT_EQ(annotations[0].classification[0].collection, "location");
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansHandlesSSequence) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "His father was it.";
+ std::vector<std::string> tags = {"O", "S-NAM-/saft/person", "O", "O"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+
+ EXPECT_EQ(annotations.size(), 1);
+ EXPECT_EQ(annotations[0].span, CodepointSpan(4, 10));
+ EXPECT_EQ(annotations[0].classification[0].collection, "person");
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansHandlesMultiple) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text =
+ "Jaromir Jagr, Barak Obama and I met in Google New York City";
+ std::vector<std::string> tags = {"B-NAM-/saft/person",
+ "E-NAM-/saft/person",
+ "B-NOM-/saft/person",
+ "E-NOM-/saft/person",
+ "O",
+ "O",
+ "O",
+ "O",
+ "S-NAM-/saft/organization",
+ "B-NAM-/saft/location",
+ "I-NAM-/saft/location",
+ "E-NAM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+
+ ASSERT_EQ(annotations.size(), 4);
+ EXPECT_EQ(annotations[0].span, CodepointSpan(0, 13));
+ ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
+ EXPECT_EQ(annotations[0].classification[0].collection, "person");
+ EXPECT_EQ(annotations[1].span, CodepointSpan(14, 25));
+ ASSERT_THAT(annotations[1].classification, Not(IsEmpty()));
+ EXPECT_EQ(annotations[1].classification[0].collection, "person");
+ EXPECT_EQ(annotations[2].span, CodepointSpan(39, 45));
+ ASSERT_THAT(annotations[2].classification, Not(IsEmpty()));
+ EXPECT_EQ(annotations[2].classification[0].collection, "organization");
+ EXPECT_EQ(annotations[3].span, CodepointSpan(46, 59));
+ ASSERT_THAT(annotations[3].classification, Not(IsEmpty()));
+ EXPECT_EQ(annotations[3].classification[0].collection, "location");
+ }
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansHandlesMultipleFirstTokenNotFirst) {
+ std::vector<AnnotatedSpan> annotations;
+ std::vector<Token> original_tokens = TokenizeOnSpace(
+ "Jaromir Jagr, Barak Obama and I met in Google New York City");
+ std::vector<std::string> tags = {"B-NOM-/saft/person",
+ "E-NOM-/saft/person",
+ "O",
+ "O",
+ "O",
+ "O",
+ "S-NAM-/saft/organization",
+ "B-NAM-/saft/location",
+ "I-NAM-/saft/location",
+ "E-NAM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(original_tokens.begin() + 2, original_tokens.end()),
+ tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(original_tokens.begin() + 2, original_tokens.end()),
+ TagsToLabels(tags), GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+
+ ASSERT_EQ(annotations.size(), 3);
+ EXPECT_EQ(annotations[0].span, CodepointSpan(14, 25));
+ ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
+ EXPECT_EQ(annotations[0].classification[0].collection, "person");
+ EXPECT_EQ(annotations[1].span, CodepointSpan(39, 45));
+ ASSERT_THAT(annotations[1].classification, Not(IsEmpty()));
+ EXPECT_EQ(annotations[1].classification[0].collection, "organization");
+ EXPECT_EQ(annotations[2].span, CodepointSpan(46, 59));
+ ASSERT_THAT(annotations[2].classification, Not(IsEmpty()));
+ EXPECT_EQ(annotations[2].classification[0].collection, "location");
+}
+
+TEST(PodNerUtilsTest, ConvertTagsToAnnotatedSpansInvalidCollection) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O", "O", "S-NAM-/saft/invalid_collection"};
+
+ ASSERT_FALSE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansIgnoresInconsistentStart) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NAM-/saft/xxx",
+ "I-NAM-/saft/location",
+ "E-NAM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeStart) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NOM-/saft/location",
+ "I-NAM-/saft/location",
+ "E-NAM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansIgnoresInconsistentInside) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NAM-/saft/location",
+ "I-NAM-/saft/xxx",
+ "E-NAM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeInside) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NAM-/saft/location",
+ "I-NOM-/saft/location",
+ "E-NAM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansHandlesInconsistentInside) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NAM-/saft/location",
+ "I-NAM-/saft/xxx",
+ "E-NAM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/true,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/true,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+
+ EXPECT_EQ(annotations.size(), 1);
+ EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
+ EXPECT_EQ(annotations[0].classification[0].collection, "location");
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansIgnoresInconsistentEnd) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NAM-/saft/location",
+ "I-NAM-/saft/location",
+ "E-NAM-/saft/xxx"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeEnd) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NAM-/saft/location",
+ "I-NAM-/saft/location",
+ "E-NOM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
+TEST_P(
+ ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansHandlesInconsistentLabelTypeWhenEntityMatches) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NOM-/saft/location",
+ "I-NOM-/saft/location",
+ "E-NAM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NAM", "NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/true, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/true, &annotations));
+ }
+
+ EXPECT_EQ(annotations.size(), 1);
+ EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
+ EXPECT_EQ(annotations[0].classification[0].collection, "location");
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansIgnoresFilteredLabel) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NAM-/saft/location",
+ "I-NAM-/saft/location",
+ "E-NAM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{"NOM"},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{MentionType_NOM},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
+TEST_P(ConvertTagsToAnnotatedSpansTest,
+ ConvertTagsToAnnotatedSpansWithEmptyLabelFilterIgnoresAll) {
+ std::vector<AnnotatedSpan> annotations;
+ std::string text = "We met in New York City";
+ std::vector<std::string> tags = {"O",
+ "O",
+ "O",
+ "B-NOM-/saft/location",
+ "I-NOM-/saft/location",
+ "E-NOM-/saft/location"};
+ if (GetParam()) {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), tags,
+ /*label_filter=*/{},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_label_category_matching=*/false, kPriorityScore,
+ &annotations));
+ } else {
+ ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
+ VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
+ GetCollections(),
+ /*mention_filter=*/{},
+ /*relaxed_inside_label_matching=*/false,
+ /*relaxed_mention_type_matching=*/false, &annotations));
+ }
+
+ EXPECT_THAT(annotations, IsEmpty());
+}
+
+TEST(PodNerUtilsTest, MergeLabelsIntoLeftSequence) {
+ std::vector<PodNerModel_::LabelT> original_labels_left;
+ original_labels_left.emplace_back(
+ CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
+ original_labels_left.emplace_back(
+ CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
+ original_labels_left.emplace_back(
+ CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
+ original_labels_left.emplace_back(
+ CreateLabel(BoiseType_SINGLE, MentionType_NAM, 1));
+ original_labels_left.emplace_back(
+ CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
+ original_labels_left.emplace_back(
+ CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
+ original_labels_left.emplace_back(
+ CreateLabel(BoiseType_SINGLE, MentionType_NAM, 2));
+
+ std::vector<PodNerModel_::LabelT> labels_right;
+ labels_right.emplace_back(
+ CreateLabel(BoiseType_BEGIN, MentionType_UNDEFINED, 3));
+ labels_right.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
+ labels_right.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
+ labels_right.emplace_back(CreateLabel(BoiseType_BEGIN, MentionType_NAM, 4));
+ labels_right.emplace_back(
+ CreateLabel(BoiseType_INTERMEDIATE, MentionType_UNDEFINED, 4));
+ labels_right.emplace_back(
+ CreateLabel(BoiseType_END, MentionType_UNDEFINED, 4));
+ std::vector<PodNerModel_::LabelT> labels_left = original_labels_left;
+
+ ASSERT_TRUE(MergeLabelsIntoLeftSequence(labels_right,
+ /*index_first_right_tag_in_left=*/3,
+ &labels_left));
+ EXPECT_EQ(labels_left.size(), 9);
+ EXPECT_EQ(labels_left[0].collection_id, 0);
+ EXPECT_EQ(labels_left[1].collection_id, 0);
+ EXPECT_EQ(labels_left[2].collection_id, 0);
+ EXPECT_EQ(labels_left[3].collection_id, 1);
+ EXPECT_EQ(labels_left[4].collection_id, 0);
+ EXPECT_EQ(labels_left[5].collection_id, 0);
+ EXPECT_EQ(labels_left[6].collection_id, 4);
+ EXPECT_EQ(labels_left[7].collection_id, 4);
+ EXPECT_EQ(labels_left[8].collection_id, 4);
+
+ labels_left = original_labels_left;
+ ASSERT_TRUE(MergeLabelsIntoLeftSequence(labels_right,
+ /*index_first_right_tag_in_left=*/2,
+ &labels_left));
+ EXPECT_EQ(labels_left.size(), 8);
+ EXPECT_EQ(labels_left[0].collection_id, 0);
+ EXPECT_EQ(labels_left[1].collection_id, 0);
+ EXPECT_EQ(labels_left[2].collection_id, 0);
+ EXPECT_EQ(labels_left[3].collection_id, 1);
+ EXPECT_EQ(labels_left[4].collection_id, 0);
+ EXPECT_EQ(labels_left[5].collection_id, 4);
+ EXPECT_EQ(labels_left[6].collection_id, 4);
+ EXPECT_EQ(labels_left[7].collection_id, 4);
+}
+
+TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanAllWordpices) {
+ std::vector<Token> tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5},
+ {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11},
+ {"my", 12, 14}, {"name", 15, 19}};
+ std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
+
+ WordpieceSpan wordpieceSpan = internal::FindWordpiecesWindowAroundSpan(
+ {2, 3}, tokens, word_starts,
+ /*num_wordpieces=*/12,
+ /*max_num_wordpieces_in_window=*/15);
+ EXPECT_EQ(wordpieceSpan, WordpieceSpan(0, 12));
+}
+
+TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanInMiddle) {
+ std::vector<Token> tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5},
+ {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11},
+ {"my", 12, 14}, {"name", 15, 19}};
+ std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
+
+ WordpieceSpan wordpieceSpan = internal::FindWordpiecesWindowAroundSpan(
+ {6, 7}, tokens, word_starts,
+ /*num_wordpieces=*/12,
+ /*max_num_wordpieces_in_window=*/5);
+ EXPECT_EQ(wordpieceSpan, WordpieceSpan(3, 8));
+
+ wordpieceSpan = internal::FindWordpiecesWindowAroundSpan(
+ {6, 7}, tokens, word_starts,
+ /*num_wordpieces=*/12,
+ /*max_num_wordpieces_in_window=*/6);
+ EXPECT_EQ(wordpieceSpan, WordpieceSpan(3, 9));
+
+ wordpieceSpan = internal::FindWordpiecesWindowAroundSpan(
+ {12, 14}, tokens, word_starts,
+ /*num_wordpieces=*/12,
+ /*max_num_wordpieces_in_window=*/3);
+ EXPECT_EQ(wordpieceSpan, WordpieceSpan(9, 12));
+}
+
+TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanCloseToStart) {
+ std::vector<Token> tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5},
+ {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11},
+ {"my", 12, 14}, {"name", 15, 19}};
+ std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
+
+ WordpieceSpan wordpieceSpan = internal::FindWordpiecesWindowAroundSpan(
+ {2, 3}, tokens, word_starts,
+ /*num_wordpieces=*/12,
+ /*max_num_wordpieces_in_window=*/7);
+ EXPECT_EQ(wordpieceSpan, WordpieceSpan(0, 7));
+}
+
+TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanCloseToEnd) {
+ std::vector<Token> tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5},
+ {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11},
+ {"my", 12, 14}, {"name", 15, 19}};
+ std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
+
+ WordpieceSpan wordpieceSpan = internal::FindWordpiecesWindowAroundSpan(
+ {15, 19}, tokens, word_starts,
+ /*num_wordpieces=*/12,
+ /*max_num_wordpieces_in_window=*/7);
+ EXPECT_EQ(wordpieceSpan, WordpieceSpan(5, 12));
+}
+
+TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanBigSpan) {
+ std::vector<Token> tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5},
+ {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11},
+ {"my", 12, 14}, {"name", 15, 19}};
+ std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
+
+ WordpieceSpan wordpieceSpan = internal::FindWordpiecesWindowAroundSpan(
+ {0, 19}, tokens, word_starts,
+ /*num_wordpieces=*/12,
+ /*max_num_wordpieces_in_window=*/5);
+ EXPECT_EQ(wordpieceSpan, WordpieceSpan(0, 12));
+}
+
+TEST(PodNerUtilsTest, FindFullTokensSpanInWindow) {
+ std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
+ int first_token_index, num_tokens;
+ WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
+ word_starts, /*wordpiece_span=*/{0, 6},
+ /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
+ &num_tokens);
+ EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6));
+ EXPECT_EQ(first_token_index, 0);
+ EXPECT_EQ(num_tokens, 4);
+
+ updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
+ word_starts, /*wordpiece_span=*/{2, 6},
+ /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
+ &num_tokens);
+ EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(2, 6));
+ EXPECT_EQ(first_token_index, 1);
+ EXPECT_EQ(num_tokens, 3);
+}
+
+TEST(PodNerUtilsTest, FindFullTokensSpanInWindowStartInMiddleOfToken) {
+ std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
+ int first_token_index, num_tokens;
+ WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
+ word_starts, /*wordpiece_span=*/{1, 6},
+ /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
+ &num_tokens);
+ EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6));
+ EXPECT_EQ(first_token_index, 0);
+ EXPECT_EQ(num_tokens, 4);
+}
+
+TEST(PodNerUtilsTest, FindFullTokensSpanInWindowEndsInMiddleOfToken) {
+ std::vector<int32_t> word_starts{0, 2, 3, 5, 6, 7, 10, 11};
+ int first_token_index, num_tokens;
+ WordpieceSpan updated_wordpiece_span = internal::FindFullTokensSpanInWindow(
+ word_starts, /*wordpiece_span=*/{1, 9},
+ /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token_index,
+ &num_tokens);
+ EXPECT_EQ(updated_wordpiece_span, WordpieceSpan(0, 6));
+ EXPECT_EQ(first_token_index, 0);
+ EXPECT_EQ(num_tokens, 4);
+}
+TEST(PodNerUtilsTest, FindFirstFullTokenIndexSizeOne) {
+ std::vector<int32_t> word_starts{1, 2, 3, 5, 6, 7, 10, 11};
+ int index_first_full_token = internal::FindFirstFullTokenIndex(
+ word_starts, /*first_wordpiece_index=*/2);
+ EXPECT_EQ(index_first_full_token, 1);
+}
+
+TEST(PodNerUtilsTest, FindFirstFullTokenIndexFirst) {
+ std::vector<int32_t> word_starts{1, 2, 3, 5, 6, 7, 10, 11};
+ int index_first_full_token = internal::FindFirstFullTokenIndex(
+ word_starts, /*first_wordpiece_index=*/0);
+ EXPECT_EQ(index_first_full_token, 0);
+}
+
+TEST(PodNerUtilsTest, FindFirstFullTokenIndexSizeGreaterThanOne) {
+ std::vector<int32_t> word_starts{1, 2, 3, 5, 6, 7, 10, 11};
+ int index_first_full_token = internal::FindFirstFullTokenIndex(
+ word_starts, /*first_wordpiece_index=*/4);
+ EXPECT_EQ(index_first_full_token, 2);
+}
+
+TEST(PodNerUtilsTest, FindLastFullTokenIndexSizeOne) {
+ std::vector<int32_t> word_starts{1, 2, 3, 5, 6, 7, 10, 11};
+ int index_last_full_token = internal::FindLastFullTokenIndex(
+ word_starts, /*num_wordpieces=*/12, /*wordpiece_end=*/3);
+ EXPECT_EQ(index_last_full_token, 1);
+}
+
+TEST(PodNerUtilsTest, FindLastFullTokenIndexSizeGreaterThanOne) {
+ std::vector<int32_t> word_starts{1, 3, 4, 6, 8, 9};
+ int index_last_full_token = internal::FindLastFullTokenIndex(
+ word_starts, /*num_wordpieces=*/10, /*wordpiece_end=*/6);
+ EXPECT_EQ(index_last_full_token, 2);
+
+ index_last_full_token = internal::FindLastFullTokenIndex(
+ word_starts, /*num_wordpieces=*/10, /*wordpiece_end=*/7);
+ EXPECT_EQ(index_last_full_token, 2);
+
+ index_last_full_token = internal::FindLastFullTokenIndex(
+ word_starts, /*num_wordpieces=*/10, /*wordpiece_end=*/5);
+ EXPECT_EQ(index_last_full_token, 1);
+}
+
+TEST(PodNerUtilsTest, FindLastFullTokenIndexLast) {
+ std::vector<int32_t> word_starts{1, 2, 3, 5, 6, 7, 10, 11};
+ int index_last_full_token = internal::FindLastFullTokenIndex(
+ word_starts, /*num_wordpieces=*/12, /*wordpiece_end=*/12);
+ EXPECT_EQ(index_last_full_token, 7);
+
+ index_last_full_token = internal::FindLastFullTokenIndex(
+ word_starts, /*num_wordpieces=*/14, /*wordpiece_end=*/14);
+ EXPECT_EQ(index_last_full_token, 7);
+}
+
+TEST(PodNerUtilsTest, FindLastFullTokenIndexBeforeLast) {
+ std::vector<int32_t> word_starts{1, 2, 3, 5, 6, 7, 10, 11};
+ int index_last_full_token = internal::FindLastFullTokenIndex(
+ word_starts, /*num_wordpieces=*/15, /*wordpiece_end=*/12);
+ EXPECT_EQ(index_last_full_token, 6);
+}
+
+TEST(PodNerUtilsTest, ExpandWindowAndAlignSequenceSmallerThanMax) {
+ WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign(
+ /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/8,
+ /*wordpiece_span_to_expand=*/{2, 5});
+ EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(0, 8));
+}
+
+TEST(PodNerUtilsTest, ExpandWindowAndAlignWindowLengthGreaterThanMax) {
+ WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign(
+ /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/100,
+ /*wordpiece_span_to_expand=*/{2, 51});
+ EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(2, 51));
+}
+
+TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexCloseToStart) {
+ WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign(
+ /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
+ /*wordpiece_span_to_expand=*/{2, 4});
+ EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(0, 10));
+}
+
+TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexCloseToEnd) {
+ WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign(
+ /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
+ /*wordpiece_span_to_expand=*/{18, 20});
+ EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(10, 20));
+}
+
+TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexInTheMiddle) {
+ int window_first_wordpiece_index = 10;
+ int window_last_wordpiece_index = 11;
+ WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign(
+ /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
+ /*wordpiece_span_to_expand=*/{10, 12});
+ EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(6, 16));
+
+ window_first_wordpiece_index = 10;
+ window_last_wordpiece_index = 12;
+ maxWordpieceSpan = internal::ExpandWindowAndAlign(
+ /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
+ /*wordpiece_span_to_expand=*/{10, 13});
+ EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(7, 17));
+}
+
+TEST(PodNerUtilsTest, WindowGenerator) {
+ std::vector<int32_t> wordpiece_indices = {10, 20, 30, 40, 50, 60, 70, 80};
+ std::vector<Token> tokens{{"a", 0, 1}, {"b", 2, 3}, {"c", 4, 5},
+ {"d", 6, 7}, {"e", 8, 9}, {"f", 10, 11}};
+ std::vector<int32_t> token_starts{0, 2, 3, 5, 6, 7};
+ WindowGenerator window_generator(wordpiece_indices, token_starts, tokens,
+ /*max_num_wordpieces=*/4,
+ /*sliding_window_overlap=*/1,
+ /*span_of_interest=*/{0, 12});
+ VectorSpan<int32_t> cur_wordpiece_indices;
+ VectorSpan<int32_t> cur_token_starts;
+ VectorSpan<Token> cur_tokens;
+ ASSERT_TRUE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
+ &cur_tokens));
+ ASSERT_FALSE(window_generator.Done());
+ ASSERT_EQ(cur_wordpiece_indices.size(), 3);
+ for (int i = 0; i < 3; i++) {
+ ASSERT_EQ(cur_wordpiece_indices[i], wordpiece_indices[i]);
+ }
+ ASSERT_EQ(cur_token_starts.size(), 2);
+ ASSERT_EQ(cur_tokens.size(), 2);
+ for (int i = 0; i < cur_tokens.size(); i++) {
+ ASSERT_EQ(cur_token_starts[i], token_starts[i]);
+ ASSERT_EQ(cur_tokens[i], tokens[i]);
+ }
+
+ ASSERT_TRUE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
+ &cur_tokens));
+ ASSERT_FALSE(window_generator.Done());
+ ASSERT_EQ(cur_wordpiece_indices.size(), 4);
+ for (int i = 0; i < cur_wordpiece_indices.size(); i++) {
+ ASSERT_EQ(cur_wordpiece_indices[i], wordpiece_indices[i + 2]);
+ }
+ ASSERT_EQ(cur_token_starts.size(), 3);
+ ASSERT_EQ(cur_tokens.size(), 3);
+ for (int i = 0; i < cur_tokens.size(); i++) {
+ ASSERT_EQ(cur_token_starts[i], token_starts[i + 1]);
+ ASSERT_EQ(cur_tokens[i], tokens[i + 1]);
+ }
+
+ ASSERT_TRUE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
+ &cur_tokens));
+ ASSERT_TRUE(window_generator.Done());
+ ASSERT_EQ(cur_wordpiece_indices.size(), 3);
+ for (int i = 0; i < cur_wordpiece_indices.size(); i++) {
+ ASSERT_EQ(cur_wordpiece_indices[i], wordpiece_indices[i + 5]);
+ }
+ ASSERT_EQ(cur_token_starts.size(), 3);
+ ASSERT_EQ(cur_tokens.size(), 3);
+ for (int i = 0; i < cur_tokens.size(); i++) {
+ ASSERT_EQ(cur_token_starts[i], token_starts[i + 3]);
+ ASSERT_EQ(cur_tokens[i], tokens[i + 3]);
+ }
+
+ ASSERT_FALSE(window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
+ &cur_tokens));
+}
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/strip-unpaired-brackets.cc b/native/annotator/strip-unpaired-brackets.cc
index b1067ad..8bf93d9 100644
--- a/native/annotator/strip-unpaired-brackets.cc
+++ b/native/annotator/strip-unpaired-brackets.cc
@@ -22,59 +22,23 @@
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
-namespace {
-// Returns true if given codepoint is contained in the given span in context.
-bool IsCodepointInSpan(const char32 codepoint,
- const UnicodeText& context_unicode,
- const CodepointSpan span) {
- auto begin_it = context_unicode.begin();
- std::advance(begin_it, span.first);
- auto end_it = context_unicode.begin();
- std::advance(end_it, span.second);
-
- return std::find(begin_it, end_it, codepoint) != end_it;
-}
-
-// Returns the first codepoint of the span.
-char32 FirstSpanCodepoint(const UnicodeText& context_unicode,
- const CodepointSpan span) {
- auto it = context_unicode.begin();
- std::advance(it, span.first);
- return *it;
-}
-
-// Returns the last codepoint of the span.
-char32 LastSpanCodepoint(const UnicodeText& context_unicode,
- const CodepointSpan span) {
- auto it = context_unicode.begin();
- std::advance(it, span.second - 1);
- return *it;
-}
-
-} // namespace
-
-CodepointSpan StripUnpairedBrackets(const std::string& context,
- CodepointSpan span, const UniLib& unilib) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- return StripUnpairedBrackets(context_unicode, span, unilib);
-}
-
-// If the first or the last codepoint of the given span is a bracket, the
-// bracket is stripped if the span does not contain its corresponding paired
-// version.
-CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
- CodepointSpan span, const UniLib& unilib) {
- if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
+CodepointSpan StripUnpairedBrackets(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span,
+ const UniLib& unilib) {
+ if (span_begin == span_end || !span.IsValid() || span.IsEmpty()) {
return span;
}
- const char32 begin_char = FirstSpanCodepoint(context_unicode, span);
+ UnicodeText::const_iterator begin = span_begin;
+ const UnicodeText::const_iterator end = span_end;
+ const char32 begin_char = *begin;
const char32 paired_begin_char = unilib.GetPairedBracket(begin_char);
if (paired_begin_char != begin_char) {
if (!unilib.IsOpeningBracket(begin_char) ||
- !IsCodepointInSpan(paired_begin_char, context_unicode, span)) {
+ std::find(begin, end, paired_begin_char) == end) {
+ ++begin;
++span.first;
}
}
@@ -83,11 +47,11 @@
return span;
}
- const char32 end_char = LastSpanCodepoint(context_unicode, span);
+ const char32 end_char = *std::prev(end);
const char32 paired_end_char = unilib.GetPairedBracket(end_char);
if (paired_end_char != end_char) {
if (!unilib.IsClosingBracket(end_char) ||
- !IsCodepointInSpan(paired_end_char, context_unicode, span)) {
+ std::find(begin, end, paired_end_char) == end) {
--span.second;
}
}
@@ -102,4 +66,21 @@
return span;
}
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context,
+ CodepointSpan span, const UniLib& unilib) {
+ if (!span.IsValid() || span.IsEmpty()) {
+ return span;
+ }
+ const UnicodeText span_text = UnicodeText::Substring(
+ context, span.first, span.second, /*do_copy=*/false);
+ return StripUnpairedBrackets(span_text.begin(), span_text.end(), span,
+ unilib);
+}
+
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+ CodepointSpan span, const UniLib& unilib) {
+ return StripUnpairedBrackets(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ span, unilib);
+}
+
} // namespace libtextclassifier3
diff --git a/native/annotator/strip-unpaired-brackets.h b/native/annotator/strip-unpaired-brackets.h
index ceb8d60..c6cdc1a 100644
--- a/native/annotator/strip-unpaired-brackets.h
+++ b/native/annotator/strip-unpaired-brackets.h
@@ -23,14 +23,21 @@
#include "utils/utf8/unilib.h"
namespace libtextclassifier3 {
+
// If the first or the last codepoint of the given span is a bracket, the
// bracket is stripped if the span does not contain its corresponding paired
// version.
-CodepointSpan StripUnpairedBrackets(const std::string& context,
+CodepointSpan StripUnpairedBrackets(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span,
+ const UniLib& unilib);
+
+// Same as above but takes a UnicodeText instance for the span.
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context,
CodepointSpan span, const UniLib& unilib);
-// Same as above but takes UnicodeText instance directly.
-CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
+// Same as above but takes a string instance.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
CodepointSpan span, const UniLib& unilib);
} // namespace libtextclassifier3
diff --git a/native/annotator/strip-unpaired-brackets_test.cc b/native/annotator/strip-unpaired-brackets_test.cc
index 32585ce..a7a3d29 100644
--- a/native/annotator/strip-unpaired-brackets_test.cc
+++ b/native/annotator/strip-unpaired-brackets_test.cc
@@ -30,36 +30,36 @@
TEST_F(StripUnpairedBracketsTest, StripUnpairedBrackets) {
// If the brackets match, nothing gets stripped.
EXPECT_EQ(StripUnpairedBrackets("call me (123) 456 today", {8, 17}, unilib_),
- std::make_pair(8, 17));
+ CodepointSpan(8, 17));
EXPECT_EQ(StripUnpairedBrackets("call me (123 456) today", {8, 17}, unilib_),
- std::make_pair(8, 17));
+ CodepointSpan(8, 17));
// If the brackets don't match, they get stripped.
EXPECT_EQ(StripUnpairedBrackets("call me (123 456 today", {8, 16}, unilib_),
- std::make_pair(9, 16));
+ CodepointSpan(9, 16));
EXPECT_EQ(StripUnpairedBrackets("call me )123 456 today", {8, 16}, unilib_),
- std::make_pair(9, 16));
+ CodepointSpan(9, 16));
EXPECT_EQ(StripUnpairedBrackets("call me 123 456) today", {8, 16}, unilib_),
- std::make_pair(8, 15));
+ CodepointSpan(8, 15));
EXPECT_EQ(StripUnpairedBrackets("call me 123 456( today", {8, 16}, unilib_),
- std::make_pair(8, 15));
+ CodepointSpan(8, 15));
// Strips brackets correctly from length-1 selections that consist of
// a bracket only.
EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}, unilib_),
- std::make_pair(12, 12));
+ CodepointSpan(12, 12));
EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}, unilib_),
- std::make_pair(12, 12));
+ CodepointSpan(12, 12));
// Handles invalid spans gracefully.
EXPECT_EQ(StripUnpairedBrackets("call me at today", {11, 11}, unilib_),
- std::make_pair(11, 11));
+ CodepointSpan(11, 11));
EXPECT_EQ(StripUnpairedBrackets("hello world", {0, 0}, unilib_),
- std::make_pair(0, 0));
+ CodepointSpan(0, 0));
EXPECT_EQ(StripUnpairedBrackets("hello world", {11, 11}, unilib_),
- std::make_pair(11, 11));
+ CodepointSpan(11, 11));
EXPECT_EQ(StripUnpairedBrackets("hello world", {-1, -1}, unilib_),
- std::make_pair(-1, -1));
+ CodepointSpan(-1, -1));
}
} // namespace
diff --git a/native/annotator/test-utils.h b/native/annotator/test-utils.h
new file mode 100644
index 0000000..d63e66e
--- /dev/null
+++ b/native/annotator/test-utils.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_TEST_UTILS_H_
+
+#include "annotator/types.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+using ::testing::Value;
+
+MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
+ const std::string first_result = arg.classification.empty()
+ ? "<INVALID RESULTS>"
+ : arg.classification[0].collection;
+ return Value(arg.span, CodepointSpan(start, end)) &&
+ Value(first_result, best_class);
+}
+
+MATCHER_P(IsAnnotationWithType, best_class, "") {
+ const std::string first_result = arg.classification.empty()
+ ? "<INVALID RESULTS>"
+ : arg.classification[0].collection;
+ return Value(first_result, best_class);
+}
+
+MATCHER_P2(IsDateResult, time_ms_utc, granularity, "") {
+ return Value(arg.collection, "date") &&
+ Value(arg.datetime_parse_result.time_ms_utc, time_ms_utc) &&
+ Value(arg.datetime_parse_result.granularity, granularity);
+}
+
+MATCHER_P2(IsDatetimeResult, time_ms_utc, granularity, "") {
+ return Value(arg.collection, "datetime") &&
+ Value(arg.datetime_parse_result.time_ms_utc, time_ms_utc) &&
+ Value(arg.datetime_parse_result.granularity, granularity);
+}
+
+MATCHER_P3(IsDurationSpan, start, end, duration_ms, "") {
+ if (arg.classification.empty()) {
+ return false;
+ }
+ return ExplainMatchResult(IsAnnotatedSpan(start, end, "duration"), arg,
+ result_listener) &&
+ arg.classification[0].duration_ms == duration_ms;
+}
+
+MATCHER_P4(IsDatetimeSpan, start, end, time_ms_utc, granularity, "") {
+ if (arg.classification.empty()) {
+ return false;
+ }
+ return ExplainMatchResult(IsAnnotatedSpan(start, end, "datetime"), arg,
+ result_listener) &&
+ arg.classification[0].datetime_parse_result.time_ms_utc ==
+ time_ms_utc &&
+ arg.classification[0].datetime_parse_result.granularity == granularity;
+}
+
+MATCHER_P2(IsBetween, low, high, "") { return low < arg && arg < high; }
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TEST_UTILS_H_
diff --git a/native/annotator/test_data/datetime.fb b/native/annotator/test_data/datetime.fb
new file mode 100644
index 0000000..5828b94
--- /dev/null
+++ b/native/annotator/test_data/datetime.fb
Binary files differ
diff --git a/native/annotator/test_data/lang_id.smfb b/native/annotator/test_data/lang_id.smfb
new file mode 100644
index 0000000..e94dada
--- /dev/null
+++ b/native/annotator/test_data/lang_id.smfb
Binary files differ
diff --git a/native/annotator/test_data/test_model.fb b/native/annotator/test_data/test_model.fb
new file mode 100644
index 0000000..2e9418e
--- /dev/null
+++ b/native/annotator/test_data/test_model.fb
Binary files differ
diff --git a/native/annotator/test_data/test_vocab_model.fb b/native/annotator/test_data/test_vocab_model.fb
new file mode 100644
index 0000000..d9d9a94
--- /dev/null
+++ b/native/annotator/test_data/test_vocab_model.fb
Binary files differ
diff --git a/native/annotator/test_data/wrong_embeddings.fb b/native/annotator/test_data/wrong_embeddings.fb
new file mode 100644
index 0000000..dfb7369
--- /dev/null
+++ b/native/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/native/annotator/translate/translate_test.cc b/native/annotator/translate/translate_test.cc
new file mode 100644
index 0000000..5c4a63f
--- /dev/null
+++ b/native/annotator/translate/translate_test.cc
@@ -0,0 +1,208 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/translate/translate.h"
+
+#include <memory>
+
+#include "annotator/model_generated.h"
+#include "utils/test-data-test-utils.h"
+#include "lang_id/fb_model/lang-id-from-fb.h"
+#include "lang_id/lang-id.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::AllOf;
+using testing::Field;
+
+const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions() {
+ static const flatbuffers::DetachedBuffer* options_data = []() {
+ TranslateAnnotatorOptionsT options;
+ options.enabled = true;
+ options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
+ options.backoff_options.reset(
+ new TranslateAnnotatorOptions_::BackoffOptionsT());
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+ }();
+
+ return flatbuffers::GetRoot<TranslateAnnotatorOptions>(options_data->data());
+}
+
+class TestingTranslateAnnotator : public TranslateAnnotator {
+ public:
+ // Make these protected members public for tests.
+ using TranslateAnnotator::BackoffDetectLanguages;
+ using TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation;
+ using TranslateAnnotator::TokenAlignedSubstringAroundSpan;
+ using TranslateAnnotator::TranslateAnnotator;
+};
+
+std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
+
+class TranslateAnnotatorTest : public ::testing::Test {
+ protected:
+ TranslateAnnotatorTest()
+ : INIT_UNILIB_FOR_TESTING(unilib_),
+ langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
+ GetModelPath() + "lang_id.smfb")),
+ translate_annotator_(TestingTranslateAnnotatorOptions(),
+ langid_model_.get(), &unilib_) {}
+
+ UniLib unilib_;
+ std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model_;
+ TestingTranslateAnnotator translate_annotator_;
+};
+
+TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) {
+ ClassificationResult classification;
+ EXPECT_TRUE(translate_annotator_.ClassifyText(
+ UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček."), {18, 28},
+ "en", &classification));
+
+ EXPECT_THAT(classification,
+ AllOf(Field(&ClassificationResult::collection, "translate")));
+ const EntityData* entity_data =
+ GetEntityData(classification.serialized_entity_data.data());
+ const auto predictions =
+ entity_data->translate()->language_prediction_results();
+ EXPECT_EQ(predictions->size(), 1);
+ EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "cs");
+ EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
+ EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
+}
+
+TEST_F(TranslateAnnotatorTest, EntityDataIsSet) {
+ ClassificationResult classification;
+ EXPECT_TRUE(translate_annotator_.ClassifyText(UTF8ToUnicodeText("学校"),
+ {0, 2}, "en", &classification));
+
+ EXPECT_THAT(classification,
+ AllOf(Field(&ClassificationResult::collection, "translate")));
+ const EntityData* entity_data =
+ GetEntityData(classification.serialized_entity_data.data());
+ const auto predictions =
+ entity_data->translate()->language_prediction_results();
+ EXPECT_EQ(predictions->size(), 2);
+ EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "zh");
+ EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
+ EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
+ EXPECT_EQ(predictions->Get(1)->language_tag()->str(), "ja");
+ EXPECT_TRUE(predictions->Get(0)->confidence_score() >=
+ predictions->Get(1)->confidence_score());
+}
+
+TEST_F(TranslateAnnotatorTest,
+ WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) {
+ ClassificationResult classification;
+ EXPECT_FALSE(translate_annotator_.ClassifyText(
+ UTF8ToUnicodeText("This is utterly unutterable."), {8, 15}, "en",
+ &classification));
+}
+
+TEST_F(TranslateAnnotatorTest,
+ WhenSpeaksMultipleAndNotCzechGetsTranslateActionForCzech) {
+ ClassificationResult classification;
+ EXPECT_TRUE(translate_annotator_.ClassifyText(
+ UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček."), {8, 15},
+ "de,en,ja", &classification));
+
+ EXPECT_THAT(classification,
+ AllOf(Field(&ClassificationResult::collection, "translate")));
+}
+
+TEST_F(TranslateAnnotatorTest,
+ WhenSpeaksMultipleAndEnglishDoesntGetTranslateActionForEnglish) {
+ ClassificationResult classification;
+ EXPECT_FALSE(translate_annotator_.ClassifyText(
+ UTF8ToUnicodeText("This is utterly unutterable."), {8, 15}, "cs,en,de,ja",
+ &classification));
+}
+
+TEST_F(TranslateAnnotatorTest, FindIndexOfNextWhitespaceOrPunctuation) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");
+
+ EXPECT_EQ(
+ translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 0, -1),
+ text.begin());
+ EXPECT_EQ(
+ translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 35, 1),
+ text.end());
+ EXPECT_EQ(
+ translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 10, -1),
+ std::next(text.begin(), 6));
+ EXPECT_EQ(
+ translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 10, 1),
+ std::next(text.begin(), 13));
+}
+
+TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringAroundSpan) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");
+
+ EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
+ text, {35, 37}, /*minimum_length=*/100),
+ text);
+ EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
+ text, {35, 37}, /*minimum_length=*/0),
+ UTF8ToUnicodeText("ač"));
+ EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
+ text, {35, 37}, /*minimum_length=*/3),
+ UTF8ToUnicodeText("stříkaček"));
+ EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
+ text, {35, 37}, /*minimum_length=*/10),
+ UTF8ToUnicodeText("stříkaček"));
+ EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
+ text, {35, 37}, /*minimum_length=*/11),
+ UTF8ToUnicodeText("stříbrných stříkaček"));
+
+ const UnicodeText text_no_whitespace =
+ UTF8ToUnicodeText("reallyreallylongstring");
+ EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
+ text_no_whitespace, {10, 11}, /*minimum_length=*/2),
+ UTF8ToUnicodeText("reallyreallylongstring"));
+}
+
+TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringWhitespaceText) {
+ const UnicodeText text = UTF8ToUnicodeText(" ");
+
+ // Shouldn't modify the selection in case it's all whitespace.
+ EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
+ text, {5, 7}, /*minimum_length=*/3),
+ UTF8ToUnicodeText(" "));
+ EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
+ text, {5, 5}, /*minimum_length=*/1),
+ UTF8ToUnicodeText(""));
+}
+
+TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringMostlyWhitespaceText) {
+ const UnicodeText text = UTF8ToUnicodeText("a a");
+
+ // Should still select the whole text even if pointing to whitespace
+ // initially.
+ EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
+ text, {5, 7}, /*minimum_length=*/11),
+ UTF8ToUnicodeText("a a"));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/types-test-util.h b/native/annotator/types-test-util.h
index 1d018a1..55dd214 100644
--- a/native/annotator/types-test-util.h
+++ b/native/annotator/types-test-util.h
@@ -34,9 +34,11 @@
TC3_DECLARE_PRINT_OPERATOR(AnnotatedSpan)
TC3_DECLARE_PRINT_OPERATOR(ClassificationResult)
+TC3_DECLARE_PRINT_OPERATOR(CodepointSpan)
TC3_DECLARE_PRINT_OPERATOR(DatetimeParsedData)
TC3_DECLARE_PRINT_OPERATOR(DatetimeParseResultSpan)
TC3_DECLARE_PRINT_OPERATOR(Token)
+TC3_DECLARE_PRINT_OPERATOR(TokenSpan)
#undef TC3_DECLARE_PRINT_OPERATOR
diff --git a/native/annotator/types.cc b/native/annotator/types.cc
index be542d3..b1dde17 100644
--- a/native/annotator/types.cc
+++ b/native/annotator/types.cc
@@ -22,6 +22,21 @@
namespace libtextclassifier3 {
+const CodepointSpan CodepointSpan::kInvalid =
+ CodepointSpan(kInvalidIndex, kInvalidIndex);
+
+const TokenSpan TokenSpan::kInvalid = TokenSpan(kInvalidIndex, kInvalidIndex);
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const CodepointSpan& span) {
+ return stream << "CodepointSpan(" << span.first << ", " << span.second << ")";
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const TokenSpan& span) {
+ return stream << "TokenSpan(" << span.first << ", " << span.second << ")";
+}
+
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
const Token& token) {
if (!token.is_padding) {
diff --git a/native/annotator/types.h b/native/annotator/types.h
index 665d4b6..45999cd 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -30,9 +30,10 @@
#include <vector>
#include "annotator/entity-data_generated.h"
+#include "annotator/knowledge/knowledge-engine-types.h"
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
#include "utils/optional.h"
#include "utils/variant.h"
@@ -56,15 +57,53 @@
// Marks a span in a sequence of codepoints. The first element is the index of
// the first codepoint of the span, and the second element is the index of the
// codepoint one past the end of the span.
-// TODO(b/71982294): Make it a struct.
-using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
+struct CodepointSpan {
+ static const CodepointSpan kInvalid;
+
+ CodepointSpan() : first(kInvalidIndex), second(kInvalidIndex) {}
+
+ CodepointSpan(CodepointIndex start, CodepointIndex end)
+ : first(start), second(end) {}
+
+ CodepointSpan& operator=(const CodepointSpan& other) = default;
+
+ bool operator==(const CodepointSpan& other) const {
+ return this->first == other.first && this->second == other.second;
+ }
+
+ bool operator!=(const CodepointSpan& other) const {
+ return !(*this == other);
+ }
+
+ bool operator<(const CodepointSpan& other) const {
+ if (this->first != other.first) {
+ return this->first < other.first;
+ }
+ return this->second < other.second;
+ }
+
+ bool IsValid() const {
+ return this->first != kInvalidIndex && this->second != kInvalidIndex &&
+ this->first <= this->second && this->first >= 0;
+ }
+
+ bool IsEmpty() const { return this->first == this->second; }
+
+ CodepointIndex first;
+ CodepointIndex second;
+};
+
+// Pretty-printing function for CodepointSpan.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const CodepointSpan& span);
inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
return a.first < b.second && b.first < a.second;
}
-inline bool ValidNonEmptySpan(const CodepointSpan& span) {
- return span.first < span.second && span.first >= 0 && span.second >= 0;
+inline bool SpanContains(const CodepointSpan& span,
+ const CodepointSpan& sub_span) {
+ return span.first <= sub_span.first && span.second >= sub_span.second;
}
template <typename T>
@@ -102,35 +141,61 @@
// Marks a span in a sequence of tokens. The first element is the index of the
// first token in the span, and the second element is the index of the token one
// past the end of the span.
-// TODO(b/71982294): Make it a struct.
-using TokenSpan = std::pair<TokenIndex, TokenIndex>;
+struct TokenSpan {
+ static const TokenSpan kInvalid;
-// Returns the size of the token span. Assumes that the span is valid.
-inline int TokenSpanSize(const TokenSpan& token_span) {
- return token_span.second - token_span.first;
-}
+ TokenSpan() : first(kInvalidIndex), second(kInvalidIndex) {}
-// Returns a token span consisting of one token.
-inline TokenSpan SingleTokenSpan(int token_index) {
- return {token_index, token_index + 1};
-}
+ TokenSpan(TokenIndex start, TokenIndex end) : first(start), second(end) {}
-// Returns an intersection of two token spans. Assumes that both spans are valid
-// and overlapping.
+ // Creates a token span consisting of one token.
+ explicit TokenSpan(int token_index)
+ : first(token_index), second(token_index + 1) {}
+
+ TokenSpan& operator=(const TokenSpan& other) = default;
+
+ bool operator==(const TokenSpan& other) const {
+ return this->first == other.first && this->second == other.second;
+ }
+
+ bool operator!=(const TokenSpan& other) const { return !(*this == other); }
+
+ bool operator<(const TokenSpan& other) const {
+ if (this->first != other.first) {
+ return this->first < other.first;
+ }
+ return this->second < other.second;
+ }
+
+ bool IsValid() const {
+ return this->first != kInvalidIndex && this->second != kInvalidIndex;
+ }
+
+ // Returns the size of the token span. Assumes that the span is valid.
+ int Size() const { return this->second - this->first; }
+
+ // Returns an expanded token span by adding a certain number of tokens on its
+ // left and on its right.
+ TokenSpan Expand(int num_tokens_left, int num_tokens_right) const {
+ return {this->first - num_tokens_left, this->second + num_tokens_right};
+ }
+
+ TokenIndex first;
+ TokenIndex second;
+};
+
+// Pretty-printing function for TokenSpan.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const TokenSpan& span);
+
+// Returns an intersection of two token spans. Assumes that both spans are
+// valid and overlapping.
inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
const TokenSpan& token_span2) {
return {std::max(token_span1.first, token_span2.first),
std::min(token_span1.second, token_span2.second)};
}
-// Returns and expanded token span by adding a certain number of tokens on its
-// left and on its right.
-inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
- int num_tokens_left, int num_tokens_right) {
- return {token_span.first - num_tokens_left,
- token_span.second + num_tokens_right};
-}
-
// Token holds a token, its position in the original string and whether it was
// part of the input span.
struct Token {
@@ -169,7 +234,7 @@
is_padding == other.is_padding;
}
- bool IsContainedInSpan(CodepointSpan span) const {
+ bool IsContainedInSpan(const CodepointSpan& span) const {
return start >= span.first && end <= span.second;
}
};
@@ -178,6 +243,11 @@
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
const Token& token);
+// Returns a TokenSpan covering all of the given tokens.
+inline TokenSpan AllOf(const std::vector<Token>& tokens) {
+ return {0, static_cast<TokenIndex>(tokens.size())};
+}
+
enum DatetimeGranularity {
GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this
// structure being uninitialized.
@@ -218,9 +288,9 @@
SECOND = 8,
// Meridiem field where 0 == AM, 1 == PM.
MERIDIEM = 9,
- // Number of hours offset from UTC this date time is in.
+ // Offset in number of minutes from UTC this date time is in.
ZONE_OFFSET = 10,
- // Number of hours offest for DST.
+ // Offset in number of hours for DST.
DST_OFFSET = 11,
};
@@ -314,17 +384,18 @@
float priority_score;
DatetimeParseResultSpan()
- : target_classification_score(-1.0), priority_score(-1.0) {}
+ : span(CodepointSpan::kInvalid),
+ target_classification_score(-1.0),
+ priority_score(-1.0) {}
DatetimeParseResultSpan(const CodepointSpan& span,
const std::vector<DatetimeParseResult>& data,
const float target_classification_score,
- const float priority_score) {
- this->span = span;
- this->data = data;
- this->target_classification_score = target_classification_score;
- this->priority_score = priority_score;
- }
+ const float priority_score)
+ : span(span),
+ data(data),
+ target_classification_score(target_classification_score),
+ priority_score(priority_score) {}
bool operator==(const DatetimeParseResultSpan& other) const {
return span == other.span && data == other.data &&
@@ -365,7 +436,8 @@
std::string serialized_knowledge_result;
ContactPointer contact_pointer;
std::string contact_name, contact_given_name, contact_family_name,
- contact_nickname, contact_email_address, contact_phone_number, contact_id;
+ contact_nickname, contact_email_address, contact_phone_number,
+ contact_account_type, contact_account_name, contact_id;
std::string app_name, app_package_name;
int64 numeric_value;
double numeric_double_value;
@@ -456,6 +528,13 @@
// The location context passed along with each annotation.
Optional<LocationContext> location_context;
+ // If true, the POD NER annotator is used.
+ bool use_pod_ner = true;
+
+ // If true and the model file supports that, the new vocab annotator is used
+ // to annotate "Dictionary". Otherwise, we use the FFModel to do so.
+ bool use_vocab_annotator = true;
+
bool operator==(const BaseOptions& other) const {
bool location_context_equality = this->location_context.has_value() ==
other.location_context.has_value();
@@ -468,7 +547,9 @@
this->annotation_usecase == other.annotation_usecase &&
this->detected_text_language_tags ==
other.detected_text_language_tags &&
- location_context_equality;
+ location_context_equality &&
+ this->use_pod_ner == other.use_pod_ner &&
+ this->use_vocab_annotator == other.use_vocab_annotator;
}
};
@@ -493,10 +574,14 @@
// Comma-separated list of language tags which the user can read and
// understand (BCP 47).
std::string user_familiar_language_tags;
+ // If true, trigger dictionary on words that are of beginner level.
+ bool trigger_dictionary_on_beginner_words = false;
bool operator==(const ClassificationOptions& other) const {
return this->user_familiar_language_tags ==
other.user_familiar_language_tags &&
+ this->trigger_dictionary_on_beginner_words ==
+ other.trigger_dictionary_on_beginner_words &&
BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
}
};
@@ -525,11 +610,24 @@
// Defines the permissions for the annotators.
Permissions permissions;
+ AnnotateMode annotate_mode = AnnotateMode::kEntityAnnotation;
+
+ // If true, trigger dictionary on words that are of beginner level.
+ bool trigger_dictionary_on_beginner_words = false;
+
+ // If true, enables an optimized code path for annotation.
+ // The optimization caused crashes previously, which is why we are rolling it
+ // out using this temporary flag. See: b/178503899
+ bool enable_optimization = false;
+
bool operator==(const AnnotationOptions& other) const {
return this->is_serialized_entity_data_enabled ==
other.is_serialized_entity_data_enabled &&
this->permissions == other.permissions &&
this->entity_types == other.entity_types &&
+ this->annotate_mode == other.annotate_mode &&
+ this->trigger_dictionary_on_beginner_words ==
+ other.trigger_dictionary_on_beginner_words &&
BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
}
};
@@ -552,7 +650,7 @@
enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME, PERSON_NAME };
// Unicode codepoint indices in the input string.
- CodepointSpan span = {kInvalidIndex, kInvalidIndex};
+ CodepointSpan span = CodepointSpan::kInvalid;
// Classification result for the span.
std::vector<ClassificationResult> classification;
@@ -574,8 +672,31 @@
source(arg_source) {}
};
+// Represents Annotations that correspond to all input fragments.
+struct Annotations {
+ // List of annotations found in the corresponding input fragments. For these
+ // annotations, topicality score will not be set.
+ std::vector<std::vector<AnnotatedSpan>> annotated_spans;
+
+ // List of topicality results found across all input fragments.
+ std::vector<ClassificationResult> topicality_results;
+
+ Annotations() = default;
+
+ explicit Annotations(
+ std::vector<std::vector<AnnotatedSpan>> arg_annotated_spans)
+ : annotated_spans(std::move(arg_annotated_spans)) {}
+
+ Annotations(std::vector<std::vector<AnnotatedSpan>> arg_annotated_spans,
+ std::vector<ClassificationResult> arg_topicality_results)
+ : annotated_spans(std::move(arg_annotated_spans)),
+ topicality_results(std::move(arg_topicality_results)) {}
+};
+
struct InputFragment {
std::string text;
+ float bounding_box_top;
+ float bounding_box_height;
// If present will override the AnnotationOptions reference time and timezone
// when annotating this specific string fragment.
@@ -591,7 +712,7 @@
class VectorSpan {
public:
VectorSpan() : begin_(), end_() {}
- VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit)
+ explicit VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit)
: begin_(v.begin()), end_(v.end()) {}
VectorSpan(typename std::vector<T>::const_iterator begin,
typename std::vector<T>::const_iterator end)
diff --git a/native/annotator/vocab/vocab-annotator-impl.cc b/native/annotator/vocab/vocab-annotator-impl.cc
new file mode 100644
index 0000000..4b5cc73
--- /dev/null
+++ b/native/annotator/vocab/vocab-annotator-impl.cc
@@ -0,0 +1,130 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/vocab/vocab-annotator-impl.h"
+
+#include "annotator/feature-processor.h"
+#include "annotator/model_generated.h"
+#include "utils/base/logging.h"
+#include "utils/optional.h"
+#include "utils/strings/numbers.h"
+
+namespace libtextclassifier3 {
+
+VocabAnnotator::VocabAnnotator(
+ std::unique_ptr<VocabLevelTable> vocab_level_table,
+ const std::vector<Locale>& triggering_locales,
+ const FeatureProcessor& feature_processor, const UniLib& unilib,
+ const VocabModel* model)
+ : vocab_level_table_(std::move(vocab_level_table)),
+ triggering_locales_(triggering_locales),
+ feature_processor_(feature_processor),
+ unilib_(unilib),
+ model_(model) {}
+
+std::unique_ptr<VocabAnnotator> VocabAnnotator::Create(
+ const VocabModel* model, const FeatureProcessor& feature_processor,
+ const UniLib& unilib) {
+ std::unique_ptr<VocabLevelTable> vocab_lebel_table =
+ VocabLevelTable::Create(model);
+ if (vocab_lebel_table == nullptr) {
+ TC3_LOG(ERROR) << "Failed to create vocab level table.";
+ return nullptr;
+ }
+ std::vector<Locale> triggering_locales;
+ if (model->triggering_locales() &&
+ !ParseLocales(model->triggering_locales()->c_str(),
+ &triggering_locales)) {
+ TC3_LOG(ERROR) << "Could not parse model supported locales.";
+ return nullptr;
+ }
+
+ return std::unique_ptr<VocabAnnotator>(
+ new VocabAnnotator(std::move(vocab_lebel_table), triggering_locales,
+ feature_processor, unilib, model));
+}
+
+bool VocabAnnotator::Annotate(
+ const UnicodeText& context,
+ const std::vector<Locale> detected_text_language_tags,
+ bool trigger_on_beginner_words, std::vector<AnnotatedSpan>* results) const {
+ std::vector<Token> tokens = feature_processor_.Tokenize(context);
+ for (const Token& token : tokens) {
+ ClassificationResult classification_result;
+ CodepointSpan stripped_span;
+ bool found = ClassifyTextInternal(
+ context, {token.start, token.end}, detected_text_language_tags,
+ trigger_on_beginner_words, &classification_result, &stripped_span);
+ if (found) {
+ results->push_back(AnnotatedSpan{stripped_span, {classification_result}});
+ }
+ }
+ return true;
+}
+
+bool VocabAnnotator::ClassifyText(
+ const UnicodeText& context, CodepointSpan click,
+ const std::vector<Locale> detected_text_language_tags,
+ bool trigger_on_beginner_words, ClassificationResult* result) const {
+ CodepointSpan stripped_span;
+ return ClassifyTextInternal(context, click, detected_text_language_tags,
+ trigger_on_beginner_words, result,
+ &stripped_span);
+}
+
+bool VocabAnnotator::ClassifyTextInternal(
+ const UnicodeText& context, const CodepointSpan click,
+ const std::vector<Locale> detected_text_language_tags,
+ bool trigger_on_beginner_words, ClassificationResult* classification_result,
+ CodepointSpan* classified_span) const {
+ if (vocab_level_table_ == nullptr) {
+ return false;
+ }
+
+ if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ triggering_locales_,
+ /*default_value=*/false)) {
+ return false;
+ }
+ const CodepointSpan stripped_span =
+ feature_processor_.StripBoundaryCodepoints(context,
+ {click.first, click.second});
+ const UnicodeText stripped_token = UnicodeText::Substring(
+ context, stripped_span.first, stripped_span.second, /*do_copy=*/false);
+ const std::string lower_token =
+ unilib_.ToLowerText(stripped_token).ToUTF8String();
+
+ const Optional<LookupResult> result = vocab_level_table_->Lookup(lower_token);
+ if (!result.has_value()) {
+ return false;
+ }
+ if (result.value().do_not_trigger_in_upper_case &&
+ unilib_.IsUpper(*stripped_token.begin())) {
+ TC3_VLOG(INFO) << "Not trigger define: proper noun in upper case.";
+ return false;
+ }
+ if (result.value().beginner_level && !trigger_on_beginner_words) {
+ TC3_VLOG(INFO) << "Not trigger define: for beginner only.";
+ return false;
+ }
+ *classification_result =
+ ClassificationResult("dictionary", model_->target_classification_score(),
+ model_->priority_score());
+ *classified_span = stripped_span;
+
+ return true;
+}
+} // namespace libtextclassifier3
diff --git a/native/annotator/vocab/vocab-annotator-impl.h b/native/annotator/vocab/vocab-annotator-impl.h
new file mode 100644
index 0000000..1a2194a
--- /dev/null
+++ b/native/annotator/vocab/vocab-annotator-impl.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_IMPL_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_IMPL_H_
+
+#include "annotator/feature-processor.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "annotator/vocab/vocab-level-table.h"
+#include "utils/i18n/locale.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Annotates vocabs of different levels which users may want to look up in a
+// dictionary.
+class VocabAnnotator {
+ public:
+ static std::unique_ptr<VocabAnnotator> Create(
+ const VocabModel *model, const FeatureProcessor &feature_processor,
+ const UniLib &unilib);
+
+ bool Annotate(const UnicodeText &context,
+ const std::vector<Locale> detected_text_language_tags,
+ bool trigger_on_beginner_words,
+ std::vector<AnnotatedSpan> *results) const;
+
+ bool ClassifyText(const UnicodeText &context, CodepointSpan click,
+ const std::vector<Locale> detected_text_language_tags,
+ bool trigger_on_beginner_words,
+ ClassificationResult *result) const;
+
+ private:
+ explicit VocabAnnotator(std::unique_ptr<VocabLevelTable> vocab_level_table,
+ const std::vector<Locale> &triggering_locales,
+ const FeatureProcessor &feature_processor,
+ const UniLib &unilib, const VocabModel *model);
+
+ bool ClassifyTextInternal(
+ const UnicodeText &context, const CodepointSpan click,
+ const std::vector<Locale> detected_text_language_tags,
+ bool trigger_on_beginner_words,
+ ClassificationResult *classification_result,
+ CodepointSpan *classified_span) const;
+ bool ShouldTriggerOnBeginnerVocabs() const;
+
+ const std::unique_ptr<VocabLevelTable> vocab_level_table_;
+ // Locales for which this annotator triggers.
+ const std::vector<Locale> triggering_locales_;
+ const FeatureProcessor &feature_processor_;
+ const UniLib &unilib_;
+ const VocabModel *model_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_IMPL_H_
diff --git a/native/annotator/vocab/vocab-annotator.h b/native/annotator/vocab/vocab-annotator.h
new file mode 100644
index 0000000..35fa928
--- /dev/null
+++ b/native/annotator/vocab/vocab-annotator.h
@@ -0,0 +1,28 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_H_
+
+#if defined TC3_VOCAB_ANNOTATOR_IMPL
+#include "annotator/vocab/vocab-annotator-impl.h"
+#elif defined TC3_VOCAB_ANNOTATOR_DUMMY
+#include "annotator/vocab/vocab-annotator-dummy.h"
+#else
+#error No vocab-annotator implementation specified.
+#endif // TC3_VOCAB_ANNOTATOR_IMPL
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_H_
diff --git a/native/annotator/vocab/vocab-level-table.cc b/native/annotator/vocab/vocab-level-table.cc
new file mode 100644
index 0000000..71b3d8f
--- /dev/null
+++ b/native/annotator/vocab/vocab-level-table.cc
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/vocab/vocab-level-table.h"
+
+#include <cstddef>
+#include <memory>
+
+#include "annotator/model_generated.h"
+#include "utils/base/endian.h"
+#include "utils/container/bit-vector.h"
+#include "utils/optional.h"
+#include "marisa/trie.h"
+
+namespace libtextclassifier3 {
+
+std::unique_ptr<VocabLevelTable> VocabLevelTable::Create(
+ const VocabModel* model) {
+ if (!LittleEndian::IsLittleEndian()) {
+ // TODO(tonymak) Consider making this work on a big endian device.
+ TC3_LOG(ERROR)
+ << "VocabLevelTable is only working on a little endian device.";
+ return nullptr;
+ }
+ const flatbuffers::Vector<uint8_t>* trie_data = model->vocab_trie();
+ if (trie_data == nullptr) {
+ TC3_LOG(ERROR) << "vocab_trie is missing from the model file.";
+ return nullptr;
+ }
+ std::unique_ptr<marisa::Trie> vocab_trie(new marisa::Trie);
+ vocab_trie->map(trie_data->data(), trie_data->size());
+
+ return std::unique_ptr<VocabLevelTable>(new VocabLevelTable(
+ model, std::move(vocab_trie), BitVector(model->beginner_level()),
+ BitVector(model->do_not_trigger_in_upper_case())));
+}
+
+VocabLevelTable::VocabLevelTable(const VocabModel* model,
+ std::unique_ptr<marisa::Trie> vocab_trie,
+ const BitVector beginner_level,
+ const BitVector do_not_trigger_in_upper_case)
+ : model_(model),
+ vocab_trie_(std::move(vocab_trie)),
+ beginner_level_(beginner_level),
+ do_not_trigger_in_upper_case_(do_not_trigger_in_upper_case) {}
+
+Optional<LookupResult> VocabLevelTable::Lookup(const std::string& vocab) const {
+ marisa::Agent agent;
+ agent.set_query(vocab.data(), vocab.size());
+ if (vocab_trie_->lookup(agent)) {
+ const int vector_idx = agent.key().id();
+ return Optional<LookupResult>({beginner_level_[vector_idx],
+ do_not_trigger_in_upper_case_[vector_idx]});
+ }
+ return Optional<LookupResult>();
+}
+} // namespace libtextclassifier3
diff --git a/native/annotator/vocab/vocab-level-table.h b/native/annotator/vocab/vocab-level-table.h
new file mode 100644
index 0000000..f83ad72
--- /dev/null
+++ b/native/annotator/vocab/vocab-level-table.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_LEVEL_TABLE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_LEVEL_TABLE_H_
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/container/bit-vector.h"
+#include "marisa/trie.h"
+
+namespace libtextclassifier3 {
+
+struct LookupResult {
+ // Whether to trigger define for users of beginner proficiency.
+ bool beginner_level;
+  // Whether we should avoid triggering define if the leading character is in
+  // upper case.
+  bool do_not_trigger_in_upper_case;
+};
+
+// A table of vocabs and their levels which is backed by a marisa trie.
+// See http://www.s-yata.jp/marisa-trie/docs/readme.en.html.
+class VocabLevelTable {
+ public:
+ static std::unique_ptr<VocabLevelTable> Create(const VocabModel* model);
+
+ Optional<LookupResult> Lookup(const std::string& vocab) const;
+
+ private:
+ explicit VocabLevelTable(const VocabModel* model,
+ std::unique_ptr<marisa::Trie> vocab_trie,
+ const BitVector beginner_level,
+ const BitVector do_not_trigger_in_upper_case);
+ static const VocabModel* LoadAndVerifyModel();
+
+ const VocabModel* model_;
+ const std::unique_ptr<marisa::Trie> vocab_trie_;
+ const BitVector beginner_level_;
+ const BitVector do_not_trigger_in_upper_case_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_LEVEL_TABLE_H_
diff --git a/native/annotator/zlib-utils.cc b/native/annotator/zlib-utils.cc
index c3c2cf1..2f3012e 100644
--- a/native/annotator/zlib-utils.cc
+++ b/native/annotator/zlib-utils.cc
@@ -20,7 +20,6 @@
#include "utils/base/logging.h"
#include "utils/intents/zlib-utils.h"
-#include "utils/resources.h"
#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
@@ -66,11 +65,6 @@
}
}
- // Compress resources.
- if (model->resources != nullptr) {
- CompressResources(model->resources.get());
- }
-
// Compress intent generator.
if (model->intent_options != nullptr) {
CompressIntentModel(model->intent_options.get());
@@ -126,10 +120,6 @@
}
}
- if (model->resources != nullptr) {
- DecompressResources(model->resources.get());
- }
-
if (model->intent_options != nullptr) {
DecompressIntentModel(model->intent_options.get());
}
diff --git a/native/annotator/zlib-utils_test.cc b/native/annotator/zlib-utils_test.cc
index df33ea1..7e4ef08 100644
--- a/native/annotator/zlib-utils_test.cc
+++ b/native/annotator/zlib-utils_test.cc
@@ -55,42 +55,9 @@
model.intent_options->generator.back()->lua_template_generator =
std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end());
- // NOTE: The resource strings contain some repetition, so that the compressed
- // version is smaller than the uncompressed one. Because the compression code
- // looks at that as well.
- model.resources.reset(new ResourcePoolT);
- model.resources->resource_entry.emplace_back(new ResourceEntryT);
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr1.1";
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr1.2";
- model.resources->resource_entry.emplace_back(new ResourceEntryT);
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr2.1";
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr2.2";
-
// Compress the model.
EXPECT_TRUE(CompressModel(&model));
- // Sanity check that uncompressed field is removed.
- EXPECT_TRUE(model.regex_model->patterns[0]->pattern.empty());
- EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty());
- EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty());
- EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty());
- EXPECT_TRUE(
- model.intent_options->generator[0]->lua_template_generator.empty());
- EXPECT_TRUE(
- model.intent_options->generator[1]->lua_template_generator.empty());
- EXPECT_TRUE(model.resources->resource_entry[0]->resource[0]->content.empty());
- EXPECT_TRUE(model.resources->resource_entry[0]->resource[1]->content.empty());
- EXPECT_TRUE(model.resources->resource_entry[1]->resource[0]->content.empty());
- EXPECT_TRUE(model.resources->resource_entry[1]->resource[1]->content.empty());
-
// Pack and load the model.
flatbuffers::FlatBufferBuilder builder;
builder.Finish(Model::Pack(builder, &model));
@@ -139,14 +106,6 @@
EXPECT_EQ(
model.intent_options->generator[1]->lua_template_generator,
std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end()));
- EXPECT_EQ(model.resources->resource_entry[0]->resource[0]->content,
- "rrrrrrrrrrrrr1.1");
- EXPECT_EQ(model.resources->resource_entry[0]->resource[1]->content,
- "rrrrrrrrrrrrr1.2");
- EXPECT_EQ(model.resources->resource_entry[1]->resource[0]->content,
- "rrrrrrrrrrrrr2.1");
- EXPECT_EQ(model.resources->resource_entry[1]->resource[1]->content,
- "rrrrrrrrrrrrr2.2");
}
} // namespace libtextclassifier3
diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc
index 3dcdd3b..19afcc4 100644
--- a/native/lang_id/common/file/mmap.cc
+++ b/native/lang_id/common/file/mmap.cc
@@ -160,6 +160,7 @@
SAFTM_LOG(ERROR) << "Error closing file descriptor: " << last_error;
}
}
+
private:
const int fd_;
@@ -195,13 +196,23 @@
size_t file_size_in_bytes = static_cast<size_t>(sb.st_size);
// Perform actual mmap.
+ return MmapFile(fd, /*offset_in_bytes=*/0, file_size_in_bytes);
+}
+
+MmapHandle MmapFile(int fd, size_t offset_in_bytes, size_t size_in_bytes) {
+ // Make sure the offset is a multiple of the page size, as returned by
+ // sysconf(_SC_PAGE_SIZE); this is required by the man-page for mmap.
+ static const size_t kPageSize = sysconf(_SC_PAGE_SIZE);
+ const size_t aligned_offset = (offset_in_bytes / kPageSize) * kPageSize;
+ const size_t alignment_shift = offset_in_bytes - aligned_offset;
+ const size_t aligned_length = size_in_bytes + alignment_shift;
+
void *mmap_addr = mmap(
// Let system pick address for mmapp-ed data.
nullptr,
- // Mmap all bytes from the file.
- file_size_in_bytes,
+ aligned_length,
// One can read / write the mapped data (but see MAP_PRIVATE below).
// Normally, we expect only to read it, but in the future, we may want to
@@ -215,16 +226,15 @@
// Descriptor of file to mmap.
fd,
- // Map bytes right from the beginning of the file. This, and
- // file_size_in_bytes (2nd argument) means we map all bytes from the file.
- 0);
+ aligned_offset);
if (mmap_addr == MAP_FAILED) {
const std::string last_error = GetLastSystemError();
SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
return GetErrorMmapHandle();
}
- return MmapHandle(mmap_addr, file_size_in_bytes);
+ return MmapHandle(static_cast<char *>(mmap_addr) + alignment_shift,
+ size_in_bytes);
}
bool Unmap(MmapHandle mmap_handle) {
diff --git a/native/lang_id/common/file/mmap.h b/native/lang_id/common/file/mmap.h
index f785465..923751a 100644
--- a/native/lang_id/common/file/mmap.h
+++ b/native/lang_id/common/file/mmap.h
@@ -19,6 +19,7 @@
#include <stddef.h>
+#include <cstddef>
#include <string>
#include "lang_id/common/lite_strings/stringpiece.h"
@@ -97,8 +98,15 @@
#endif
// Like MmapFile(const std::string &filename), but uses a file descriptor.
+// This function maps the entire file content.
MmapHandle MmapFile(FileDescriptorOrHandle fd);
+// Like MmapFile(const std::string &filename), but uses a file descriptor,
+// with an offset relative to the file start and a specified size, such that we
+// consider only a range of the file content.
+MmapHandle MmapFile(FileDescriptorOrHandle fd, size_t offset_in_bytes,
+ size_t size_in_bytes);
+
// Unmaps a file mapped using MmapFile. Returns true on success, false
// otherwise.
bool Unmap(MmapHandle mmap_handle);
@@ -112,6 +120,10 @@
explicit ScopedMmap(FileDescriptorOrHandle fd) : handle_(MmapFile(fd)) {}
+ explicit ScopedMmap(FileDescriptorOrHandle fd, size_t offset_in_bytes,
+ size_t size_in_bytes)
+ : handle_(MmapFile(fd, offset_in_bytes, size_in_bytes)) {}
+
~ScopedMmap() {
if (handle_.ok()) {
Unmap(handle_);
diff --git a/native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc b/native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
index ee22420..d6daa3f 100644
--- a/native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
+++ b/native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
@@ -384,6 +384,7 @@
const flatbuffers::Vector<uint16_t> *scales = matrix->scales();
if (scales == nullptr) {
SAFTM_LOG(ERROR) << "nullptr scales";
+ return nullptr;
}
return scales->data();
}
diff --git a/native/lang_id/common/flatbuffers/model-utils.cc b/native/lang_id/common/flatbuffers/model-utils.cc
index 66f7f38..8efa386 100644
--- a/native/lang_id/common/flatbuffers/model-utils.cc
+++ b/native/lang_id/common/flatbuffers/model-utils.cc
@@ -26,14 +26,6 @@
namespace libtextclassifier3 {
namespace saft_fbs {
-namespace {
-
-// Returns true if we have clear evidence that |model| fails its checksum.
-//
-// E.g., if |model| has the crc32 field, and the value of that field does not
-// match the checksum, then this function returns true. If there is no crc32
-// field, then we don't know what the original (at build time) checksum was, so
-// we don't know anything clear and this function returns false.
bool ClearlyFailsChecksum(const Model &model) {
if (!flatbuffers::IsFieldPresent(&model, Model::VT_CRC32)) {
SAFTM_LOG(WARNING)
@@ -50,7 +42,6 @@
SAFTM_DLOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
return false;
}
-} // namespace
const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) {
if ((data == nullptr) || (num_bytes == 0)) {
diff --git a/native/lang_id/common/flatbuffers/model-utils.h b/native/lang_id/common/flatbuffers/model-utils.h
index 197e1e3..cf33dd5 100644
--- a/native/lang_id/common/flatbuffers/model-utils.h
+++ b/native/lang_id/common/flatbuffers/model-utils.h
@@ -66,6 +66,14 @@
// corruption. GetVerifiedModelFromBytes performs this check.
mobile::uint32 ComputeCrc2Checksum(const Model *model);
+// Returns true if we have clear evidence that |model| fails its checksum.
+//
+// E.g., if |model| has the crc32 field, and the value of that field does not
+// match the checksum, then this function returns true. If there is no crc32
+// field, then we don't know what the original (at build time) checksum was, so
+// we don't know anything clear and this function returns false.
+bool ClearlyFailsChecksum(const Model &model);
+
} // namespace saft_fbs
} // namespace nlp_saft
diff --git a/native/lang_id/common/lite_base/endian.h b/native/lang_id/common/lite_base/endian.h
index 16c2dca..2e3ee26 100644
--- a/native/lang_id/common/lite_base/endian.h
+++ b/native/lang_id/common/lite_base/endian.h
@@ -93,15 +93,6 @@
// Conversion functions.
#ifdef SAFTM_IS_LITTLE_ENDIAN
- static uint16 FromHost16(uint16 x) { return x; }
- static uint16 ToHost16(uint16 x) { return x; }
-
- static uint32 FromHost32(uint32 x) { return x; }
- static uint32 ToHost32(uint32 x) { return x; }
-
- static uint64 FromHost64(uint64 x) { return x; }
- static uint64 ToHost64(uint64 x) { return x; }
-
static bool IsLittleEndian() { return true; }
#elif defined SAFTM_IS_BIG_ENDIAN
diff --git a/native/lang_id/common/lite_base/integral-types.h b/native/lang_id/common/lite_base/integral-types.h
index 4c3038c..9b02296 100644
--- a/native/lang_id/common/lite_base/integral-types.h
+++ b/native/lang_id/common/lite_base/integral-types.h
@@ -19,11 +19,13 @@
#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_INTEGRAL_TYPES_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_INTEGRAL_TYPES_H_
+#include <cstdint>
+
namespace libtextclassifier3 {
namespace mobile {
typedef unsigned int uint32;
-typedef unsigned long long uint64;
+typedef uint64_t uint64;
#ifndef SWIG
typedef int int32;
@@ -37,11 +39,7 @@
typedef signed int char32;
#endif // SWIG
-#ifdef COMPILER_MSVC
-typedef __int64 int64;
-#else
-typedef long long int64; // NOLINT
-#endif // COMPILER_MSVC
+using int64 = int64_t;
// Some compile-time assertions that our new types have the intended size.
static_assert(sizeof(int) == 4, "Our typedefs depend on int being 32 bits");
diff --git a/native/lang_id/fb_model/lang-id-from-fb.cc b/native/lang_id/fb_model/lang-id-from-fb.cc
index b2163eb..dc36fb7 100644
--- a/native/lang_id/fb_model/lang-id-from-fb.cc
+++ b/native/lang_id/fb_model/lang-id-from-fb.cc
@@ -44,6 +44,16 @@
new LangId(std::move(model_provider)));
}
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(
+ FileDescriptorOrHandle fd, size_t offset, size_t num_bytes) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(fd, offset, num_bytes));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
size_t num_bytes) {
std::unique_ptr<ModelProvider> model_provider(
diff --git a/native/lang_id/fb_model/lang-id-from-fb.h b/native/lang_id/fb_model/lang-id-from-fb.h
index 061247b..eed843d 100644
--- a/native/lang_id/fb_model/lang-id-from-fb.h
+++ b/native/lang_id/fb_model/lang-id-from-fb.h
@@ -40,6 +40,11 @@
FileDescriptorOrHandle fd);
// Returns a LangId built using the SAFT model in flatbuffer format from
+// given file descriptor, staring at |offset| and of size |num_bytes|.
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(
+ FileDescriptorOrHandle fd, size_t offset, size_t num_bytes);
+
+// Returns a LangId built using the SAFT model in flatbuffer format from
// the |num_bytes| bytes that start at address |data|.
//
// IMPORTANT: the model bytes must be alive during the lifetime of the returned
diff --git a/native/lang_id/fb_model/model-provider-from-fb.cc b/native/lang_id/fb_model/model-provider-from-fb.cc
index c81b116..43bf860 100644
--- a/native/lang_id/fb_model/model-provider-from-fb.cc
+++ b/native/lang_id/fb_model/model-provider-from-fb.cc
@@ -48,6 +48,16 @@
Initialize(scoped_mmap_->handle().to_stringpiece());
}
+ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
+ FileDescriptorOrHandle fd, std::size_t offset, std::size_t size)
+
+ // Using mmap as a fast way to read the model bytes. As the file is
+ // unmapped only when the field scoped_mmap_ is destructed, the model bytes
+ // stay alive for the entire lifetime of this object.
+ : scoped_mmap_(new ScopedMmap(fd, offset, size)) {
+ Initialize(scoped_mmap_->handle().to_stringpiece());
+}
+
void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
// Note: valid_ was initialized to false. In the code below, we set valid_ to
// true only if all initialization steps completed successfully. Otherwise,
diff --git a/native/lang_id/fb_model/model-provider-from-fb.h b/native/lang_id/fb_model/model-provider-from-fb.h
index c3def49..55e631c 100644
--- a/native/lang_id/fb_model/model-provider-from-fb.h
+++ b/native/lang_id/fb_model/model-provider-from-fb.h
@@ -43,6 +43,11 @@
// file descriptor |fd|.
explicit ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd);
+ // Constructs a model provider based on a flatbuffer-format SAFT model from
+ // file descriptor |fd|.
+ ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd, std::size_t offset,
+ std::size_t size);
+
// Constructs a model provider from a flatbuffer-format SAFT model the bytes
// of which are already in RAM (size bytes starting from address data).
// Useful if you "transport" these bytes otherwise than via a normal file
diff --git a/native/lang_id/lang-id-wrapper.cc b/native/lang_id/lang-id-wrapper.cc
index 4246cce..baeb0f2 100644
--- a/native/lang_id/lang-id-wrapper.cc
+++ b/native/lang_id/lang-id-wrapper.cc
@@ -40,6 +40,13 @@
return langid_model;
}
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromUnownedBuffer(
+ const char* buffer, int size) {
+ std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
+ libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferBytes(buffer, size);
+ return langid_model;
+}
+
std::vector<std::pair<std::string, float>> GetPredictions(
const libtextclassifier3::mobile::lang_id::LangId* model, const std::string& text) {
return GetPredictions(model, text.data(), text.size());
diff --git a/native/lang_id/lang-id-wrapper.h b/native/lang_id/lang-id-wrapper.h
index 47e6f44..8e65af7 100644
--- a/native/lang_id/lang-id-wrapper.h
+++ b/native/lang_id/lang-id-wrapper.h
@@ -35,6 +35,11 @@
std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromDescriptor(
const int fd);
+// Loads the LangId model from a buffer. The buffer needs to outlive the LangId
+// instance.
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromUnownedBuffer(
+ const char* buffer, int size);
+
// Returns the LangId predictions (locale, confidence) from the given LangId
// model. The maximum number of predictions returned will be computed internally
// relatively to the noise threshold.
diff --git a/native/lang_id/lang-id_jni.cc b/native/lang_id/lang-id_jni.cc
index 30753dc..e86f198 100644
--- a/native/lang_id/lang-id_jni.cc
+++ b/native/lang_id/lang-id_jni.cc
@@ -28,9 +28,9 @@
#include "lang_id/lang-id.h"
using libtextclassifier3::JniHelper;
+using libtextclassifier3::JStringToUtf8String;
using libtextclassifier3::ScopedLocalRef;
using libtextclassifier3::StatusOr;
-using libtextclassifier3::ToStlString;
using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
using libtextclassifier3::mobile::lang_id::LangId;
@@ -63,7 +63,7 @@
env, result_class.get(), result_class_constructor,
predicted_language.get(),
static_cast<jfloat>(lang_id_predictions[i].second)));
- env->SetObjectArrayElement(results.get(), i, result.get());
+ JniHelper::SetObjectArrayElement(env, results.get(), i, result.get());
}
return results;
}
@@ -74,7 +74,7 @@
} // namespace
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
-(JNIEnv* env, jobject thiz, jint fd) {
+(JNIEnv* env, jobject clazz, jint fd) {
std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
if (!lang_id->is_valid()) {
return reinterpret_cast<jlong>(nullptr);
@@ -83,8 +83,9 @@
}
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
-(JNIEnv* env, jobject thiz, jstring path) {
- TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
+(JNIEnv* env, jobject clazz, jstring path) {
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str,
+ JStringToUtf8String(env, path));
std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
if (!lang_id->is_valid()) {
return reinterpret_cast<jlong>(nullptr);
@@ -92,14 +93,25 @@
return reinterpret_cast<jlong>(lang_id.release());
}
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+ std::unique_ptr<LangId> lang_id =
+ GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
+ if (!lang_id->is_valid()) {
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ return reinterpret_cast<jlong>(lang_id.release());
+}
+
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
-(JNIEnv* env, jobject clazz, jlong ptr, jstring text) {
+(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
LangId* model = reinterpret_cast<LangId*>(ptr);
if (!model) {
return nullptr;
}
- TC3_ASSIGN_OR_RETURN_NULL(const std::string text_str, ToStlString(env, text));
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string text_str,
+ JStringToUtf8String(env, text));
const std::vector<std::pair<std::string, float>>& prediction_results =
libtextclassifier3::langid::GetPredictions(model, text_str);
@@ -111,7 +123,7 @@
}
TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
-(JNIEnv* env, jobject clazz, jlong ptr) {
+(JNIEnv* env, jobject thiz, jlong ptr) {
if (!ptr) {
TC3_LOG(ERROR) << "Trying to close null LangId.";
return;
@@ -121,7 +133,7 @@
}
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jlong ptr) {
+(JNIEnv* env, jobject thiz, jlong ptr) {
if (!ptr) {
return -1;
}
@@ -164,3 +176,13 @@
LangId* model = reinterpret_cast<LangId*>(ptr);
return model->GetFloatProperty("min_text_size_in_bytes", 0);
}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+ std::unique_ptr<LangId> lang_id =
+ GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
+ if (!lang_id->is_valid()) {
+ return -1;
+ }
+ return lang_id->GetModelVersion();
+}
diff --git a/native/lang_id/lang-id_jni.h b/native/lang_id/lang-id_jni.h
index 219349c..e917197 100644
--- a/native/lang_id/lang-id_jni.h
+++ b/native/lang_id/lang-id_jni.h
@@ -20,7 +20,9 @@
#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
#include <jni.h>
+
#include <string>
+
#include "utils/java/jni-base.h"
#ifndef TC3_LANG_ID_CLASS_NAME
@@ -39,14 +41,17 @@
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject clazz, jstring path);
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
-(JNIEnv* env, jobject clazz, jlong ptr, jstring text);
+(JNIEnv* env, jobject thiz, jlong ptr, jstring text);
TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
-(JNIEnv* env, jobject clazz, jlong ptr);
+(JNIEnv* env, jobject thiz, jlong ptr);
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jlong ptr);
+(JNIEnv* env, jobject thiz, jlong ptr);
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
(JNIEnv* env, jobject clazz, jint fd);
@@ -60,6 +65,9 @@
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
(JNIEnv* env, jobject thizz, jlong ptr);
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
#ifdef __cplusplus
}
#endif
diff --git a/native/lang_id/script/tiny-script-detector.h b/native/lang_id/script/tiny-script-detector.h
index a55da04..d08270c 100644
--- a/native/lang_id/script/tiny-script-detector.h
+++ b/native/lang_id/script/tiny-script-detector.h
@@ -74,12 +74,12 @@
// CPU, so it's better to use than int32.
static const unsigned int kGreekStart = 0x370;
- // Commented out (unsued in the code): kGreekEnd = 0x3FF;
+ // Commented out (unused in the code): kGreekEnd = 0x3FF;
static const unsigned int kCyrillicStart = 0x400;
static const unsigned int kCyrillicEnd = 0x4FF;
static const unsigned int kHebrewStart = 0x590;
- // Commented out (unsued in the code): kHebrewEnd = 0x5FF;
+ // Commented out (unused in the code): kHebrewEnd = 0x5FF;
static const unsigned int kArabicStart = 0x600;
static const unsigned int kArabicEnd = 0x6FF;
const unsigned int codepoint = ((p[0] & 0x1F) << 6) | (p[1] & 0x3F);
@@ -117,7 +117,7 @@
static const unsigned int kHiraganaStart = 0x3041;
static const unsigned int kHiraganaEnd = 0x309F;
- // Commented out (unsued in the code): kKatakanaStart = 0x30A0;
+ // Commented out (unused in the code): kKatakanaStart = 0x30A0;
static const unsigned int kKatakanaEnd = 0x30FF;
const unsigned int codepoint =
((p[0] & 0x0F) << 12) | ((p[1] & 0x3F) << 6) | (p[2] & 0x3F);
diff --git a/native/models/actions_suggestions.en.model b/native/models/actions_suggestions.en.model
index d4b0ced..74422f6 100755
--- a/native/models/actions_suggestions.en.model
+++ b/native/models/actions_suggestions.en.model
Binary files differ
diff --git a/native/models/actions_suggestions.universal.model b/native/models/actions_suggestions.universal.model
index 2ee546c..f74fed4 100755
--- a/native/models/actions_suggestions.universal.model
+++ b/native/models/actions_suggestions.universal.model
Binary files differ
diff --git a/native/models/textclassifier.ar.model b/native/models/textclassifier.ar.model
index 923d8af..ff460e6 100755
--- a/native/models/textclassifier.ar.model
+++ b/native/models/textclassifier.ar.model
Binary files differ
diff --git a/native/models/textclassifier.en.model b/native/models/textclassifier.en.model
index aec4302..9eca5dd 100755
--- a/native/models/textclassifier.en.model
+++ b/native/models/textclassifier.en.model
Binary files differ
diff --git a/native/models/textclassifier.es.model b/native/models/textclassifier.es.model
index 7ff3d73..c25fef1 100755
--- a/native/models/textclassifier.es.model
+++ b/native/models/textclassifier.es.model
Binary files differ
diff --git a/native/models/textclassifier.fr.model b/native/models/textclassifier.fr.model
index cc5f488..b98c075 100755
--- a/native/models/textclassifier.fr.model
+++ b/native/models/textclassifier.fr.model
Binary files differ
diff --git a/native/models/textclassifier.it.model b/native/models/textclassifier.it.model
index 5d40ef5..5bb5a21 100755
--- a/native/models/textclassifier.it.model
+++ b/native/models/textclassifier.it.model
Binary files differ
diff --git a/native/models/textclassifier.ja.model b/native/models/textclassifier.ja.model
index 9d65601..8851b7c 100755
--- a/native/models/textclassifier.ja.model
+++ b/native/models/textclassifier.ja.model
Binary files differ
diff --git a/native/models/textclassifier.ko.model b/native/models/textclassifier.ko.model
index becba7a..7b1b26a 100755
--- a/native/models/textclassifier.ko.model
+++ b/native/models/textclassifier.ko.model
Binary files differ
diff --git a/native/models/textclassifier.nl.model b/native/models/textclassifier.nl.model
index bac8350..7005cf4 100755
--- a/native/models/textclassifier.nl.model
+++ b/native/models/textclassifier.nl.model
Binary files differ
diff --git a/native/models/textclassifier.pl.model b/native/models/textclassifier.pl.model
index 03b2825..9d3b7e3 100755
--- a/native/models/textclassifier.pl.model
+++ b/native/models/textclassifier.pl.model
Binary files differ
diff --git a/native/models/textclassifier.pt.model b/native/models/textclassifier.pt.model
index 39f0b12..4af2b0d 100755
--- a/native/models/textclassifier.pt.model
+++ b/native/models/textclassifier.pt.model
Binary files differ
diff --git a/native/models/textclassifier.ru.model b/native/models/textclassifier.ru.model
index 6d08044..fda7a7c 100755
--- a/native/models/textclassifier.ru.model
+++ b/native/models/textclassifier.ru.model
Binary files differ
diff --git a/native/models/textclassifier.th.model b/native/models/textclassifier.th.model
index 5e0f9dd..f3b6ce5 100755
--- a/native/models/textclassifier.th.model
+++ b/native/models/textclassifier.th.model
Binary files differ
diff --git a/native/models/textclassifier.tr.model b/native/models/textclassifier.tr.model
index 2dbc1d8..8e34988 100755
--- a/native/models/textclassifier.tr.model
+++ b/native/models/textclassifier.tr.model
Binary files differ
diff --git a/native/models/textclassifier.universal.model b/native/models/textclassifier.universal.model
index 853e389..09f1e0b 100755
--- a/native/models/textclassifier.universal.model
+++ b/native/models/textclassifier.universal.model
Binary files differ
diff --git a/native/models/textclassifier.zh.model b/native/models/textclassifier.zh.model
index 8d989d7..f664882 100755
--- a/native/models/textclassifier.zh.model
+++ b/native/models/textclassifier.zh.model
Binary files differ
diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc
new file mode 100644
index 0000000..e28b04d
--- /dev/null
+++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.cc
@@ -0,0 +1,347 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h"
+
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace seq_flow_lite {
+namespace ops {
+namespace custom {
+
+namespace {
+
+const int kInputIndex = 0;
+const int kScaleIndex = 1;
+const int kOffsetIndex = 2;
+const int kAxisIndex = 3;
+const int kOutputIndex = 0;
+
+TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
+ if (node->outputs->size != 1) {
+ return kTfLiteError;
+ }
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputIndex]];
+ TfLiteTensor* scale = &context->tensors[node->inputs->data[kScaleIndex]];
+ TfLiteTensor* offset = &context->tensors[node->inputs->data[kOffsetIndex]];
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteUInt8);
+ TF_LITE_ENSURE_EQ(context, offset->dims->data[0], 1);
+ TF_LITE_ENSURE_EQ(context, offset->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, offset->type, kTfLiteUInt8);
+ TF_LITE_ENSURE_EQ(context, scale->dims->data[0], 1);
+ TF_LITE_ENSURE_EQ(context, scale->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, scale->type, kTfLiteUInt8);
+ if (node->inputs->size == 4) {
+ TfLiteTensor* axis = &context->tensors[node->inputs->data[kAxisIndex]];
+ TF_LITE_ENSURE_EQ(context, axis->type, kTfLiteInt32);
+ }
+
+ TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputIndex]];
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteUInt8);
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+int GetNumberOfSteps(const TfLiteTensor* input) {
+ int number_of_steps = 1;
+ for (int i = 0; i < input->dims->size; ++i) {
+ number_of_steps *= input->dims->data[i];
+ }
+ return number_of_steps;
+}
+
+inline int GetNumberOfFeatures(const TfLiteTensor* input, const int* axis,
+ const int num_axis) {
+ int num_features = 1;
+ for (int i = 0; i < num_axis; ++i) {
+ num_features *= input->dims->data[axis[i]];
+ }
+ return num_features;
+}
+
+// Performs sanity checks on input axis and resolves into valid dimensions.
+inline bool ResolveAxis(const int num_dims, const int* axis, const int num_axis,
+ int* out_axis, int* out_num_axis) {
+ *out_num_axis = 0;
+ // Short-circuit axis resolution for scalars; the axis will go unused.
+ if (num_dims == 0) {
+ return true;
+ }
+
+ // Using an unordered set to reduce complexity in looking up duplicates.
+ std::unordered_set<int> unique_indices;
+ for (int64_t idx = 0; idx < num_axis; ++idx) {
+ // Handle negative index.
+ int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
+ assert(current >= 0 && current < num_dims);
+ // Only adding the axis if it wasn't added before.
+ if (unique_indices.find(current) == unique_indices.end()) {
+ unique_indices.insert(current);
+ out_axis[*out_num_axis] = current;
+ *out_num_axis += 1;
+ }
+ }
+ return true;
+}
+
+// Given current position in the input array, the api computes the next valid
+// index.
+bool ValidIndex(const int* input_dims, const int input_dims_size,
+ int* curr_pos) {
+ if (input_dims_size == 0) {
+ return false;
+ }
+ assert(input_dims != nullptr);
+ assert(curr_pos != nullptr);
+ for (int idx = input_dims_size - 1; idx >= 0; --idx) {
+ int current_val = curr_pos[idx] + 1;
+ assert(input_dims[idx] >= current_val);
+ if (input_dims[idx] == current_val) {
+ curr_pos[idx] = 0;
+ } else {
+ curr_pos[idx] = current_val;
+ return true;
+ }
+ }
+ return false;
+}
+
+// Gets next offset depending on reduction axis. Implementation borrowed from
+// tflite reduce mean implementation.
+int GetOffset(const int* input_dims, const int input_dims_size,
+ const int* curr_pos, const int* axis, const int axis_size) {
+ if (input_dims_size == 0) return 0;
+ assert(input_dims != nullptr);
+ assert(curr_pos != nullptr);
+ int offset = 0;
+ for (int idx = 0; idx < input_dims_size; ++idx) {
+ // if idx is part of reduction axes, we skip offset calculation.
+ bool is_axis = false;
+ if (axis != nullptr) {
+ for (int redux = 0; redux < axis_size; ++redux) {
+ if (idx == axis[redux]) {
+ is_axis = true;
+ break;
+ }
+ }
+ }
+ if (!is_axis) offset = offset * input_dims[idx] + curr_pos[idx];
+ }
+
+ return offset;
+}
+
+// TODO(b/132896827): Current implementation needs further evaluation to reduce
+// space time complexities.
+TfLiteStatus FlexibleLayerNorm(const TfLiteTensor* input, const float scale,
+ const float offset, const int* axis,
+ const int num_axis, TfLiteTensor* output) {
+ int num_features = GetNumberOfFeatures(input, &axis[0], num_axis);
+ int time_steps = static_cast<int>(GetNumberOfSteps(input) / num_features);
+
+ std::vector<float> sum_x(time_steps, 0.0f);
+ std::vector<float> sum_xx(time_steps, 0.0f);
+ std::vector<int> index_iter(input->dims->size, 0);
+
+ // Computing sum and squared sum for features across the reduction axes.
+ do {
+ // Not passing reduction axes to get the input offset as we are simply
+ // iterating through the multidimensional array.
+ int input_offset = GetOffset(input->dims->data, input->dims->size,
+ &index_iter[0], nullptr, 0);
+ // Passing in the valid reduction axes as we would like to get the output
+ // offset after reduction.
+ int stats_offset = GetOffset(input->dims->data, input->dims->size,
+ &index_iter[0], &axis[0], num_axis);
+ float input_val = PodDequantize(*input, input_offset);
+ sum_x[stats_offset] += input_val;
+ sum_xx[stats_offset] += input_val * input_val;
+ } while (ValidIndex(input->dims->data, input->dims->size, &index_iter[0]));
+
+ std::vector<float> multiplier(time_steps, 1.0f);
+ std::vector<float> bias(time_steps, 0.0f);
+
+ // Computing stats for the reduction axes.
+ for (int i = 0; i < time_steps; ++i) {
+ sum_x[i] = sum_x[i] / num_features;
+ sum_xx[i] = sum_xx[i] / num_features;
+ const float variance = sum_xx[i] - sum_x[i] * sum_x[i];
+ const float inverse_stddev = 1 / sqrt(variance + 1e-6);
+ multiplier[i] = inverse_stddev * scale;
+ bias[i] = offset - sum_x[i] * inverse_stddev * scale;
+ }
+
+ const float out_inverse_scale = 1.0f / output->params.scale;
+ const int32_t out_zero_point = output->params.zero_point;
+ uint8_t* out_ptr = output->data.uint8;
+ std::fill(index_iter.begin(), index_iter.end(), 0);
+
+ // Using the stats to fill the output pointer.
+ do {
+ // Not passing reduction axes to get the input offset as we are simply
+ // iterating through the multidimensional array.
+ int input_offset = GetOffset(input->dims->data, input->dims->size,
+ &index_iter[0], nullptr, 0);
+ // Passing in the valid reduction axes as we would like to get the output
+ // offset after reduction.
+ int stats_offset = GetOffset(input->dims->data, input->dims->size,
+ &index_iter[0], &axis[0], num_axis);
+ float input_val = PodDequantize(*input, input_offset);
+
+ const float value =
+ input_val * multiplier[stats_offset] + bias[stats_offset];
+ out_ptr[input_offset] =
+ PodQuantize(value, out_zero_point, out_inverse_scale);
+ } while (ValidIndex(input->dims->data, input->dims->size, &index_iter[0]));
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus DefaultLayerNormFloat(const TfLiteTensor* input, const float scale,
+ const float offset, TfLiteTensor* output) {
+ const int input_rank = input->dims->size;
+ const int num_features = input->dims->data[input_rank - 1];
+ const int time_steps =
+ static_cast<int>(GetNumberOfSteps(input) / num_features);
+ float* out_ptr = output->data.f;
+ for (int i = 0; i < time_steps; ++i) {
+ float sum_x = 0;
+ float sum_xx = 0;
+ for (int j = 0, index = i * num_features; j < num_features; ++j, ++index) {
+ sum_x += input->data.f[index];
+ sum_xx += input->data.f[index] * input->data.f[index];
+ }
+ const float exp_xx = sum_xx / num_features;
+ const float exp_x = sum_x / num_features;
+ const float variance = exp_xx - exp_x * exp_x;
+ const float inverse_stddev = 1 / sqrt(variance + 1e-6);
+ const float multiplier = inverse_stddev * scale;
+
+ const float bias = offset - exp_x * inverse_stddev * scale;
+ for (int j = 0, index = i * num_features; j < num_features; ++j, ++index) {
+ out_ptr[index] = input->data.f[index] * multiplier + bias;
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus DefaultLayerNorm(const TfLiteTensor* input, const float scale,
+ const float offset, TfLiteTensor* output) {
+ const int input_rank = input->dims->size;
+ const int num_features = input->dims->data[input_rank - 1];
+ const int time_steps =
+ static_cast<int>(GetNumberOfSteps(input) / num_features);
+
+ std::vector<float> temp_buffer(num_features, 0.0f);
+ const float out_inverse_scale = 1.0f / output->params.scale;
+ const int32_t out_zero_point = output->params.zero_point;
+ uint8_t* out_ptr = output->data.uint8;
+ for (int i = 0; i < time_steps; ++i) {
+ float sum_x = 0;
+ float sum_xx = 0;
+ for (int j = 0, index = i * num_features; j < num_features; ++j, ++index) {
+ temp_buffer[j] = PodDequantize(*input, index);
+ sum_x += temp_buffer[j];
+ sum_xx += temp_buffer[j] * temp_buffer[j];
+ }
+ const float exp_xx = sum_xx / num_features;
+ const float exp_x = sum_x / num_features;
+ const float variance = exp_xx - exp_x * exp_x;
+ const float inverse_stddev = 1 / sqrt(variance + 1e-6);
+ const float multiplier = inverse_stddev * scale;
+ const float bias = offset - exp_x * inverse_stddev * scale;
+ for (int j = 0, index = i * num_features; j < num_features; ++j, ++index) {
+ const float value = temp_buffer[j] * multiplier + bias;
+ out_ptr[index] = PodQuantize(value, out_zero_point, out_inverse_scale);
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input =
+ &context->tensors[node->inputs->data[kInputIndex]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputIndex]];
+ TfLiteTensor scale_tensor = context->tensors[node->inputs->data[kScaleIndex]];
+ TfLiteTensor offset_tensor =
+ context->tensors[node->inputs->data[kOffsetIndex]];
+ float scale = 1.0;
+ float offset = 0.0;
+ if (input->type == kTfLiteUInt8) {
+ scale = PodDequantize(scale_tensor, 0);
+ offset = PodDequantize(offset_tensor, 0);
+ } else {
+ scale = scale_tensor.data.f[0];
+ offset = offset_tensor.data.f[0];
+ }
+
+ TfLiteTensor* axis = &context->tensors[node->inputs->data[kAxisIndex]];
+ int num_axis = static_cast<int>(tflite::NumElements(axis));
+ // For backward compatibility reasons, we handle the default layer norm for
+ // last channel as below.
+ if (num_axis == 1 && (axis->data.i32[0] == -1 ||
+ axis->data.i32[0] == (input->dims->size - 1))) {
+ if (input->type == kTfLiteUInt8) {
+ return DefaultLayerNorm(input, scale, offset, output);
+ } else if (input->type == kTfLiteFloat32) {
+ return DefaultLayerNormFloat(input, scale, offset, output);
+ } else {
+ TF_LITE_ENSURE_MSG(context, false,
+ "Input should be eith Uint8 or Float32.");
+ }
+ }
+
+ std::vector<int> resolved_axis(num_axis);
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ if (!ResolveAxis(input->dims->size, axis->data.i32, num_axis,
+ &resolved_axis[0], &num_resolved_axis)) {
+ return kTfLiteError;
+ }
+
+ return FlexibleLayerNorm(input, scale, offset, &resolved_axis[0],
+ num_resolved_axis, output);
+}
+
+} // namespace
+
+TfLiteRegistration* Register_LAYER_NORM() {
+ static TfLiteRegistration r = {nullptr, nullptr, Resize, Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace seq_flow_lite
diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h
new file mode 100644
index 0000000..6d84ca4
--- /dev/null
+++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
+#define LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
+
+#include "tensorflow/lite/kernels/register.h"
+
+namespace seq_flow_lite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_LAYER_NORM();
+
+} // namespace custom
+} // namespace ops
+} // namespace seq_flow_lite
+
+#endif // LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
diff --git a/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h b/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h
new file mode 100644
index 0000000..7f2db41
--- /dev/null
+++ b/native/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h
@@ -0,0 +1,69 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
+#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
+
+#include <algorithm>
+#include <cmath>
+
+#include "tensorflow/lite/context.h"
+
+namespace seq_flow_lite {
+
+// Returns the original (dequantized) value of 8bit value.
+inline float PodDequantizeValue(const TfLiteTensor& tensor, uint8_t value) {
+ const int32_t zero_point = tensor.params.zero_point;
+ const float scale = tensor.params.scale;
+ return (static_cast<int32_t>(value) - zero_point) * scale;
+}
+
+// Returns the original (dequantized) value of the 'index'-th element of
+// 'tensor'.
+inline float PodDequantize(const TfLiteTensor& tensor, int index) {
+ return PodDequantizeValue(tensor, tensor.data.uint8[index]);
+}
+
+// Quantizes 'value' to 8bit, given the quantization bias (zero_point) and
+// factor (inverse_scale).
+inline uint8_t PodQuantize(float value, int32_t zero_point,
+ float inverse_scale) {
+ const float integer_value_in_float = value * inverse_scale;
+ const float offset = (integer_value_in_float >= 0.0) ? 0.5f : -0.5f;
+ // NOTE(sfeuz): This assumes value * inverse_scale is within [INT_MIN,
+ // INT_MAX].
+ int32_t integer_value =
+ static_cast<int32_t>(integer_value_in_float + offset) + zero_point;
+ return static_cast<uint8_t>(std::max(std::min(255, integer_value), 0));
+}
+
+} // namespace seq_flow_lite
+
+#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
diff --git a/native/testing/JvmTestLauncher.java b/native/testing/JvmTestLauncher.java
new file mode 100644
index 0000000..c35544c
--- /dev/null
+++ b/native/testing/JvmTestLauncher.java
@@ -0,0 +1,29 @@
+package com.google.android.textclassifier.tests;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import android.content.Context;
+import androidx.test.InstrumentationRegistry;
+
+
+/** This is a launcher of the tests because we need a valid JNIEnv in some C++ tests. */
+@RunWith(JUnit4.class)
+public class JvmTestLauncher {
+
+ @Before
+ public void setUp() throws Exception {
+ System.loadLibrary("jvm_test_launcher");
+ }
+
+ private native boolean testsMain(Context context);
+
+ @Test
+ public void testNative() throws Exception {
+ assertThat(testsMain(InstrumentationRegistry.getContext())).isTrue();
+ }
+}
diff --git a/native/testing/jvm_test_launcher.cc b/native/testing/jvm_test_launcher.cc
new file mode 100644
index 0000000..2c68cf7
--- /dev/null
+++ b/native/testing/jvm_test_launcher.cc
@@ -0,0 +1,23 @@
+#include <jni.h>
+
+#include "utils/testing/logging_event_listener.h"
+#include "gtest/gtest.h"
+
+JNIEnv* g_jenv = nullptr;
+jobject g_context = nullptr;
+
+// This method is called from Java to trigger running of all the tests.
+extern "C" JNIEXPORT jboolean JNICALL
+Java_com_google_android_textclassifier_tests_JvmTestLauncher_testsMain(
+ JNIEnv* env, jclass clazz, jobject context) {
+ g_jenv = env;
+ g_context = context;
+
+ char arg[] = "jvm_test_launcher";
+ std::vector<char*> argv = {arg};
+ int argc = 1;
+ testing::InitGoogleTest(&argc, argv.data());
+ testing::UnitTest::GetInstance()->listeners().Append(
+ new libtextclassifier3::LoggingEventListener());
+ return RUN_ALL_TESTS() == 0;
+}
\ No newline at end of file
diff --git a/native/util/hash/hash.cc b/native/util/hash/hash.cc
deleted file mode 100644
index eaa85ae..0000000
--- a/native/util/hash/hash.cc
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "util/hash/hash.h"
-
-#include "utils/base/macros.h"
-
-namespace libtextclassifier2 {
-
-namespace {
-// Lower-level versions of Get... that read directly from a character buffer
-// without any bounds checking.
-inline uint32 DecodeFixed32(const char *ptr) {
- return ((static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) |
- (static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) |
- (static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) |
- (static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24));
-}
-
-// 0xff is in case char is signed.
-static inline uint32 ByteAs32(char c) { return static_cast<uint32>(c) & 0xff; }
-} // namespace
-
-uint32 Hash32(const char *data, size_t n, uint32 seed) {
- // 'm' and 'r' are mixing constants generated offline.
- // They're not really 'magic', they just happen to work well.
- const uint32 m = 0x5bd1e995;
- const int r = 24;
-
- // Initialize the hash to a 'random' value
- uint32 h = static_cast<uint32>(seed ^ n);
-
- // Mix 4 bytes at a time into the hash
- while (n >= 4) {
- uint32 k = DecodeFixed32(data);
- k *= m;
- k ^= k >> r;
- k *= m;
- h *= m;
- h ^= k;
- data += 4;
- n -= 4;
- }
-
- // Handle the last few bytes of the input array
- switch (n) {
- case 3:
- h ^= ByteAs32(data[2]) << 16;
- TC3_FALLTHROUGH_INTENDED;
- case 2:
- h ^= ByteAs32(data[1]) << 8;
- TC3_FALLTHROUGH_INTENDED;
- case 1:
- h ^= ByteAs32(data[0]);
- h *= m;
- }
-
- // Do a few final mixes of the hash to ensure the last few
- // bytes are well-incorporated.
- h ^= h >> 13;
- h *= m;
- h ^= h >> 15;
- return h;
-}
-
-} // namespace libtextclassifier2
diff --git a/native/util/hash/hash.h b/native/util/hash/hash.h
deleted file mode 100644
index 9353e5f..0000000
--- a/native/util/hash/hash.h
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
-#define LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
-
-#include <string>
-
-#include "utils/base/integral_types.h"
-
-namespace libtextclassifier2 {
-
-using namespace libtextclassifier3;
-
-uint32 Hash32(const char *data, size_t n, uint32 seed);
-
-static inline uint32 Hash32WithDefaultSeed(const char *data, size_t n) {
- return Hash32(data, n, 0xBEEF);
-}
-
-static inline uint32 Hash32WithDefaultSeed(const std::string &input) {
- return Hash32WithDefaultSeed(input.data(), input.size());
-}
-
-} // namespace libtextclassifier2
-
-#endif // LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
diff --git a/native/utils/base/arena.cc b/native/utils/base/arena.cc
index fcaed8e..d1e85e8 100644
--- a/native/utils/base/arena.cc
+++ b/native/utils/base/arena.cc
@@ -29,6 +29,7 @@
namespace libtextclassifier3 {
+#ifndef __cpp_aligned_new
static void *aligned_malloc(size_t size, int minimum_alignment) {
void *ptr = nullptr;
// posix_memalign requires that the requested alignment be at least
@@ -42,6 +43,7 @@
else
return ptr;
}
+#endif // !__cpp_aligned_new
// The value here doesn't matter until page_aligned_ is supported.
static const int kPageSize = 8192; // should be getpagesize()
diff --git a/native/utils/base/arena.h b/native/utils/base/arena.h
index 28b6f6c..712deeb 100644
--- a/native/utils/base/arena.h
+++ b/native/utils/base/arena.h
@@ -204,7 +204,7 @@
// Allocates and initializes an object on the arena.
template <typename T, typename... Args>
- T* AllocAndInit(Args... args) {
+ T* AllocAndInit(Args&&... args) {
return new (reinterpret_cast<T*>(AllocAligned(sizeof(T), alignof(T))))
T(std::forward<Args>(args)...);
}
diff --git a/native/utils/base/arena_test.cc b/native/utils/base/arena_test.cc
new file mode 100644
index 0000000..a84190d
--- /dev/null
+++ b/native/utils/base/arena_test.cc
@@ -0,0 +1,385 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/base/arena.h"
+
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+//------------------------------------------------------------------------
+// Write random data to allocated memory
+static void TestMemory(void* mem, int size) {
+ // Do some memory allocation to check that the arena doesn't mess up
+ // the internal memory allocator
+ char* tmp[100];
+ for (int i = 0; i < TC3_ARRAYSIZE(tmp); i++) {
+ tmp[i] = new char[i * i + 1];
+ }
+
+ memset(mem, 0xcc, size);
+
+ // Free up the allocated memory;
+ for (char* s : tmp) {
+ delete[] s;
+ }
+}
+
+//------------------------------------------------------------------------
+// Check memory ptr
+static void CheckMemory(void* mem, int size) {
+ TC3_CHECK(mem != nullptr);
+ TestMemory(mem, size);
+}
+
+//------------------------------------------------------------------------
+// Check memory ptr and alignment
+static void CheckAlignment(void* mem, int size, int alignment) {
+ TC3_CHECK(mem != nullptr);
+ ASSERT_EQ(0, (reinterpret_cast<uintptr_t>(mem) & (alignment - 1)))
+ << "mem=" << mem << " alignment=" << alignment;
+ TestMemory(mem, size);
+}
+
+//------------------------------------------------------------------------
+template <class A>
+void TestArena(const char* name, A* a, int blksize) {
+ TC3_VLOG(INFO) << "Testing arena '" << name << "': blksize = " << blksize
+ << ": actual blksize = " << a->block_size();
+
+ int s;
+ blksize = a->block_size();
+
+ // Allocate zero bytes
+ TC3_CHECK(a->is_empty());
+ a->Alloc(0);
+ TC3_CHECK(a->is_empty());
+
+ // Allocate same as blksize
+ CheckMemory(a->Alloc(blksize), blksize);
+ TC3_CHECK(!a->is_empty());
+
+ // Allocate some chunks adding up to blksize
+ s = blksize / 4;
+ CheckMemory(a->Alloc(s), s);
+ CheckMemory(a->Alloc(s), s);
+ CheckMemory(a->Alloc(s), s);
+
+ int s2 = blksize - (s * 3);
+ CheckMemory(a->Alloc(s2), s2);
+
+ // Allocate large chunk
+ CheckMemory(a->Alloc(blksize * 2), blksize * 2);
+ CheckMemory(a->Alloc(blksize * 2 + 1), blksize * 2 + 1);
+ CheckMemory(a->Alloc(blksize * 2 + 2), blksize * 2 + 2);
+ CheckMemory(a->Alloc(blksize * 2 + 3), blksize * 2 + 3);
+
+ // Allocate aligned
+ s = blksize / 2;
+ CheckAlignment(a->AllocAligned(s, 1), s, 1);
+ CheckAlignment(a->AllocAligned(s + 1, 2), s + 1, 2);
+ CheckAlignment(a->AllocAligned(s + 2, 2), s + 2, 2);
+ CheckAlignment(a->AllocAligned(s + 3, 4), s + 3, 4);
+ CheckAlignment(a->AllocAligned(s + 4, 4), s + 4, 4);
+ CheckAlignment(a->AllocAligned(s + 5, 4), s + 5, 4);
+ CheckAlignment(a->AllocAligned(s + 6, 4), s + 6, 4);
+
+ // Free
+ for (int i = 0; i < 100; i++) {
+ int i2 = i * i;
+ a->Free(a->Alloc(i2), i2);
+ }
+
+ // Memdup
+ char mem[500];
+ for (int i = 0; i < 500; i++) mem[i] = i & 255;
+ char* mem2 = a->Memdup(mem, sizeof(mem));
+ TC3_CHECK_EQ(0, memcmp(mem, mem2, sizeof(mem)));
+
+ // MemdupPlusNUL
+ const char* msg_mpn = "won't use all this length";
+ char* msg2_mpn = a->MemdupPlusNUL(msg_mpn, 10);
+ TC3_CHECK_EQ(0, strcmp(msg2_mpn, "won't use "));
+ a->Free(msg2_mpn, 11);
+
+ // Strdup
+ const char* msg = "arena unit test is cool...";
+ char* msg2 = a->Strdup(msg);
+ TC3_CHECK_EQ(0, strcmp(msg, msg2));
+ a->Free(msg2, strlen(msg) + 1);
+
+ // Strndup
+ char* msg3 = a->Strndup(msg, 10);
+ TC3_CHECK_EQ(0, strncmp(msg3, msg, 10));
+ a->Free(msg3, 10);
+ TC3_CHECK(!a->is_empty());
+
+ // Reset
+ a->Reset();
+ TC3_CHECK(a->is_empty());
+
+ // Realloc
+ char* m1 = a->Alloc(blksize / 2);
+ CheckMemory(m1, blksize / 2);
+ TC3_CHECK(!a->is_empty());
+ CheckMemory(a->Alloc(blksize / 2), blksize / 2); // Allocate another block
+ m1 = a->Realloc(m1, blksize / 2, blksize);
+ CheckMemory(m1, blksize);
+ m1 = a->Realloc(m1, blksize, 23456);
+ CheckMemory(m1, 23456);
+
+ // Shrink
+ m1 = a->Shrink(m1, 200);
+ CheckMemory(m1, 200);
+ m1 = a->Shrink(m1, 100);
+ CheckMemory(m1, 100);
+ m1 = a->Shrink(m1, 1);
+ CheckMemory(m1, 1);
+ a->Free(m1, 1);
+ TC3_CHECK(!a->is_empty());
+
+ // Calloc
+ char* m2 = a->Calloc(2000);
+ for (int i = 0; i < 2000; ++i) {
+ TC3_CHECK_EQ(0, m2[i]);
+ }
+
+ // bytes_until_next_allocation
+ a->Reset();
+ TC3_CHECK(a->is_empty());
+ int alignment = blksize - a->bytes_until_next_allocation();
+ TC3_VLOG(INFO) << "Alignment overhead in initial block = " << alignment;
+
+ s = a->bytes_until_next_allocation() - 1;
+ CheckMemory(a->Alloc(s), s);
+ TC3_CHECK_EQ(a->bytes_until_next_allocation(), 1);
+ CheckMemory(a->Alloc(1), 1);
+ TC3_CHECK_EQ(a->bytes_until_next_allocation(), 0);
+
+ CheckMemory(a->Alloc(2 * blksize), 2 * blksize);
+ TC3_CHECK_EQ(a->bytes_until_next_allocation(), 0);
+
+ CheckMemory(a->Alloc(1), 1);
+ TC3_CHECK_EQ(a->bytes_until_next_allocation(), blksize - 1);
+
+ s = blksize / 2;
+ char* m0 = a->Alloc(s);
+ CheckMemory(m0, s);
+ TC3_CHECK_EQ(a->bytes_until_next_allocation(), blksize - s - 1);
+ m0 = a->Shrink(m0, 1);
+ CheckMemory(m0, 1);
+ TC3_CHECK_EQ(a->bytes_until_next_allocation(), blksize - 2);
+
+ a->Reset();
+ TC3_CHECK(a->is_empty());
+ TC3_CHECK_EQ(a->bytes_until_next_allocation(), blksize - alignment);
+}
+
+static void EnsureNoAddressInRangeIsPoisoned(void* buffer, size_t range_size) {
+#ifdef ADDRESS_SANITIZER
+ TC3_CHECK_EQ(nullptr, __asan_region_is_poisoned(buffer, range_size));
+#endif
+}
+
+static void DoTest(const char* label, int blksize, char* buffer) {
+ {
+ UnsafeArena ua(buffer, blksize);
+ TestArena((std::string("UnsafeArena") + label).c_str(), &ua, blksize);
+ }
+ EnsureNoAddressInRangeIsPoisoned(buffer, blksize);
+}
+
+//------------------------------------------------------------------------
+class BasicTest : public ::testing::TestWithParam<int> {};
+
+INSTANTIATE_TEST_SUITE_P(AllSizes, BasicTest,
+ ::testing::Values(BaseArena::kDefaultAlignment + 1, 10,
+ 100, 1024, 12345, 123450, 131072,
+ 1234500));
+
+TEST_P(BasicTest, DoTest) {
+ const int blksize = GetParam();
+
+ // Allocate some memory from heap first
+ char* tmp[100];
+ for (int i = 0; i < TC3_ARRAYSIZE(tmp); i++) {
+ tmp[i] = new char[i * i];
+ }
+
+ // Initial buffer for testing pre-allocated arenas
+ char* buffer = new char[blksize + BaseArena::kDefaultAlignment];
+
+ DoTest("", blksize, nullptr);
+ DoTest("(p0)", blksize, buffer + 0);
+ DoTest("(p1)", blksize, buffer + 1);
+ DoTest("(p2)", blksize, buffer + 2);
+ DoTest("(p3)", blksize, buffer + 3);
+ DoTest("(p4)", blksize, buffer + 4);
+ DoTest("(p5)", blksize, buffer + 5);
+
+ // Free up the allocated heap memory
+ for (char* s : tmp) {
+ delete[] s;
+ }
+
+ delete[] buffer;
+}
+
+//------------------------------------------------------------------------
+// NOTE: these stats will only be accurate in non-debug mode (otherwise
+// they'll all be 0). So: if you want accurate timing, run in "normal"
+// or "opt" mode. If you want accurate stats, run in "debug" mode.
+void ShowStatus(const char* const header, const BaseArena::Status& status) {
+ printf("\n--- status: %s\n", header);
+ printf(" %zu bytes allocated\n", status.bytes_allocated());
+}
+
+// This just tests the arena code proper, without use of allocators of
+// gladiators or STL or anything like that
+void TestArena2(UnsafeArena* const arena) {
+ const char sshort[] = "This is a short string";
+ char slong[3000];
+ memset(slong, 'a', sizeof(slong));
+ slong[sizeof(slong) - 1] = '\0';
+
+ char* s1 = arena->Strdup(sshort);
+ char* s2 = arena->Strdup(slong);
+ char* s3 = arena->Strndup(sshort, 100);
+ char* s4 = arena->Strndup(slong, 100);
+ char* s5 = arena->Memdup(sshort, 10);
+ char* s6 = arena->Realloc(s5, 10, 20);
+ arena->Shrink(s5, 10); // get s5 back to using 10 bytes again
+ char* s7 = arena->Memdup(slong, 10);
+ char* s8 = arena->Realloc(s7, 10, 5);
+ char* s9 = arena->Strdup(s1);
+ char* s10 = arena->Realloc(s4, 100, 10);
+ char* s11 = arena->Realloc(s4, 10, 100);
+ char* s12 = arena->Strdup(s9);
+ char* s13 = arena->Realloc(s9, sizeof(sshort) - 1, 100000); // won't fit :-)
+
+ TC3_CHECK_EQ(0, strcmp(s1, sshort));
+ TC3_CHECK_EQ(0, strcmp(s2, slong));
+ TC3_CHECK_EQ(0, strcmp(s3, sshort));
+ // s4 was realloced so it is not safe to read from
+ TC3_CHECK_EQ(0, strncmp(s5, sshort, 10));
+ TC3_CHECK_EQ(0, strncmp(s6, sshort, 10));
+ TC3_CHECK_EQ(s5, s6); // Should have room to grow here
+ // only the first 5 bytes of s7 should match; the realloc should have
+ // caused the next byte to actually go to s9
+ TC3_CHECK_EQ(0, strncmp(s7, slong, 5));
+ TC3_CHECK_EQ(s7, s8); // Realloc-smaller should cause us to realloc in place
+ // s9 was realloced so it is not safe to read from
+ TC3_CHECK_EQ(s10, s4); // Realloc-smaller should cause us to realloc in place
+ // Even though we're back to prev size, we had to move the pointer. Thus
+ // only the first 10 bytes are known since we grew from 10 to 100
+ TC3_CHECK_NE(s11, s4);
+ TC3_CHECK_EQ(0, strncmp(s11, slong, 10));
+ TC3_CHECK_EQ(0, strcmp(s12, s1));
+ TC3_CHECK_NE(s12, s13); // too big to grow-in-place, so we should move
+}
+
+//--------------------------------------------------------------------
+// Test some fundamental STL containers
+
+template <typename T>
+struct test_hash {
+ int operator()(const T&) const { return 0; }
+ inline bool operator()(const T& s1, const T& s2) const { return s1 < s2; }
+};
+template <>
+struct test_hash<const char*> {
+ int operator()(const char*) const { return 0; }
+
+ inline bool operator()(const char* s1, const char* s2) const {
+ return (s1 != s2) &&
+ (s2 == nullptr || (s1 != nullptr && strcmp(s1, s2) < 0));
+ }
+};
+
+// temp definitions from strutil.h, until the compiler error
+// generated by #including that file is fixed.
+struct streq {
+ bool operator()(const char* s1, const char* s2) const {
+ return ((s1 == nullptr && s2 == nullptr) ||
+ (s1 && s2 && *s1 == *s2 && strcmp(s1, s2) == 0));
+ }
+};
+struct strlt {
+ bool operator()(const char* s1, const char* s2) const {
+ return (s1 != s2) &&
+ (s2 == nullptr || (s1 != nullptr && strcmp(s1, s2) < 0));
+ }
+};
+
+void DoPoisonTest(BaseArena* b, size_t size) {
+#ifdef ADDRESS_SANITIZER
+ TC3_LOG(INFO) << "DoPoisonTest(" << b << ", " << size << ")";
+ char* c1 = b->SlowAlloc(size);
+ char* c2 = b->SlowAlloc(size);
+ TC3_CHECK_EQ(nullptr, __asan_region_is_poisoned(c1, size));
+ TC3_CHECK_EQ(nullptr, __asan_region_is_poisoned(c2, size));
+ char* c3 = b->SlowRealloc(c2, size, size / 2);
+ TC3_CHECK_EQ(nullptr, __asan_region_is_poisoned(c3, size / 2));
+ TC3_CHECK_NE(nullptr, __asan_region_is_poisoned(c2, size));
+ b->Reset();
+ TC3_CHECK_NE(nullptr, __asan_region_is_poisoned(c1, size));
+ TC3_CHECK_NE(nullptr, __asan_region_is_poisoned(c2, size));
+ TC3_CHECK_NE(nullptr, __asan_region_is_poisoned(c3, size / 2));
+#endif
+}
+
+TEST(ArenaTest, TestPoison) {
+ {
+ UnsafeArena arena(512);
+ DoPoisonTest(&arena, 128);
+ DoPoisonTest(&arena, 256);
+ DoPoisonTest(&arena, 512);
+ DoPoisonTest(&arena, 1024);
+ }
+
+ char* buffer = new char[512];
+ {
+ UnsafeArena arena(buffer, 512);
+ DoPoisonTest(&arena, 128);
+ DoPoisonTest(&arena, 256);
+ DoPoisonTest(&arena, 512);
+ DoPoisonTest(&arena, 1024);
+ }
+ EnsureNoAddressInRangeIsPoisoned(buffer, 512);
+
+ delete[] buffer;
+}
+
+//------------------------------------------------------------------------
+
+template <class A>
+void TestStrndupUnterminated() {
+ const char kFoo[3] = {'f', 'o', 'o'};
+ char* source = new char[3];
+ memcpy(source, kFoo, sizeof(kFoo));
+ A arena(4096);
+ char* dup = arena.Strndup(source, sizeof(kFoo));
+ TC3_CHECK_EQ(0, memcmp(dup, kFoo, sizeof(kFoo)));
+ delete[] source;
+}
+
+TEST(ArenaTest, StrndupWithUnterminatedStringUnsafe) {
+ TestStrndupUnterminated<UnsafeArena>();
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/base/endian.h b/native/utils/base/endian.h
index 9312704..810bc46 100644
--- a/native/utils/base/endian.h
+++ b/native/utils/base/endian.h
@@ -53,8 +53,8 @@
#define bswap_64(x) OSSwapInt64(x)
#endif // !defined(bswap_16)
#else
-#define GG_LONGLONG(x) x##LL
-#define GG_ULONGLONG(x) x##ULL
+// NOTE: The former GG_LONGLONG/GG_ULONGLONG helper macros are gone; the
+// brace-initializer syntax int64_t{...} / uint64_t{...} used below replaces them.
static inline uint16 bswap_16(uint16 x) {
return (uint16)(((x & 0xFF) << 8) | ((x & 0xFF00) >> 8)); // NOLINT
}
@@ -65,14 +65,12 @@
}
#define bswap_32(x) bswap_32(x)
static inline uint64 bswap_64(uint64 x) {
- return (((x & GG_ULONGLONG(0xFF)) << 56) |
- ((x & GG_ULONGLONG(0xFF00)) << 40) |
- ((x & GG_ULONGLONG(0xFF0000)) << 24) |
- ((x & GG_ULONGLONG(0xFF000000)) << 8) |
- ((x & GG_ULONGLONG(0xFF00000000)) >> 8) |
- ((x & GG_ULONGLONG(0xFF0000000000)) >> 24) |
- ((x & GG_ULONGLONG(0xFF000000000000)) >> 40) |
- ((x & GG_ULONGLONG(0xFF00000000000000)) >> 56));
+ return (((x & uint64_t{0xFF}) << 56) | ((x & uint64_t{0xFF00}) << 40) |
+ ((x & uint64_t{0xFF0000}) << 24) | ((x & uint64_t{0xFF000000}) << 8) |
+ ((x & uint64_t{0xFF00000000}) >> 8) |
+ ((x & uint64_t{0xFF0000000000}) >> 24) |
+ ((x & uint64_t{0xFF000000000000}) >> 40) |
+ ((x & uint64_t{0xFF00000000000000}) >> 56));
}
#define bswap_64(x) bswap_64(x)
#endif
diff --git a/native/utils/base/logging.h b/native/utils/base/logging.h
index eae71b9..ce7cac8 100644
--- a/native/utils/base/logging.h
+++ b/native/utils/base/logging.h
@@ -24,7 +24,6 @@
#include "utils/base/logging_levels.h"
#include "utils/base/port.h"
-
namespace libtextclassifier3 {
namespace logging {
diff --git a/native/utils/base/status_macros.h b/native/utils/base/status_macros.h
index 40159fe..604b5b3 100644
--- a/native/utils/base/status_macros.h
+++ b/native/utils/base/status_macros.h
@@ -56,11 +56,20 @@
// TC3_RETURN_IF_ERROR(foo.Method(args...));
// return libtextclassifier3::Status();
// }
-#define TC3_RETURN_IF_ERROR(expr) \
+#define TC3_RETURN_IF_ERROR(expr) \
+ TC3_RETURN_IF_ERROR_INTERNAL(expr, std::move(adapter).status())
+
+#define TC3_RETURN_NULL_IF_ERROR(expr) \
+ TC3_RETURN_IF_ERROR_INTERNAL(expr, nullptr)
+
+#define TC3_RETURN_FALSE_IF_ERROR(expr) \
+ TC3_RETURN_IF_ERROR_INTERNAL(expr, false)
+
+#define TC3_RETURN_IF_ERROR_INTERNAL(expr, return_value) \
TC3_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
if (::libtextclassifier3::StatusAdapter adapter{expr}) { \
} else /* NOLINT */ \
- return std::move(adapter).status()
+ return return_value
// The GNU compiler emits a warning for code like:
//
diff --git a/native/utils/base/statusor.h b/native/utils/base/statusor.h
index dde9ecd..1bafcc7 100644
--- a/native/utils/base/statusor.h
+++ b/native/utils/base/statusor.h
@@ -34,7 +34,7 @@
inline StatusOr();
// Builds from a non-OK status. Crashes if an OK status is specified.
- inline StatusOr(const Status& status); // NOLINT
+ inline StatusOr(const Status& status); // NOLINT
// Builds from the specified value.
inline StatusOr(const T& value); // NOLINT
@@ -88,6 +88,8 @@
// Conversion assignment operator, T must be assignable from U
template <typename U>
inline StatusOr& operator=(const StatusOr<U>& other);
+ template <typename U>
+ inline StatusOr& operator=(StatusOr<U>&& other);
inline ~StatusOr();
@@ -136,6 +138,40 @@
friend class StatusOr;
private:
+ void Clear() {
+ if (ok()) {
+ value_.~T();
+ }
+ }
+
+ // Construct the value through placement new with the passed argument.
+ template <typename... Arg>
+ void MakeValue(Arg&&... arg) {
+ new (&value_) T(std::forward<Arg>(arg)...);
+ }
+
+ // Creates a valid instance of type T constructed with U and assigns it to
+ // value_. Handles how to properly assign to value_ if value_ was never
+ // actually initialized (if this is currently non-OK).
+ template <typename U>
+ void AssignValue(U&& value) {
+ if (ok()) {
+ value_ = std::forward<U>(value);
+ } else {
+ MakeValue(std::forward<U>(value));
+ status_ = Status::OK;
+ }
+ }
+
+ // Creates a status constructed with U and assigns it to status_. It also
+ // properly destroys value_ if this is OK and value_ represents a valid
+ // instance of T.
+ template <typename U>
+ void AssignStatus(U&& v) {
+ Clear();
+ status_ = static_cast<Status>(std::forward<U>(v));
+ }
+
Status status_;
// The members of unions do not require initialization and are not destructed
// unless specifically called. This allows us to construct instances of
@@ -212,35 +248,47 @@
template <typename T>
inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr& other) {
- status_ = other.status_;
- if (status_.ok()) {
- value_ = other.value_;
+ if (other.ok()) {
+ AssignValue(other.value_);
+ } else {
+ AssignStatus(other.status_);
}
return *this;
}
template <typename T>
inline StatusOr<T>& StatusOr<T>::operator=(StatusOr&& other) {
- status_ = other.status_;
- if (status_.ok()) {
- value_ = std::move(other.value_);
+ if (other.ok()) {
+ AssignValue(std::move(other.value_));
+ } else {
+ AssignStatus(std::move(other.status_));
}
return *this;
}
template <typename T>
inline StatusOr<T>::~StatusOr() {
- if (ok()) {
- value_.~T();
- }
+ Clear();
}
template <typename T>
template <typename U>
inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr<U>& other) {
- status_ = other.status_;
- if (status_.ok()) {
- value_ = other.value_;
+ if (other.ok()) {
+ AssignValue(other.value_);
+ } else {
+ AssignStatus(other.status_);
+ }
+ return *this;
+}
+
+template <typename T>
+template <typename U>
+inline StatusOr<T>& StatusOr<T>::operator=(StatusOr<U>&& other) {
+ if (other.ok()) {
+ AssignValue(std::move(other.value_));
+ } else {
+ AssignStatus(std::move(other.status_));
}
return *this;
}
@@ -259,7 +307,17 @@
#define TC3_ASSIGN_OR_RETURN_FALSE(lhs, rexpr) \
TC3_ASSIGN_OR_RETURN(lhs, rexpr, false)
-#define TC3_ASSIGN_OR_RETURN_0(lhs, rexpr) TC3_ASSIGN_OR_RETURN(lhs, rexpr, 0)
+#define TC3_ASSIGN_OR_RETURN_0(...) \
+ TC_STATUS_MACROS_IMPL_GET_VARIADIC_( \
+ (__VA_ARGS__, TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_0_3_, \
+ TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_0_2_)) \
+ (__VA_ARGS__)
+
+#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_0_2_(lhs, rexpr) \
+ TC3_ASSIGN_OR_RETURN(lhs, rexpr, 0)
+#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_0_3_(lhs, rexpr, \
+ log_expression) \
+ TC3_ASSIGN_OR_RETURN(lhs, rexpr, (log_expression, 0))
// =================================================================
// == Implementation details, do not rely on anything below here. ==
@@ -281,11 +339,11 @@
#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \
TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, _)
-#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, \
- error_expression) \
- TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \
- TC_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \
- error_expression)
+#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, \
+ error_expression) \
+ TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \
+ TC_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __COUNTER__), lhs, \
+ rexpr, error_expression)
#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \
error_expression) \
auto statusor = (rexpr); \
diff --git a/native/utils/base/statusor_test.cc b/native/utils/base/statusor_test.cc
index 23165b0..04ac8ee 100644
--- a/native/utils/base/statusor_test.cc
+++ b/native/utils/base/statusor_test.cc
@@ -103,6 +103,84 @@
EXPECT_FALSE(moved_error_status.ok());
}
+// Create a class that has validly defined copy and move operators, but will
+// cause a crash if assignment operators are invoked on an instance that was
+// never initialized.
+class Baz {
+ public:
+ Baz() : i_(new int), invalid_(false) {}
+ Baz(const Baz& other) {
+ i_ = new int;
+ *i_ = *other.i_;
+ invalid_ = false;
+ }
+ Baz(const Foo& other) { // NOLINT
+ i_ = new int;
+ *i_ = other.i();
+ invalid_ = false;
+ }
+ Baz(Baz&& other) {
+ // Copy other.i_ into this so that this holds it now. Mark other as invalid
+ // so that it doesn't destroy the int that this now owns when other is
+ // destroyed.
+ i_ = other.i_;
+ other.invalid_ = true;
+ invalid_ = false;
+ }
+ Baz& operator=(const Baz& rhs) {
+    // Copy rhs.i_ into tmp. Then swap this with tmp so that this now has the
+ // value that rhs had and tmp will destroy the value that this used to hold.
+ Baz tmp(rhs);
+ std::swap(i_, tmp.i_);
+ return *this;
+ }
+ Baz& operator=(Baz&& rhs) {
+ std::swap(i_, rhs.i_);
+ return *this;
+ }
+ ~Baz() {
+ if (!invalid_) delete i_;
+ }
+
+ private:
+ int* i_;
+ bool invalid_;
+};
+
+TEST(StatusOrTest, CopyAssignment) {
+ StatusOr<Baz> baz_or;
+ EXPECT_FALSE(baz_or.ok());
+ Baz b;
+ StatusOr<Baz> other(b);
+ baz_or = other;
+ EXPECT_TRUE(baz_or.ok());
+ EXPECT_TRUE(other.ok());
+}
+
+TEST(StatusOrTest, MoveAssignment) {
+ StatusOr<Baz> baz_or;
+ EXPECT_FALSE(baz_or.ok());
+ baz_or = StatusOr<Baz>(Baz());
+ EXPECT_TRUE(baz_or.ok());
+}
+
+TEST(StatusOrTest, CopyConversionAssignment) {
+ StatusOr<Baz> baz_or;
+ EXPECT_FALSE(baz_or.ok());
+ StatusOr<Foo> foo_or(Foo(12));
+ baz_or = foo_or;
+ EXPECT_TRUE(baz_or.ok());
+ EXPECT_TRUE(foo_or.ok());
+}
+
+TEST(StatusOrTest, MoveConversionAssignment) {
+ StatusOr<Baz> baz_or;
+ EXPECT_FALSE(baz_or.ok());
+ StatusOr<Foo> foo_or(Foo(12));
+ baz_or = std::move(foo_or);
+ EXPECT_TRUE(baz_or.ok());
+}
+
struct OkFn {
StatusOr<int> operator()() { return 42; }
};
diff --git a/native/utils/bert_tokenizer.cc b/native/utils/bert_tokenizer.cc
new file mode 100644
index 0000000..bf9341f
--- /dev/null
+++ b/native/utils/bert_tokenizer.cc
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/bert_tokenizer.h"
+
+#include <string>
+
+#include "annotator/types.h"
+#include "utils/tokenizer-utils.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "absl/strings/string_view.h"
+
+namespace libtextclassifier3 {
+
+FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece(
+ const std::vector<std::string>& vocab)
+ : vocab_{vocab} {
+ for (int i = 0; i < vocab_.size(); ++i) {
+ index_map_[vocab_[i]] = i;
+ }
+}
+
+LookupStatus FlatHashMapBackedWordpiece::Contains(absl::string_view key,
+ bool* value) const {
+ *value = index_map_.contains(key);
+ return LookupStatus();
+}
+
+bool FlatHashMapBackedWordpiece::LookupId(const absl::string_view key,
+ int* result) const {
+ auto it = index_map_.find(key);
+ if (it == index_map_.end()) {
+ return false;
+ }
+ *result = it->second;
+ return true;
+}
+
+bool FlatHashMapBackedWordpiece::LookupWord(int vocab_id,
+ absl::string_view* result) const {
+ if (vocab_id >= vocab_.size() || vocab_id < 0) {
+ return false;
+ }
+ *result = vocab_[vocab_id];
+ return true;
+}
+
+TokenizerResult BertTokenizer::TokenizeSingleToken(const std::string& token) {
+ std::vector<std::string> tokens = {token};
+ return BertTokenizer::Tokenize(tokens);
+}
+
+TokenizerResult BertTokenizer::Tokenize(const std::string& input) {
+ std::vector<std::string> tokens = PreTokenize(input);
+ return BertTokenizer::Tokenize(tokens);
+}
+
// Wordpiece-tokenizes each pre-split token, accumulating subwords and their
// offsets across all tokens.
TokenizerResult BertTokenizer::Tokenize(
    const std::vector<std::string>& tokens) {
  WordpieceTokenizerResult result;
  // References into `result` so WordpieceTokenize appends into it directly.
  std::vector<std::string>& subwords = result.subwords;
  std::vector<int>& wp_absolute_begin_offset = result.wp_begin_offset;
  std::vector<int>& wp_absolute_end_offset = result.wp_end_offset;

  for (int token_index = 0; token_index < tokens.size(); token_index++) {
    auto& token = tokens[token_index];
    int num_word_pieces = 0;
    LookupStatus status = WordpieceTokenize(
        token, options_.max_bytes_per_token, options_.max_chars_per_subtoken,
        options_.suffix_indicator, options_.use_unknown_token,
        options_.unknown_token, options_.split_unknown_chars, &vocab_,
        &subwords, &wp_absolute_begin_offset, &wp_absolute_end_offset,
        &num_word_pieces);

    // NOTE(review): on failure the subwords accumulated so far are returned
    // as-is; the error status is not surfaced to the caller.
    if (!status.success) {
      return std::move(result);
    }
  }

  // The declared return type is the TokenizerResult base class, so this
  // return slices off wp_begin_offset/wp_end_offset/row_lengths; only
  // `subwords` reaches the caller. std::move selects the move (rather than
  // copy) for this derived-to-base return.
  return std::move(result);
}
+
+// This replicates how the original bert_tokenizer from the tflite-support
+// library pretokenize text by using regex_split with these default regexes.
+// It splits the text on spaces, punctuations and chinese characters and
+// output all the tokens except spaces.
+// So far, the only difference between this and the original implementation
+// we are aware of is that the original regexes has 8 ranges of chinese
+// unicodes. We have all these 8 ranges plus two extra ranges.
+std::vector<std::string> BertTokenizer::PreTokenize(
+ const absl::string_view input) {
+ const std::vector<Token> tokens =
+ TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
+ std::vector<std::string> token_texts;
+ std::transform(tokens.begin(), tokens.end(), std::back_inserter(token_texts),
+ [](Token const& token) { return std::move(token.value); });
+
+ return token_texts;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/bert_tokenizer.h b/native/utils/bert_tokenizer.h
new file mode 100644
index 0000000..eb5f978
--- /dev/null
+++ b/native/utils/bert_tokenizer.h
@@ -0,0 +1,140 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
+#define LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "utils/wordpiece_tokenizer.h"
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+#include "tensorflow_lite_support/cc/utils/common_utils.h"
+
+namespace libtextclassifier3 {
+
+using ::tflite::support::text::tokenizer::TokenizerResult;
+using ::tflite::support::utils::LoadVocabFromBuffer;
+using ::tflite::support::utils::LoadVocabFromFile;
+
// Default tokenizer configuration values; BertTokenizerOptions below mirrors
// these field-for-field.
constexpr int kDefaultMaxBytesPerToken = 100;
constexpr int kDefaultMaxCharsPerSubToken = 100;
constexpr char kDefaultSuffixIndicator[] = "##";
constexpr bool kDefaultUseUnknownToken = true;
constexpr char kDefaultUnknownToken[] = "[UNK]";
constexpr bool kDefaultSplitUnknownChars = false;
+
+// Result of wordpiece tokenization including subwords and offsets.
+// Example:
+// input: tokenize me please
+// subwords: token ##ize me plea ##se
+// wp_begin_offset: [0, 5, 9, 12, 16]
+// wp_end_offset: [ 5, 8, 11, 16, 18]
+// row_lengths: [2, 1, 1]
struct WordpieceTokenizerResult
    : tflite::support::text::tokenizer::TokenizerResult {
  std::vector<int> wp_begin_offset;  // Start offset of each subword (see example above).
  std::vector<int> wp_end_offset;    // End offset of each subword, parallel to wp_begin_offset.
  std::vector<int> row_lengths;      // Subwords per input token -- NOTE(review): not populated by BertTokenizer::Tokenize; confirm intended.
};
+
+// Options to create a BertTokenizer.
struct BertTokenizerOptions {
  // Tokens longer than this many bytes are handled as unknown.
  int max_bytes_per_token = kDefaultMaxBytesPerToken;
  // Upper bound on the length of an individual subtoken.
  int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken;
  // Prefix marking a non-initial wordpiece (e.g. "##ize").
  std::string suffix_indicator = kDefaultSuffixIndicator;
  // Whether unmatched input maps to unknown_token instead of failing.
  bool use_unknown_token = kDefaultUseUnknownToken;
  std::string unknown_token = kDefaultUnknownToken;
  // Whether unknown input is split into individual characters.
  bool split_unknown_chars = kDefaultSplitUnknownChars;
};
+
+// A flat-hash-map based implementation of WordpieceVocab, used in
+// BertTokenizer to invoke tensorflow::text::WordpieceTokenize within.
class FlatHashMapBackedWordpiece : public WordpieceVocab {
 public:
  explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab);

  // Sets *value to whether `key` is in the vocab; the status is always ok.
  LookupStatus Contains(absl::string_view key, bool* value) const override;
  // Id of `key` in the vocab; returns false if absent.
  bool LookupId(absl::string_view key, int* result) const;
  // Wordpiece for `vocab_id`; returns false if out of range.
  bool LookupWord(int vocab_id, absl::string_view* result) const;
  int VocabularySize() const { return vocab_.size(); }

 private:
  // All words indexed position in vocabulary file.
  std::vector<std::string> vocab_;
  // Maps each wordpiece (viewing into vocab_) to its index.
  absl::flat_hash_map<absl::string_view, int> index_map_;
};
+
+// Wordpiece tokenizer for bert models. Initialized with a vocab file or vector.
class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
 public:
  // Initialize the tokenizer from vocab vector and tokenizer configs.
  explicit BertTokenizer(const std::vector<std::string>& vocab,
                         const BertTokenizerOptions& options = {})
      : vocab_{FlatHashMapBackedWordpiece(vocab)}, options_{options} {}

  // Initialize the tokenizer from file path to vocab and tokenizer configs.
  explicit BertTokenizer(const std::string& path_to_vocab,
                         const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromFile(path_to_vocab), options) {}

  // Initialize the tokenizer from buffer and size of vocab and tokenizer
  // configs.
  BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
                const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
                      options) {}

  // Perform tokenization, first tokenize the input and then find the subwords.
  // return tokenized results containing the subwords.
  // Note: the return type is the TokenizerResult base class, so the wordpiece
  // offsets computed internally are not part of the returned value.
  TokenizerResult Tokenize(const std::string& input) override;

  // Perform tokenization on a single token, return tokenized results containing
  // the subwords.
  TokenizerResult TokenizeSingleToken(const std::string& token);

  // Perform tokenization, return tokenized results containing the subwords.
  TokenizerResult Tokenize(const std::vector<std::string>& tokens);

  // Check if a certain key is included in the vocab.
  LookupStatus Contains(const absl::string_view key, bool* value) const {
    return vocab_.Contains(key, value);
  }

  // Find the id of a wordpiece.
  bool LookupId(absl::string_view key, int* result) const override {
    return vocab_.LookupId(key, result);
  }

  // Find the wordpiece from an id.
  bool LookupWord(int vocab_id, absl::string_view* result) const override {
    return vocab_.LookupWord(vocab_id, result);
  }

  int VocabularySize() const { return vocab_.VocabularySize(); }

  // Splits `input` on whitespace/punctuation/CJK letters, dropping spaces.
  static std::vector<std::string> PreTokenize(const absl::string_view input);

 private:
  FlatHashMapBackedWordpiece vocab_;
  BertTokenizerOptions options_;
};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
diff --git a/native/utils/bert_tokenizer_test.cc b/native/utils/bert_tokenizer_test.cc
new file mode 100644
index 0000000..3c4e52c
--- /dev/null
+++ b/native/utils/bert_tokenizer_test.cc
@@ -0,0 +1,171 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/bert_tokenizer.h"
+
+#include "utils/test-data-test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+using ::testing::ElementsAre;
+
+namespace {
+constexpr char kTestVocabPath[] = "annotator/pod_ner/test_data/vocab.txt";
+
// Shared check: the test vocab must tokenize "i'm question" into exactly
// these four wordpieces, whichever constructor built the tokenizer.
void AssertTokenizerResults(std::unique_ptr<BertTokenizer> tokenizer) {
  auto results = tokenizer->Tokenize("i'm question");

  EXPECT_THAT(results.subwords, ElementsAre("i", "'", "m", "question"));
}
+
// Constructing from an in-memory copy of the vocab file must behave like
// constructing from the file itself.
TEST(BertTokenizerTest, TestTokenizerCreationFromBuffer) {
  std::string buffer = GetTestFileContent(kTestVocabPath);

  auto tokenizer =
      absl::make_unique<BertTokenizer>(buffer.data(), buffer.size());

  AssertTokenizerResults(std::move(tokenizer));
}
+
// Constructing from a vocab file path.
TEST(BertTokenizerTest, TestTokenizerCreationFromFile) {
  auto tokenizer =
      absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));

  AssertTokenizerResults(std::move(tokenizer));
}
+
+TEST(BertTokenizerTest, TestTokenizerCreationFromVector) {
+ std::vector<std::string> vocab;
+ vocab.emplace_back("i");
+ vocab.emplace_back("'");
+ vocab.emplace_back("m");
+ vocab.emplace_back("question");
+ auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+
+ AssertTokenizerResults(std::move(tokenizer));
+}
+
// A long token must be decomposed into an initial piece plus "##"-prefixed
// continuation pieces.
TEST(BertTokenizerTest, TestTokenizerMultipleRows) {
  auto tokenizer =
      absl::make_unique<BertTokenizer>(GetTestDataPath(kTestVocabPath));

  auto results = tokenizer->Tokenize("i'm questionansweraskask");

  EXPECT_THAT(results.subwords, ElementsAre("i", "'", "m", "question", "##ans",
                                            "##wer", "##ask", "##ask"));
}
+
+TEST(BertTokenizerTest, TestTokenizerUnknownTokens) {
+ std::vector<std::string> vocab;
+ vocab.emplace_back("i");
+ vocab.emplace_back("'");
+ vocab.emplace_back("m");
+ vocab.emplace_back("question");
+ auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+
+ auto results = tokenizer->Tokenize("i'm questionansweraskask");
+
+ EXPECT_THAT(results.subwords,
+ ElementsAre("i", "'", "m", kDefaultUnknownToken));
+}
+
+TEST(BertTokenizerTest, TestLookupId) {
+ std::vector<std::string> vocab;
+ vocab.emplace_back("i");
+ vocab.emplace_back("'");
+ vocab.emplace_back("m");
+ vocab.emplace_back("question");
+ auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+
+ int i;
+ ASSERT_FALSE(tokenizer->LookupId("iDontExist", &i));
+
+ ASSERT_TRUE(tokenizer->LookupId("i", &i));
+ ASSERT_EQ(i, 0);
+ ASSERT_TRUE(tokenizer->LookupId("'", &i));
+ ASSERT_EQ(i, 1);
+ ASSERT_TRUE(tokenizer->LookupId("m", &i));
+ ASSERT_EQ(i, 2);
+ ASSERT_TRUE(tokenizer->LookupId("question", &i));
+ ASSERT_EQ(i, 3);
+}
+
+TEST(BertTokenizerTest, TestLookupWord) {
+ std::vector<std::string> vocab;
+ vocab.emplace_back("i");
+ vocab.emplace_back("'");
+ vocab.emplace_back("m");
+ vocab.emplace_back("question");
+ auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+
+ absl::string_view result;
+ ASSERT_FALSE(tokenizer->LookupWord(6, &result));
+
+ ASSERT_TRUE(tokenizer->LookupWord(0, &result));
+ ASSERT_EQ(result, "i");
+ ASSERT_TRUE(tokenizer->LookupWord(1, &result));
+ ASSERT_EQ(result, "'");
+ ASSERT_TRUE(tokenizer->LookupWord(2, &result));
+ ASSERT_EQ(result, "m");
+ ASSERT_TRUE(tokenizer->LookupWord(3, &result));
+ ASSERT_EQ(result, "question");
+}
+
+TEST(BertTokenizerTest, TestContains) {
+ std::vector<std::string> vocab;
+ vocab.emplace_back("i");
+ vocab.emplace_back("'");
+ vocab.emplace_back("m");
+ vocab.emplace_back("question");
+ auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+
+ bool result;
+ tokenizer->Contains("iDontExist", &result);
+ ASSERT_FALSE(result);
+
+ tokenizer->Contains("i", &result);
+ ASSERT_TRUE(result);
+ tokenizer->Contains("'", &result);
+ ASSERT_TRUE(result);
+ tokenizer->Contains("m", &result);
+ ASSERT_TRUE(result);
+ tokenizer->Contains("question", &result);
+ ASSERT_TRUE(result);
+}
+
+TEST(BertTokenizerTest, TestLVocabularySize) {
+ std::vector<std::string> vocab;
+ vocab.emplace_back("i");
+ vocab.emplace_back("'");
+ vocab.emplace_back("m");
+ vocab.emplace_back("question");
+ auto tokenizer = absl::make_unique<BertTokenizer>(vocab);
+
+ ASSERT_EQ(tokenizer->VocabularySize(), 4);
+}
+
// PreTokenize splits on spaces and punctuation and drops the spaces.
TEST(BertTokenizerTest, SimpleEnglishWithPunctuation) {
  absl::string_view input = "I am fine, thanks!";

  std::vector<std::string> tokens = BertTokenizer::PreTokenize(input);

  EXPECT_THAT(tokens, testing::ElementsAreArray(
                          {"I", "am", "fine", ",", "thanks", "!"}));
}
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/calendar/calendar_test.cc b/native/utils/calendar/calendar_test.cc
new file mode 100644
index 0000000..b94813c
--- /dev/null
+++ b/native/utils/calendar/calendar_test.cc
@@ -0,0 +1,369 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/jvm-test-utils.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Fixture constructing a fresh CalendarLib per test.
class CalendarTest : public ::testing::Test {
 protected:
  CalendarTest()
      : calendarlib_(libtextclassifier3::CreateCalendarLibForTesting()) {}

  // Day-of-week constant passed as a DAY_OF_WEEK absolute value; 4 is
  // Wednesday here (presumably Sunday == 1 -- confirm against CalendarLib).
  static constexpr int kWednesday = 4;
  std::unique_ptr<CalendarLib> calendarlib_;
};
+
// Smoke test: InterpretParseData on empty data must not crash; the result is
// only logged, nothing is asserted. NOTE(review): "Zurich" is not a valid
// IANA timezone id ("Europe/Zurich" is, as used elsewhere) -- presumably
// intentional to exercise the fallback path; confirm.
TEST_F(CalendarTest, Interface) {
  int64 time;
  DatetimeGranularity granularity;
  std::string timezone;
  DatetimeParsedData data;
  bool result = calendarlib_->InterpretParseData(
      data, /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity);
  TC3_LOG(INFO) << result;
}
+
// An absolute (non-relative) parse must be independent of the reference time:
// the same year-only parse yields the same instant for two reference times.
TEST_F(CalendarTest, SetsZeroTimeWhenNotRelative) {
  int64 time;
  DatetimeGranularity granularity;
  DatetimeParsedData data;
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, 2018);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/1L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
}
+
// An explicit ZONE_OFFSET overrides the reference timezone: the same wall
// clock interpreted at GMT+02:00 is one hour earlier than at GMT+01:00.
TEST_F(CalendarTest, SetsTimeZone) {
  int64 time;
  DatetimeGranularity granularity;
  DatetimeParsedData data;
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, 2018);
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 7);
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 30);
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::SECOND, 10);

  // No explicit offset: interpreted in the reference timezone.
  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1514788210000L /* Jan 01 2018 07:30:10 GMT+01:00 */);

  data.SetAbsoluteValue(DatetimeComponent::ComponentType::ZONE_OFFSET,
                        60);  // GMT+01:00
  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1514788210000L /* Jan 01 2018 07:30:10 GMT+01:00 */);

  // Now the hour is in terms of GMT+02:00 which is one hour ahead of
  // GMT+01:00.
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::ZONE_OFFSET,
                        120);  // GMT+02:00
  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1514784610000L /* Jan 01 2018 06:30:10 GMT+01:00 */);
}
+
// Progressively adding finer components (year -> second) must round the
// result down to each newly specified granularity.
TEST_F(CalendarTest, RoundingToGranularityBasic) {
  int64 time;
  DatetimeGranularity granularity;
  DatetimeParsedData data;

  data.SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, 2018);
  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);

  data.SetAbsoluteValue(DatetimeComponent::ComponentType::MONTH, 4);
  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */);

  data.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_MONTH, 25);
  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */);

  data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 9);
  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */);

  data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 33);
  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */);

  data.SetAbsoluteValue(DatetimeComponent::ComponentType::SECOND, 59);
  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1524641639000 /* Apr 25 2018 09:33:59 */);
}
+
// "Next week" rounds to the locale-dependent start of week: Monday for de-CH,
// Sunday for en-US.
TEST_F(CalendarTest, RoundingToGranularityWeek) {
  int64 time;
  DatetimeGranularity granularity;
  // Prepare data structure that means: "next week"
  DatetimeParsedData data;
  data.SetRelativeValue(DatetimeComponent::ComponentType::WEEK,
                        DatetimeComponent::RelativeQualifier::NEXT);
  data.SetRelativeCount(DatetimeComponent::ComponentType::WEEK, 1);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"de-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 342000000L /* Mon Jan 05 1970 00:00:00 */);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-US",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 255600000L /* Sun Jan 04 1970 00:00:00 */);
}
+
// Relative parses against a fixed reference instant. All expected values in
// the comments below are local Europe/Zurich time (GMT+02:00), i.e. the
// reference renders as Wed Apr 25 2018 11:33:59 locally.
TEST_F(CalendarTest, RelativeTime) {
  const int64 ref_time = 1524648839000L; /* 25 April 2018 09:33:59 UTC */
  int64 time;
  DatetimeGranularity granularity;

  // Two Weds from now.
  DatetimeParsedData future_wed_parse;
  future_wed_parse.SetRelativeValue(
      DatetimeComponent::ComponentType::DAY_OF_WEEK,
      DatetimeComponent::RelativeQualifier::FUTURE);
  future_wed_parse.SetRelativeCount(
      DatetimeComponent::ComponentType::DAY_OF_WEEK, 2);
  future_wed_parse.SetAbsoluteValue(
      DatetimeComponent::ComponentType::DAY_OF_WEEK, kWednesday);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      future_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-US",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1525858439000L /* Wed May 09 2018 11:33:59 */);
  EXPECT_EQ(granularity, GRANULARITY_DAY);

  // Next Wed. Note: NEXT rounds down to the start of the day, unlike
  // FUTURE/PAST which keep the time of day.
  DatetimeParsedData next_wed_parse;
  next_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  kWednesday);
  next_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  DatetimeComponent::RelativeQualifier::NEXT);
  next_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  1);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      next_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-US",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1525212000000L /* Wed May 02 2018 00:00:00 */);
  EXPECT_EQ(granularity, GRANULARITY_DAY);

  // Same Wed.
  DatetimeParsedData same_wed_parse;
  same_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  DatetimeComponent::RelativeQualifier::THIS);
  same_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  kWednesday);
  same_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  1);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      same_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-US",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1524607200000L /* Wed Apr 25 2018 00:00:00 */);
  EXPECT_EQ(granularity, GRANULARITY_DAY);

  // Previous Wed.
  DatetimeParsedData last_wed_parse;
  last_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  DatetimeComponent::RelativeQualifier::LAST);
  last_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  kWednesday);
  last_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  1);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      last_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-US",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1524002400000L /* Wed Apr 18 2018 00:00:00 */);
  EXPECT_EQ(granularity, GRANULARITY_DAY);

  // Two Weds ago.
  DatetimeParsedData past_wed_parse;
  past_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  DatetimeComponent::RelativeQualifier::PAST);
  past_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  kWednesday);
  past_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                  -2);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      past_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-US",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1523439239000L /* Wed Apr 11 2018 11:33:59 */);
  EXPECT_EQ(granularity, GRANULARITY_DAY);

  // In 3 hours.
  DatetimeParsedData in_3_hours_parse;
  in_3_hours_parse.SetRelativeValue(
      DatetimeComponent::ComponentType::HOUR,
      DatetimeComponent::RelativeQualifier::FUTURE);
  in_3_hours_parse.SetRelativeCount(DatetimeComponent::ComponentType::HOUR, 3);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      in_3_hours_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-US",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1524659639000L /* Wed Apr 25 2018 14:33:59 */);
  EXPECT_EQ(granularity, GRANULARITY_HOUR);

  // In 5 minutes.
  DatetimeParsedData in_5_minutes_parse;
  in_5_minutes_parse.SetRelativeValue(
      DatetimeComponent::ComponentType::MINUTE,
      DatetimeComponent::RelativeQualifier::FUTURE);
  in_5_minutes_parse.SetRelativeCount(DatetimeComponent::ComponentType::MINUTE,
                                      5);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      in_5_minutes_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-US",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1524649139000L /* Wed Apr 25 2018 11:38:59 */);
  EXPECT_EQ(granularity, GRANULARITY_MINUTE);

  // In 10 seconds.
  DatetimeParsedData in_10_seconds_parse;
  in_10_seconds_parse.SetRelativeValue(
      DatetimeComponent::ComponentType::SECOND,
      DatetimeComponent::RelativeQualifier::FUTURE);
  in_10_seconds_parse.SetRelativeCount(DatetimeComponent::ComponentType::SECOND,
                                       10);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      in_10_seconds_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-US",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1524648849000L /* Wed Apr 25 2018 11:34:09 */);
  EXPECT_EQ(granularity, GRANULARITY_SECOND);
}
+
// With prefer_future enabled, a day-less time already in the past is bumped
// to the next day.
TEST_F(CalendarTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
  int64 time;
  DatetimeGranularity granularity;
  DatetimeParsedData data;
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 7);
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
      /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH", /*prefer_future_for_unspecified_date=*/true,
      &time, &granularity));
  EXPECT_EQ(time, 1567401000000L /* Sept 02 2019 07:10:00 */);
}
+
// With prefer_future disabled, a day-less past time stays on the reference
// day.
TEST_F(CalendarTest,
       DoesntAddADayWhenTimeInThePastAndDayNotSpecifiedAndDisabled) {
  int64 time;
  DatetimeGranularity granularity;
  DatetimeParsedData data;
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 7);
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
      /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1567314600000L /* Sept 01 2019 07:10:00 */);
}
+
// A day-less time still in the future is never bumped, regardless of the
// prefer_future flag.
TEST_F(CalendarTest, DoesntAddADayWhenTimeInTheFutureAndDayNotSpecified) {
  int64 time;
  DatetimeGranularity granularity;
  DatetimeParsedData data;
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 9);
  data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
      /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH", /*prefer_future_for_unspecified_date=*/true,
      &time, &granularity));
  EXPECT_EQ(time, 1567321800000L /* Sept 01 2019 09:10:00 */);

  ASSERT_TRUE(calendarlib_->InterpretParseData(
      data,
      /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
      /*reference_timezone=*/"Europe/Zurich",
      /*reference_locale=*/"en-CH",
      /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
  EXPECT_EQ(time, 1567321800000L /* Sept 01 2019 09:10:00 */);
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/codepoint-range.fbs b/native/utils/codepoint-range.fbs
old mode 100755
new mode 100644
diff --git a/native/utils/container/bit-vector.cc b/native/utils/container/bit-vector.cc
new file mode 100644
index 0000000..388e488
--- /dev/null
+++ b/native/utils/container/bit-vector.cc
@@ -0,0 +1,96 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/container/bit-vector.h"
+
+#include <math.h>
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/container/bit-vector_generated.h"
+
+namespace libtextclassifier3 {
+
// Wraps a flatbuffer-backed bit vector; does not take ownership of the data,
// which must outlive this object. A null pointer yields an all-zero vector.
BitVector::BitVector(const BitVectorData* bit_vector_data)
    : bit_vector_data_(bit_vector_data) {}
+
+bool BitVector::GetFromSparseData(int32 index) const {
+ return std::binary_search(
+ bit_vector_data_->sparse_data()->sorted_indices_32()->begin(),
+ bit_vector_data_->sparse_data()->sorted_indices_32()->end(), index);
+}
+
+bool BitVector::GetFromDenseData(int32 index) const {
+ if (index >= bit_vector_data_->dense_data()->size()) {
+ return false;
+ }
+ int32 byte_index = index / 8;
+ uint8 extracted_byte =
+ bit_vector_data_->dense_data()->data()->Get(byte_index);
+ uint8 bit_index = index % 8;
+ return extracted_byte & (1 << bit_index);
+}
+
+bool BitVector::Get(int32 index) const {
+ TC3_DCHECK(index >= 0);
+
+ if (bit_vector_data_ == nullptr) {
+ return false;
+ }
+ if (bit_vector_data_->dense_data() != nullptr) {
+ return GetFromDenseData(index);
+ }
+ return GetFromSparseData(index);
+}
+
+std::unique_ptr<BitVectorDataT> BitVector::CreateSparseBitVectorData(
+ const std::vector<int32>& indices) {
+ auto bit_vector_data = std::make_unique<BitVectorDataT>();
+ bit_vector_data->sparse_data =
+ std::make_unique<libtextclassifier3::SparseBitVectorDataT>();
+ bit_vector_data->sparse_data->sorted_indices_32 = indices;
+ return bit_vector_data;
+}
+
+std::unique_ptr<BitVectorDataT> BitVector::CreateDenseBitVectorData(
+ const std::vector<bool>& data) {
+ uint8_t temp = 0;
+ std::vector<uint8_t> result;
+ for (int i = 0; i < data.size(); i++) {
+ if (i != 0 && (i % 8) == 0) {
+ result.push_back(temp);
+ temp = 0;
+ }
+ if (data[i]) {
+ temp += (1 << (i % 8));
+ }
+ }
+ if ((data.size() % 8) != 0) {
+ result.push_back(temp);
+ }
+
+ auto bit_vector_data = std::make_unique<BitVectorDataT>();
+ bit_vector_data->dense_data =
+ std::make_unique<libtextclassifier3::DenseBitVectorDataT>();
+ bit_vector_data->dense_data->data = result;
+ bit_vector_data->dense_data->size = data.size();
+ return bit_vector_data;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/container/bit-vector.fbs b/native/utils/container/bit-vector.fbs
new file mode 100644
index 0000000..d117ee5
--- /dev/null
+++ b/native/utils/container/bit-vector.fbs
@@ -0,0 +1,40 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// The data that is used to represent a BitVector.
+namespace libtextclassifier3;
+table BitVectorData {
+ dense_data:DenseBitVectorData;
+ sparse_data:SparseBitVectorData;
+}
+
+// A dense representation of a bit vector.
+namespace libtextclassifier3;
+table DenseBitVectorData {
+ // The bits.
+ data:[ubyte];
+
+ // Number of bits in this bit vector.
+ size:int;
+}
+
+// A sparse representation of a bit vector.
+namespace libtextclassifier3;
+table SparseBitVectorData {
+ // A vector of sorted indices of elements that are 1.
+ sorted_indices_32:[int];
+}
+
diff --git a/native/utils/container/bit-vector.h b/native/utils/container/bit-vector.h
new file mode 100644
index 0000000..f6716d5
--- /dev/null
+++ b/native/utils/container/bit-vector.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_CONTAINER_BIT_VECTOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_CONTAINER_BIT_VECTOR_H_
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/container/bit-vector_generated.h"
+
+namespace libtextclassifier3 {
+
+// A read-only bit vector. It does not own the data and it is like a view on
+// the given data. There are two internal representations, sparse and dense.
+// The dense one stores every bits. The sparse stores only the indices of
+// elements that are 1.
+class BitVector {
+ public:
+ explicit BitVector(const BitVectorData* bit_vector_data);
+
+ // Gets a particular bit. If the underlying data does not contain the
+ // value of the asked bit, false is returned.
+  bool operator[](int index) const { return Get(index); }
+
+ // Creates a BitVectorDataT using the dense representation.
+ static std::unique_ptr<BitVectorDataT> CreateDenseBitVectorData(
+ const std::vector<bool>& data);
+
+ // Creates a BitVectorDataT using the sparse representation.
+ static std::unique_ptr<BitVectorDataT> CreateSparseBitVectorData(
+ const std::vector<int32>& indices);
+
+ private:
+ const BitVectorData* bit_vector_data_;
+
+ bool Get(int index) const;
+ bool GetFromSparseData(int index) const;
+ bool GetFromDenseData(int index) const;
+};
+
+} // namespace libtextclassifier3
+#endif // LIBTEXTCLASSIFIER_UTILS_CONTAINER_BIT_VECTOR_H_
diff --git a/native/utils/container/bit-vector_test.cc b/native/utils/container/bit-vector_test.cc
new file mode 100644
index 0000000..dfa67e8
--- /dev/null
+++ b/native/utils/container/bit-vector_test.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/container/bit-vector.h"
+
+#include <memory>
+
+#include "utils/container/bit-vector_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(BitVectorTest, Dense) {
+ std::vector<bool> data = {false, true, true, true, false,
+ false, true, false, false, true};
+
+ std::unique_ptr<BitVectorDataT> mutable_bit_vector_data =
+ BitVector::CreateDenseBitVectorData(data);
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(BitVectorData::Pack(builder, mutable_bit_vector_data.get()));
+ const flatbuffers::DetachedBuffer buffer = builder.Release();
+ const BitVectorData* bit_vector_data =
+ flatbuffers::GetRoot<BitVectorData>(buffer.data());
+
+ BitVector bit_vector(bit_vector_data);
+ EXPECT_EQ(bit_vector[0], false);
+ EXPECT_EQ(bit_vector[1], true);
+ EXPECT_EQ(bit_vector[2], true);
+ EXPECT_EQ(bit_vector[3], true);
+ EXPECT_EQ(bit_vector[4], false);
+ EXPECT_EQ(bit_vector[5], false);
+ EXPECT_EQ(bit_vector[6], true);
+ EXPECT_EQ(bit_vector[7], false);
+ EXPECT_EQ(bit_vector[8], false);
+ EXPECT_EQ(bit_vector[9], true);
+ EXPECT_EQ(bit_vector[10], false);
+}
+
+TEST(BitVectorTest, Sparse) {
+ std::vector<int32> sorted_indices = {3, 7};
+
+ std::unique_ptr<BitVectorDataT> mutable_bit_vector_data =
+ BitVector::CreateSparseBitVectorData(sorted_indices);
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(BitVectorData::Pack(builder, mutable_bit_vector_data.get()));
+ const flatbuffers::DetachedBuffer buffer = builder.Release();
+ const BitVectorData* bit_vector_data =
+ flatbuffers::GetRoot<BitVectorData>(buffer.data());
+
+ BitVector bit_vector(bit_vector_data);
+ EXPECT_EQ(bit_vector[0], false);
+ EXPECT_EQ(bit_vector[1], false);
+ EXPECT_EQ(bit_vector[2], false);
+ EXPECT_EQ(bit_vector[3], true);
+ EXPECT_EQ(bit_vector[4], false);
+ EXPECT_EQ(bit_vector[5], false);
+ EXPECT_EQ(bit_vector[6], false);
+ EXPECT_EQ(bit_vector[7], true);
+ EXPECT_EQ(bit_vector[8], false);
+}
+
+TEST(BitVectorTest, Null) {
+ BitVector bit_vector(nullptr);
+
+ EXPECT_EQ(bit_vector[0], false);
+ EXPECT_EQ(bit_vector[1], false);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/container/double-array-trie_test.cc b/native/utils/container/double-array-trie_test.cc
new file mode 100644
index 0000000..b639d53
--- /dev/null
+++ b/native/utils/container/double-array-trie_test.cc
@@ -0,0 +1,132 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/container/double-array-trie.h"
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "utils/test-data-test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+std::string GetTestConfigPath() {
+ return GetTestDataPath("utils/container/test_data/test_trie.bin");
+}
+
+TEST(DoubleArrayTest, Lookup) {
+ // Test trie that contains pieces "hell", "hello", "o", "there".
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ DoubleArrayTrie trie(reinterpret_cast<const TrieNode*>(config.data()),
+ config.size() / sizeof(TrieNode));
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches("hello there", &matches));
+ EXPECT_EQ(matches.size(), 2);
+ EXPECT_EQ(matches[0].id, 0 /*hell*/);
+ EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
+ EXPECT_EQ(matches[1].id, 1 /*hello*/);
+ EXPECT_EQ(matches[1].match_length, 5 /*hello*/);
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches("he", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches("abcd", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches("", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches("hi there", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches(StringPiece("\0", 1), &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(
+        trie.FindAllPrefixMatches(StringPiece("\xff\xfe", 2), &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(trie.LongestPrefixMatch("hella there", &match));
+ EXPECT_EQ(match.id, 0 /*hell*/);
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(trie.LongestPrefixMatch("hello there", &match));
+ EXPECT_EQ(match.id, 1 /*hello*/);
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(trie.LongestPrefixMatch("abcd", &match));
+ EXPECT_EQ(match.id, -1);
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(trie.LongestPrefixMatch("", &match));
+ EXPECT_EQ(match.id, -1);
+ }
+
+ {
+ int value;
+ EXPECT_TRUE(trie.Find("hell", &value));
+ EXPECT_EQ(value, 0);
+ }
+
+ {
+ int value;
+ EXPECT_FALSE(trie.Find("hella", &value));
+ }
+
+ {
+ int value;
+ EXPECT_TRUE(trie.Find("hello", &value));
+ EXPECT_EQ(value, 1 /*hello*/);
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/container/test_data/test_trie.bin b/native/utils/container/test_data/test_trie.bin
new file mode 100644
index 0000000..ade1f29
--- /dev/null
+++ b/native/utils/container/test_data/test_trie.bin
Binary files differ
diff --git a/native/utils/flatbuffers.h b/native/utils/flatbuffers.h
deleted file mode 100644
index aaf248e..0000000
--- a/native/utils/flatbuffers.h
+++ /dev/null
@@ -1,449 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Utility functions for working with FlatBuffers.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
-#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
-
-#include <memory>
-#include <string>
-#include <unordered_map>
-
-#include "annotator/model_generated.h"
-#include "utils/base/logging.h"
-#include "utils/flatbuffers_generated.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/variant.h"
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/reflection.h"
-#include "flatbuffers/reflection_generated.h"
-
-namespace libtextclassifier3 {
-
-class ReflectiveFlatBuffer;
-class RepeatedField;
-
-// Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
-// integrity.
-template <typename FlatbufferMessage>
-const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
- const FlatbufferMessage* message =
- flatbuffers::GetRoot<FlatbufferMessage>(buffer);
- if (message == nullptr) {
- return nullptr;
- }
- flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
- size);
- if (message->Verify(verifier)) {
- return message;
- } else {
- return nullptr;
- }
-}
-
-// Same as above but takes string.
-template <typename FlatbufferMessage>
-const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) {
- return LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer.c_str(),
- buffer.size());
-}
-
-// Loads and interprets the buffer as 'FlatbufferMessage', verifies its
-// integrity and returns its mutable version.
-template <typename FlatbufferMessage>
-std::unique_ptr<typename FlatbufferMessage::NativeTableType>
-LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) {
- const FlatbufferMessage* message =
- LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size);
- if (message == nullptr) {
- return nullptr;
- }
- return std::unique_ptr<typename FlatbufferMessage::NativeTableType>(
- message->UnPack());
-}
-
-// Same as above but takes string.
-template <typename FlatbufferMessage>
-std::unique_ptr<typename FlatbufferMessage::NativeTableType>
-LoadAndVerifyMutableFlatbuffer(const std::string& buffer) {
- return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(buffer.c_str(),
- buffer.size());
-}
-
-template <typename FlatbufferMessage>
-const char* FlatbufferFileIdentifier() {
- return nullptr;
-}
-
-template <>
-const char* FlatbufferFileIdentifier<Model>();
-
-// Packs the mutable flatbuffer message to string.
-template <typename FlatbufferMessage>
-std::string PackFlatbuffer(
- const typename FlatbufferMessage::NativeTableType* mutable_message) {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(FlatbufferMessage::Pack(builder, mutable_message),
- FlatbufferFileIdentifier<FlatbufferMessage>());
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-class ReflectiveFlatbuffer;
-
-// Checks whether a variant value type agrees with a field type.
-template <typename T>
-bool IsMatchingType(const reflection::BaseType type) {
- switch (type) {
- case reflection::Bool:
- return std::is_same<T, bool>::value;
- case reflection::Byte:
- return std::is_same<T, int8>::value;
- case reflection::UByte:
- return std::is_same<T, uint8>::value;
- case reflection::Int:
- return std::is_same<T, int32>::value;
- case reflection::UInt:
- return std::is_same<T, uint32>::value;
- case reflection::Long:
- return std::is_same<T, int64>::value;
- case reflection::ULong:
- return std::is_same<T, uint64>::value;
- case reflection::Float:
- return std::is_same<T, float>::value;
- case reflection::Double:
- return std::is_same<T, double>::value;
- case reflection::String:
- return std::is_same<T, std::string>::value ||
- std::is_same<T, StringPiece>::value ||
- std::is_same<T, const char*>::value;
- case reflection::Obj:
- return std::is_same<T, ReflectiveFlatbuffer>::value;
- default:
- return false;
- }
-}
-
-// A flatbuffer that can be built using flatbuffer reflection data of the
-// schema.
-// Normally, field information is hard-coded in code generated from a flatbuffer
-// schema. Here we lookup the necessary information for building a flatbuffer
-// from the provided reflection meta data.
-// When serializing a flatbuffer, the library requires that the sub messages
-// are already serialized, therefore we explicitly keep the field values and
-// serialize the message in (reverse) topological dependency order.
-class ReflectiveFlatbuffer {
- public:
- ReflectiveFlatbuffer(const reflection::Schema* schema,
- const reflection::Object* type)
- : schema_(schema), type_(type) {}
-
- // Gets the field information for a field name, returns nullptr if the
- // field was not defined.
- const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
- const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
- const reflection::Field* GetFieldOrNull(const int field_offset) const;
-
- // Gets a nested field and the message it is defined on.
- bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
- ReflectiveFlatbuffer** parent,
- reflection::Field const** field);
-
- // Sets a field to a specific value.
- // Returns true if successful, and false if the field was not found or the
- // expected type doesn't match.
- template <typename T>
- bool Set(StringPiece field_name, T value);
-
- // Sets a field to a specific value.
- // Returns true if successful, and false if the expected type doesn't match.
- // Expects `field` to be non-null.
- template <typename T>
- bool Set(const reflection::Field* field, T value);
-
- // Sets a field to a specific value. Field is specified by path.
- template <typename T>
- bool Set(const FlatbufferFieldPath* path, T value);
-
- // Sets sub-message field (if not set yet), and returns a pointer to it.
- // Returns nullptr if the field was not found, or the field type was not a
- // table.
- ReflectiveFlatbuffer* Mutable(StringPiece field_name);
- ReflectiveFlatbuffer* Mutable(const reflection::Field* field);
-
- // Parses the value (according to the type) and sets a primitive field to the
- // parsed value.
- bool ParseAndSet(const reflection::Field* field, const std::string& value);
- bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
-
- // Adds a primitive value to the repeated field.
- template <typename T>
- bool Add(StringPiece field_name, T value);
-
- // Add a sub-message to the repeated field.
- ReflectiveFlatbuffer* Add(StringPiece field_name);
-
- template <typename T>
- bool Add(const reflection::Field* field, T value);
-
- ReflectiveFlatbuffer* Add(const reflection::Field* field);
-
- // Gets the reflective flatbuffer for a repeated field.
- // Returns nullptr if the field was not found, or the field type was not a
- // vector.
- RepeatedField* Repeated(StringPiece field_name);
- RepeatedField* Repeated(const reflection::Field* field);
-
- // Serializes the flatbuffer.
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const;
- std::string Serialize() const;
-
- // Merges the fields from the given flatbuffer table into this flatbuffer.
- // Scalar fields will be overwritten, if present in `from`.
- // Embedded messages will be merged.
- bool MergeFrom(const flatbuffers::Table* from);
- bool MergeFromSerializedFlatbuffer(StringPiece from);
-
- // Flattens the flatbuffer as a flat map.
- // (Nested) fields names are joined by `key_separator`.
- std::map<std::string, Variant> AsFlatMap(
- const std::string& key_separator = ".") const {
- std::map<std::string, Variant> result;
- AsFlatMap(key_separator, /*key_prefix=*/"", &result);
- return result;
- }
-
- // Converts the flatbuffer's content to a human-readable textproto
- // representation.
- std::string ToTextProto() const;
-
- bool HasExplicitlySetFields() const {
- return !fields_.empty() || !children_.empty() || !repeated_fields_.empty();
- }
-
- private:
- // Helper function for merging given repeated field from given flatbuffer
- // table. Appends the elements.
- template <typename T>
- bool AppendFromVector(const flatbuffers::Table* from,
- const reflection::Field* field);
-
- const reflection::Schema* const schema_;
- const reflection::Object* const type_;
-
- // Cached primitive fields (scalars and strings).
- std::unordered_map<const reflection::Field*, Variant> fields_;
-
- // Cached sub-messages.
- std::unordered_map<const reflection::Field*,
- std::unique_ptr<ReflectiveFlatbuffer>>
- children_;
-
- // Cached repeated fields.
- std::unordered_map<const reflection::Field*, std::unique_ptr<RepeatedField>>
- repeated_fields_;
-
- // Flattens the flatbuffer as a flat map.
- // (Nested) fields names are joined by `key_separator` and prefixed by
- // `key_prefix`.
- void AsFlatMap(const std::string& key_separator,
- const std::string& key_prefix,
- std::map<std::string, Variant>* result) const;
-};
-
-// A helper class to build flatbuffers based on schema reflection data.
-// Can be used to a `ReflectiveFlatbuffer` for the root message of the
-// schema, or any defined table via name.
-class ReflectiveFlatbufferBuilder {
- public:
- explicit ReflectiveFlatbufferBuilder(const reflection::Schema* schema)
- : schema_(schema) {}
-
- // Starts a new root table message.
- std::unique_ptr<ReflectiveFlatbuffer> NewRoot() const;
-
- // Starts a new table message. Returns nullptr if no table with given name is
- // found in the schema.
- std::unique_ptr<ReflectiveFlatbuffer> NewTable(
- const StringPiece table_name) const;
-
- private:
- const reflection::Schema* const schema_;
-};
-
-// Encapsulates a repeated field.
-// Serves as a common base class for repeated fields.
-class RepeatedField {
- public:
- RepeatedField(const reflection::Schema* const schema,
- const reflection::Field* field)
- : schema_(schema),
- field_(field),
- is_primitive_(field->type()->element() != reflection::BaseType::Obj) {}
-
- template <typename T>
- bool Add(const T value);
-
- ReflectiveFlatbuffer* Add();
-
- template <typename T>
- T Get(int index) const {
- return items_.at(index).Value<T>();
- }
-
- template <>
- ReflectiveFlatbuffer* Get(int index) const {
- if (is_primitive_) {
- TC3_LOG(ERROR) << "Trying to get primitive value out of non-primitive "
- "repeated field.";
- return nullptr;
- }
- return object_items_.at(index).get();
- }
-
- int Size() const {
- if (is_primitive_) {
- return items_.size();
- } else {
- return object_items_.size();
- }
- }
-
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const;
-
- private:
- flatbuffers::uoffset_t SerializeString(
- flatbuffers::FlatBufferBuilder* builder) const;
- flatbuffers::uoffset_t SerializeObject(
- flatbuffers::FlatBufferBuilder* builder) const;
-
- const reflection::Schema* const schema_;
- const reflection::Field* field_;
- bool is_primitive_;
-
- std::vector<Variant> items_;
- std::vector<std::unique_ptr<ReflectiveFlatbuffer>> object_items_;
-};
-
-template <typename T>
-bool ReflectiveFlatbuffer::Set(StringPiece field_name, T value) {
- if (const reflection::Field* field = GetFieldOrNull(field_name)) {
- if (field->type()->base_type() == reflection::BaseType::Vector ||
- field->type()->base_type() == reflection::BaseType::Obj) {
- TC3_LOG(ERROR)
- << "Trying to set a primitive value on a non-scalar field.";
- return false;
- }
- return Set<T>(field, value);
- }
- TC3_LOG(ERROR) << "Couldn't find a field: " << field_name;
- return false;
-}
-
-template <typename T>
-bool ReflectiveFlatbuffer::Set(const reflection::Field* field, T value) {
- if (field == nullptr) {
- TC3_LOG(ERROR) << "Expected non-null field.";
- return false;
- }
- Variant variant_value(value);
- if (!IsMatchingType<T>(field->type()->base_type())) {
- TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
- << "`, expected: " << field->type()->base_type()
- << ", got: " << variant_value.GetType();
- return false;
- }
- fields_[field] = variant_value;
- return true;
-}
-
-template <typename T>
-bool ReflectiveFlatbuffer::Set(const FlatbufferFieldPath* path, T value) {
- ReflectiveFlatbuffer* parent;
- const reflection::Field* field;
- if (!GetFieldWithParent(path, &parent, &field)) {
- return false;
- }
- return parent->Set<T>(field, value);
-}
-
-template <typename T>
-bool ReflectiveFlatbuffer::Add(StringPiece field_name, T value) {
- const reflection::Field* field = GetFieldOrNull(field_name);
- if (field == nullptr) {
- return false;
- }
-
- if (field->type()->base_type() != reflection::BaseType::Vector) {
- return false;
- }
-
- return Add<T>(field, value);
-}
-
-template <typename T>
-bool ReflectiveFlatbuffer::Add(const reflection::Field* field, T value) {
- if (field == nullptr) {
- return false;
- }
- Repeated(field)->Add(value);
- return true;
-}
-
-template <typename T>
-bool RepeatedField::Add(const T value) {
- if (!is_primitive_ || !IsMatchingType<T>(field_->type()->element())) {
- TC3_LOG(ERROR) << "Trying to add value of unmatching type.";
- return false;
- }
- items_.push_back(Variant{value});
- return true;
-}
-
-// Resolves field lookups by name to the concrete field offsets.
-bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
- FlatbufferFieldPathT* path);
-
-template <typename T>
-bool ReflectiveFlatbuffer::AppendFromVector(const flatbuffers::Table* from,
- const reflection::Field* field) {
- const flatbuffers::Vector<T>* from_vector =
- from->GetPointer<const flatbuffers::Vector<T>*>(field->offset());
- if (from_vector == nullptr) {
- return false;
- }
-
- RepeatedField* to_repeated = Repeated(field);
- for (const T element : *from_vector) {
- to_repeated->Add(element);
- }
- return true;
-}
-
-inline logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream, flatbuffers::String* message) {
- if (message != nullptr) {
- stream.message.append(message->c_str(), message->size());
- }
- return stream;
-}
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
diff --git a/native/utils/flatbuffers.fbs b/native/utils/flatbuffers/flatbuffers.fbs
old mode 100755
new mode 100644
similarity index 100%
rename from native/utils/flatbuffers.fbs
rename to native/utils/flatbuffers/flatbuffers.fbs
diff --git a/native/utils/flatbuffers/flatbuffers.h b/native/utils/flatbuffers/flatbuffers.h
new file mode 100644
index 0000000..1bb739b
--- /dev/null
+++ b/native/utils/flatbuffers/flatbuffers.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
+#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
+
+#include <string>
+
+#include "annotator/model_generated.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3 {
+
+// Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
+// integrity.
+template <typename FlatbufferMessage>
+const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
+ const FlatbufferMessage* message =
+ flatbuffers::GetRoot<FlatbufferMessage>(buffer);
+ if (message == nullptr) {
+ return nullptr;
+ }
+ flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
+ size);
+ if (message->Verify(verifier)) {
+ return message;
+ } else {
+ return nullptr;
+ }
+}
+
+// Same as above but takes string.
+template <typename FlatbufferMessage>
+const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) {
+ return LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer.c_str(),
+ buffer.size());
+}
+
+// Loads and interprets the buffer as 'FlatbufferMessage', verifies its
+// integrity and returns its mutable version.
+template <typename FlatbufferMessage>
+std::unique_ptr<typename FlatbufferMessage::NativeTableType>
+LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) {
+ const FlatbufferMessage* message =
+ LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size);
+ if (message == nullptr) {
+ return nullptr;
+ }
+ return std::unique_ptr<typename FlatbufferMessage::NativeTableType>(
+ message->UnPack());
+}
+
+// Same as above but takes string.
+template <typename FlatbufferMessage>
+std::unique_ptr<typename FlatbufferMessage::NativeTableType>
+LoadAndVerifyMutableFlatbuffer(const std::string& buffer) {
+ return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(buffer.c_str(),
+ buffer.size());
+}
+
+template <typename FlatbufferMessage>
+const char* FlatbufferFileIdentifier() {
+ return nullptr;
+}
+
+template <>
+inline const char* FlatbufferFileIdentifier<Model>() {
+ return ModelIdentifier();
+}
+
+// Packs the mutable flatbuffer message to string.
+template <typename FlatbufferMessage>
+std::string PackFlatbuffer(
+ const typename FlatbufferMessage::NativeTableType* mutable_message) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(FlatbufferMessage::Pack(builder, mutable_message),
+ FlatbufferFileIdentifier<FlatbufferMessage>());
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+// A convenience flatbuffer object with its underlying buffer.
+template <typename T, typename B = flatbuffers::DetachedBuffer>
+class OwnedFlatbuffer {
+ public:
+ explicit OwnedFlatbuffer(B&& buffer) : buffer_(std::move(buffer)) {}
+
+ // Cast as flatbuffer type.
+ const T* get() const { return flatbuffers::GetRoot<T>(buffer_.data()); }
+
+ const B& buffer() const { return buffer_; }
+
+ const T* operator->() const {
+ return flatbuffers::GetRoot<T>(buffer_.data());
+ }
+
+ private:
+ B buffer_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
diff --git a/native/utils/flatbuffers/flatbuffers_test.bfbs b/native/utils/flatbuffers/flatbuffers_test.bfbs
new file mode 100644
index 0000000..519550f
--- /dev/null
+++ b/native/utils/flatbuffers/flatbuffers_test.bfbs
Binary files differ
diff --git a/native/utils/flatbuffers/flatbuffers_test.fbs b/native/utils/flatbuffers/flatbuffers_test.fbs
new file mode 100644
index 0000000..f501e13
--- /dev/null
+++ b/native/utils/flatbuffers/flatbuffers_test.fbs
@@ -0,0 +1,67 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.test;
+
+table FlightNumberInfo {
+ carrier_code: string;
+ flight_code: int;
+}
+
+table ContactInfo {
+ first_name: string;
+ last_name: string;
+ phone_number: string;
+ score: float;
+}
+
+table Reminder {
+ title: string;
+ notes: [string];
+}
+
+table NestedA {
+ nestedb: NestedB;
+ value: string;
+ repeated_str: [string];
+}
+
+table NestedB {
+ nesteda: NestedA;
+}
+
+enum EnumValue : short {
+ VALUE_0 = 0,
+ VALUE_1 = 1,
+ VALUE_2 = 2,
+}
+
+table EntityData {
+ an_int_field: int;
+ a_long_field: int64;
+ a_bool_field: bool;
+ a_float_field: float;
+ a_double_field: double;
+ flight_number: FlightNumberInfo;
+ contact_info: ContactInfo;
+ reminders: [Reminder];
+ numbers: [int];
+ strings: [string];
+ nested: NestedA;
+ enum_value: EnumValue;
+}
+
+root_type libtextclassifier3.test.EntityData;
diff --git a/native/utils/flatbuffers/flatbuffers_test_extended.bfbs b/native/utils/flatbuffers/flatbuffers_test_extended.bfbs
new file mode 100644
index 0000000..fec4363
--- /dev/null
+++ b/native/utils/flatbuffers/flatbuffers_test_extended.bfbs
Binary files differ
diff --git a/native/utils/flatbuffers/flatbuffers_test_extended.fbs b/native/utils/flatbuffers/flatbuffers_test_extended.fbs
new file mode 100644
index 0000000..0410874
--- /dev/null
+++ b/native/utils/flatbuffers/flatbuffers_test_extended.fbs
@@ -0,0 +1,68 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.test;
+
+table FlightNumberInfo {
+ carrier_code: string;
+ flight_code: int;
+}
+
+table ContactInfo {
+ first_name: string;
+ last_name: string;
+ phone_number: string;
+ score: float;
+}
+
+table Reminder {
+ title: string;
+ notes: [string];
+}
+
+table NestedA {
+ nestedb: NestedB;
+ value: string;
+ repeated_str: [string];
+}
+
+table NestedB {
+ nesteda: NestedA;
+}
+
+enum EnumValue : short {
+ VALUE_0 = 0,
+ VALUE_1 = 1,
+ VALUE_2 = 2,
+}
+
+table EntityData {
+ an_int_field: int;
+ a_long_field: int64;
+ a_bool_field: bool;
+ a_float_field: float;
+ a_double_field: double;
+ flight_number: FlightNumberInfo;
+ contact_info: ContactInfo;
+ reminders: [Reminder];
+ numbers: [int];
+ strings: [string];
+ nested: NestedA;
+ enum_value: EnumValue;
+ mystic: string; // Extra field.
+}
+
+root_type libtextclassifier3.test.EntityData;
diff --git a/native/utils/flatbuffers.cc b/native/utils/flatbuffers/mutable.cc
similarity index 66%
rename from native/utils/flatbuffers.cc
rename to native/utils/flatbuffers/mutable.cc
index cf4c97f..0f425eb 100644
--- a/native/utils/flatbuffers.cc
+++ b/native/utils/flatbuffers/mutable.cc
@@ -14,10 +14,11 @@
* limitations under the License.
*/
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include <vector>
+#include "utils/flatbuffers/reflection.h"
#include "utils/strings/numbers.h"
#include "utils/variant.h"
#include "flatbuffers/reflection_generated.h"
@@ -25,56 +26,6 @@
namespace libtextclassifier3 {
namespace {
-// Gets the field information for a field name, returns nullptr if the
-// field was not defined.
-const reflection::Field* GetFieldOrNull(const reflection::Object* type,
- const StringPiece field_name) {
- TC3_CHECK(type != nullptr && type->fields() != nullptr);
- return type->fields()->LookupByKey(field_name.data());
-}
-
-const reflection::Field* GetFieldOrNull(const reflection::Object* type,
- const int field_offset) {
- if (type->fields() == nullptr) {
- return nullptr;
- }
- for (const reflection::Field* field : *type->fields()) {
- if (field->offset() == field_offset) {
- return field;
- }
- }
- return nullptr;
-}
-
-const reflection::Field* GetFieldOrNull(const reflection::Object* type,
- const StringPiece field_name,
- const int field_offset) {
- // Lookup by name might be faster as the fields are sorted by name in the
- // schema data, so try that first.
- if (!field_name.empty()) {
- return GetFieldOrNull(type, field_name.data());
- }
- return GetFieldOrNull(type, field_offset);
-}
-
-const reflection::Field* GetFieldOrNull(const reflection::Object* type,
- const FlatbufferField* field) {
- TC3_CHECK(type != nullptr && field != nullptr);
- if (field->field_name() == nullptr) {
- return GetFieldOrNull(type, field->field_offset());
- }
- return GetFieldOrNull(
- type,
- StringPiece(field->field_name()->data(), field->field_name()->size()),
- field->field_offset());
-}
-
-const reflection::Field* GetFieldOrNull(const reflection::Object* type,
- const FlatbufferFieldT* field) {
- TC3_CHECK(type != nullptr && field != nullptr);
- return GetFieldOrNull(type, field->field_name, field->field_offset);
-}
-
bool Parse(const std::string& str_value, float* value) {
double double_value;
if (!ParseDouble(str_value.data(), &double_value)) {
@@ -103,8 +54,7 @@
template <typename T>
bool ParseAndSetField(const reflection::Field* field,
- const std::string& str_value,
- ReflectiveFlatbuffer* buffer) {
+ const std::string& str_value, MutableFlatbuffer* buffer) {
T value;
if (!Parse(str_value, &value)) {
TC3_LOG(ERROR) << "Could not parse '" << str_value << "'";
@@ -120,44 +70,48 @@
} // namespace
-template <>
-const char* FlatbufferFileIdentifier<Model>() {
- return ModelIdentifier();
+MutableFlatbufferBuilder::MutableFlatbufferBuilder(
+ const reflection::Schema* schema, StringPiece root_type)
+ : schema_(schema), root_type_(TypeForName(schema, root_type)) {}
+
+std::unique_ptr<MutableFlatbuffer> MutableFlatbufferBuilder::NewRoot() const {
+ return NewTable(root_type_);
}
-std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewRoot()
- const {
- if (!schema_->root_table()) {
- TC3_LOG(ERROR) << "No root table specified.";
+std::unique_ptr<MutableFlatbuffer> MutableFlatbufferBuilder::NewTable(
+ StringPiece table_name) const {
+ return NewTable(TypeForName(schema_, table_name));
+}
+
+std::unique_ptr<MutableFlatbuffer> MutableFlatbufferBuilder::NewTable(
+ const int type_id) const {
+ if (type_id < 0 || type_id >= schema_->objects()->size()) {
+ TC3_LOG(ERROR) << "Invalid type id: " << type_id;
return nullptr;
}
- return std::unique_ptr<ReflectiveFlatbuffer>(
- new ReflectiveFlatbuffer(schema_, schema_->root_table()));
+ return NewTable(schema_->objects()->Get(type_id));
}
-std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewTable(
- StringPiece table_name) const {
- for (const reflection::Object* object : *schema_->objects()) {
- if (table_name.Equals(object->name()->str())) {
- return std::unique_ptr<ReflectiveFlatbuffer>(
- new ReflectiveFlatbuffer(schema_, object));
- }
+std::unique_ptr<MutableFlatbuffer> MutableFlatbufferBuilder::NewTable(
+ const reflection::Object* type) const {
+ if (type == nullptr) {
+ return nullptr;
}
- return nullptr;
+ return std::make_unique<MutableFlatbuffer>(schema_, type);
}
-const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+const reflection::Field* MutableFlatbuffer::GetFieldOrNull(
const StringPiece field_name) const {
return libtextclassifier3::GetFieldOrNull(type_, field_name);
}
-const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+const reflection::Field* MutableFlatbuffer::GetFieldOrNull(
const FlatbufferField* field) const {
return libtextclassifier3::GetFieldOrNull(type_, field);
}
-bool ReflectiveFlatbuffer::GetFieldWithParent(
- const FlatbufferFieldPath* field_path, ReflectiveFlatbuffer** parent,
+bool MutableFlatbuffer::GetFieldWithParent(
+ const FlatbufferFieldPath* field_path, MutableFlatbuffer** parent,
reflection::Field const** field) {
const auto* path = field_path->field();
if (path == nullptr || path->size() == 0) {
@@ -178,13 +132,48 @@
return true;
}
-const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+const reflection::Field* MutableFlatbuffer::GetFieldOrNull(
const int field_offset) const {
return libtextclassifier3::GetFieldOrNull(type_, field_offset);
}
-bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
- const std::string& value) {
+bool MutableFlatbuffer::SetFromEnumValueName(const reflection::Field* field,
+ StringPiece value_name) {
+ if (!IsEnum(field->type())) {
+ return false;
+ }
+ Variant variant_value = ParseEnumValue(schema_, field->type(), value_name);
+ if (!variant_value.HasValue()) {
+ return false;
+ }
+ fields_[field] = variant_value;
+ return true;
+}
+
+bool MutableFlatbuffer::SetFromEnumValueName(StringPiece field_name,
+ StringPiece value_name) {
+ if (const reflection::Field* field = GetFieldOrNull(field_name)) {
+ return SetFromEnumValueName(field, value_name);
+ }
+ return false;
+}
+
+bool MutableFlatbuffer::SetFromEnumValueName(const FlatbufferFieldPath* path,
+ StringPiece value_name) {
+ MutableFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!GetFieldWithParent(path, &parent, &field)) {
+ return false;
+ }
+ return parent->SetFromEnumValueName(field, value_name);
+}
+
+bool MutableFlatbuffer::ParseAndSet(const reflection::Field* field,
+ const std::string& value) {
+ // Try parsing as an enum value.
+ if (IsEnum(field->type()) && SetFromEnumValueName(field, value)) {
+ return true;
+ }
switch (field->type()->base_type() == reflection::Vector
? field->type()->element()
: field->type()->base_type()) {
@@ -204,9 +193,9 @@
}
}
-bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
- const std::string& value) {
- ReflectiveFlatbuffer* parent;
+bool MutableFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
+ const std::string& value) {
+ MutableFlatbuffer* parent;
const reflection::Field* field;
if (!GetFieldWithParent(path, &parent, &field)) {
return false;
@@ -214,7 +203,7 @@
return parent->ParseAndSet(field, value);
}
-ReflectiveFlatbuffer* ReflectiveFlatbuffer::Add(StringPiece field_name) {
+MutableFlatbuffer* MutableFlatbuffer::Add(StringPiece field_name) {
const reflection::Field* field = GetFieldOrNull(field_name);
if (field == nullptr) {
return nullptr;
@@ -227,16 +216,14 @@
return Add(field);
}
-ReflectiveFlatbuffer* ReflectiveFlatbuffer::Add(
- const reflection::Field* field) {
+MutableFlatbuffer* MutableFlatbuffer::Add(const reflection::Field* field) {
if (field == nullptr) {
return nullptr;
}
return Repeated(field)->Add();
}
-ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
- const StringPiece field_name) {
+MutableFlatbuffer* MutableFlatbuffer::Mutable(const StringPiece field_name) {
if (const reflection::Field* field = GetFieldOrNull(field_name)) {
return Mutable(field);
}
@@ -244,10 +231,8 @@
return nullptr;
}
-ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
- const reflection::Field* field) {
+MutableFlatbuffer* MutableFlatbuffer::Mutable(const reflection::Field* field) {
if (field->type()->base_type() != reflection::Obj) {
- TC3_LOG(ERROR) << "Field is not of type Object.";
return nullptr;
}
const auto entry = children_.find(field);
@@ -258,12 +243,31 @@
/*hint=*/entry,
std::make_pair(
field,
- std::unique_ptr<ReflectiveFlatbuffer>(new ReflectiveFlatbuffer(
+ std::unique_ptr<MutableFlatbuffer>(new MutableFlatbuffer(
schema_, schema_->objects()->Get(field->type()->index())))));
return it->second.get();
}
-RepeatedField* ReflectiveFlatbuffer::Repeated(StringPiece field_name) {
+MutableFlatbuffer* MutableFlatbuffer::Mutable(const FlatbufferFieldPath* path) {
+ const auto* field_path = path->field();
+ if (field_path == nullptr || field_path->size() == 0) {
+ return this;
+ }
+ MutableFlatbuffer* object = this;
+ for (int i = 0; i < field_path->size(); i++) {
+ const reflection::Field* field = object->GetFieldOrNull(field_path->Get(i));
+ if (field == nullptr) {
+ return nullptr;
+ }
+ object = object->Mutable(field);
+ if (object == nullptr) {
+ return nullptr;
+ }
+ }
+ return object;
+}
+
+RepeatedField* MutableFlatbuffer::Repeated(StringPiece field_name) {
if (const reflection::Field* field = GetFieldOrNull(field_name)) {
return Repeated(field);
}
@@ -271,7 +275,7 @@
return nullptr;
}
-RepeatedField* ReflectiveFlatbuffer::Repeated(const reflection::Field* field) {
+RepeatedField* MutableFlatbuffer::Repeated(const reflection::Field* field) {
if (field->type()->base_type() != reflection::Vector) {
TC3_LOG(ERROR) << "Field is not of type Vector.";
return nullptr;
@@ -291,7 +295,16 @@
return it->second.get();
}
-flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
+RepeatedField* MutableFlatbuffer::Repeated(const FlatbufferFieldPath* path) {
+ MutableFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!GetFieldWithParent(path, &parent, &field)) {
+ return nullptr;
+ }
+ return parent->Repeated(field);
+}
+
+flatbuffers::uoffset_t MutableFlatbuffer::Serialize(
flatbuffers::FlatBufferBuilder* builder) const {
// Build all children before we can start with this table.
std::vector<
@@ -380,51 +393,14 @@
return builder->EndTable(table_start);
}
-std::string ReflectiveFlatbuffer::Serialize() const {
+std::string MutableFlatbuffer::Serialize() const {
flatbuffers::FlatBufferBuilder builder;
builder.Finish(flatbuffers::Offset<void>(Serialize(&builder)));
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
}
-template <>
-bool ReflectiveFlatbuffer::AppendFromVector<std::string>(
- const flatbuffers::Table* from, const reflection::Field* field) {
- auto* from_vector = from->GetPointer<
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
- field->offset());
- if (from_vector == nullptr) {
- return false;
- }
-
- RepeatedField* to_repeated = Repeated(field);
- for (const flatbuffers::String* element : *from_vector) {
- to_repeated->Add(element->str());
- }
- return true;
-}
-
-template <>
-bool ReflectiveFlatbuffer::AppendFromVector<ReflectiveFlatbuffer>(
- const flatbuffers::Table* from, const reflection::Field* field) {
- auto* from_vector = from->GetPointer<const flatbuffers::Vector<
- flatbuffers::Offset<const flatbuffers::Table>>*>(field->offset());
- if (from_vector == nullptr) {
- return false;
- }
-
- RepeatedField* to_repeated = Repeated(field);
- for (const flatbuffers::Table* const from_element : *from_vector) {
- ReflectiveFlatbuffer* to_element = to_repeated->Add();
- if (to_element == nullptr) {
- return false;
- }
- to_element->MergeFrom(from_element);
- }
- return true;
-}
-
-bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
+bool MutableFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
// No fields to set.
if (type_->fields() == nullptr) {
return true;
@@ -479,7 +455,7 @@
->str());
break;
case reflection::Obj:
- if (ReflectiveFlatbuffer* nested_field = Mutable(field);
+ if (MutableFlatbuffer* nested_field = Mutable(field);
nested_field == nullptr ||
!nested_field->MergeFrom(
from->GetPointer<const flatbuffers::Table* const>(
@@ -487,46 +463,13 @@
return false;
}
break;
- case reflection::Vector:
- switch (field->type()->element()) {
- case reflection::Int:
- AppendFromVector<int32>(from, field);
- break;
- case reflection::UInt:
- AppendFromVector<uint>(from, field);
- break;
- case reflection::Long:
- AppendFromVector<int64>(from, field);
- break;
- case reflection::ULong:
- AppendFromVector<uint64>(from, field);
- break;
- case reflection::Byte:
- AppendFromVector<int8_t>(from, field);
- break;
- case reflection::UByte:
- AppendFromVector<uint8_t>(from, field);
- break;
- case reflection::String:
- AppendFromVector<std::string>(from, field);
- break;
- case reflection::Obj:
- AppendFromVector<ReflectiveFlatbuffer>(from, field);
- break;
- case reflection::Double:
- AppendFromVector<double>(from, field);
- break;
- case reflection::Float:
- AppendFromVector<float>(from, field);
- break;
- default:
- TC3_LOG(ERROR) << "Repeated unsupported type: "
- << field->type()->element()
- << " for field: " << field->name()->str();
- return false;
- break;
+ case reflection::Vector: {
+ if (RepeatedField* repeated_field = Repeated(field);
+ repeated_field == nullptr || !repeated_field->Extend(from)) {
+ return false;
}
break;
+ }
default:
TC3_LOG(ERROR) << "Unsupported type: " << type
<< " for field: " << field->name()->str();
@@ -536,12 +479,12 @@
return true;
}
-bool ReflectiveFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
+bool MutableFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
return MergeFrom(flatbuffers::GetAnyRoot(
reinterpret_cast<const unsigned char*>(from.data())));
}
-void ReflectiveFlatbuffer::AsFlatMap(
+void MutableFlatbuffer::AsFlatMap(
const std::string& key_separator, const std::string& key_prefix,
std::map<std::string, Variant>* result) const {
// Add direct fields.
@@ -557,7 +500,23 @@
}
}
-std::string ReflectiveFlatbuffer::ToTextProto() const {
+std::string RepeatedField::ToTextProto() const {
+ std::string result = " [";
+ std::string current_field_separator;
+ for (int index = 0; index < Size(); index++) {
+ if (is_primitive_) {
+ result.append(current_field_separator + items_.at(index).ToString());
+ } else {
+ result.append(current_field_separator + "{" +
+ Get<MutableFlatbuffer*>(index)->ToTextProto() + "}");
+ }
+ current_field_separator = ", ";
+ }
+ result.append("] ");
+ return result;
+}
+
+std::string MutableFlatbuffer::ToTextProto() const {
std::string result;
std::string current_field_separator;
// Add direct fields.
@@ -573,6 +532,14 @@
current_field_separator = ", ";
}
+ // Add repeated message
+ for (const auto& repeated_fb_pair : repeated_fields_) {
+ result.append(current_field_separator +
+ repeated_fb_pair.first->name()->c_str() + ": " +
+ repeated_fb_pair.second->ToTextProto());
+ current_field_separator = ", ";
+ }
+
// Add nested messages.
for (const auto& field_flatbuffer_pair : children_) {
const std::string field_name = field_flatbuffer_pair.first->name()->str();
@@ -584,47 +551,17 @@
return result;
}
-bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
- FlatbufferFieldPathT* path) {
- if (schema == nullptr || !schema->root_table()) {
- TC3_LOG(ERROR) << "Empty schema provided.";
- return false;
- }
-
- reflection::Object const* type = schema->root_table();
- for (int i = 0; i < path->field.size(); i++) {
- const reflection::Field* field = GetFieldOrNull(type, path->field[i].get());
- if (field == nullptr) {
- TC3_LOG(ERROR) << "Could not find field: " << path->field[i]->field_name;
- return false;
- }
- path->field[i]->field_name.clear();
- path->field[i]->field_offset = field->offset();
-
- // Descend.
- if (i < path->field.size() - 1) {
- if (field->type()->base_type() != reflection::Obj) {
- TC3_LOG(ERROR) << "Field: " << field->name()->str()
- << " is not of type `Object`.";
- return false;
- }
- type = schema->objects()->Get(field->type()->index());
- }
- }
- return true;
-}
-
//
// Repeated field methods.
//
-ReflectiveFlatbuffer* RepeatedField::Add() {
+MutableFlatbuffer* RepeatedField::Add() {
if (is_primitive_) {
TC3_LOG(ERROR) << "Trying to add sub-message on a primitive-typed field.";
return nullptr;
}
- object_items_.emplace_back(new ReflectiveFlatbuffer(
+ object_items_.emplace_back(new MutableFlatbuffer(
schema_, schema_->objects()->Get(field_->type()->index())));
return object_items_.back().get();
}
@@ -644,6 +581,46 @@
} // namespace
+bool RepeatedField::Extend(const flatbuffers::Table* from) {
+ switch (field_->type()->element()) {
+ case reflection::Int:
+ AppendFromVector<int32>(from);
+ return true;
+ case reflection::UInt:
+ AppendFromVector<uint>(from);
+ return true;
+ case reflection::Long:
+ AppendFromVector<int64>(from);
+ return true;
+ case reflection::ULong:
+ AppendFromVector<uint64>(from);
+ return true;
+ case reflection::Byte:
+ AppendFromVector<int8_t>(from);
+ return true;
+ case reflection::UByte:
+ AppendFromVector<uint8_t>(from);
+ return true;
+ case reflection::String:
+ AppendFromVector<std::string>(from);
+ return true;
+ case reflection::Obj:
+ AppendFromVector<MutableFlatbuffer>(from);
+ return true;
+ case reflection::Double:
+ AppendFromVector<double>(from);
+ return true;
+ case reflection::Float:
+ AppendFromVector<float>(from);
+ return true;
+ default:
+ TC3_LOG(ERROR) << "Repeated unsupported type: "
+ << field_->type()->element()
+ << " for field: " << field_->name()->str();
+ return false;
+ }
+}
+
flatbuffers::uoffset_t RepeatedField::Serialize(
flatbuffers::FlatBufferBuilder* builder) const {
switch (field_->type()->element()) {
diff --git a/native/utils/flatbuffers/mutable.h b/native/utils/flatbuffers/mutable.h
new file mode 100644
index 0000000..90f6baa
--- /dev/null
+++ b/native/utils/flatbuffers/mutable.h
@@ -0,0 +1,429 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
+#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "annotator/model_generated.h"
+#include "utils/base/logging.h"
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+#include "flatbuffers/reflection_generated.h"
+
+namespace libtextclassifier3 {
+
+class MutableFlatbuffer;
+class RepeatedField;
+
+template <typename T>
+constexpr bool IsStringType() {
+ return std::is_same<T, std::string>::value ||
+ std::is_same<T, StringPiece>::value ||
+ std::is_same<T, const char*>::value;
+}
+
+// Checks whether a variant value type agrees with a field type.
+template <typename T>
+bool IsMatchingType(const reflection::BaseType type) {
+ switch (type) {
+ case reflection::String:
+ return IsStringType<T>();
+ case reflection::Obj:
+ return std::is_same<T, MutableFlatbuffer>::value;
+ default:
+ return type == flatbuffers_base_type<T>::value;
+ }
+}
+
+// A mutable flatbuffer that can be built using flatbuffer reflection data of
+// the schema. Normally, field information is hard-coded in code generated from
+// a flatbuffer schema. Here we lookup the necessary information for building a
+// flatbuffer from the provided reflection meta data. When serializing a
+// flatbuffer, the library requires that the sub messages are already
+// serialized, therefore we explicitly keep the field values and serialize the
+// message in (reverse) topological dependency order.
+class MutableFlatbuffer {
+ public:
+ MutableFlatbuffer(const reflection::Schema* schema,
+ const reflection::Object* type)
+ : schema_(schema), type_(type) {}
+
+ // Gets the field information for a field name, returns nullptr if the
+ // field was not defined.
+ const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
+ const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
+ const reflection::Field* GetFieldOrNull(const int field_offset) const;
+
+ // Gets a nested field and the message it is defined on.
+ bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
+ MutableFlatbuffer** parent,
+ reflection::Field const** field);
+
+ // Sets a field to a specific value.
+ // Returns true if successful, and false if the field was not found or the
+ // expected type doesn't match.
+ template <typename T>
+ bool Set(StringPiece field_name, T value);
+
+ // Sets a field to a specific value.
+ // Returns true if successful, and false if the expected type doesn't match.
+ // Expects `field` to be non-null.
+ template <typename T>
+ bool Set(const reflection::Field* field, T value);
+
+ // Sets a field to a specific value. Field is specified by path.
+ template <typename T>
+ bool Set(const FlatbufferFieldPath* path, T value);
+
+ // Sets an enum field from an enum value name.
+ // Returns true if the value could be successfully parsed.
+ bool SetFromEnumValueName(StringPiece field_name, StringPiece value_name);
+
+ // Sets an enum field from an enum value name.
+ // Returns true if the value could be successfully parsed.
+ bool SetFromEnumValueName(const reflection::Field* field,
+ StringPiece value_name);
+
+ // Sets an enum field from an enum value name. Field is specified by path.
+ // Returns true if the value could be successfully parsed.
+ bool SetFromEnumValueName(const FlatbufferFieldPath* path,
+ StringPiece value_name);
+
+ // Sets sub-message field (if not set yet), and returns a pointer to it.
+ // Returns nullptr if the field was not found, or the field type was not a
+ // table.
+ MutableFlatbuffer* Mutable(StringPiece field_name);
+ MutableFlatbuffer* Mutable(const reflection::Field* field);
+
+ // Sets a sub-message field (if not set yet) specified by path, and returns a
+ // pointer to it. Returns nullptr if the field was not found, or the field
+ // type was not a table.
+ MutableFlatbuffer* Mutable(const FlatbufferFieldPath* path);
+
+ // Parses the value (according to the type) and sets a primitive field to the
+ // parsed value.
+ bool ParseAndSet(const reflection::Field* field, const std::string& value);
+ bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
+
+ // Adds a primitive value to the repeated field.
+ template <typename T>
+ bool Add(StringPiece field_name, T value);
+
+ // Add a sub-message to the repeated field.
+ MutableFlatbuffer* Add(StringPiece field_name);
+
+ template <typename T>
+ bool Add(const reflection::Field* field, T value);
+
+ MutableFlatbuffer* Add(const reflection::Field* field);
+
+ // Gets the reflective flatbuffer for a repeated field.
+ // Returns nullptr if the field was not found, or the field type was not a
+ // vector.
+ RepeatedField* Repeated(StringPiece field_name);
+ RepeatedField* Repeated(const reflection::Field* field);
+
+ // Gets a repeated field specified by path.
+ // Returns nullptr if the field was not found, or the field
+ // type was not a repeated field.
+ RepeatedField* Repeated(const FlatbufferFieldPath* path);
+
+ // Serializes the flatbuffer.
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const;
+ std::string Serialize() const;
+
+ // Merges the fields from the given flatbuffer table into this flatbuffer.
+ // Scalar fields will be overwritten, if present in `from`.
+ // Embedded messages will be merged.
+ bool MergeFrom(const flatbuffers::Table* from);
+ bool MergeFromSerializedFlatbuffer(StringPiece from);
+
+ // Flattens the flatbuffer as a flat map.
+ // (Nested) fields names are joined by `key_separator`.
+ std::map<std::string, Variant> AsFlatMap(
+ const std::string& key_separator = ".") const {
+ std::map<std::string, Variant> result;
+ AsFlatMap(key_separator, /*key_prefix=*/"", &result);
+ return result;
+ }
+
+ // Converts the flatbuffer's content to a human-readable textproto
+ // representation.
+ std::string ToTextProto() const;
+
+ bool HasExplicitlySetFields() const {
+ return !fields_.empty() || !children_.empty() || !repeated_fields_.empty();
+ }
+
+ const reflection::Object* type() const { return type_; }
+
+ private:
+ // Helper function for merging given repeated field from given flatbuffer
+ // table. Appends the elements.
+ template <typename T>
+ bool AppendFromVector(const flatbuffers::Table* from,
+ const reflection::Field* field);
+
+ // Flattens the flatbuffer as a flat map.
+ // (Nested) fields names are joined by `key_separator` and prefixed by
+ // `key_prefix`.
+ void AsFlatMap(const std::string& key_separator,
+ const std::string& key_prefix,
+ std::map<std::string, Variant>* result) const;
+
+ const reflection::Schema* const schema_;
+ const reflection::Object* const type_;
+
+ // Cached primitive fields (scalars and strings).
+ std::unordered_map<const reflection::Field*, Variant> fields_;
+
+ // Cached sub-messages.
+ std::unordered_map<const reflection::Field*,
+ std::unique_ptr<MutableFlatbuffer>>
+ children_;
+
+ // Cached repeated fields.
+ std::unordered_map<const reflection::Field*, std::unique_ptr<RepeatedField>>
+ repeated_fields_;
+};
+
+// A helper class to build flatbuffers based on schema reflection data.
+// Can be used to a `MutableFlatbuffer` for the root message of the
+// schema, or any defined table via name.
+class MutableFlatbufferBuilder {
+ public:
+ explicit MutableFlatbufferBuilder(const reflection::Schema* schema)
+ : schema_(schema), root_type_(schema->root_table()) {}
+ explicit MutableFlatbufferBuilder(const reflection::Schema* schema,
+ StringPiece root_type);
+
+ // Starts a new root table message.
+ std::unique_ptr<MutableFlatbuffer> NewRoot() const;
+
+ // Creates a new table message. Returns nullptr if no table with given name is
+ // found in the schema.
+ std::unique_ptr<MutableFlatbuffer> NewTable(
+ const StringPiece table_name) const;
+
+ // Creates a new message for the given type id. Returns nullptr if the type is
+ // invalid.
+ std::unique_ptr<MutableFlatbuffer> NewTable(int type_id) const;
+
+ // Creates a new message for the given type.
+ std::unique_ptr<MutableFlatbuffer> NewTable(
+ const reflection::Object* type) const;
+
+ private:
+ const reflection::Schema* const schema_;
+ const reflection::Object* const root_type_;
+};
+
+// Encapsulates a repeated field.
+// Serves as a common base class for repeated fields.
+class RepeatedField {
+ public:
+ RepeatedField(const reflection::Schema* const schema,
+ const reflection::Field* field)
+ : schema_(schema),
+ field_(field),
+ is_primitive_(field->type()->element() != reflection::BaseType::Obj) {}
+
+ template <typename T>
+ bool Add(const T value);
+
+ MutableFlatbuffer* Add();
+
+ template <typename T>
+ T Get(int index) const {
+ return items_.at(index).Value<T>();
+ }
+
+ template <>
+ MutableFlatbuffer* Get(int index) const {
+ if (is_primitive_) {
+ TC3_LOG(ERROR) << "Trying to get primitive value out of non-primitive "
+ "repeated field.";
+ return nullptr;
+ }
+ return object_items_.at(index).get();
+ }
+
+ int Size() const {
+ if (is_primitive_) {
+ return items_.size();
+ } else {
+ return object_items_.size();
+ }
+ }
+
+ bool Extend(const flatbuffers::Table* from);
+
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const;
+
+ std::string ToTextProto() const;
+
+ private:
+ template <typename T>
+ bool AppendFromVector(const flatbuffers::Table* from);
+
+ flatbuffers::uoffset_t SerializeString(
+ flatbuffers::FlatBufferBuilder* builder) const;
+ flatbuffers::uoffset_t SerializeObject(
+ flatbuffers::FlatBufferBuilder* builder) const;
+
+ const reflection::Schema* const schema_;
+ const reflection::Field* field_;
+ bool is_primitive_;
+
+ std::vector<Variant> items_;
+ std::vector<std::unique_ptr<MutableFlatbuffer>> object_items_;
+};
+
+template <typename T>
+bool MutableFlatbuffer::Set(StringPiece field_name, T value) {
+ if (const reflection::Field* field = GetFieldOrNull(field_name)) {
+ if (field->type()->base_type() == reflection::BaseType::Vector ||
+ field->type()->base_type() == reflection::BaseType::Obj) {
+ TC3_LOG(ERROR)
+ << "Trying to set a primitive value on a non-scalar field.";
+ return false;
+ }
+ return Set<T>(field, value);
+ }
+ TC3_LOG(ERROR) << "Couldn't find a field: " << field_name;
+ return false;
+}
+
+template <typename T>
+bool MutableFlatbuffer::Set(const reflection::Field* field, T value) {
+ if (field == nullptr) {
+ TC3_LOG(ERROR) << "Expected non-null field.";
+ return false;
+ }
+ Variant variant_value(value);
+ if (!IsMatchingType<T>(field->type()->base_type())) {
+ TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
+ << "`, expected: "
+ << EnumNameBaseType(field->type()->base_type())
+ << ", got: " << variant_value.GetType();
+ return false;
+ }
+ fields_[field] = variant_value;
+ return true;
+}
+
+template <typename T>
+bool MutableFlatbuffer::Set(const FlatbufferFieldPath* path, T value) {
+ MutableFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!GetFieldWithParent(path, &parent, &field)) {
+ return false;
+ }
+ return parent->Set<T>(field, value);
+}
+
+template <typename T>
+bool MutableFlatbuffer::Add(StringPiece field_name, T value) {
+ const reflection::Field* field = GetFieldOrNull(field_name);
+ if (field == nullptr) {
+ return false;
+ }
+
+ if (field->type()->base_type() != reflection::BaseType::Vector) {
+ return false;
+ }
+
+ return Add<T>(field, value);
+}
+
+template <typename T>
+bool MutableFlatbuffer::Add(const reflection::Field* field, T value) {
+ if (field == nullptr) {
+ return false;
+ }
+ Repeated(field)->Add(value);
+ return true;
+}
+
+template <typename T>
+bool RepeatedField::Add(const T value) {
+ if (!is_primitive_ || !IsMatchingType<T>(field_->type()->element())) {
+ TC3_LOG(ERROR) << "Trying to add value of unmatching type.";
+ return false;
+ }
+ items_.push_back(Variant{value});
+ return true;
+}
+
+template <typename T>
+bool RepeatedField::AppendFromVector(const flatbuffers::Table* from) {
+ const flatbuffers::Vector<T>* values =
+ from->GetPointer<const flatbuffers::Vector<T>*>(field_->offset());
+ if (values == nullptr) {
+ return false;
+ }
+ for (const T element : *values) {
+ Add(element);
+ }
+ return true;
+}
+
+template <>
+inline bool RepeatedField::AppendFromVector<std::string>(
+ const flatbuffers::Table* from) {
+ auto* values = from->GetPointer<
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
+ field_->offset());
+ if (values == nullptr) {
+ return false;
+ }
+ for (const flatbuffers::String* element : *values) {
+ Add(element->str());
+ }
+ return true;
+}
+
+template <>
+inline bool RepeatedField::AppendFromVector<MutableFlatbuffer>(
+ const flatbuffers::Table* from) {
+ auto* values = from->GetPointer<const flatbuffers::Vector<
+ flatbuffers::Offset<const flatbuffers::Table>>*>(field_->offset());
+ if (values == nullptr) {
+ return false;
+ }
+ for (const flatbuffers::Table* const from_element : *values) {
+ MutableFlatbuffer* to_element = Add();
+ if (to_element == nullptr) {
+ return false;
+ }
+ to_element->MergeFrom(from_element);
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
diff --git a/native/utils/flatbuffers/mutable_test.cc b/native/utils/flatbuffers/mutable_test.cc
new file mode 100644
index 0000000..8fefc07
--- /dev/null
+++ b/native/utils/flatbuffers/mutable_test.cc
@@ -0,0 +1,367 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/flatbuffers/mutable.h"
+
+#include <map>
+#include <memory>
+#include <string>
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/flatbuffers/flatbuffers_test_generated.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+#include "flatbuffers/reflection_generated.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::SizeIs;
+
+class MutableFlatbufferTest : public testing::Test {
+ public:
+ explicit MutableFlatbufferTest()
+ : schema_(LoadTestMetadata()), builder_(schema_.get()) {}
+
+ protected:
+ OwnedFlatbuffer<reflection::Schema, std::string> schema_;
+ MutableFlatbufferBuilder builder_;
+};
+
+TEST_F(MutableFlatbufferTest, PrimitiveFieldsAreCorrectlySet) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ EXPECT_TRUE(buffer != nullptr);
+ EXPECT_TRUE(buffer->Set("an_int_field", 42));
+ EXPECT_TRUE(buffer->Set("a_long_field", int64{84}));
+ EXPECT_TRUE(buffer->Set("a_bool_field", true));
+ EXPECT_TRUE(buffer->Set("a_float_field", 1.f));
+ EXPECT_TRUE(buffer->Set("a_double_field", 1.0));
+
+ // Try to parse with the generated code.
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ ASSERT_NE(entity_data, nullptr);
+ EXPECT_EQ(entity_data->an_int_field, 42);
+ EXPECT_EQ(entity_data->a_long_field, 84);
+ EXPECT_EQ(entity_data->a_bool_field, true);
+ EXPECT_NEAR(entity_data->a_float_field, 1.f, 1e-4);
+ EXPECT_NEAR(entity_data->a_double_field, 1.f, 1e-4);
+}
+
+TEST_F(MutableFlatbufferTest, EnumValuesCanBeSpecifiedByName) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ EXPECT_TRUE(buffer != nullptr);
+
+ EXPECT_TRUE(IsEnum(buffer->GetFieldOrNull("enum_value")->type()));
+
+ EXPECT_TRUE(buffer->SetFromEnumValueName("enum_value", "VALUE_1"));
+
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ EXPECT_EQ(entity_data->enum_value,
+ libtextclassifier3::test::EnumValue_VALUE_1);
+}
+
+TEST_F(MutableFlatbufferTest, HandlesUnknownFields) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ EXPECT_TRUE(buffer != nullptr);
+
+ // Add a field that is not known to the (statically generated) code.
+ EXPECT_TRUE(buffer->Set("mystic", "this is an unknown field."));
+
+ OwnedFlatbuffer<flatbuffers::Table, std::string> extra(buffer->Serialize());
+ EXPECT_EQ(extra
+ ->GetPointer<const flatbuffers::String*>(
+ buffer->GetFieldOrNull("mystic")->offset())
+ ->str(),
+ "this is an unknown field.");
+}
+
+TEST_F(MutableFlatbufferTest, HandlesNestedFields) {
+ OwnedFlatbuffer<FlatbufferFieldPath, std::string> path =
+ CreateFieldPath({"flight_number", "carrier_code"});
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+
+ MutableFlatbuffer* parent = nullptr;
+ reflection::Field const* field = nullptr;
+ EXPECT_TRUE(buffer->GetFieldWithParent(path.get(), &parent, &field));
+ EXPECT_EQ(parent, buffer->Mutable("flight_number"));
+ EXPECT_EQ(field,
+ buffer->Mutable("flight_number")->GetFieldOrNull("carrier_code"));
+}
+
+TEST_F(MutableFlatbufferTest, HandlesMultipleNestedFields) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ MutableFlatbuffer* flight_info = buffer->Mutable("flight_number");
+ flight_info->Set("carrier_code", "LX");
+ flight_info->Set("flight_code", 38);
+
+ MutableFlatbuffer* contact_info = buffer->Mutable("contact_info");
+ EXPECT_TRUE(contact_info->Set("first_name", "Barack"));
+ EXPECT_TRUE(contact_info->Set("last_name", "Obama"));
+ EXPECT_TRUE(contact_info->Set("phone_number", "1-800-TEST"));
+ EXPECT_TRUE(contact_info->Set("score", 1.f));
+
+ // Try to parse with the generated code.
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ ASSERT_NE(entity_data, nullptr);
+ EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+ EXPECT_EQ(entity_data->flight_number->flight_code, 38);
+ EXPECT_EQ(entity_data->contact_info->first_name, "Barack");
+ EXPECT_EQ(entity_data->contact_info->last_name, "Obama");
+ EXPECT_EQ(entity_data->contact_info->phone_number, "1-800-TEST");
+ EXPECT_NEAR(entity_data->contact_info->score, 1.f, 1e-4);
+}
+
+TEST_F(MutableFlatbufferTest, HandlesFieldsSetWithNamePath) {
+ FlatbufferFieldPathT path;
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "flight_number";
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "carrier_code";
+ flatbuffers::FlatBufferBuilder path_builder;
+ path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
+
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ // Test setting value using Set function.
+ buffer->Mutable("flight_number")->Set("flight_code", 38);
+ // Test setting value using FlatbufferFieldPath.
+ buffer->Set(flatbuffers::GetRoot<FlatbufferFieldPath>(
+ path_builder.GetBufferPointer()),
+ "LX");
+
+ // Try to parse with the generated code.
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ ASSERT_NE(entity_data, nullptr);
+ EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+ EXPECT_EQ(entity_data->flight_number->flight_code, 38);
+}
+
+TEST_F(MutableFlatbufferTest, HandlesFieldsSetWithOffsetPath) {
+ FlatbufferFieldPathT path;
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_offset = 14;
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_offset = 4;
+ flatbuffers::FlatBufferBuilder path_builder;
+ path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
+
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ // Test setting value using Set function.
+ buffer->Mutable("flight_number")->Set("flight_code", 38);
+ // Test setting value using FlatbufferFieldPath.
+ buffer->Set(flatbuffers::GetRoot<FlatbufferFieldPath>(
+ path_builder.GetBufferPointer()),
+ "LX");
+
+ // Try to parse with the generated code.
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ ASSERT_NE(entity_data, nullptr);
+ EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+ EXPECT_EQ(entity_data->flight_number->flight_code, 38);
+}
+
+TEST_F(MutableFlatbufferTest, PartialBuffersAreCorrectlyMerged) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ buffer->Set("an_int_field", 42);
+ buffer->Set("a_long_field", int64{84});
+ MutableFlatbuffer* flight_info = buffer->Mutable("flight_number");
+ flight_info->Set("carrier_code", "LX");
+ flight_info->Set("flight_code", 38);
+ auto* reminders = buffer->Repeated("reminders");
+ MutableFlatbuffer* reminder1 = reminders->Add();
+ reminder1->Set("title", "reminder1");
+ auto* reminder1_notes = reminder1->Repeated("notes");
+ reminder1_notes->Add("note1");
+ reminder1_notes->Add("note2");
+
+ // Create message to merge.
+ test::EntityDataT additional_entity_data;
+ additional_entity_data.an_int_field = 43;
+ additional_entity_data.flight_number.reset(new test::FlightNumberInfoT);
+ additional_entity_data.flight_number->flight_code = 39;
+ additional_entity_data.contact_info.reset(new test::ContactInfoT);
+ additional_entity_data.contact_info->first_name = "Barack";
+ additional_entity_data.reminders.push_back(
+ std::unique_ptr<test::ReminderT>(new test::ReminderT));
+ additional_entity_data.reminders[0]->notes.push_back("additional note1");
+ additional_entity_data.reminders[0]->notes.push_back("additional note2");
+ additional_entity_data.numbers.push_back(9);
+ additional_entity_data.numbers.push_back(10);
+ additional_entity_data.strings.push_back("str1");
+ additional_entity_data.strings.push_back("str2");
+
+ // Merge it.
+ EXPECT_TRUE(buffer->MergeFromSerializedFlatbuffer(
+ PackFlatbuffer<test::EntityData>(&additional_entity_data)));
+
+ // Try to parse it with the generated code.
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ ASSERT_NE(entity_data, nullptr);
+ EXPECT_EQ(entity_data->an_int_field, 43);
+ EXPECT_EQ(entity_data->a_long_field, 84);
+ EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+ EXPECT_EQ(entity_data->flight_number->flight_code, 39);
+ EXPECT_EQ(entity_data->contact_info->first_name, "Barack");
+ ASSERT_THAT(entity_data->reminders, SizeIs(2));
+ EXPECT_THAT(entity_data->reminders[1]->notes,
+ ElementsAre("additional note1", "additional note2"));
+ EXPECT_THAT(entity_data->numbers, ElementsAre(9, 10));
+ EXPECT_THAT(entity_data->strings, ElementsAre("str1", "str2"));
+}
+
+TEST_F(MutableFlatbufferTest, MergesNestedFields) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+
+ // Set a multiply nested field.
+ OwnedFlatbuffer<FlatbufferFieldPath, std::string> field_path =
+ CreateFieldPath({"nested", "nestedb", "nesteda", "nestedb", "nesteda"});
+ buffer->Mutable(field_path.get())->Set("value", "le value");
+
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ ASSERT_NE(entity_data, nullptr);
+ EXPECT_EQ(entity_data->nested->nestedb->nesteda->nestedb->nesteda->value,
+ "le value");
+}
+
+TEST_F(MutableFlatbufferTest, PrimitiveAndNestedFieldsAreCorrectlyFlattened) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ buffer->Set("an_int_field", 42);
+ buffer->Set("a_long_field", int64{84});
+ MutableFlatbuffer* flight_info = buffer->Mutable("flight_number");
+ flight_info->Set("carrier_code", "LX");
+ flight_info->Set("flight_code", 38);
+
+ std::map<std::string, Variant> entity_data_map = buffer->AsFlatMap();
+ EXPECT_EQ(4, entity_data_map.size());
+ EXPECT_EQ(42, entity_data_map["an_int_field"].Value<int>());
+ EXPECT_EQ(84, entity_data_map["a_long_field"].Value<int64>());
+ EXPECT_EQ("LX", entity_data_map["flight_number.carrier_code"]
+ .ConstRefValue<std::string>());
+ EXPECT_EQ(38, entity_data_map["flight_number.flight_code"].Value<int>());
+}
+
+TEST_F(MutableFlatbufferTest, ToTextProtoWorks) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ buffer->Set("an_int_field", 42);
+ buffer->Set("a_long_field", int64{84});
+ MutableFlatbuffer* flight_info = buffer->Mutable("flight_number");
+ flight_info->Set("carrier_code", "LX");
+ flight_info->Set("flight_code", 38);
+
+ // Add non primitive type.
+ auto reminders = buffer->Repeated("reminders");
+ auto foo_reminder = reminders->Add();
+ foo_reminder->Set("title", "foo reminder");
+ auto bar_reminder = reminders->Add();
+ bar_reminder->Set("title", "bar reminder");
+
+ // Add primitive type.
+ EXPECT_TRUE(buffer->Repeated("numbers")->Add(static_cast<int>(111)));
+ EXPECT_TRUE(buffer->Repeated("numbers")->Add(static_cast<int>(222)));
+ EXPECT_TRUE(buffer->Repeated("numbers")->Add(static_cast<int>(333)));
+
+ EXPECT_EQ(buffer->ToTextProto(),
+ "a_long_field: 84, an_int_field: 42, numbers: [111, 222, 333] , "
+ "reminders: [{title: 'foo reminder'}, {title: 'bar reminder'}] , "
+ "flight_number {flight_code: 38, carrier_code: 'LX'}");
+}
+
+TEST_F(MutableFlatbufferTest, RepeatedFieldSetThroughReflectionCanBeRead) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+
+ auto reminders = buffer->Repeated("reminders");
+ {
+ auto reminder = reminders->Add();
+ reminder->Set("title", "test reminder");
+ auto notes = reminder->Repeated("notes");
+ notes->Add("note A");
+ notes->Add("note B");
+ }
+ {
+ auto reminder = reminders->Add();
+ reminder->Set("title", "test reminder 2");
+ reminder->Add("notes", "note i");
+ reminder->Add("notes", "note ii");
+ reminder->Add("notes", "note iii");
+ }
+
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ ASSERT_NE(entity_data, nullptr);
+ EXPECT_THAT(entity_data->reminders, SizeIs(2));
+ EXPECT_EQ(entity_data->reminders[0]->title, "test reminder");
+ EXPECT_THAT(entity_data->reminders[0]->notes,
+ ElementsAre("note A", "note B"));
+ EXPECT_EQ(entity_data->reminders[1]->title, "test reminder 2");
+ EXPECT_THAT(entity_data->reminders[1]->notes,
+ ElementsAre("note i", "note ii", "note iii"));
+}
+
+TEST_F(MutableFlatbufferTest, RepeatedFieldAddMethodWithIncompatibleValues) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ EXPECT_FALSE(buffer->Repeated("numbers")->Add(static_cast<int64>(123)));
+ EXPECT_FALSE(buffer->Repeated("numbers")->Add(static_cast<int8>(9)));
+ EXPECT_TRUE(buffer->Repeated("numbers")->Add(static_cast<int>(999)));
+
+ // Try to parse it with the generated code.
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ ASSERT_NE(entity_data, nullptr);
+ ASSERT_NE(entity_data, nullptr);
+ ASSERT_EQ(entity_data->numbers.size(), 1);
+ EXPECT_EQ(entity_data->numbers[0], 999);
+}
+
+TEST_F(MutableFlatbufferTest, RepeatedFieldGetAndSizeMethods) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ EXPECT_TRUE(buffer->Repeated("numbers")->Add(1));
+ EXPECT_TRUE(buffer->Repeated("numbers")->Add(2));
+ EXPECT_TRUE(buffer->Repeated("numbers")->Add(3));
+
+ EXPECT_EQ(buffer->Repeated("numbers")->Size(), 3);
+ EXPECT_EQ(buffer->Repeated("numbers")->Get<int>(0), 1);
+ EXPECT_EQ(buffer->Repeated("numbers")->Get<int>(1), 2);
+ EXPECT_EQ(buffer->Repeated("numbers")->Get<int>(2), 3);
+}
+
+TEST_F(MutableFlatbufferTest, GetsRepeatedFieldFromPath) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ OwnedFlatbuffer<FlatbufferFieldPath, std::string> notes =
+ CreateFieldPath({"nested", "repeated_str"});
+
+ EXPECT_TRUE(buffer->Repeated(notes.get())->Add("a"));
+ EXPECT_TRUE(buffer->Repeated(notes.get())->Add("test"));
+
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ ASSERT_NE(entity_data, nullptr);
+ EXPECT_THAT(entity_data->nested->repeated_str, SizeIs(2));
+ EXPECT_THAT(entity_data->nested->repeated_str, ElementsAre("a", "test"));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers/reflection.cc b/native/utils/flatbuffers/reflection.cc
new file mode 100644
index 0000000..7d6d3f4
--- /dev/null
+++ b/native/utils/flatbuffers/reflection.cc
@@ -0,0 +1,167 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const StringPiece field_name) {
+ TC3_CHECK(type != nullptr && type->fields() != nullptr);
+ return type->fields()->LookupByKey(field_name.data());
+}
+
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const int field_offset) {
+ if (type->fields() == nullptr) {
+ return nullptr;
+ }
+ for (const reflection::Field* field : *type->fields()) {
+ if (field->offset() == field_offset) {
+ return field;
+ }
+ }
+ return nullptr;
+}
+
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const StringPiece field_name,
+ const int field_offset) {
+ // Lookup by name might be faster as the fields are sorted by name in the
+ // schema data, so try that first.
+ if (!field_name.empty()) {
+ return GetFieldOrNull(type, field_name.data());
+ }
+ return GetFieldOrNull(type, field_offset);
+}
+
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const FlatbufferField* field) {
+ TC3_CHECK(type != nullptr && field != nullptr);
+ if (field->field_name() == nullptr) {
+ return GetFieldOrNull(type, field->field_offset());
+ }
+ return GetFieldOrNull(
+ type,
+ StringPiece(field->field_name()->data(), field->field_name()->size()),
+ field->field_offset());
+}
+
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const FlatbufferFieldT* field) {
+ TC3_CHECK(type != nullptr && field != nullptr);
+ return GetFieldOrNull(type, field->field_name, field->field_offset);
+}
+
+const reflection::Object* TypeForName(const reflection::Schema* schema,
+ const StringPiece type_name) {
+ for (const reflection::Object* object : *schema->objects()) {
+ if (type_name.Equals(object->name()->str())) {
+ return object;
+ }
+ }
+ return nullptr;
+}
+
+Optional<int> TypeIdForObject(const reflection::Schema* schema,
+ const reflection::Object* type) {
+ for (int i = 0; i < schema->objects()->size(); i++) {
+ if (schema->objects()->Get(i) == type) {
+ return Optional<int>(i);
+ }
+ }
+ return Optional<int>();
+}
+
+Optional<int> TypeIdForName(const reflection::Schema* schema,
+ const StringPiece type_name) {
+ for (int i = 0; i < schema->objects()->size(); i++) {
+ if (type_name.Equals(schema->objects()->Get(i)->name()->str())) {
+ return Optional<int>(i);
+ }
+ }
+ return Optional<int>();
+}
+
+bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
+ FlatbufferFieldPathT* path) {
+ if (schema == nullptr || !schema->root_table()) {
+ TC3_LOG(ERROR) << "Empty schema provided.";
+ return false;
+ }
+
+ reflection::Object const* type = schema->root_table();
+ for (int i = 0; i < path->field.size(); i++) {
+ const reflection::Field* field = GetFieldOrNull(type, path->field[i].get());
+ if (field == nullptr) {
+ TC3_LOG(ERROR) << "Could not find field: " << path->field[i]->field_name;
+ return false;
+ }
+ path->field[i]->field_name.clear();
+ path->field[i]->field_offset = field->offset();
+
+ // Descend.
+ if (i < path->field.size() - 1) {
+ if (field->type()->base_type() != reflection::Obj) {
+ TC3_LOG(ERROR) << "Field: " << field->name()->str()
+ << " is not of type `Object`.";
+ return false;
+ }
+ type = schema->objects()->Get(field->type()->index());
+ }
+ }
+ return true;
+}
+
+Variant ParseEnumValue(const reflection::Schema* schema,
+ const reflection::Type* type, StringPiece value) {
+ TC3_DCHECK(IsEnum(type));
+ TC3_CHECK_NE(schema->enums(), nullptr);
+ const auto* enum_values = schema->enums()->Get(type->index())->values();
+ if (enum_values == nullptr) {
+ TC3_LOG(ERROR) << "Enum has no specified values.";
+ return Variant();
+ }
+ for (const reflection::EnumVal* enum_value : *enum_values) {
+ if (value.Equals(StringPiece(enum_value->name()->c_str(),
+ enum_value->name()->size()))) {
+ const int64 value = enum_value->value();
+ switch (type->base_type()) {
+ case reflection::BaseType::Byte:
+ return Variant(static_cast<int8>(value));
+ case reflection::BaseType::UByte:
+ return Variant(static_cast<uint8>(value));
+ case reflection::BaseType::Short:
+ return Variant(static_cast<int16>(value));
+ case reflection::BaseType::UShort:
+ return Variant(static_cast<uint16>(value));
+ case reflection::BaseType::Int:
+ return Variant(static_cast<int32>(value));
+ case reflection::BaseType::UInt:
+ return Variant(static_cast<uint32>(value));
+ case reflection::BaseType::Long:
+ return Variant(value);
+ case reflection::BaseType::ULong:
+ return Variant(static_cast<uint64>(value));
+ default:
+ break;
+ }
+ }
+ }
+ return Variant();
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers/reflection.h b/native/utils/flatbuffers/reflection.h
new file mode 100644
index 0000000..9a0fec7
--- /dev/null
+++ b/native/utils/flatbuffers/reflection.h
@@ -0,0 +1,197 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_REFLECTION_H_
+#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_REFLECTION_H_
+
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/optional.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
+#include "flatbuffers/reflection.h"
+#include "flatbuffers/reflection_generated.h"
+
+namespace libtextclassifier3 {
+
+template <typename T>
+struct flatbuffers_base_type {
+ static const reflection::BaseType value;
+};
+
+template <typename T>
+inline const reflection::BaseType flatbuffers_base_type<T>::value =
+ reflection::None;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<bool>::value =
+ reflection::Bool;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<int8>::value =
+ reflection::Byte;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<uint8>::value =
+ reflection::UByte;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<int16>::value =
+ reflection::Short;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<uint16>::value =
+ reflection::UShort;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<int32>::value =
+ reflection::Int;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<uint32>::value =
+ reflection::UInt;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<int64>::value =
+ reflection::Long;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<uint64>::value =
+ reflection::ULong;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<float>::value =
+ reflection::Float;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<double>::value =
+ reflection::Double;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<StringPiece>::value =
+ reflection::String;
+
+template <reflection::BaseType>
+struct flatbuffers_cpp_type;
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Bool> {
+ using value = bool;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Byte> {
+ using value = int8;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::UByte> {
+ using value = uint8;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Short> {
+ using value = int16;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::UShort> {
+ using value = uint16;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Int> {
+ using value = int32;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::UInt> {
+ using value = uint32;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Long> {
+ using value = int64;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::ULong> {
+ using value = uint64;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Float> {
+ using value = float;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Double> {
+ using value = double;
+};
+
+// Gets the field information for a field name, returns nullptr if the
+// field was not defined.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const StringPiece field_name);
+
+// Gets the field information for a field offset, returns nullptr if no field was
+// defined with the given offset.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const int field_offset);
+
+// Gets a field by name or offset, returns nullptr if no field was found.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const StringPiece field_name,
+ const int field_offset);
+
+// Gets a field by a field spec, either by name or offset. Returns nullptr if no
+// such field was found.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const FlatbufferField* field);
+
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const FlatbufferFieldT* field);
+
+// Gets the type information for the given type name or nullptr if not
+// specified.
+const reflection::Object* TypeForName(const reflection::Schema* schema,
+ const StringPiece type_name);
+
+// Gets the type id for a type name.
+Optional<int> TypeIdForName(const reflection::Schema* schema,
+ const StringPiece type_name);
+
+// Gets the type id for a type.
+Optional<int> TypeIdForObject(const reflection::Schema* schema,
+ const reflection::Object* type);
+
+// Resolves field lookups by name to the concrete field offsets.
+bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
+ FlatbufferFieldPathT* path);
+
+// Checks whether a type denotes an enum.
+inline bool IsEnum(const reflection::Type* type) {
+ return flatbuffers::IsInteger(type->base_type()) && type->index() >= 0;
+}
+
+// Parses an enum value.
+Variant ParseEnumValue(const reflection::Schema* schema,
+ const reflection::Type* type, StringPiece value);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_REFLECTION_H_
diff --git a/native/utils/flatbuffers/reflection_test.cc b/native/utils/flatbuffers/reflection_test.cc
new file mode 100644
index 0000000..9a56b77
--- /dev/null
+++ b/native/utils/flatbuffers/reflection_test.cc
@@ -0,0 +1,69 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/flatbuffers/reflection.h"
+
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/reflection_generated.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(ReflectionTest, ResolvesFieldOffsets) {
+ std::string metadata_buffer = LoadTestMetadata();
+ const reflection::Schema* schema =
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
+ FlatbufferFieldPathT path;
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "flight_number";
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "carrier_code";
+
+ EXPECT_TRUE(SwapFieldNamesForOffsetsInPath(schema, &path));
+
+ EXPECT_THAT(path.field[0]->field_name, testing::IsEmpty());
+ EXPECT_EQ(14, path.field[0]->field_offset);
+ EXPECT_THAT(path.field[1]->field_name, testing::IsEmpty());
+ EXPECT_EQ(4, path.field[1]->field_offset);
+}
+
+TEST(ReflectionTest, ParseEnumValuesByName) {
+ std::string metadata_buffer = LoadTestMetadata();
+ const reflection::Schema* schema =
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
+ const reflection::Field* field =
+ GetFieldOrNull(schema->root_table(), "enum_value");
+
+ Variant enum_value_0 = ParseEnumValue(schema, field->type(), "VALUE_0");
+ Variant enum_value_1 = ParseEnumValue(schema, field->type(), "VALUE_1");
+ Variant enum_no_value = ParseEnumValue(schema, field->type(), "NO_VALUE");
+
+ EXPECT_TRUE(enum_value_0.HasValue());
+ EXPECT_EQ(enum_value_0.GetType(), Variant::TYPE_INT_VALUE);
+ EXPECT_EQ(enum_value_0.Value<int>(), 0);
+
+ EXPECT_TRUE(enum_value_1.HasValue());
+ EXPECT_EQ(enum_value_1.GetType(), Variant::TYPE_INT_VALUE);
+ EXPECT_EQ(enum_value_1.Value<int>(), 1);
+
+ EXPECT_FALSE(enum_no_value.HasValue());
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers/test-utils.h b/native/utils/flatbuffers/test-utils.h
new file mode 100644
index 0000000..dbfc732
--- /dev/null
+++ b/native/utils/flatbuffers/test-utils.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Common test utils.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_TEST_UTILS_H_
+
+#include <fstream>
+#include <string>
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/test-data-test-utils.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+inline std::string LoadTestMetadata() {
+ std::ifstream test_config_stream(
+ GetTestDataPath("utils/flatbuffers/flatbuffers_test_extended.bfbs"));
+ return std::string((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+}
+
+// Creates a flatbuffer field path from a dot separated field path string.
+inline std::unique_ptr<FlatbufferFieldPathT> CreateUnpackedFieldPath(
+ const std::vector<std::string>& fields) {
+ std::unique_ptr<FlatbufferFieldPathT> path(new FlatbufferFieldPathT);
+ for (const std::string& field : fields) {
+ path->field.emplace_back(new FlatbufferFieldT);
+ path->field.back()->field_name = field;
+ }
+ return path;
+}
+
+inline OwnedFlatbuffer<FlatbufferFieldPath, std::string> CreateFieldPath(
+ const std::vector<std::string>& fields) {
+ std::unique_ptr<FlatbufferFieldPathT> path = CreateUnpackedFieldPath(fields);
+ return OwnedFlatbuffer<FlatbufferFieldPath, std::string>(
+ PackFlatbuffer<FlatbufferFieldPath>(path.get()));
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_TEST_UTILS_H_
diff --git a/native/utils/grammar/analyzer.cc b/native/utils/grammar/analyzer.cc
new file mode 100644
index 0000000..b760442
--- /dev/null
+++ b/native/utils/grammar/analyzer.cc
@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/analyzer.h"
+
+#include "utils/base/status_macros.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3::grammar {
+
+Analyzer::Analyzer(const UniLib* unilib, const RulesSet* rules_set)
+ // TODO(smillius): Add tokenizer options to `RulesSet`.
+ : owned_tokenizer_(new Tokenizer(libtextclassifier3::TokenizationType_ICU,
+ unilib,
+ /*codepoint_ranges=*/{},
+ /*internal_tokenizer_codepoint_ranges=*/{},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false)),
+ tokenizer_(owned_tokenizer_.get()),
+ parser_(unilib, rules_set),
+ semantic_evaluator_(rules_set->semantic_values_schema() != nullptr
+ ? flatbuffers::GetRoot<reflection::Schema>(
+ rules_set->semantic_values_schema()->data())
+ : nullptr) {}
+
+Analyzer::Analyzer(const UniLib* unilib, const RulesSet* rules_set,
+ const Tokenizer* tokenizer)
+ : tokenizer_(tokenizer),
+ parser_(unilib, rules_set),
+ semantic_evaluator_(rules_set->semantic_values_schema() != nullptr
+ ? flatbuffers::GetRoot<reflection::Schema>(
+ rules_set->semantic_values_schema()->data())
+ : nullptr) {}
+
+StatusOr<std::vector<EvaluatedDerivation>> Analyzer::Parse(
+ const TextContext& input, UnsafeArena* arena,
+ bool deduplicate_derivations) const {
+ std::vector<EvaluatedDerivation> result;
+
+ std::vector<Derivation> derivations = parser_.Parse(input, arena);
+ if (deduplicate_derivations) {
+ derivations = DeduplicateDerivations<Derivation>(derivations);
+ }
+ // Evaluate each derivation.
+ for (const Derivation& derivation : derivations) {
+ if (derivation.IsValid()) {
+ TC3_ASSIGN_OR_RETURN(const SemanticValue* value,
+ semantic_evaluator_.Eval(input, derivation, arena));
+ result.emplace_back(
+ EvaluatedDerivation{{/*parse_tree=*/derivation.parse_tree,
+ /*rule_id=*/derivation.rule_id},
+ /*semantic_value=*/value});
+ }
+ }
+
+ return result;
+}
+
+StatusOr<std::vector<EvaluatedDerivation>> Analyzer::Parse(
+ const UnicodeText& text, const std::vector<Locale>& locales,
+ UnsafeArena* arena, bool deduplicate_derivations) const {
+ return Parse(BuildTextContextForInput(text, locales), arena,
+ deduplicate_derivations);
+}
+
+TextContext Analyzer::BuildTextContextForInput(
+ const UnicodeText& text, const std::vector<Locale>& locales) const {
+ TextContext context;
+ context.text = UnicodeText(text, /*do_copy=*/false);
+ context.tokens = tokenizer_->Tokenize(context.text);
+ context.codepoints = context.text.Codepoints();
+ context.codepoints.push_back(context.text.end());
+ context.locales = locales;
+ context.context_span.first = 0;
+ context.context_span.second = context.tokens.size();
+ return context;
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/analyzer.h b/native/utils/grammar/analyzer.h
new file mode 100644
index 0000000..6d1dd46
--- /dev/null
+++ b/native/utils/grammar/analyzer.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_ANALYZER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_ANALYZER_H_
+
+#include "utils/base/arena.h"
+#include "utils/base/statusor.h"
+#include "utils/grammar/evaluated-derivation.h"
+#include "utils/grammar/parsing/parser.h"
+#include "utils/grammar/semantics/composer.h"
+#include "utils/grammar/text-context.h"
+#include "utils/i18n/locale.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::grammar {
+
+// An analyzer that parses and semantically evaluates an input text with a
+// grammar.
+class Analyzer {
+ public:
+ explicit Analyzer(const UniLib* unilib, const RulesSet* rules_set);
+ explicit Analyzer(const UniLib* unilib, const RulesSet* rules_set,
+ const Tokenizer* tokenizer);
+
+ // Parses and evaluates an input.
+ StatusOr<std::vector<EvaluatedDerivation>> Parse(
+ const TextContext& input, UnsafeArena* arena,
+ bool deduplicate_derivations = true) const;
+
+ StatusOr<std::vector<EvaluatedDerivation>> Parse(
+ const UnicodeText& text, const std::vector<Locale>& locales,
+ UnsafeArena* arena, bool deduplicate_derivations = true) const;
+
+ // Pre-processes an input text for parsing.
+ TextContext BuildTextContextForInput(
+ const UnicodeText& text, const std::vector<Locale>& locales = {}) const;
+
+ const Parser& parser() const { return parser_; }
+
+ private:
+ std::unique_ptr<Tokenizer> owned_tokenizer_;
+ const Tokenizer* tokenizer_;
+ Parser parser_;
+ SemanticComposer semantic_evaluator_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_ANALYZER_H_
diff --git a/native/utils/grammar/analyzer_test.cc b/native/utils/grammar/analyzer_test.cc
new file mode 100644
index 0000000..3905b70
--- /dev/null
+++ b/native/utils/grammar/analyzer_test.cc
@@ -0,0 +1,99 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/analyzer.h"
+
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/types.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/utf8/unicodetext.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::SizeIs;
+
+class AnalyzerTest : public GrammarTest {};
+
+TEST_F(AnalyzerTest, ParsesTextWithGrammar) {
+ RulesSetT model;
+
+ // Add semantic values schema.
+ model.semantic_values_schema.assign(semantic_values_schema_.buffer().begin(),
+ semantic_values_schema_.buffer().end());
+
+ // Define rules and semantics.
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<month>", {"january"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ model.semantic_expression.push_back(CreatePrimitiveConstExpression(1));
+
+ rules.Add("<month>", {"february"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ model.semantic_expression.push_back(CreatePrimitiveConstExpression(2));
+
+ const int kMonth = 0;
+ rules.Add("<month_rule>", {"<month>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kMonth);
+ rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
+ const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
+
+ Analyzer analyzer(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
+
+ {
+ auto maybe_results = analyzer.Parse(
+ UTF8ToUnicodeText("The month is January 2020", /*do_copy=*/false),
+ /*locales=*/{}, &arena_);
+    ASSERT_TRUE(maybe_results.ok());
+
+ const std::vector<EvaluatedDerivation> results = maybe_results.ValueOrDie();
+    ASSERT_THAT(results, SizeIs(1));
+
+ // Check parse tree.
+ EXPECT_THAT(results[0], IsDerivation(kMonth /* rule_id */, 13 /* begin */,
+ 20 /* end */));
+
+ // Check semantic result.
+ EXPECT_EQ(results[0].value->Value<int32>(), 1);
+ }
+
+ {
+ auto maybe_results =
+ analyzer.Parse(UTF8ToUnicodeText("february", /*do_copy=*/false),
+ /*locales=*/{}, &arena_);
+    ASSERT_TRUE(maybe_results.ok());
+
+ const std::vector<EvaluatedDerivation> results = maybe_results.ValueOrDie();
+    ASSERT_THAT(results, SizeIs(1));
+
+ // Check parse tree.
+ EXPECT_THAT(results[0],
+ IsDerivation(kMonth /* rule_id */, 0 /* begin */, 8 /* end */));
+
+ // Check semantic result.
+ EXPECT_EQ(results[0].value->Value<int32>(), 2);
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/callback-delegate.h b/native/utils/grammar/callback-delegate.h
deleted file mode 100644
index a5424dd..0000000
--- a/native/utils/grammar/callback-delegate.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_
-#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_
-
-#include "utils/base/integral_types.h"
-#include "utils/grammar/match.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/types.h"
-
-namespace libtextclassifier3::grammar {
-
-class Matcher;
-
-// CallbackDelegate is an interface and default implementation used by the
-// grammar matcher to dispatch rule matches.
-class CallbackDelegate {
- public:
- virtual ~CallbackDelegate() = default;
-
- // This is called by the matcher whenever it finds a match for a rule to
- // which a callback is attached.
- virtual void MatchFound(const Match* match, const CallbackId callback_id,
- const int64 callback_param, Matcher* matcher) {}
-};
-
-} // namespace libtextclassifier3::grammar
-
-#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_
diff --git a/native/utils/grammar/evaluated-derivation.h b/native/utils/grammar/evaluated-derivation.h
new file mode 100644
index 0000000..4ae409d
--- /dev/null
+++ b/native/utils/grammar/evaluated-derivation.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_EVALUATED_DERIVATION_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_EVALUATED_DERIVATION_H_
+
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// A parse tree for a root rule and its semantic value.
+struct EvaluatedDerivation : public Derivation {
+ const SemanticValue* value;
+};
+
+}  // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_EVALUATED_DERIVATION_H_
diff --git a/native/utils/grammar/lexer.cc b/native/utils/grammar/lexer.cc
deleted file mode 100644
index 3a2d0d3..0000000
--- a/native/utils/grammar/lexer.cc
+++ /dev/null
@@ -1,321 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/grammar/lexer.h"
-
-#include <unordered_map>
-
-#include "annotator/types.h"
-#include "utils/zlib/zlib.h"
-#include "utils/zlib/zlib_regex.h"
-
-namespace libtextclassifier3::grammar {
-namespace {
-
-inline bool CheckMemoryUsage(const Matcher* matcher) {
- // The maximum memory usage for matching.
- constexpr int kMaxMemoryUsage = 1 << 20;
- return matcher->ArenaSize() <= kMaxMemoryUsage;
-}
-
-Match* CheckedAddMatch(const Nonterm nonterm,
- const CodepointSpan codepoint_span,
- const int match_offset, const int16 type,
- Matcher* matcher) {
- if (nonterm == kUnassignedNonterm || !CheckMemoryUsage(matcher)) {
- return nullptr;
- }
- return matcher->AllocateAndInitMatch<Match>(nonterm, codepoint_span,
- match_offset, type);
-}
-
-void CheckedEmit(const Nonterm nonterm, const CodepointSpan codepoint_span,
- const int match_offset, int16 type, Matcher* matcher) {
- if (nonterm != kUnassignedNonterm && CheckMemoryUsage(matcher)) {
- matcher->AddMatch(matcher->AllocateAndInitMatch<Match>(
- nonterm, codepoint_span, match_offset, type));
- }
-}
-
-int MapCodepointToTokenPaddingIfPresent(
- const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment,
- const int start) {
- const auto it = token_alignment.find(start);
- if (it != token_alignment.end()) {
- return it->second;
- }
- return start;
-}
-
-} // namespace
-
-Lexer::Lexer(const UniLib* unilib, const RulesSet* rules)
- : unilib_(*unilib),
- rules_(rules),
- regex_annotators_(BuildRegexAnnotator(unilib_, rules)) {}
-
-std::vector<Lexer::RegexAnnotator> Lexer::BuildRegexAnnotator(
- const UniLib& unilib, const RulesSet* rules) const {
- std::vector<Lexer::RegexAnnotator> result;
- if (rules->regex_annotator() != nullptr) {
- std::unique_ptr<ZlibDecompressor> decompressor =
- ZlibDecompressor::Instance();
- result.reserve(rules->regex_annotator()->size());
- for (const RulesSet_::RegexAnnotator* regex_annotator :
- *rules->regex_annotator()) {
- result.push_back(
- {UncompressMakeRegexPattern(unilib_, regex_annotator->pattern(),
- regex_annotator->compressed_pattern(),
- rules->lazy_regex_compilation(),
- decompressor.get()),
- regex_annotator->nonterminal()});
- }
- }
- return result;
-}
-
-void Lexer::Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms,
- Matcher* matcher) const {
- switch (symbol.type) {
- case Symbol::Type::TYPE_MATCH: {
- // Just emit the match.
- matcher->AddMatch(symbol.match);
- return;
- }
- case Symbol::Type::TYPE_DIGITS: {
- // Emit <digits> if used by the rules.
- CheckedEmit(nonterms->digits_nt(), symbol.codepoint_span,
- symbol.match_offset, Match::kDigitsType, matcher);
-
- // Emit <n_digits> if used by the rules.
- if (nonterms->n_digits_nt() != nullptr) {
- const int num_digits =
- symbol.codepoint_span.second - symbol.codepoint_span.first;
- if (num_digits <= nonterms->n_digits_nt()->size()) {
- CheckedEmit(nonterms->n_digits_nt()->Get(num_digits - 1),
- symbol.codepoint_span, symbol.match_offset,
- Match::kDigitsType, matcher);
- }
- }
- break;
- }
- case Symbol::Type::TYPE_TERM: {
- // Emit <uppercase_token> if used by the rules.
- if (nonterms->uppercase_token_nt() != 0 &&
- unilib_.IsUpperText(
- UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) {
- CheckedEmit(nonterms->uppercase_token_nt(), symbol.codepoint_span,
- symbol.match_offset, Match::kTokenType, matcher);
- }
- break;
- }
- default:
- break;
- }
-
- // Emit the token as terminal.
- if (CheckMemoryUsage(matcher)) {
- matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset,
- symbol.lexeme);
- }
-
- // Emit <token> if used by rules.
- CheckedEmit(nonterms->token_nt(), symbol.codepoint_span, symbol.match_offset,
- Match::kTokenType, matcher);
-}
-
-Lexer::Symbol::Type Lexer::GetSymbolType(
- const UnicodeText::const_iterator& it) const {
- if (unilib_.IsPunctuation(*it)) {
- return Symbol::Type::TYPE_PUNCTUATION;
- } else if (unilib_.IsDigit(*it)) {
- return Symbol::Type::TYPE_DIGITS;
- } else {
- return Symbol::Type::TYPE_TERM;
- }
-}
-
-void Lexer::ProcessToken(const StringPiece value, const int prev_token_end,
- const CodepointSpan codepoint_span,
- std::vector<Lexer::Symbol>* symbols) const {
- // Possibly split token.
- UnicodeText token_unicode = UTF8ToUnicodeText(value.data(), value.size(),
- /*do_copy=*/false);
- int last_end = prev_token_end;
- auto token_end = token_unicode.end();
- auto it = token_unicode.begin();
- Symbol::Type type = GetSymbolType(it);
- CodepointIndex sub_token_start = codepoint_span.first;
- while (it != token_end) {
- auto next = std::next(it);
- int num_codepoints = 1;
- Symbol::Type next_type;
- while (next != token_end) {
- next_type = GetSymbolType(next);
- if (type == Symbol::Type::TYPE_PUNCTUATION || next_type != type) {
- break;
- }
- ++next;
- ++num_codepoints;
- }
- symbols->push_back(Symbol{
- type, CodepointSpan{sub_token_start, sub_token_start + num_codepoints},
- /*match_offset=*/last_end,
- /*lexeme=*/
- StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data())});
- last_end = sub_token_start + num_codepoints;
- it = next;
- type = next_type;
- sub_token_start = last_end;
- }
-}
-
-void Lexer::Process(const UnicodeText& text, const std::vector<Token>& tokens,
- const std::vector<AnnotatedSpan>* annotations,
- Matcher* matcher) const {
- return Process(text, tokens.begin(), tokens.end(), annotations, matcher);
-}
-
-void Lexer::Process(const UnicodeText& text,
- const std::vector<Token>::const_iterator& begin,
- const std::vector<Token>::const_iterator& end,
- const std::vector<AnnotatedSpan>* annotations,
- Matcher* matcher) const {
- if (begin == end) {
- return;
- }
-
- const RulesSet_::Nonterminals* nonterminals = rules_->nonterminals();
-
- // Initialize processing of new text.
- CodepointIndex prev_token_end = 0;
- std::vector<Symbol> symbols;
- matcher->Reset();
-
- // The matcher expects the terminals and non-terminals it received to be in
- // non-decreasing end-position order. The sorting above makes sure the
- // pre-defined matches adhere to that order.
- // Ideally, we would just have to emit a predefined match whenever we see that
- // the next token we feed would be ending later.
- // But as we implicitly ignore whitespace, we have to merge preceding
- // whitespace to the match start so that tokens and non-terminals fed appear
- // as next to each other without whitespace.
- // We keep track of real token starts and precending whitespace in
- // `token_match_start`, so that we can extend a predefined match's start to
- // include the preceding whitespace.
- std::unordered_map<CodepointIndex, CodepointIndex> token_match_start;
-
- // Add start symbols.
- if (Match* match =
- CheckedAddMatch(nonterminals->start_nt(), CodepointSpan{0, 0},
- /*match_offset=*/0, Match::kBreakType, matcher)) {
- symbols.push_back(Symbol(match));
- }
- if (Match* match =
- CheckedAddMatch(nonterminals->wordbreak_nt(), CodepointSpan{0, 0},
- /*match_offset=*/0, Match::kBreakType, matcher)) {
- symbols.push_back(Symbol(match));
- }
-
- for (auto token_it = begin; token_it != end; token_it++) {
- const Token& token = *token_it;
-
- // Record match starts for token boundaries, so that we can snap pre-defined
- // matches to it.
- if (prev_token_end != token.start) {
- token_match_start[token.start] = prev_token_end;
- }
-
- ProcessToken(token.value,
- /*prev_token_end=*/prev_token_end,
- CodepointSpan{token.start, token.end}, &symbols);
- prev_token_end = token.end;
-
- // Add word break symbol if used by the grammar.
- if (Match* match = CheckedAddMatch(
- nonterminals->wordbreak_nt(), CodepointSpan{token.end, token.end},
- /*match_offset=*/token.end, Match::kBreakType, matcher)) {
- symbols.push_back(Symbol(match));
- }
- }
-
- // Add end symbol if used by the grammar.
- if (Match* match = CheckedAddMatch(
- nonterminals->end_nt(), CodepointSpan{prev_token_end, prev_token_end},
- /*match_offset=*/prev_token_end, Match::kBreakType, matcher)) {
- symbols.push_back(Symbol(match));
- }
-
- // Add matches based on annotations.
- auto annotation_nonterminals = nonterminals->annotation_nt();
- if (annotation_nonterminals != nullptr && annotations != nullptr) {
- for (const AnnotatedSpan& annotated_span : *annotations) {
- const ClassificationResult& classification =
- annotated_span.classification.front();
- if (auto entry = annotation_nonterminals->LookupByKey(
- classification.collection.c_str())) {
- AnnotationMatch* match = matcher->AllocateAndInitMatch<AnnotationMatch>(
- entry->value(), annotated_span.span,
- /*match_offset=*/
- MapCodepointToTokenPaddingIfPresent(token_match_start,
- annotated_span.span.first),
- Match::kAnnotationMatch);
- match->annotation = &classification;
- symbols.push_back(Symbol(match));
- }
- }
- }
-
- // Add regex annotator matches for the range covered by the tokens.
- for (const RegexAnnotator& regex_annotator : regex_annotators_) {
- std::unique_ptr<UniLib::RegexMatcher> regex_matcher =
- regex_annotator.pattern->Matcher(UnicodeText::Substring(
- text, begin->start, prev_token_end, /*do_copy=*/false));
- int status = UniLib::RegexMatcher::kNoError;
- while (regex_matcher->Find(&status) &&
- status == UniLib::RegexMatcher::kNoError) {
- const CodepointSpan span = {
- regex_matcher->Start(0, &status) + begin->start,
- regex_matcher->End(0, &status) + begin->start};
- if (Match* match =
- CheckedAddMatch(regex_annotator.nonterm, span, /*match_offset=*/
- MapCodepointToTokenPaddingIfPresent(
- token_match_start, span.first),
- Match::kUnknownType, matcher)) {
- symbols.push_back(Symbol(match));
- }
- }
- }
-
- std::sort(symbols.begin(), symbols.end(),
- [](const Symbol& a, const Symbol& b) {
- // Sort by increasing (end, start) position to guarantee the
- // matcher requirement that the tokens are fed in non-decreasing
- // end position order.
- return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
- std::tie(b.codepoint_span.second, b.codepoint_span.first);
- });
-
- // Emit symbols to matcher.
- for (const Symbol& symbol : symbols) {
- Emit(symbol, nonterminals, matcher);
- }
-
- // Finish the matching.
- matcher->Finish();
-}
-
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/lexer.h b/native/utils/grammar/lexer.h
deleted file mode 100644
index ca31c25..0000000
--- a/native/utils/grammar/lexer.h
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// This is a lexer that runs off the tokenizer and outputs the tokens to a
-// grammar matcher. The tokens it forwards are the same as the ones produced
-// by the tokenizer, but possibly further split and normalized (downcased).
-// Examples:
-//
-// - single character tokens for punctuation (e.g., AddTerminal("?"))
-//
-// - a string of letters (e.g., "Foo" -- it calls AddTerminal() on "foo")
-//
-// - a string of digits (e.g., AddTerminal("37"))
-//
-// In addition to the terminal tokens above, it also outputs certain
-// special nonterminals:
-//
-// - a <token> nonterminal, which it outputs in addition to the
-// regular AddTerminal() call for every token
-//
-// - a <digits> nonterminal, which it outputs in addition to
-// the regular AddTerminal() call for each string of digits
-//
-// - <N_digits> nonterminals, where N is the length of the string of
-// digits. By default the maximum N that will be output is 20. This
-// may be changed at compile time by kMaxNDigitsLength. For instance,
-// "123" will produce a <3_digits> nonterminal, "1234567" will produce
-// a <7_digits> nonterminal.
-//
-// It does not output any whitespace. Instead, whitespace gets absorbed into
-// the token that follows them in the text.
-// For example, if the text contains:
-//
-// ...hello there world...
-// | | |
-// offset=16 39 52
-//
-// then the output will be:
-//
-// "hello" [?, 16)
-// "there" [16, 44) <-- note "16" NOT "39"
-// "world" [44, ?) <-- note "44" NOT "52"
-//
-// This makes it appear to the Matcher as if the tokens are adjacent -- so
-// whitespace is simply ignored.
-//
-// A minor optimization: We don't bother to output nonterminals if the grammar
-// rules don't reference them.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
-#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
-
-#include "annotator/types.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/types.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3::grammar {
-
-class Lexer {
- public:
- explicit Lexer(const UniLib* unilib, const RulesSet* rules);
-
- // Processes a tokenized text. Classifies the tokens and feeds them to the
- // matcher.
- // The provided annotations will be fed to the matcher alongside the tokens.
- // NOTE: The `annotations` need to outlive any dependent processing.
- void Process(const UnicodeText& text, const std::vector<Token>& tokens,
- const std::vector<AnnotatedSpan>* annotations,
- Matcher* matcher) const;
- void Process(const UnicodeText& text,
- const std::vector<Token>::const_iterator& begin,
- const std::vector<Token>::const_iterator& end,
- const std::vector<AnnotatedSpan>* annotations,
- Matcher* matcher) const;
-
- private:
- // A lexical symbol with an identified meaning that represents raw tokens,
- // token categories or predefined text matches.
- // It is the unit fed to the grammar matcher.
- struct Symbol {
- // The type of the lexical symbol.
- enum class Type {
- // A raw token.
- TYPE_TERM,
-
- // A symbol representing a string of digits.
- TYPE_DIGITS,
-
- // Punctuation characters.
- TYPE_PUNCTUATION,
-
- // A predefined match.
- TYPE_MATCH
- };
-
- explicit Symbol() = default;
-
- // Constructs a symbol of a given type with an anchor in the text.
- Symbol(const Type type, const CodepointSpan codepoint_span,
- const int match_offset, StringPiece lexeme)
- : type(type),
- codepoint_span(codepoint_span),
- match_offset(match_offset),
- lexeme(lexeme) {}
-
- // Constructs a symbol from a pre-defined match.
- explicit Symbol(Match* match)
- : type(Type::TYPE_MATCH),
- codepoint_span(match->codepoint_span),
- match_offset(match->match_offset),
- match(match) {}
-
- // The type of the symbole.
- Type type;
-
- // The span in the text as codepoint offsets.
- CodepointSpan codepoint_span;
-
- // The match start offset (including preceding whitespace) as codepoint
- // offset.
- int match_offset;
-
- // The symbol text value.
- StringPiece lexeme;
-
- // The predefined match.
- Match* match;
- };
-
- // Processes a single token: the token is split and classified into symbols.
- void ProcessToken(const StringPiece value, const int prev_token_end,
- const CodepointSpan codepoint_span,
- std::vector<Symbol>* symbols) const;
-
- // Emits a token to the matcher.
- void Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms,
- Matcher* matcher) const;
-
- // Gets the type of a character.
- Symbol::Type GetSymbolType(const UnicodeText::const_iterator& it) const;
-
- private:
- struct RegexAnnotator {
- std::unique_ptr<UniLib::RegexPattern> pattern;
- Nonterm nonterm;
- };
-
- // Uncompress and build the defined regex annotators.
- std::vector<RegexAnnotator> BuildRegexAnnotator(const UniLib& unilib,
- const RulesSet* rules) const;
-
- const UniLib& unilib_;
- const RulesSet* rules_;
- std::vector<RegexAnnotator> regex_annotators_;
-};
-
-} // namespace libtextclassifier3::grammar
-
-#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
diff --git a/native/utils/grammar/match.cc b/native/utils/grammar/match.cc
deleted file mode 100644
index ecf9874..0000000
--- a/native/utils/grammar/match.cc
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/grammar/match.h"
-
-#include <algorithm>
-#include <stack>
-
-namespace libtextclassifier3::grammar {
-
-void Traverse(const Match* root,
- const std::function<bool(const Match*)>& node_fn) {
- std::stack<const Match*> open;
- open.push(root);
-
- while (!open.empty()) {
- const Match* node = open.top();
- open.pop();
- if (!node_fn(node) || node->IsLeaf()) {
- continue;
- }
- open.push(node->rhs2);
- if (node->rhs1 != nullptr) {
- open.push(node->rhs1);
- }
- }
-}
-
-const Match* SelectFirst(const Match* root,
- const std::function<bool(const Match*)>& pred_fn) {
- std::stack<const Match*> open;
- open.push(root);
-
- while (!open.empty()) {
- const Match* node = open.top();
- open.pop();
- if (pred_fn(node)) {
- return node;
- }
- if (node->IsLeaf()) {
- continue;
- }
- open.push(node->rhs2);
- if (node->rhs1 != nullptr) {
- open.push(node->rhs1);
- }
- }
-
- return nullptr;
-}
-
-std::vector<const Match*> SelectAll(
- const Match* root, const std::function<bool(const Match*)>& pred_fn) {
- std::vector<const Match*> result;
- Traverse(root, [&result, pred_fn](const Match* node) {
- if (pred_fn(node)) {
- result.push_back(node);
- }
- return true;
- });
- return result;
-}
-
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/match.h b/native/utils/grammar/match.h
deleted file mode 100644
index 97edac9..0000000
--- a/native/utils/grammar/match.h
+++ /dev/null
@@ -1,172 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_
-#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_
-
-#include <functional>
-#include <vector>
-
-#include "annotator/types.h"
-#include "utils/grammar/types.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3::grammar {
-
-// Represents a single match that was found for a particular nonterminal.
-// Instances should be created by calling Matcher::AllocateMatch().
-// This uses an arena to allocate matches (and subclasses thereof).
-struct Match {
- static constexpr int16 kUnknownType = 0;
- static constexpr int16 kTokenType = -1;
- static constexpr int16 kDigitsType = -2;
- static constexpr int16 kBreakType = -3;
- static constexpr int16 kAssertionMatch = -4;
- static constexpr int16 kMappingMatch = -5;
- static constexpr int16 kExclusionMatch = -6;
- static constexpr int16 kAnnotationMatch = -7;
-
- void Init(const Nonterm arg_lhs, const CodepointSpan arg_codepoint_span,
- const int arg_match_offset, const int arg_type = kUnknownType) {
- lhs = arg_lhs;
- codepoint_span = arg_codepoint_span;
- match_offset = arg_match_offset;
- type = arg_type;
- rhs1 = nullptr;
- rhs2 = nullptr;
- }
-
- void Init(const Match& other) { *this = other; }
-
- // For binary rule matches: rhs1 != NULL and rhs2 != NULL
- // unary rule matches: rhs1 == NULL and rhs2 != NULL
- // terminal rule matches: rhs1 != NULL and rhs2 == NULL
- // custom leaves: rhs1 == NULL and rhs2 == NULL
- bool IsInteriorNode() const { return rhs2 != nullptr; }
- bool IsLeaf() const { return !rhs2; }
-
- bool IsBinaryRule() const { return rhs1 && rhs2; }
- bool IsUnaryRule() const { return !rhs1 && rhs2; }
- bool IsTerminalRule() const { return rhs1 && !rhs2; }
- bool HasLeadingWhitespace() const {
- return codepoint_span.first != match_offset;
- }
-
- const Match* unary_rule_rhs() const { return rhs2; }
-
- // Used in singly-linked queue of matches for processing.
- Match* next = nullptr;
-
- // Nonterminal we found a match for.
- Nonterm lhs = kUnassignedNonterm;
-
- // Type of the match.
- int16 type = kUnknownType;
-
- // The span in codepoints.
- CodepointSpan codepoint_span;
-
- // The begin codepoint offset used during matching.
- // This is usually including any prefix whitespace.
- int match_offset;
-
- union {
- // The first sub match for binary rules.
- const Match* rhs1 = nullptr;
-
- // The terminal, for terminal rules.
- const char* terminal;
- };
- // First or second sub-match for interior nodes.
- const Match* rhs2 = nullptr;
-};
-
-// Match type to keep track of associated values.
-struct MappingMatch : public Match {
- // The associated id or value.
- int64 id;
-};
-
-// Match type to keep track of assertions.
-struct AssertionMatch : public Match {
- // If true, the assertion is negative and will be valid if the input doesn't
- // match.
- bool negative;
-};
-
-// Match type to define exclusions.
-struct ExclusionMatch : public Match {
- // The nonterminal that denotes matches to exclude from a successful match.
- // So the match is only valid if there is no match of `exclusion_nonterm`
- // spanning the same text range.
- Nonterm exclusion_nonterm;
-};
-
-// Match to represent an annotator annotated span in the grammar.
-struct AnnotationMatch : public Match {
- const ClassificationResult* annotation;
-};
-
-// Utility functions for parse tree traversal.
-
-// Does a preorder traversal, calling `node_fn` on each node.
-// `node_fn` is expected to return whether to continue expanding a node.
-void Traverse(const Match* root,
- const std::function<bool(const Match*)>& node_fn);
-
-// Does a preorder traversal, calling `pred_fn` and returns the first node
-// on which `pred_fn` returns true.
-const Match* SelectFirst(const Match* root,
- const std::function<bool(const Match*)>& pred_fn);
-
-// Does a preorder traversal, selecting all nodes where `pred_fn` returns true.
-std::vector<const Match*> SelectAll(
- const Match* root, const std::function<bool(const Match*)>& pred_fn);
-
-// Selects all terminals from a parse tree.
-inline std::vector<const Match*> SelectTerminals(const Match* root) {
- return SelectAll(root, &Match::IsTerminalRule);
-}
-
-// Selects all leaves from a parse tree.
-inline std::vector<const Match*> SelectLeaves(const Match* root) {
- return SelectAll(root, &Match::IsLeaf);
-}
-
-// Retrieves the first child node of a given type.
-template <typename T>
-const T* SelectFirstOfType(const Match* root, const int16 type) {
- return static_cast<const T*>(SelectFirst(
- root, [type](const Match* node) { return node->type == type; }));
-}
-
-// Retrieves all nodes of a given type.
-template <typename T>
-const std::vector<const T*> SelectAllOfType(const Match* root,
- const int16 type) {
- std::vector<const T*> result;
- Traverse(root, [&result, type](const Match* node) {
- if (node->type == type) {
- result.push_back(static_cast<const T*>(node));
- }
- return true;
- });
- return result;
-}
-
-} // namespace libtextclassifier3::grammar
-
-#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_
diff --git a/native/utils/grammar/matcher.h b/native/utils/grammar/matcher.h
deleted file mode 100644
index 47bac43..0000000
--- a/native/utils/grammar/matcher.h
+++ /dev/null
@@ -1,246 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// A token matcher based on context-free grammars.
-//
-// A lexer passes token to the matcher: literal terminal strings and token
-// types. It passes tokens to the matcher by calling AddTerminal() and
-// AddMatch() for literal terminals and token types, respectively.
-// The lexer passes each token along with the [begin, end) position range
-// in which it occurs. So for an input string "Groundhog February 2, 2007", the
-// lexer would tell the matcher that:
-//
-// "Groundhog" occurs at [0, 9)
-// <space> occurs at [9, 10)
-// "February" occurs at [10, 18)
-// <space> occurs at [18, 19)
-// <string_of_digits> occurs at [19, 20)
-// "," occurs at [20, 21)
-// <space> occurs at [21, 22)
-// <string_of_digits> occurs at [22, 26)
-//
-// The lexer passes tokens to the matcher by calling AddTerminal() and
-// AddMatch() for literal terminals and token types, respectively.
-//
-// Although it is unnecessary for this example grammar, a lexer can
-// output multiple tokens for the same input range. So our lexer could
-// additionally output:
-// "2" occurs at [19, 20) // a second token for [19, 20)
-// "2007" occurs at [22, 26)
-// <syllable> occurs at [0, 6) // overlaps with (Groundhog [0, 9))
-// <syllable> occurs at [6, 9)
-// The only constraint on the lexer's output is that it has to pass tokens
-// to the matcher in left-to-right order, strictly speaking, their "end"
-// positions must be nondecreasing. (This constraint allows a more
-// efficient matching algorithm.) The "begin" positions can be in any
-// order.
-//
-// There are two kinds of supported callbacks:
-// (1) OUTPUT: Callbacks are the only output mechanism a matcher has. For each
-// "top-level" rule in your grammar, like the rule for <date> above -- something
-// you're trying to find instances of -- you use a callback which the matcher
-// will invoke every time it finds an instance of <date>.
-// (2) FILTERS:
-// Callbacks allow you to put extra conditions on when a grammar rule
-// applies. In the example grammar, the rule
-//
-// <day> ::= <string_of_digits> // must be between 1 and 31
-//
-// should only apply for *some* <string_of_digits> tokens, not others.
-// By using a filter callback on this rule, you can tell the matcher that
-// an instance of the rule's RHS is only *sometimes* considered an
-// instance of its LHS. The filter callback will get invoked whenever
-// the matcher finds an instance of <string_of_digits>. The callback can
-// look at the digits and decide whether they represent a number between
-// 1 and 31. If so, the callback calls Matcher::AddMatch() to tell the
-// matcher there's a <day> there. If not, the callback simply exits
-// without calling AddMatch().
-//
-// Technically, a FILTER callback can make any number of calls to
-// AddMatch() or even AddTerminal(). But the expected usage is to just
-// make zero or one call to AddMatch(). OUTPUT callbacks are not expected
-// to call either of these -- output callbacks are invoked merely as a
-// side-effect, not in order to decide whether a rule applies or not.
-//
-// In the above example, you would probably use three callbacks. Filter
-// callbacks on the rules for <day> and <year> would check the numeric
-// value of the <string_of_digits>. An output callback on the rule for
-// <date> would simply increment the counter of dates found on the page.
-//
-// Note that callbacks are attached to rules, not to nonterminals. You
-// could have two alternative rules for <date> and use a different
-// callback for each one.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_
-#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_
-
-#include <array>
-#include <functional>
-#include <vector>
-
-#include "annotator/types.h"
-#include "utils/base/arena.h"
-#include "utils/grammar/callback-delegate.h"
-#include "utils/grammar/match.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3::grammar {
-
-class Matcher {
- public:
- explicit Matcher(const UniLib* unilib, const RulesSet* rules,
- const std::vector<const RulesSet_::Rules*> rules_shards,
- CallbackDelegate* delegate)
- : state_(STATE_DEFAULT),
- unilib_(*unilib),
- arena_(kBlocksize),
- rules_(rules),
- rules_shards_(rules_shards),
- delegate_(delegate) {
- TC3_CHECK(rules_ != nullptr);
- Reset();
- }
- explicit Matcher(const UniLib* unilib, const RulesSet* rules,
- CallbackDelegate* delegate)
- : Matcher(unilib, rules, {}, delegate) {
- rules_shards_.reserve(rules->rules()->size());
- rules_shards_.insert(rules_shards_.end(), rules->rules()->begin(),
- rules->rules()->end());
- }
-
- // Resets the matcher.
- void Reset();
-
- // Finish the matching.
- void Finish();
-
- // Tells the matcher that the given terminal was found occupying position
- // range [begin, end) in the input.
- // The matcher may invoke callback functions before returning, if this
- // terminal triggers any new matches for rules in the grammar.
- // Calls to AddTerminal() and AddMatch() must be in left-to-right order,
- // that is, the sequence of `end` values must be non-decreasing.
- void AddTerminal(const CodepointSpan codepoint_span, const int match_offset,
- StringPiece terminal);
- void AddTerminal(const CodepointIndex begin, const CodepointIndex end,
- StringPiece terminal) {
- AddTerminal(CodepointSpan{begin, end}, begin, terminal);
- }
-
- // Adds a nonterminal match to the chart.
- // This can be invoked by the lexer if the lexer needs to add nonterminals to
- // the chart.
- void AddMatch(Match* match);
-
- // Allocates memory from an area for a new match.
- // The `size` parameter is there to allow subclassing of the match object
- // with additional fields.
- Match* AllocateMatch(const size_t size) {
- return reinterpret_cast<Match*>(arena_.Alloc(size));
- }
-
- template <typename T>
- T* AllocateMatch() {
- return reinterpret_cast<T*>(arena_.Alloc(sizeof(T)));
- }
-
- template <typename T, typename... Args>
- T* AllocateAndInitMatch(Args... args) {
- T* match = AllocateMatch<T>();
- match->Init(args...);
- return match;
- }
-
- // Returns the current number of bytes allocated for all match objects.
- size_t ArenaSize() const { return arena_.status().bytes_allocated(); }
-
- private:
- static constexpr int kBlocksize = 16 << 10;
-
- // The state of the matcher.
- enum State {
- // The matcher is in the default state.
- STATE_DEFAULT = 0,
-
- // The matcher is currently processing queued match items.
- STATE_PROCESSING = 1,
- };
- State state_;
-
- // Process matches from lhs set.
- void ExecuteLhsSet(const CodepointSpan codepoint_span, const int match_offset,
- const int whitespace_gap,
- const std::function<void(Match*)>& initializer,
- const RulesSet_::LhsSet* lhs_set,
- CallbackDelegate* delegate);
-
- // Queues a newly created match item.
- void QueueForProcessing(Match* item);
-
- // Queues a match item for later post checking of the exclusion condition.
- // For exclusions we need to check that the `item->excluded_nonterminal`
- // doesn't match the same span. As we cannot know which matches have already
- // been added, we queue the item for later post checking - once all matches
- // up to `item->codepoint_span.second` have been added.
- void QueueForPostCheck(ExclusionMatch* item);
-
- // Adds pending items to the chart, possibly generating new matches as a
- // result.
- void ProcessPendingSet();
-
- // Returns whether the chart contains a match for a given nonterminal.
- bool ContainsMatch(const Nonterm nonterm, const CodepointSpan& span) const;
-
- // Checks all pending exclusion matches that their exclusion condition is
- // fulfilled.
- void ProcessPendingExclusionMatches();
-
- UniLib unilib_;
-
- // Memory arena for match allocation.
- UnsafeArena arena_;
-
- // The end position of the most recent match or terminal, for sanity
- // checking.
- int last_end_;
-
- // Rules.
- const RulesSet* rules_;
-
- // The set of items pending to be added to the chart as a singly-linked list.
- Match* pending_items_;
-
- // The set of items pending to be post-checked as a singly-linked list.
- ExclusionMatch* pending_exclusion_items_;
-
- // The chart data structure: a hashtable containing all matches, indexed by
- // their end positions.
- static constexpr int kChartHashTableNumBuckets = 1 << 8;
- static constexpr int kChartHashTableBitmask = kChartHashTableNumBuckets - 1;
- std::array<Match*, kChartHashTableNumBuckets> chart_;
-
- // The active rule shards.
- std::vector<const RulesSet_::Rules*> rules_shards_;
-
- // The callback handler.
- CallbackDelegate* delegate_;
-};
-
-} // namespace libtextclassifier3::grammar
-
-#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_
diff --git a/native/utils/grammar/parsing/chart.h b/native/utils/grammar/parsing/chart.h
new file mode 100644
index 0000000..4ec05d7
--- /dev/null
+++ b/native/utils/grammar/parsing/chart.h
@@ -0,0 +1,108 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_CHART_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_CHART_H_
+
+#include <array>
+
+#include "annotator/types.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/parse-tree.h"
+
+namespace libtextclassifier3::grammar {
+
+// Chart is a hashtable container for use with a CYK style parser.
+// The hashtable contains all matches, indexed by their end positions.
+template <int NumBuckets = 1 << 8>
+class Chart {
+ public:
+ explicit Chart() { std::fill(chart_.begin(), chart_.end(), nullptr); }
+
+ // Iterator that allows iterating through recorded matches that end at a given
+ // match offset.
+ class Iterator {
+ public:
+ explicit Iterator(const int match_offset, const ParseTree* value)
+ : match_offset_(match_offset), value_(value) {}
+
+ bool Done() const {
+ return value_ == nullptr ||
+ (value_->codepoint_span.second < match_offset_);
+ }
+ const ParseTree* Item() const { return value_; }
+ void Next() {
+ TC3_DCHECK(!Done());
+ value_ = value_->next;
+ }
+
+ private:
+ const int match_offset_;
+ const ParseTree* value_;
+ };
+
+ // Returns whether the chart contains a match for a given nonterminal.
+ bool HasMatch(const Nonterm nonterm, const CodepointSpan& span) const;
+
+ // Adds a match to the chart.
+ void Add(ParseTree* item) {
+ item->next = chart_[item->codepoint_span.second & kChartHashTableBitmask];
+ chart_[item->codepoint_span.second & kChartHashTableBitmask] = item;
+ }
+
+ // Records a derivation of a root rule.
+ void AddDerivation(const Derivation& derivation) {
+ root_derivations_.push_back(derivation);
+ }
+
+ // Returns an iterator through all matches ending at `match_offset`.
+ Iterator MatchesEndingAt(const int match_offset) const {
+ const ParseTree* value = chart_[match_offset & kChartHashTableBitmask];
+ // The chain of items is in decreasing `end` order.
+ // Find the ones that have prev->end == item->begin.
+ while (value != nullptr && (value->codepoint_span.second > match_offset)) {
+ value = value->next;
+ }
+ return Iterator(match_offset, value);
+ }
+
+ const std::vector<Derivation> derivations() const {
+ return root_derivations_;
+ }
+
+ private:
+ static constexpr int kChartHashTableBitmask = NumBuckets - 1;
+ std::array<ParseTree*, NumBuckets> chart_;
+ std::vector<Derivation> root_derivations_;
+};
+
+template <int NumBuckets>
+bool Chart<NumBuckets>::HasMatch(const Nonterm nonterm,
+ const CodepointSpan& span) const {
+ // Lookup by end.
+ for (Chart<NumBuckets>::Iterator it = MatchesEndingAt(span.second);
+ !it.Done(); it.Next()) {
+ if (it.Item()->lhs == nonterm &&
+ it.Item()->codepoint_span.first == span.first) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_CHART_H_
diff --git a/native/utils/grammar/parsing/chart_test.cc b/native/utils/grammar/parsing/chart_test.cc
new file mode 100644
index 0000000..e4ec72f
--- /dev/null
+++ b/native/utils/grammar/parsing/chart_test.cc
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/chart.h"
+
+#include "annotator/types.h"
+#include "utils/base/arena.h"
+#include "utils/grammar/parsing/parse-tree.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::IsEmpty;
+
+class ChartTest : public testing::Test {
+ protected:
+ explicit ChartTest() : arena_(/*block_size=*/16 << 10) {}
+ UnsafeArena arena_;
+};
+
+TEST_F(ChartTest, IsEmptyByDefault) {
+ Chart<> chart;
+
+ EXPECT_THAT(chart.derivations(), IsEmpty());
+ EXPECT_TRUE(chart.MatchesEndingAt(0).Done());
+}
+
+TEST_F(ChartTest, IteratesThroughCell) {
+ Chart<> chart;
+ ParseTree* m0 = arena_.AllocAndInit<ParseTree>(/*lhs=*/0, CodepointSpan{0, 1},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m0);
+ ParseTree* m1 = arena_.AllocAndInit<ParseTree>(/*lhs=*/1, CodepointSpan{0, 2},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m1);
+ ParseTree* m2 = arena_.AllocAndInit<ParseTree>(/*lhs=*/2, CodepointSpan{0, 2},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m2);
+
+ // Position 0 should be empty.
+ EXPECT_TRUE(chart.MatchesEndingAt(0).Done());
+
+ // Position 1 should contain m0.
+ {
+ Chart<>::Iterator it = chart.MatchesEndingAt(1);
+ ASSERT_FALSE(it.Done());
+ EXPECT_EQ(it.Item(), m0);
+ it.Next();
+ EXPECT_TRUE(it.Done());
+ }
+
+ // Position 2 should contain m1 and m2.
+ {
+ Chart<>::Iterator it = chart.MatchesEndingAt(2);
+ ASSERT_FALSE(it.Done());
+ EXPECT_EQ(it.Item(), m2);
+ it.Next();
+ ASSERT_FALSE(it.Done());
+ EXPECT_EQ(it.Item(), m1);
+ it.Next();
+ EXPECT_TRUE(it.Done());
+ }
+}
+
+TEST_F(ChartTest, ChecksExistingMatches) {
+ Chart<> chart;
+ ParseTree* m0 = arena_.AllocAndInit<ParseTree>(/*lhs=*/0, CodepointSpan{0, 1},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m0);
+ ParseTree* m1 = arena_.AllocAndInit<ParseTree>(/*lhs=*/1, CodepointSpan{0, 2},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m1);
+ ParseTree* m2 = arena_.AllocAndInit<ParseTree>(/*lhs=*/2, CodepointSpan{0, 2},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m2);
+
+ EXPECT_TRUE(chart.HasMatch(0, CodepointSpan{0, 1}));
+ EXPECT_FALSE(chart.HasMatch(0, CodepointSpan{0, 2}));
+ EXPECT_TRUE(chart.HasMatch(1, CodepointSpan{0, 2}));
+ EXPECT_TRUE(chart.HasMatch(2, CodepointSpan{0, 2}));
+ EXPECT_FALSE(chart.HasMatch(0, CodepointSpan{0, 2}));
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/derivation.cc b/native/utils/grammar/parsing/derivation.cc
new file mode 100644
index 0000000..4298be5
--- /dev/null
+++ b/native/utils/grammar/parsing/derivation.cc
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/derivation.h"
+
+#include <algorithm>
+#include <vector>
+
+namespace libtextclassifier3::grammar {
+
+bool Derivation::IsValid() const {
+ bool result = true;
+ Traverse(parse_tree, [&result](const ParseTree* node) {
+ if (node->type != ParseTree::Type::kAssertion) {
+ // Only validation if all checks so far passed.
+ return result;
+ }
+ // Positive assertions are by definition fulfilled,
+ // fail if the assertion is negative.
+ if (static_cast<const AssertionNode*>(node)->negative) {
+ result = false;
+ }
+ return result;
+ });
+ return result;
+}
+
+std::vector<Derivation> ValidDeduplicatedDerivations(
+ const std::vector<Derivation>& derivations) {
+ std::vector<Derivation> result;
+ for (const Derivation& derivation :
+ DeduplicateDerivations<Derivation>(derivations)) {
+ // Check that asserts are fulfilled.
+ if (derivation.IsValid()) {
+ result.push_back(derivation);
+ }
+ }
+ return result;
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/derivation.h b/native/utils/grammar/parsing/derivation.h
new file mode 100644
index 0000000..2196495
--- /dev/null
+++ b/native/utils/grammar/parsing/derivation.h
@@ -0,0 +1,102 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
+
+#include <vector>
+
+#include "utils/grammar/parsing/parse-tree.h"
+
+namespace libtextclassifier3::grammar {
+
+// A parse tree for a root rule.
+struct Derivation {
+ const ParseTree* parse_tree;
+ int64 rule_id;
+
+ // Checks that all assertions are fulfilled.
+ bool IsValid() const;
+ int64 GetRuleId() const { return rule_id; }
+ const ParseTree* GetParseTree() const { return parse_tree; }
+};
+
+// Deduplicates rule derivations by containing overlap.
+// The grammar system can output multiple candidates for optional parts.
+// For example if a rule has an optional suffix, we
+// will get two rule derivations when the suffix is present: one with and one
+// without the suffix. We therefore deduplicate by containing overlap, viz. from
+// two candidates we keep the longer one if it completely contains the shorter.
+// This factory function works with any type T that extends Derivation.
+template <typename T, typename std::enable_if<std::is_base_of<
+ Derivation, T>::value>::type* = nullptr>
+// std::vector<T> DeduplicateDerivations(const std::vector<T>& derivations);
+std::vector<T> DeduplicateDerivations(const std::vector<T>& derivations) {
+ std::vector<T> sorted_candidates = derivations;
+
+ std::stable_sort(sorted_candidates.begin(), sorted_candidates.end(),
+ [](const T& a, const T& b) {
+ // Sort by id.
+ if (a.GetRuleId() != b.GetRuleId()) {
+ return a.GetRuleId() < b.GetRuleId();
+ }
+
+ // Sort by increasing start.
+ if (a.GetParseTree()->codepoint_span.first !=
+ b.GetParseTree()->codepoint_span.first) {
+ return a.GetParseTree()->codepoint_span.first <
+ b.GetParseTree()->codepoint_span.first;
+ }
+
+ // Sort by decreasing end.
+ return a.GetParseTree()->codepoint_span.second >
+ b.GetParseTree()->codepoint_span.second;
+ });
+
+ // Deduplicate by overlap.
+ std::vector<T> result;
+ for (int i = 0; i < sorted_candidates.size(); i++) {
+ const T& candidate = sorted_candidates[i];
+ bool eliminated = false;
+
+ // Due to the sorting above, the candidate can only be completely
+ // intersected by a match before it in the sorted order.
+ for (int j = i - 1; j >= 0; j--) {
+ if (sorted_candidates[j].rule_id != candidate.rule_id) {
+ break;
+ }
+ if (sorted_candidates[j].parse_tree->codepoint_span.first <=
+ candidate.parse_tree->codepoint_span.first &&
+ sorted_candidates[j].parse_tree->codepoint_span.second >=
+ candidate.parse_tree->codepoint_span.second) {
+ eliminated = true;
+ break;
+ }
+ }
+ if (!eliminated) {
+ result.push_back(candidate);
+ }
+ }
+ return result;
+}
+
+// Deduplicates and validates rule derivations.
+std::vector<Derivation> ValidDeduplicatedDerivations(
+ const std::vector<Derivation>& derivations);
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
diff --git a/native/utils/grammar/parsing/lexer.cc b/native/utils/grammar/parsing/lexer.cc
new file mode 100644
index 0000000..79e92e1
--- /dev/null
+++ b/native/utils/grammar/parsing/lexer.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/lexer.h"
+
+namespace libtextclassifier3::grammar {
+
+Symbol::Type Lexer::GetSymbolType(const UnicodeText::const_iterator& it) const {
+ if (unilib_.IsPunctuation(*it)) {
+ return Symbol::Type::TYPE_PUNCTUATION;
+ } else if (unilib_.IsDigit(*it)) {
+ return Symbol::Type::TYPE_DIGITS;
+ } else {
+ return Symbol::Type::TYPE_TERM;
+ }
+}
+
+void Lexer::AppendTokenSymbols(const StringPiece value, int match_offset,
+ const CodepointSpan codepoint_span,
+ std::vector<Symbol>* symbols) const {
+ // Possibly split token.
+ UnicodeText token_unicode = UTF8ToUnicodeText(value.data(), value.size(),
+ /*do_copy=*/false);
+ int next_match_offset = match_offset;
+ auto token_end = token_unicode.end();
+ auto it = token_unicode.begin();
+ Symbol::Type type = GetSymbolType(it);
+ CodepointIndex sub_token_start = codepoint_span.first;
+ while (it != token_end) {
+ auto next = std::next(it);
+ int num_codepoints = 1;
+ Symbol::Type next_type;
+ while (next != token_end) {
+ next_type = GetSymbolType(next);
+ if (type == Symbol::Type::TYPE_PUNCTUATION || next_type != type) {
+ break;
+ }
+ ++next;
+ ++num_codepoints;
+ }
+ symbols->emplace_back(
+ type, CodepointSpan{sub_token_start, sub_token_start + num_codepoints},
+ /*match_offset=*/next_match_offset,
+ /*lexeme=*/
+ StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data()));
+ next_match_offset = sub_token_start + num_codepoints;
+ it = next;
+ type = next_type;
+ sub_token_start = next_match_offset;
+ }
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/lexer.h b/native/utils/grammar/parsing/lexer.h
new file mode 100644
index 0000000..f902fbd
--- /dev/null
+++ b/native/utils/grammar/parsing/lexer.h
@@ -0,0 +1,120 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// A lexer that (possibly) splits and classifies tokens.
+//
+// Any whitespace gets absorbed into the token that follows it in the text.
+// For example, if the text contains:
+//
+// ...hello there world...
+// | | |
+// offset=16 39 52
+//
+// then the output will be:
+//
+// "hello" [?, 16)
+// "there" [16, 44) <-- note "16" NOT "39"
+// "world" [44, ?) <-- note "44" NOT "52"
+//
+// This makes it appear to the Matcher as if the tokens are adjacent.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_LEXER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_LEXER_H_
+
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/types.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::grammar {
+
+// A lexical symbol with an identified meaning that represents raw tokens,
+// token categories or predefined text matches.
+// It is the unit fed to the grammar matcher.
+struct Symbol {
+ // The type of the lexical symbol.
+ enum class Type {
+ // A raw token.
+ TYPE_TERM,
+
+ // A symbol representing a string of digits.
+ TYPE_DIGITS,
+
+ // Punctuation characters.
+ TYPE_PUNCTUATION,
+
+ // A predefined parse tree.
+ TYPE_PARSE_TREE
+ };
+
+ explicit Symbol() = default;
+
+ // Constructs a symbol of a given type with an anchor in the text.
+ Symbol(const Type type, const CodepointSpan codepoint_span,
+ const int match_offset, StringPiece lexeme)
+ : type(type),
+ codepoint_span(codepoint_span),
+ match_offset(match_offset),
+ lexeme(lexeme) {}
+
+ // Constructs a symbol from a pre-defined parse tree.
+ explicit Symbol(ParseTree* parse_tree)
+ : type(Type::TYPE_PARSE_TREE),
+ codepoint_span(parse_tree->codepoint_span),
+ match_offset(parse_tree->match_offset),
+ parse_tree(parse_tree) {}
+
+ // The type of the symbol.
+ Type type;
+
+ // The span in the text as codepoint offsets.
+ CodepointSpan codepoint_span;
+
+ // The match start offset (including preceding whitespace) as codepoint
+ // offset.
+ int match_offset;
+
+ // The symbol text value.
+ StringPiece lexeme;
+
+ // The predefined parse tree.
+ ParseTree* parse_tree;
+};
+
+class Lexer {
+ public:
+ explicit Lexer(const UniLib* unilib) : unilib_(*unilib) {}
+
+ // Processes a single token.
+ // Splits a token into classified symbols.
+ void AppendTokenSymbols(const StringPiece value, int match_offset,
+ const CodepointSpan codepoint_span,
+ std::vector<Symbol>* symbols) const;
+
+ private:
+ // Gets the type of a character.
+ Symbol::Type GetSymbolType(const UnicodeText::const_iterator& it) const;
+
+ const UniLib& unilib_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_LEXER_H_
diff --git a/native/utils/grammar/parsing/lexer_test.cc b/native/utils/grammar/parsing/lexer_test.cc
new file mode 100644
index 0000000..dad3b8e
--- /dev/null
+++ b/native/utils/grammar/parsing/lexer_test.cc
@@ -0,0 +1,170 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Unit tests for the lexer.
+
+#include "utils/grammar/parsing/lexer.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+
+std::ostream& operator<<(std::ostream& os, const Symbol& symbol) {
+ return os << "Symbol(type=" << static_cast<int>(symbol.type) << ", span=("
+ << symbol.codepoint_span.first << ", "
+ << symbol.codepoint_span.second
+ << "), lexeme=" << symbol.lexeme.ToString() << ")";
+}
+
+namespace {
+
+using ::testing::DescribeMatcher;
+using ::testing::ElementsAre;
+using ::testing::ExplainMatchResult;
+
+// Superclass of all tests here.
+class LexerTest : public testing::Test {
+ protected:
+ explicit LexerTest()
+ : unilib_(libtextclassifier3::CreateUniLibForTesting()),
+ tokenizer_(TokenizationType_ICU, unilib_.get(),
+ /*codepoint_ranges=*/{},
+ /*internal_tokenizer_codepoint_ranges=*/{},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false),
+ lexer_(unilib_.get()) {}
+
+ std::vector<Symbol> SymbolsForTokens(const std::vector<Token>& tokens) const {
+ std::vector<Symbol> symbols;
+ for (const Token& token : tokens) {
+ lexer_.AppendTokenSymbols(token.value, token.start,
+ CodepointSpan{token.start, token.end},
+ &symbols);
+ }
+ return symbols;
+ }
+
+ std::unique_ptr<UniLib> unilib_;
+ Tokenizer tokenizer_;
+ Lexer lexer_;
+};
+
+MATCHER_P4(IsSymbol, type, begin, end, terminal,
+ "is symbol with type that " +
+ DescribeMatcher<Symbol::Type>(type, negation) + ", begin that " +
+ DescribeMatcher<int>(begin, negation) + ", end that " +
+ DescribeMatcher<int>(end, negation) + ", value that " +
+ DescribeMatcher<std::string>(terminal, negation)) {
+ return ExplainMatchResult(type, arg.type, result_listener) &&
+ ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
+ result_listener) &&
+ ExplainMatchResult(terminal, arg.lexeme.ToString(), result_listener);
+}
+
+TEST_F(LexerTest, HandlesSimpleWords) {
+ std::vector<Token> tokens = tokenizer_.Tokenize("This is a word");
+ EXPECT_THAT(SymbolsForTokens(tokens),
+ ElementsAre(IsSymbol(Symbol::Type::TYPE_TERM, 0, 4, "This"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 5, 7, "is"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 8, 9, "a"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 10, 14, "word")));
+}
+
+TEST_F(LexerTest, SplitsConcatedLettersAndDigit) {
+ std::vector<Token> tokens = tokenizer_.Tokenize("1234This a4321cde");
+ EXPECT_THAT(SymbolsForTokens(tokens),
+ ElementsAre(IsSymbol(Symbol::Type::TYPE_DIGITS, 0, 4, "1234"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 4, 8, "This"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 9, 10, "a"),
+ IsSymbol(Symbol::Type::TYPE_DIGITS, 10, 14, "4321"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 14, 17, "cde")));
+}
+
+TEST_F(LexerTest, SplitsPunctuation) {
+ std::vector<Token> tokens = tokenizer_.Tokenize("10/18/2014");
+ EXPECT_THAT(SymbolsForTokens(tokens),
+ ElementsAre(IsSymbol(Symbol::Type::TYPE_DIGITS, 0, 2, "10"),
+ IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 2, 3, "/"),
+ IsSymbol(Symbol::Type::TYPE_DIGITS, 3, 5, "18"),
+ IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 5, 6, "/"),
+ IsSymbol(Symbol::Type::TYPE_DIGITS, 6, 10, "2014")));
+}
+
+TEST_F(LexerTest, SplitsUTF8Punctuation) {
+ std::vector<Token> tokens = tokenizer_.Tokenize("电话:0871—6857(曹");
+ EXPECT_THAT(
+ SymbolsForTokens(tokens),
+ ElementsAre(IsSymbol(Symbol::Type::TYPE_TERM, 0, 2, "电话"),
+ IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 2, 3, ":"),
+ IsSymbol(Symbol::Type::TYPE_DIGITS, 3, 7, "0871"),
+ IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 7, 8, "—"),
+ IsSymbol(Symbol::Type::TYPE_DIGITS, 8, 12, "6857"),
+ IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 12, 13, "("),
+ IsSymbol(Symbol::Type::TYPE_TERM, 13, 14, "曹")));
+}
+
+TEST_F(LexerTest, HandlesMixedPunctuation) {
+ std::vector<Token> tokens = tokenizer_.Tokenize("电话 :0871—6857(曹");
+ EXPECT_THAT(
+ SymbolsForTokens(tokens),
+ ElementsAre(IsSymbol(Symbol::Type::TYPE_TERM, 0, 2, "电话"),
+ IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 3, 4, ":"),
+ IsSymbol(Symbol::Type::TYPE_DIGITS, 4, 8, "0871"),
+ IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 8, 9, "—"),
+ IsSymbol(Symbol::Type::TYPE_DIGITS, 9, 13, "6857"),
+ IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 13, 14, "("),
+ IsSymbol(Symbol::Type::TYPE_TERM, 14, 15, "曹")));
+}
+
+TEST_F(LexerTest, HandlesTokensWithDigits) {
+ std::vector<Token> tokens =
+ tokenizer_.Tokenize("The.qUIck\n brown2345fox88 \xE2\x80\x94 the");
+ EXPECT_THAT(SymbolsForTokens(tokens),
+ ElementsAre(IsSymbol(Symbol::Type::TYPE_TERM, 0, 3, "The"),
+ IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 3, 4, "."),
+ IsSymbol(Symbol::Type::TYPE_TERM, 4, 9, "qUIck"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 11, 16, "brown"),
+ IsSymbol(Symbol::Type::TYPE_DIGITS, 16, 20, "2345"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 20, 23, "fox"),
+ IsSymbol(Symbol::Type::TYPE_DIGITS, 23, 25, "88"),
+ IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 26, 27, "—"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 28, 31, "the")));
+}
+
+TEST_F(LexerTest, SplitsPlusSigns) {
+ std::vector<Token> tokens = tokenizer_.Tokenize("The+2345++the +");
+ EXPECT_THAT(SymbolsForTokens(tokens),
+ ElementsAre(IsSymbol(Symbol::Type::TYPE_TERM, 0, 3, "The"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 3, 4, "+"),
+ IsSymbol(Symbol::Type::TYPE_DIGITS, 4, 8, "2345"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 8, 9, "+"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 9, 10, "+"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 10, 13, "the"),
+ IsSymbol(Symbol::Type::TYPE_TERM, 14, 15, "+")));
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/matcher.cc b/native/utils/grammar/parsing/matcher.cc
similarity index 67%
rename from native/utils/grammar/matcher.cc
rename to native/utils/grammar/parsing/matcher.cc
index a8ebba5..fa0ea0a 100644
--- a/native/utils/grammar/matcher.cc
+++ b/native/utils/grammar/parsing/matcher.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "utils/grammar/matcher.h"
+#include "utils/grammar/parsing/matcher.h"
#include <iostream>
#include <limits>
@@ -58,10 +58,13 @@
// Queue next character.
if (buffer_pos >= buffer_size) {
buffer_pos = 0;
- // Lower-case the next character.
+
+ // Lower-case the next character. The character and its lower-cased
+ // counterpart may be represented with a different number of bytes in
+ // utf8.
buffer_size =
ValidRuneToChar(unilib.ToLower(ValidCharToRune(data)), buffer);
- data += buffer_size;
+ data += GetNumBytesForUTF8Char(data);
}
TC3_DCHECK_LT(buffer_pos, buffer_size);
return buffer[buffer_pos++];
@@ -130,7 +133,7 @@
}
++match_length;
- // By the loop variant and due to the fact that the strings are sorted,
+ // By the loop invariant and due to the fact that the strings are sorted,
// a matching string will be at `left` now.
if (!input_iterator.HasNext()) {
const int string_offset = LittleEndian::ToHost32(offsets[left]);
@@ -217,7 +220,7 @@
}
inline void GetLhs(const RulesSet* rules_set, const int lhs_entry,
- Nonterm* nonterminal, CallbackId* callback, uint64* param,
+ Nonterm* nonterminal, CallbackId* callback, int64* param,
int8* max_whitespace_gap) {
if (lhs_entry > 0) {
// Direct encoding of the nonterminal.
@@ -236,27 +239,18 @@
} // namespace
-void Matcher::Reset() {
- state_ = STATE_DEFAULT;
- arena_.Reset();
- pending_items_ = nullptr;
- pending_exclusion_items_ = nullptr;
- std::fill(chart_.begin(), chart_.end(), nullptr);
- last_end_ = std::numeric_limits<int>().lowest();
-}
-
void Matcher::Finish() {
// Check any pending items.
ProcessPendingExclusionMatches();
}
-void Matcher::QueueForProcessing(Match* item) {
+void Matcher::QueueForProcessing(ParseTree* item) {
// Push element to the front.
item->next = pending_items_;
pending_items_ = item;
}
-void Matcher::QueueForPostCheck(ExclusionMatch* item) {
+void Matcher::QueueForPostCheck(ExclusionNode* item) {
// Push element to the front.
item->next = pending_exclusion_items_;
pending_exclusion_items_ = item;
@@ -282,11 +276,11 @@
ExecuteLhsSet(
codepoint_span, match_offset,
/*whitespace_gap=*/(codepoint_span.first - match_offset),
- [terminal](Match* match) {
- match->terminal = terminal.data();
- match->rhs2 = nullptr;
+ [terminal](ParseTree* parse_tree) {
+ parse_tree->terminal = terminal.data();
+ parse_tree->rhs2 = nullptr;
},
- lhs_set, delegate_);
+ lhs_set);
}
// Try case-insensitive matches.
@@ -298,42 +292,41 @@
ExecuteLhsSet(
codepoint_span, match_offset,
/*whitespace_gap=*/(codepoint_span.first - match_offset),
- [terminal](Match* match) {
- match->terminal = terminal.data();
- match->rhs2 = nullptr;
+ [terminal](ParseTree* parse_tree) {
+ parse_tree->terminal = terminal.data();
+ parse_tree->rhs2 = nullptr;
},
- lhs_set, delegate_);
+ lhs_set);
}
}
ProcessPendingSet();
}
-void Matcher::AddMatch(Match* match) {
- TC3_CHECK_GE(match->codepoint_span.second, last_end_);
+void Matcher::AddParseTree(ParseTree* parse_tree) {
+ TC3_CHECK_GE(parse_tree->codepoint_span.second, last_end_);
// Finish any pending post-checks.
- if (match->codepoint_span.second > last_end_) {
+ if (parse_tree->codepoint_span.second > last_end_) {
ProcessPendingExclusionMatches();
}
- last_end_ = match->codepoint_span.second;
- QueueForProcessing(match);
+ last_end_ = parse_tree->codepoint_span.second;
+ QueueForProcessing(parse_tree);
ProcessPendingSet();
}
-void Matcher::ExecuteLhsSet(const CodepointSpan codepoint_span,
- const int match_offset_bytes,
- const int whitespace_gap,
- const std::function<void(Match*)>& initializer,
- const RulesSet_::LhsSet* lhs_set,
- CallbackDelegate* delegate) {
+void Matcher::ExecuteLhsSet(
+ const CodepointSpan codepoint_span, const int match_offset_bytes,
+ const int whitespace_gap,
+ const std::function<void(ParseTree*)>& initializer_fn,
+ const RulesSet_::LhsSet* lhs_set) {
TC3_CHECK(lhs_set);
- Match* match = nullptr;
+ ParseTree* parse_tree = nullptr;
Nonterm prev_lhs = kUnassignedNonterm;
for (const int32 lhs_entry : *lhs_set->lhs()) {
Nonterm lhs;
CallbackId callback_id;
- uint64 callback_param;
+ int64 callback_param;
int8 max_whitespace_gap;
GetLhs(rules_, lhs_entry, &lhs, &callback_id, &callback_param,
&max_whitespace_gap);
@@ -343,91 +336,70 @@
continue;
}
- // Handle default callbacks.
+ // Handle callbacks.
switch (static_cast<DefaultCallback>(callback_id)) {
- case DefaultCallback::kSetType: {
- Match* typed_match = AllocateAndInitMatch<Match>(lhs, codepoint_span,
- match_offset_bytes);
- initializer(typed_match);
- typed_match->type = callback_param;
- QueueForProcessing(typed_match);
- continue;
- }
case DefaultCallback::kAssertion: {
- AssertionMatch* assertion_match = AllocateAndInitMatch<AssertionMatch>(
- lhs, codepoint_span, match_offset_bytes);
- initializer(assertion_match);
- assertion_match->type = Match::kAssertionMatch;
- assertion_match->negative = (callback_param != 0);
- QueueForProcessing(assertion_match);
+ AssertionNode* assertion_node = arena_->AllocAndInit<AssertionNode>(
+ lhs, codepoint_span, match_offset_bytes,
+ /*negative=*/(callback_param != 0));
+ initializer_fn(assertion_node);
+ QueueForProcessing(assertion_node);
continue;
}
case DefaultCallback::kMapping: {
- MappingMatch* mapping_match = AllocateAndInitMatch<MappingMatch>(
- lhs, codepoint_span, match_offset_bytes);
- initializer(mapping_match);
- mapping_match->type = Match::kMappingMatch;
- mapping_match->id = callback_param;
- QueueForProcessing(mapping_match);
+ MappingNode* mapping_node = arena_->AllocAndInit<MappingNode>(
+ lhs, codepoint_span, match_offset_bytes, /*id=*/callback_param);
+ initializer_fn(mapping_node);
+ QueueForProcessing(mapping_node);
continue;
}
case DefaultCallback::kExclusion: {
// We can only check the exclusion once all matches up to this position
// have been processed. Schedule and post check later.
- ExclusionMatch* exclusion_match = AllocateAndInitMatch<ExclusionMatch>(
- lhs, codepoint_span, match_offset_bytes);
- initializer(exclusion_match);
- exclusion_match->exclusion_nonterm = callback_param;
- QueueForPostCheck(exclusion_match);
+ ExclusionNode* exclusion_node = arena_->AllocAndInit<ExclusionNode>(
+ lhs, codepoint_span, match_offset_bytes,
+ /*exclusion_nonterm=*/callback_param);
+ initializer_fn(exclusion_node);
+ QueueForPostCheck(exclusion_node);
+ continue;
+ }
+ case DefaultCallback::kSemanticExpression: {
+ SemanticExpressionNode* expression_node =
+ arena_->AllocAndInit<SemanticExpressionNode>(
+ lhs, codepoint_span, match_offset_bytes,
+ /*expression=*/
+ rules_->semantic_expression()->Get(callback_param));
+ initializer_fn(expression_node);
+ QueueForProcessing(expression_node);
continue;
}
default:
break;
}
- if (callback_id != kNoCallback && rules_->callback() != nullptr) {
- const RulesSet_::CallbackEntry* callback_info =
- rules_->callback()->LookupByKey(callback_id);
- if (callback_info && callback_info->value().is_filter()) {
- // Filter callback.
- Match candidate;
- candidate.Init(lhs, codepoint_span, match_offset_bytes);
- initializer(&candidate);
- delegate->MatchFound(&candidate, callback_id, callback_param, this);
- continue;
- }
- }
-
if (prev_lhs != lhs) {
prev_lhs = lhs;
- match =
- AllocateAndInitMatch<Match>(lhs, codepoint_span, match_offset_bytes);
- initializer(match);
- QueueForProcessing(match);
+ parse_tree = arena_->AllocAndInit<ParseTree>(
+ lhs, codepoint_span, match_offset_bytes, ParseTree::Type::kDefault);
+ initializer_fn(parse_tree);
+ QueueForProcessing(parse_tree);
}
- if (callback_id != kNoCallback) {
- // This is an output callback.
- delegate->MatchFound(match, callback_id, callback_param, this);
+ if (static_cast<DefaultCallback>(callback_id) ==
+ DefaultCallback::kRootRule) {
+ chart_.AddDerivation(Derivation{parse_tree, /*rule_id=*/callback_param});
}
}
}
void Matcher::ProcessPendingSet() {
- // Avoid recursion caused by:
- // ProcessPendingSet --> callback --> AddMatch --> ProcessPendingSet --> ...
- if (state_ == STATE_PROCESSING) {
- return;
- }
- state_ = STATE_PROCESSING;
while (pending_items_) {
// Process.
- Match* item = pending_items_;
+ ParseTree* item = pending_items_;
pending_items_ = pending_items_->next;
// Add it to the chart.
- item->next = chart_[item->codepoint_span.second & kChartHashTableBitmask];
- chart_[item->codepoint_span.second & kChartHashTableBitmask] = item;
+ chart_.Add(item);
// Check unary rules that trigger.
for (const RulesSet_::Rules* shard : rules_shards_) {
@@ -437,26 +409,19 @@
item->codepoint_span, item->match_offset,
/*whitespace_gap=*/
(item->codepoint_span.first - item->match_offset),
- [item](Match* match) {
- match->rhs1 = nullptr;
- match->rhs2 = item;
+ [item](ParseTree* parse_tree) {
+ parse_tree->rhs1 = nullptr;
+ parse_tree->rhs2 = item;
},
- lhs_set, delegate_);
+ lhs_set);
}
}
// Check binary rules that trigger.
// Lookup by begin.
- Match* prev = chart_[item->match_offset & kChartHashTableBitmask];
- // The chain of items is in decreasing `end` order.
- // Find the ones that have prev->end == item->begin.
- while (prev != nullptr &&
- (prev->codepoint_span.second > item->match_offset)) {
- prev = prev->next;
- }
- for (;
- prev != nullptr && (prev->codepoint_span.second == item->match_offset);
- prev = prev->next) {
+ for (Chart<>::Iterator it = chart_.MatchesEndingAt(item->match_offset);
+ !it.Done(); it.Next()) {
+ const ParseTree* prev = it.Item();
for (const RulesSet_::Rules* shard : rules_shards_) {
if (const RulesSet_::LhsSet* lhs_set =
FindBinaryRulesMatches(rules_, shard, {prev->lhs, item->lhs})) {
@@ -468,45 +433,27 @@
(item->codepoint_span.first -
item->match_offset), // Whitespace gap is the gap
// between the two parts.
- [prev, item](Match* match) {
- match->rhs1 = prev;
- match->rhs2 = item;
+ [prev, item](ParseTree* parse_tree) {
+ parse_tree->rhs1 = prev;
+ parse_tree->rhs2 = item;
},
- lhs_set, delegate_);
+ lhs_set);
}
}
}
}
- state_ = STATE_DEFAULT;
}
void Matcher::ProcessPendingExclusionMatches() {
while (pending_exclusion_items_) {
- ExclusionMatch* item = pending_exclusion_items_;
- pending_exclusion_items_ = static_cast<ExclusionMatch*>(item->next);
+ ExclusionNode* item = pending_exclusion_items_;
+ pending_exclusion_items_ = static_cast<ExclusionNode*>(item->next);
// Check that the exclusion condition is fulfilled.
- if (!ContainsMatch(item->exclusion_nonterm, item->codepoint_span)) {
- AddMatch(item);
+ if (!chart_.HasMatch(item->exclusion_nonterm, item->codepoint_span)) {
+ AddParseTree(item);
}
}
}
-bool Matcher::ContainsMatch(const Nonterm nonterm,
- const CodepointSpan& span) const {
- // Lookup by end.
- Match* match = chart_[span.second & kChartHashTableBitmask];
- // The chain of items is in decreasing `end` order.
- while (match != nullptr && match->codepoint_span.second > span.second) {
- match = match->next;
- }
- while (match != nullptr && match->codepoint_span.second == span.second) {
- if (match->lhs == nonterm && match->codepoint_span.first == span.first) {
- return true;
- }
- match = match->next;
- }
- return false;
-}
-
} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/matcher.h b/native/utils/grammar/parsing/matcher.h
new file mode 100644
index 0000000..f12a6a5
--- /dev/null
+++ b/native/utils/grammar/parsing/matcher.h
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// A token based context-free grammar matcher.
+//
+// A parser passes tokens to the matcher: literal terminal strings and token
+// types.
+// The parser passes each token along with the [begin, end) position range
+// in which it occurs. So for an input string "Groundhog February 2, 2007", the
+// parser would tell the matcher that:
+//
+// "Groundhog" occurs at [0, 9)
+// "February" occurs at [9, 18)
+// <digits> occurs at [18, 20)
+// "," occurs at [20, 21)
+// <digits> occurs at [21, 26)
+//
+// Multiple overlapping symbols can be passed.
+// The only constraint on symbol order is that they have to be passed in
+// left-to-right order, strictly speaking, their "end" positions must be
+// nondecreasing. This constraint allows a more efficient matching algorithm.
+// The "begin" positions can be in any order.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_MATCHER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_MATCHER_H_
+
+#include <array>
+#include <functional>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/base/arena.h"
+#include "utils/grammar/parsing/chart.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::grammar {
+
+class Matcher {
+ public:
+ explicit Matcher(const UniLib* unilib, const RulesSet* rules,
+ const std::vector<const RulesSet_::Rules*> rules_shards,
+ UnsafeArena* arena)
+ : unilib_(*unilib),
+ arena_(arena),
+ last_end_(std::numeric_limits<int>().lowest()),
+ rules_(rules),
+ rules_shards_(rules_shards),
+ pending_items_(nullptr),
+ pending_exclusion_items_(nullptr) {
+ TC3_CHECK_NE(rules, nullptr);
+ }
+
+ explicit Matcher(const UniLib* unilib, const RulesSet* rules,
+ UnsafeArena* arena)
+ : Matcher(unilib, rules, {}, arena) {
+ rules_shards_.reserve(rules->rules()->size());
+ rules_shards_.insert(rules_shards_.end(), rules->rules()->begin(),
+ rules->rules()->end());
+ }
+
+ // Finish the matching.
+ void Finish();
+
+ // Tells the matcher that the given terminal was found occupying position
+ // range [begin, end) in the input.
+ // The matcher may invoke callback functions before returning, if this
+ // terminal triggers any new matches for rules in the grammar.
+ // Calls to AddTerminal() and AddParseTree() must be in left-to-right order,
+ // that is, the sequence of `end` values must be non-decreasing.
+ void AddTerminal(const CodepointSpan codepoint_span, const int match_offset,
+ StringPiece terminal);
+ void AddTerminal(const CodepointIndex begin, const CodepointIndex end,
+ StringPiece terminal) {
+ AddTerminal(CodepointSpan{begin, end}, begin, terminal);
+ }
+
+ // Adds predefined parse tree.
+ void AddParseTree(ParseTree* parse_tree);
+
+ const Chart<> chart() const { return chart_; }
+
+ private:
+ // Process matches from lhs set.
+ void ExecuteLhsSet(const CodepointSpan codepoint_span, const int match_offset,
+ const int whitespace_gap,
+ const std::function<void(ParseTree*)>& initializer_fn,
+ const RulesSet_::LhsSet* lhs_set);
+
+ // Queues a newly created match item.
+ void QueueForProcessing(ParseTree* item);
+
+ // Queues a match item for later post checking of the exclusion condition.
+ // For exclusions we need to check that the `item->excluded_nonterminal`
+ // doesn't match the same span. As we cannot know which matches have already
+ // been added, we queue the item for later post checking - once all matches
+ // up to `item->codepoint_span.second` have been added.
+ void QueueForPostCheck(ExclusionNode* item);
+
+ // Adds pending items to the chart, possibly generating new matches as a
+ // result.
+ void ProcessPendingSet();
+
+ // Checks all pending exclusion matches that their exclusion condition is
+ // fulfilled.
+ void ProcessPendingExclusionMatches();
+
+ UniLib unilib_;
+
+ // Memory arena for match allocation.
+ UnsafeArena* arena_;
+
+ // The end position of the most recent match or terminal, for sanity
+ // checking.
+ int last_end_;
+
+ // Rules.
+ const RulesSet* rules_;
+ // The active rule shards.
+ std::vector<const RulesSet_::Rules*> rules_shards_;
+
+ // The set of items pending to be added to the chart as a singly-linked list.
+ ParseTree* pending_items_;
+
+ // The set of items pending to be post-checked as a singly-linked list.
+ ExclusionNode* pending_exclusion_items_;
+
+ // The chart data structure: a hashtable containing all matches, indexed by
+ // their end positions.
+ Chart<> chart_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_MATCHER_H_
diff --git a/native/utils/grammar/parsing/matcher_test.cc b/native/utils/grammar/parsing/matcher_test.cc
new file mode 100644
index 0000000..7c9a14d
--- /dev/null
+++ b/native/utils/grammar/parsing/matcher_test.cc
@@ -0,0 +1,441 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/matcher.h"
+
+#include <string>
+#include <vector>
+
+#include "utils/base/arena.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/types.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/strings/append.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::DescribeMatcher;
+using ::testing::ElementsAre;
+using ::testing::ExplainMatchResult;
+using ::testing::IsEmpty;
+
+struct TestMatchResult {
+ CodepointSpan codepoint_span;
+ std::string terminal;
+ std::string nonterminal;
+ int rule_id;
+
+ friend std::ostream& operator<<(std::ostream& os,
+ const TestMatchResult& match) {
+ return os << "Result(rule_id=" << match.rule_id
+ << ", begin=" << match.codepoint_span.first
+ << ", end=" << match.codepoint_span.second
+ << ", terminal=" << match.terminal
+ << ", nonterminal=" << match.nonterminal << ")";
+ }
+};
+
+MATCHER_P3(IsTerminal, begin, end, terminal,
+ "is terminal with begin that " +
+ DescribeMatcher<int>(begin, negation) + ", end that " +
+ DescribeMatcher<int>(end, negation) + ", value that " +
+ DescribeMatcher<std::string>(terminal, negation)) {
+ return ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
+ result_listener) &&
+ ExplainMatchResult(terminal, arg.terminal, result_listener);
+}
+
+MATCHER_P3(IsNonterminal, begin, end, name,
+ "is nonterminal with begin that " +
+ DescribeMatcher<int>(begin, negation) + ", end that " +
+ DescribeMatcher<int>(end, negation) + ", name that " +
+ DescribeMatcher<std::string>(name, negation)) {
+ return ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
+ result_listener) &&
+ ExplainMatchResult(name, arg.nonterminal, result_listener);
+}
+
+MATCHER_P4(IsDerivation, begin, end, name, rule_id,
+ "is derivation of rule that " +
+ DescribeMatcher<int>(rule_id, negation) + ", begin that " +
+ DescribeMatcher<int>(begin, negation) + ", end that " +
+ DescribeMatcher<int>(end, negation) + ", name that " +
+ DescribeMatcher<std::string>(name, negation)) {
+ return ExplainMatchResult(IsNonterminal(begin, end, name), arg,
+ result_listener) &&
+ ExplainMatchResult(rule_id, arg.rule_id, result_listener);
+}
+
+// Superclass of all tests.
+class MatcherTest : public testing::Test {
+ protected:
+ MatcherTest()
+ : INIT_UNILIB_FOR_TESTING(unilib_), arena_(/*block_size=*/16 << 10) {}
+
+ std::string GetNonterminalName(
+ const RulesSet_::DebugInformation* debug_information,
+ const Nonterm nonterminal) const {
+ if (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry =
+ debug_information->nonterminal_names()->LookupByKey(nonterminal)) {
+ return entry->value()->str();
+ }
+ // Unnamed Nonterm.
+ return "()";
+ }
+
+ std::vector<TestMatchResult> GetMatchResults(
+ const Chart<>& chart,
+ const RulesSet_::DebugInformation* debug_information) {
+ std::vector<TestMatchResult> result;
+ for (const Derivation& derivation : chart.derivations()) {
+ result.emplace_back();
+ result.back().rule_id = derivation.rule_id;
+ result.back().codepoint_span = derivation.parse_tree->codepoint_span;
+ result.back().nonterminal =
+ GetNonterminalName(debug_information, derivation.parse_tree->lhs);
+ if (derivation.parse_tree->IsTerminalRule()) {
+ result.back().terminal = derivation.parse_tree->terminal;
+ }
+ }
+ return result;
+ }
+
+ UniLib unilib_;
+ UnsafeArena arena_;
+};
+
+TEST_F(MatcherTest, HandlesBasicOperations) {
+ // Create an example grammar.
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<test>", {"the", "quick", "brown", "fox"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+ rules.Add("<action>", {"<test>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+ Matcher matcher(&unilib_, rules_set, &arena_);
+
+ matcher.AddTerminal(0, 1, "the");
+ matcher.AddTerminal(1, 2, "quick");
+ matcher.AddTerminal(2, 3, "brown");
+ matcher.AddTerminal(3, 4, "fox");
+
+ EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsNonterminal(0, 4, "<test>"),
+ IsNonterminal(0, 4, "<action>")));
+}
+
+std::string CreateTestGrammar() {
+ // Create an example grammar.
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+
+ // Callbacks on terminal rules.
+ rules.Add("<output_5>", {"quick"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 6);
+ rules.Add("<output_0>", {"the"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 1);
+
+ // Callbacks on non-terminal rules.
+ rules.Add("<output_1>", {"the", "quick", "brown", "fox"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 2);
+ rules.Add("<output_2>", {"the", "quick"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 3);
+ rules.Add("<output_3>", {"brown", "fox"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 4);
+
+ // Now a complex thing: "the* brown fox".
+ rules.Add("<thestarbrownfox>", {"brown", "fox"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 5);
+ rules.Add("<thestarbrownfox>", {"the", "<thestarbrownfox>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 5);
+
+ return rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+}
+
+Nonterm FindNontermForName(const RulesSet* rules,
+ const std::string& nonterminal_name) {
+ for (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry :
+ *rules->debug_information()->nonterminal_names()) {
+ if (entry->value()->str() == nonterminal_name) {
+ return entry->key();
+ }
+ }
+ return kUnassignedNonterm;
+}
+
+TEST_F(MatcherTest, HandlesDerivationsOfRules) {
+ const std::string rules_buffer = CreateTestGrammar();
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+ Matcher matcher(&unilib_, rules_set, &arena_);
+
+ matcher.AddTerminal(0, 1, "the");
+ matcher.AddTerminal(1, 2, "quick");
+ matcher.AddTerminal(2, 3, "brown");
+ matcher.AddTerminal(3, 4, "fox");
+ matcher.AddTerminal(3, 5, "fox");
+ matcher.AddTerminal(4, 6, "fox"); // Not adjacent to "brown".
+
+ EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(
+ // the
+ IsDerivation(0, 1, "<output_0>", 1),
+
+ // quick
+ IsDerivation(1, 2, "<output_5>", 6),
+ IsDerivation(0, 2, "<output_2>", 3),
+
+ // brown
+
+ // fox
+ IsDerivation(0, 4, "<output_1>", 2),
+ IsDerivation(2, 4, "<output_3>", 4),
+ IsDerivation(2, 4, "<thestarbrownfox>", 5),
+
+ // fox
+ IsDerivation(0, 5, "<output_1>", 2),
+ IsDerivation(2, 5, "<output_3>", 4),
+ IsDerivation(2, 5, "<thestarbrownfox>", 5)));
+}
+
+TEST_F(MatcherTest, HandlesRecursiveRules) {
+ const std::string rules_buffer = CreateTestGrammar();
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+ Matcher matcher(&unilib_, rules_set, &arena_);
+
+ matcher.AddTerminal(0, 1, "the");
+ matcher.AddTerminal(1, 2, "the");
+ matcher.AddTerminal(2, 4, "the");
+ matcher.AddTerminal(3, 4, "the");
+ matcher.AddTerminal(4, 5, "brown");
+ matcher.AddTerminal(5, 6, "fox"); // Generates 5 of <thestarbrownfox>
+
+ EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsTerminal(0, 1, "the"), IsTerminal(1, 2, "the"),
+ IsTerminal(2, 4, "the"), IsTerminal(3, 4, "the"),
+ IsNonterminal(4, 6, "<output_3>"),
+ IsNonterminal(4, 6, "<thestarbrownfox>"),
+ IsNonterminal(3, 6, "<thestarbrownfox>"),
+ IsNonterminal(2, 6, "<thestarbrownfox>"),
+ IsNonterminal(1, 6, "<thestarbrownfox>"),
+ IsNonterminal(0, 6, "<thestarbrownfox>")));
+}
+
+TEST_F(MatcherTest, HandlesManualAddParseTreeCalls) {
+ const std::string rules_buffer = CreateTestGrammar();
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+ Matcher matcher(&unilib_, rules_set, &arena_);
+
+ // Test having the lexer call AddParseTree() instead of AddTerminal()
+ matcher.AddTerminal(-4, 37, "the");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ FindNontermForName(rules_set, "<thestarbrownfox>"), CodepointSpan{37, 42},
+ /*match_offset=*/37, ParseTree::Type::kDefault));
+
+ EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsTerminal(-4, 37, "the"),
+ IsNonterminal(-4, 42, "<thestarbrownfox>")));
+}
+
+TEST_F(MatcherTest, HandlesOptionalRuleElements) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<output_0>", {"a?", "b?", "c?", "d?", "e"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+ rules.Add("<output_1>", {"a", "b?", "c", "d?", "e"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+ rules.Add("<output_2>", {"a", "b?", "c", "d", "e?"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+ Matcher matcher(&unilib_, rules_set, &arena_);
+
+ // Run the matcher on "a b c d e".
+ matcher.AddTerminal(0, 1, "a");
+ matcher.AddTerminal(1, 2, "b");
+ matcher.AddTerminal(2, 3, "c");
+ matcher.AddTerminal(3, 4, "d");
+ matcher.AddTerminal(4, 5, "e");
+
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(
+ IsNonterminal(0, 4, "<output_2>"), IsTerminal(4, 5, "e"),
+ IsNonterminal(0, 5, "<output_0>"), IsNonterminal(0, 5, "<output_1>"),
+ IsNonterminal(0, 5, "<output_2>"), IsNonterminal(1, 5, "<output_0>"),
+ IsNonterminal(2, 5, "<output_0>"),
+ IsNonterminal(3, 5, "<output_0>")));
+}
+
+TEST_F(MatcherTest, HandlesWhitespaceGapLimits) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<iata>", {"lx"});
+ rules.Add("<iata>", {"aa"});
+ // Require no whitespace between code and flight number.
+ rules.Add("<flight_number>", {"<iata>", "<4_digits>"},
+ /*callback=*/static_cast<CallbackId>(DefaultCallback::kRootRule), 0,
+ /*max_whitespace_gap=*/0);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+
+ // Check that the grammar triggers on LX1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(0, 2, "LX");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
+ }
+
+ // Check that the grammar doesn't trigger on LX 1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(6, 8, "LX");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{9, 13}, /*match_offset=*/8, ParseTree::Type::kDefault));
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ IsEmpty());
+ }
+}
+
+TEST_F(MatcherTest, HandlesCaseSensitiveTerminals) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<iata>", {"LX"}, /*callback=*/kNoCallback, 0,
+ /*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
+ rules.Add("<iata>", {"AA"}, /*callback=*/kNoCallback, 0,
+ /*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
+ rules.Add("<iata>", {"dl"}, /*callback=*/kNoCallback, 0,
+ /*max_whitespace_gap*/ -1, /*case_sensitive=*/false);
+ // Require no whitespace between code and flight number.
+ rules.Add("<flight_number>", {"<iata>", "<4_digits>"},
+ /*callback=*/static_cast<CallbackId>(DefaultCallback::kRootRule), 0,
+ /*max_whitespace_gap=*/0);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+
+ // Check that the grammar triggers on LX1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(0, 2, "LX");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
+ }
+
+ // Check that the grammar doesn't trigger on lx1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(6, 8, "lx");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{8, 12}, /*match_offset=*/8, ParseTree::Type::kDefault));
+ EXPECT_THAT(matcher.chart().derivations(), IsEmpty());
+ }
+
+ // Check that the grammar does trigger on dl1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(12, 14, "dl");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{14, 18}, /*match_offset=*/14, ParseTree::Type::kDefault));
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsNonterminal(12, 18, "<flight_number>")));
+ }
+}
+
+TEST_F(MatcherTest, HandlesExclusions) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+
+ rules.Add("<all_zeros>", {"0000"});
+ rules.AddWithExclusion("<flight_code>", {"<4_digits>"},
+ /*excluded_nonterminal=*/"<all_zeros>");
+ rules.Add("<iata>", {"lx"});
+ rules.Add("<iata>", {"aa"});
+ rules.Add("<iata>", {"dl"});
+ // Require no whitespace between code and flight number.
+ rules.Add("<flight_number>", {"<iata>", "<flight_code>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+
+ // Check that the grammar triggers on LX1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(0, 2, "LX");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
+ matcher.Finish();
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
+ }
+
+ // Check that the grammar doesn't trigger on LX0000.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(6, 8, "LX");
+ matcher.AddTerminal(8, 12, "0000");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{8, 12}, /*match_offset=*/8, ParseTree::Type::kDefault));
+ matcher.Finish();
+ EXPECT_THAT(matcher.chart().derivations(), IsEmpty());
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/parse-tree.cc b/native/utils/grammar/parsing/parse-tree.cc
new file mode 100644
index 0000000..8a53173
--- /dev/null
+++ b/native/utils/grammar/parsing/parse-tree.cc
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/parse-tree.h"
+
+#include <algorithm>
+#include <stack>
+
+namespace libtextclassifier3::grammar {
+
+void Traverse(const ParseTree* root,
+ const std::function<bool(const ParseTree*)>& node_fn) {
+ std::stack<const ParseTree*> open;
+ open.push(root);
+
+ while (!open.empty()) {
+ const ParseTree* node = open.top();
+ open.pop();
+ if (!node_fn(node) || node->IsLeaf()) {
+ continue;
+ }
+ open.push(node->rhs2);
+ if (node->rhs1 != nullptr) {
+ open.push(node->rhs1);
+ }
+ }
+}
+
+std::vector<const ParseTree*> SelectAll(
+ const ParseTree* root,
+ const std::function<bool(const ParseTree*)>& pred_fn) {
+ std::vector<const ParseTree*> result;
+ Traverse(root, [&result, pred_fn](const ParseTree* node) {
+ if (pred_fn(node)) {
+ result.push_back(node);
+ }
+ return true;
+ });
+ return result;
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/parse-tree.h b/native/utils/grammar/parsing/parse-tree.h
new file mode 100644
index 0000000..d3075d8
--- /dev/null
+++ b/native/utils/grammar/parsing/parse-tree.h
@@ -0,0 +1,195 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSE_TREE_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSE_TREE_H_
+
+#include <functional>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/types.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3::grammar {
+
+// Represents a parse tree for a match that was found for a nonterminal.
+struct ParseTree {
+ enum class Type : int8 {
+ // Default, untyped match.
+ kDefault = 0,
+
+ // An assertion match (see: AssertionNode).
+ kAssertion = 1,
+
+ // A value mapping match (see: MappingNode).
+ kMapping = 2,
+
+ // An exclusion match (see: ExclusionNode).
+ kExclusion = 3,
+
+ // A match for an annotation (see: AnnotationNode).
+ kAnnotation = 4,
+
+ // A match for a semantic annotation (see: SemanticExpressionNode).
+ kExpression = 5,
+ };
+
+ explicit ParseTree() = default;
+ explicit ParseTree(const Nonterm lhs, const CodepointSpan& codepoint_span,
+ const int match_offset, const Type type)
+ : lhs(lhs),
+ type(type),
+ codepoint_span(codepoint_span),
+ match_offset(match_offset) {}
+
+ // For binary rule matches: rhs1 != NULL and rhs2 != NULL
+ // unary rule matches: rhs1 == NULL and rhs2 != NULL
+ // terminal rule matches: rhs1 != NULL and rhs2 == NULL
+ // custom leaves: rhs1 == NULL and rhs2 == NULL
+ bool IsInteriorNode() const { return rhs2 != nullptr; }
+ bool IsLeaf() const { return !rhs2; }
+
+ bool IsBinaryRule() const { return rhs1 && rhs2; }
+ bool IsUnaryRule() const { return !rhs1 && rhs2; }
+ bool IsTerminalRule() const { return rhs1 && !rhs2; }
+ bool HasLeadingWhitespace() const {
+ return codepoint_span.first != match_offset;
+ }
+
+ const ParseTree* unary_rule_rhs() const { return rhs2; }
+
+ // Used in singly-linked queue of matches for processing.
+ ParseTree* next = nullptr;
+
+ // Nonterminal we found a match for.
+ Nonterm lhs = kUnassignedNonterm;
+
+ // Type of the match.
+ Type type = Type::kDefault;
+
+ // The span in codepoints.
+ CodepointSpan codepoint_span;
+
+ // The begin codepoint offset used during matching.
+ // This is usually including any prefix whitespace.
+ int match_offset;
+
+ union {
+ // The first sub match for binary rules.
+ const ParseTree* rhs1 = nullptr;
+
+ // The terminal, for terminal rules.
+ const char* terminal;
+ };
+ // First or second sub-match for interior nodes.
+ const ParseTree* rhs2 = nullptr;
+};
+
+// Node type to keep track of associated values.
+struct MappingNode : public ParseTree {
+ explicit MappingNode(const Nonterm arg_lhs,
+ const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset, const int64 arg_value)
+ : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset,
+ Type::kMapping),
+ id(arg_value) {}
+ // The associated id or value.
+ int64 id;
+};
+
+// Node type to keep track of assertions.
+struct AssertionNode : public ParseTree {
+ explicit AssertionNode(const Nonterm arg_lhs,
+ const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset, const bool arg_negative)
+ : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset,
+ Type::kAssertion),
+ negative(arg_negative) {}
+ // If true, the assertion is negative and will be valid if the input doesn't
+ // match.
+ bool negative;
+};
+
+// Node type to define exclusions.
+struct ExclusionNode : public ParseTree {
+ explicit ExclusionNode(const Nonterm arg_lhs,
+ const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset,
+ const Nonterm arg_exclusion_nonterm)
+ : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset,
+ Type::kExclusion),
+ exclusion_nonterm(arg_exclusion_nonterm) {}
+ // The nonterminal that denotes matches to exclude from a successful match.
+ // So the match is only valid if there is no match of `exclusion_nonterm`
+ // spanning the same text range.
+ Nonterm exclusion_nonterm;
+};
+
+// Match to represent an annotator annotated span in the grammar.
+struct AnnotationNode : public ParseTree {
+ explicit AnnotationNode(const Nonterm arg_lhs,
+ const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset,
+ const ClassificationResult* arg_annotation)
+ : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset,
+ Type::kAnnotation),
+ annotation(arg_annotation) {}
+ const ClassificationResult* annotation;
+};
+
+// Node type to represent an associated semantic expression.
+struct SemanticExpressionNode : public ParseTree {
+ explicit SemanticExpressionNode(const Nonterm arg_lhs,
+ const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset,
+ const SemanticExpression* arg_expression)
+ : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset,
+ Type::kExpression),
+ expression(arg_expression) {}
+ const SemanticExpression* expression;
+};
+
+// Utility functions for parse tree traversal.
+
+// Does a preorder traversal, calling `node_fn` on each node.
+// `node_fn` is expected to return whether to continue expanding a node.
+void Traverse(const ParseTree* root,
+ const std::function<bool(const ParseTree*)>& node_fn);
+
+// Does a preorder traversal, selecting all nodes where `pred_fn` returns true.
+std::vector<const ParseTree*> SelectAll(
+ const ParseTree* root,
+ const std::function<bool(const ParseTree*)>& pred_fn);
+
+// Retrieves all nodes of a given type.
+template <typename T>
+const std::vector<const T*> SelectAllOfType(const ParseTree* root,
+ const ParseTree::Type type) {
+ std::vector<const T*> result;
+ Traverse(root, [&result, type](const ParseTree* node) {
+ if (node->type == type) {
+ result.push_back(static_cast<const T*>(node));
+ }
+ return true;
+ });
+ return result;
+}
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSE_TREE_H_
diff --git a/native/utils/grammar/parsing/parser.cc b/native/utils/grammar/parsing/parser.cc
new file mode 100644
index 0000000..4e39a98
--- /dev/null
+++ b/native/utils/grammar/parsing/parser.cc
@@ -0,0 +1,278 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/parser.h"
+
+#include <unordered_map>
+
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/rules-utils.h"
+#include "utils/grammar/types.h"
+#include "utils/zlib/zlib.h"
+#include "utils/zlib/zlib_regex.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+inline bool CheckMemoryUsage(const UnsafeArena* arena) {
+ // The maximum memory usage for matching.
+ constexpr int kMaxMemoryUsage = 1 << 20;
+ return arena->status().bytes_allocated() <= kMaxMemoryUsage;
+}
+
+// Maps a codepoint to include the token padding if it aligns with a token
+// start. Whitespace is ignored when symbols are fed to the matcher. Preceding
+// whitespace is merged to the match start so that tokens and non-terminals
+// appear next to each other without whitespace. For text or regex annotations,
+// we therefore merge the whitespace padding to the start if the annotation
+// starts at a token.
+int MapCodepointToTokenPaddingIfPresent(
+ const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment,
+ const int start) {
+ const auto it = token_alignment.find(start);
+ if (it != token_alignment.end()) {
+ return it->second;
+ }
+ return start;
+}
+
+} // namespace
+
+Parser::Parser(const UniLib* unilib, const RulesSet* rules)
+ : unilib_(*unilib),
+ rules_(rules),
+ lexer_(unilib),
+ nonterminals_(rules_->nonterminals()),
+ rules_locales_(ParseRulesLocales(rules_)),
+ regex_annotators_(BuildRegexAnnotators()) {}
+
+// Uncompresses and builds the defined regex annotators.
+std::vector<Parser::RegexAnnotator> Parser::BuildRegexAnnotators() const {
+ std::vector<RegexAnnotator> result;
+ if (rules_->regex_annotator() != nullptr) {
+ std::unique_ptr<ZlibDecompressor> decompressor =
+ ZlibDecompressor::Instance();
+ result.reserve(rules_->regex_annotator()->size());
+ for (const RulesSet_::RegexAnnotator* regex_annotator :
+ *rules_->regex_annotator()) {
+ result.push_back(
+ {UncompressMakeRegexPattern(unilib_, regex_annotator->pattern(),
+ regex_annotator->compressed_pattern(),
+ rules_->lazy_regex_compilation(),
+ decompressor.get()),
+ regex_annotator->nonterminal()});
+ }
+ }
+ return result;
+}
+
+std::vector<Symbol> Parser::SortedSymbolsForInput(const TextContext& input,
+ UnsafeArena* arena) const {
+ // Whitespace is ignored when symbols are fed to the matcher.
+ // For regex matches and existing text annotations we therefore have to merge
+ // preceding whitespace to the match start so that tokens and non-terminals
+  // appear next to each other without whitespace. We keep track of real
+  // token starts and preceding whitespace in `token_match_start`, so that we
+ // can extend a match's start to include the preceding whitespace.
+ std::unordered_map<CodepointIndex, CodepointIndex> token_match_start;
+ for (int i = input.context_span.first + 1; i < input.context_span.second;
+ i++) {
+ const CodepointIndex token_start = input.tokens[i].start;
+ const CodepointIndex prev_token_end = input.tokens[i - 1].end;
+ if (token_start != prev_token_end) {
+ token_match_start[token_start] = prev_token_end;
+ }
+ }
+
+ std::vector<Symbol> symbols;
+ CodepointIndex match_offset = input.tokens[input.context_span.first].start;
+
+ // Add start symbol.
+ if (input.context_span.first == 0 &&
+ nonterminals_->start_nt() != kUnassignedNonterm) {
+ match_offset = 0;
+ symbols.emplace_back(arena->AllocAndInit<ParseTree>(
+ nonterminals_->start_nt(), CodepointSpan{0, 0},
+ /*match_offset=*/0, ParseTree::Type::kDefault));
+ }
+
+ if (nonterminals_->wordbreak_nt() != kUnassignedNonterm) {
+ symbols.emplace_back(arena->AllocAndInit<ParseTree>(
+ nonterminals_->wordbreak_nt(),
+ CodepointSpan{match_offset, match_offset},
+ /*match_offset=*/match_offset, ParseTree::Type::kDefault));
+ }
+
+ // Add symbols from tokens.
+ for (int i = input.context_span.first; i < input.context_span.second; i++) {
+ const Token& token = input.tokens[i];
+ lexer_.AppendTokenSymbols(token.value, /*match_offset=*/match_offset,
+ CodepointSpan{token.start, token.end}, &symbols);
+ match_offset = token.end;
+
+ // Add word break symbol.
+ if (nonterminals_->wordbreak_nt() != kUnassignedNonterm) {
+ symbols.emplace_back(arena->AllocAndInit<ParseTree>(
+ nonterminals_->wordbreak_nt(),
+ CodepointSpan{match_offset, match_offset},
+ /*match_offset=*/match_offset, ParseTree::Type::kDefault));
+ }
+ }
+
+ // Add end symbol if used by the grammar.
+ if (input.context_span.second == input.tokens.size() &&
+ nonterminals_->end_nt() != kUnassignedNonterm) {
+ symbols.emplace_back(arena->AllocAndInit<ParseTree>(
+ nonterminals_->end_nt(), CodepointSpan{match_offset, match_offset},
+ /*match_offset=*/match_offset, ParseTree::Type::kDefault));
+ }
+
+ // Add symbols from the regex annotators.
+ const CodepointIndex context_start =
+ input.tokens[input.context_span.first].start;
+ const CodepointIndex context_end =
+ input.tokens[input.context_span.second - 1].end;
+ for (const RegexAnnotator& regex_annotator : regex_annotators_) {
+ std::unique_ptr<UniLib::RegexMatcher> regex_matcher =
+ regex_annotator.pattern->Matcher(UnicodeText::Substring(
+ input.text, context_start, context_end, /*do_copy=*/false));
+ int status = UniLib::RegexMatcher::kNoError;
+ while (regex_matcher->Find(&status) &&
+ status == UniLib::RegexMatcher::kNoError) {
+ const CodepointSpan span{regex_matcher->Start(0, &status) + context_start,
+ regex_matcher->End(0, &status) + context_start};
+ symbols.emplace_back(arena->AllocAndInit<ParseTree>(
+ regex_annotator.nonterm, span, /*match_offset=*/
+ MapCodepointToTokenPaddingIfPresent(token_match_start, span.first),
+ ParseTree::Type::kDefault));
+ }
+ }
+
+ // Add symbols based on annotations.
+ if (auto annotation_nonterminals = nonterminals_->annotation_nt()) {
+ for (const AnnotatedSpan& annotated_span : input.annotations) {
+ const ClassificationResult& classification =
+ annotated_span.classification.front();
+ if (auto entry = annotation_nonterminals->LookupByKey(
+ classification.collection.c_str())) {
+ symbols.emplace_back(arena->AllocAndInit<AnnotationNode>(
+ entry->value(), annotated_span.span, /*match_offset=*/
+ MapCodepointToTokenPaddingIfPresent(token_match_start,
+ annotated_span.span.first),
+ &classification));
+ }
+ }
+ }
+
+ std::sort(symbols.begin(), symbols.end(),
+ [](const Symbol& a, const Symbol& b) {
+ // Sort by increasing (end, start) position to guarantee the
+ // matcher requirement that the tokens are fed in non-decreasing
+ // end position order.
+ return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
+ std::tie(b.codepoint_span.second, b.codepoint_span.first);
+ });
+
+ return symbols;
+}
+
+void Parser::EmitSymbol(const Symbol& symbol, UnsafeArena* arena,
+ Matcher* matcher) const {
+ if (!CheckMemoryUsage(arena)) {
+ return;
+ }
+ switch (symbol.type) {
+ case Symbol::Type::TYPE_PARSE_TREE: {
+ // Just emit the parse tree.
+ matcher->AddParseTree(symbol.parse_tree);
+ return;
+ }
+ case Symbol::Type::TYPE_DIGITS: {
+ // Emit <digits> if used by the rules.
+ if (nonterminals_->digits_nt() != kUnassignedNonterm) {
+ matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
+ nonterminals_->digits_nt(), symbol.codepoint_span,
+ symbol.match_offset, ParseTree::Type::kDefault));
+ }
+
+ // Emit <n_digits> if used by the rules.
+ if (nonterminals_->n_digits_nt() != nullptr) {
+ const int num_digits =
+ symbol.codepoint_span.second - symbol.codepoint_span.first;
+ if (num_digits <= nonterminals_->n_digits_nt()->size()) {
+ const Nonterm n_digits_nt =
+ nonterminals_->n_digits_nt()->Get(num_digits - 1);
+ if (n_digits_nt != kUnassignedNonterm) {
+ matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
+ nonterminals_->n_digits_nt()->Get(num_digits - 1),
+ symbol.codepoint_span, symbol.match_offset,
+ ParseTree::Type::kDefault));
+ }
+ }
+ }
+ break;
+ }
+ case Symbol::Type::TYPE_TERM: {
+ // Emit <uppercase_token> if used by the rules.
+ if (nonterminals_->uppercase_token_nt() != 0 &&
+ unilib_.IsUpperText(
+ UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) {
+ matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
+ nonterminals_->uppercase_token_nt(), symbol.codepoint_span,
+ symbol.match_offset, ParseTree::Type::kDefault));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+
+ // Emit the token as terminal.
+ matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset,
+ symbol.lexeme);
+
+ // Emit <token> if used by rules.
+ matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
+ nonterminals_->token_nt(), symbol.codepoint_span, symbol.match_offset,
+ ParseTree::Type::kDefault));
+}
+
+// Parses an input text and returns the root rule derivations.
+std::vector<Derivation> Parser::Parse(const TextContext& input,
+ UnsafeArena* arena) const {
+ // Check the tokens, input can be non-empty (whitespace) but have no tokens.
+ if (input.tokens.empty()) {
+ return {};
+ }
+
+ // Select locale matching rules.
+ std::vector<const RulesSet_::Rules*> locale_rules =
+ SelectLocaleMatchingShards(rules_, rules_locales_, input.locales);
+
+ if (locale_rules.empty()) {
+ // Nothing to do.
+ return {};
+ }
+
+ Matcher matcher(&unilib_, rules_, locale_rules, arena);
+ for (const Symbol& symbol : SortedSymbolsForInput(input, arena)) {
+ EmitSymbol(symbol, arena, &matcher);
+ }
+ matcher.Finish();
+ return matcher.chart().derivations();
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/parser.h b/native/utils/grammar/parsing/parser.h
new file mode 100644
index 0000000..0b320a0
--- /dev/null
+++ b/native/utils/grammar/parsing/parser.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSER_H_
+
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/base/arena.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/lexer.h"
+#include "utils/grammar/parsing/matcher.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/text-context.h"
+#include "utils/i18n/locale.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::grammar {
+
+// Syntactic parsing pass.
+// The parser validates and deduplicates candidates produced by the grammar
+// matcher. It augments the parse trees with derivation information for semantic
+// evaluation.
+class Parser {
+ public:
+ explicit Parser(const UniLib* unilib, const RulesSet* rules);
+
+ // Parses an input text and returns the root rule derivations.
+ std::vector<Derivation> Parse(const TextContext& input,
+ UnsafeArena* arena) const;
+
+ private:
+ struct RegexAnnotator {
+ std::unique_ptr<UniLib::RegexPattern> pattern;
+ Nonterm nonterm;
+ };
+
+ // Uncompresses and builds the defined regex annotators.
+ std::vector<RegexAnnotator> BuildRegexAnnotators() const;
+
+ // Produces symbols for a text input to feed to a matcher.
+ // These are symbols for each token from the lexer, existing text annotations
+ // and regex annotations.
+ // The symbols are sorted with increasing end-positions to satisfy the matcher
+ // requirements.
+ std::vector<Symbol> SortedSymbolsForInput(const TextContext& input,
+ UnsafeArena* arena) const;
+
+ // Emits a symbol to the matcher.
+ void EmitSymbol(const Symbol& symbol, UnsafeArena* arena,
+ Matcher* matcher) const;
+
+ const UniLib& unilib_;
+ const RulesSet* rules_;
+ const Lexer lexer_;
+
+ // Pre-defined nonterminals.
+ const RulesSet_::Nonterminals* nonterminals_;
+
+ // Pre-parsed locales of the rules.
+ const std::vector<std::vector<Locale>> rules_locales_;
+
+ std::vector<RegexAnnotator> regex_annotators_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSER_H_
diff --git a/native/utils/grammar/parsing/parser_test.cc b/native/utils/grammar/parsing/parser_test.cc
new file mode 100644
index 0000000..183be0e
--- /dev/null
+++ b/native/utils/grammar/parsing/parser_test.cc
@@ -0,0 +1,320 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/parser.h"
+
+#include <string>
+#include <vector>
+
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/types.h"
+#include "utils/grammar/utils/ir.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/i18n/locale.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+
+class ParserTest : public GrammarTest {};
+
+TEST_F(ParserTest, ParsesSimpleRules) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<day>", {"<2_digits>"});
+ rules.Add("<month>", {"<2_digits>"});
+ rules.Add("<year>", {"<4_digits>"});
+ constexpr int kDate = 0;
+ rules.Add("<date>", {"<year>", "/", "<month>", "/", "<day>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kDate);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Event: 2020/05/08"), &arena_)),
+ ElementsAre(IsDerivation(kDate, 7, 17)));
+}
+
+TEST_F(ParserTest, HandlesEmptyInput) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ constexpr int kTest = 0;
+ rules.Add("<test>", {"test"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kTest);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText("Event: test"), &arena_)),
+ ElementsAre(IsDerivation(kTest, 7, 11)));
+
+ // Check that we bail out in case of empty input.
+ EXPECT_THAT(ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText(""), &arena_)),
+ IsEmpty());
+ EXPECT_THAT(ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText(" "), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesUppercaseTokens) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ constexpr int kScriptedReply = 0;
+ rules.Add("<test>", {"please?", "reply", "<uppercase_token>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule),
+ kScriptedReply);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Reply STOP to cancel."), &arena_)),
+ ElementsAre(IsDerivation(kScriptedReply, 0, 10)));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Reply stop to cancel."), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesAnchors) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ constexpr int kScriptedReply = 0;
+ rules.Add("<test>", {"<^>", "reply", "<uppercase_token>", "<$>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule),
+ kScriptedReply);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText("Reply STOP"), &arena_)),
+ ElementsAre(IsDerivation(kScriptedReply, 0, 10)));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Please reply STOP to cancel."), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesWordBreaks) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ constexpr int kFlight = 0;
+ rules.Add("<flight>", {"<carrier>", "<digits>", "<\b>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ // Make sure the grammar recognizes "LX 38".
+ EXPECT_THAT(
+ ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("My flight is: LX 38. Arriving later"), &arena_)),
+ ElementsAre(IsDerivation(kFlight, 14, 19)));
+
+ // Make sure the grammar doesn't trigger on "LX 38.00".
+ EXPECT_THAT(ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText("LX 38.00"), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesAnnotations) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ constexpr int kCallPhone = 0;
+ rules.Add("<flight>", {"dial", "<phone>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kCallPhone);
+ rules.BindAnnotation("<phone>", "phone");
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ TextContext context = TextContextForText("Please dial 911");
+
+ // Sanity check that we don't trigger if we don't feed the correct
+ // annotations.
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(context, &arena_)),
+ IsEmpty());
+
+ // Create a phone annotation.
+ AnnotatedSpan phone_span;
+ phone_span.span = CodepointSpan{12, 15};
+ phone_span.classification.emplace_back("phone", 1.0);
+ context.annotations.push_back(phone_span);
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(context, &arena_)),
+ ElementsAre(IsDerivation(kCallPhone, 7, 15)));
+}
+
+TEST_F(ParserTest, HandlesRegexAnnotators) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.AddRegex("<code>",
+ "(\"([A-Za-z]+)\"|\\b\"?(?:[A-Z]+[0-9]*|[0-9])\"?\\b)");
+ constexpr int kScriptedReply = 0;
+ rules.Add("<test>", {"please?", "reply", "<code>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule),
+ kScriptedReply);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Reply STOP to cancel."), &arena_)),
+ ElementsAre(IsDerivation(kScriptedReply, 0, 10)));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Reply Stop to cancel."), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesExclusions) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<excluded>", {"be", "safe"});
+ rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
+ /*excluded_nonterminal=*/"<excluded>");
+ constexpr int kSetReminder = 0;
+ rules.Add("<set_reminder>",
+ {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kSetReminder);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("do not forget to be there"), &arena_)),
+ ElementsAre(IsDerivation(kSetReminder, 0, 25)));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("do not forget to be safe"), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesFillers) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ constexpr int kSetReminder = 0;
+ rules.Add("<set_reminder>", {"do", "not", "forget", "to", "<filler>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kSetReminder);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("do not forget to be there"), &arena_)),
+ ElementsAre(IsDerivation(kSetReminder, 0, 25)));
+}
+
+TEST_F(ParserTest, HandlesAssertions) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ // Flight: carrier + flight code and check right context.
+ constexpr int kFlight = 0;
+ rules.Add("<track_flight>",
+ {"<carrier>", "<flight_code>", "<context_assertion>?"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight);
+ // Exclude matches like: LX 38.00 etc.
+ rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
+ /*negative=*/true);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(
+ ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText("LX38 aa 44 LX 38.38"), &arena_)),
+ ElementsAre(IsDerivation(kFlight, 0, 4), IsDerivation(kFlight, 5, 10)));
+}
+
+TEST_F(ParserTest, HandlesWhitespaceGapLimit) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ // Flight: carrier + flight code and check right context.
+ constexpr int kFlight = 0;
+ rules.Add("<track_flight>", {"<carrier>", "<flight_code>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight,
+ /*max_whitespace_gap=*/0);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("LX38 aa 44 LX 38"), &arena_)),
+ ElementsAre(IsDerivation(kFlight, 0, 4)));
+}
+
+TEST_F(ParserTest, HandlesCaseSensitiveMatching) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<carrier>", {"Lx"}, /*callback=*/kNoCallback, /*callback_param=*/0,
+ /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
+ rules.Add("<carrier>", {"AA"}, /*callback=*/kNoCallback, /*callback_param=*/0,
+ /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ // Flight: carrier + flight code and check right context.
+ constexpr int kFlight = 0;
+ rules.Add("<track_flight>", {"<carrier>", "<flight_code>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(
+ ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText("Lx38 AA 44 LX 38"), &arena_)),
+ ElementsAre(IsDerivation(kFlight, 0, 4), IsDerivation(kFlight, 5, 10)));
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/rules-utils.cc b/native/utils/grammar/rules-utils.cc
index 56c928a..5e8c189 100644
--- a/native/utils/grammar/rules-utils.cc
+++ b/native/utils/grammar/rules-utils.cc
@@ -54,70 +54,4 @@
return shards;
}
-std::vector<Derivation> DeduplicateDerivations(
- const std::vector<Derivation>& derivations) {
- std::vector<Derivation> sorted_candidates = derivations;
- std::stable_sort(
- sorted_candidates.begin(), sorted_candidates.end(),
- [](const Derivation& a, const Derivation& b) {
- // Sort by id.
- if (a.rule_id != b.rule_id) {
- return a.rule_id < b.rule_id;
- }
-
- // Sort by increasing start.
- if (a.match->codepoint_span.first != b.match->codepoint_span.first) {
- return a.match->codepoint_span.first < b.match->codepoint_span.first;
- }
-
- // Sort by decreasing end.
- return a.match->codepoint_span.second > b.match->codepoint_span.second;
- });
-
- // Deduplicate by overlap.
- std::vector<Derivation> result;
- for (int i = 0; i < sorted_candidates.size(); i++) {
- const Derivation& candidate = sorted_candidates[i];
- bool eliminated = false;
-
- // Due to the sorting above, the candidate can only be completely
- // intersected by a match before it in the sorted order.
- for (int j = i - 1; j >= 0; j--) {
- if (sorted_candidates[j].rule_id != candidate.rule_id) {
- break;
- }
- if (sorted_candidates[j].match->codepoint_span.first <=
- candidate.match->codepoint_span.first &&
- sorted_candidates[j].match->codepoint_span.second >=
- candidate.match->codepoint_span.second) {
- eliminated = true;
- break;
- }
- }
-
- if (!eliminated) {
- result.push_back(candidate);
- }
- }
- return result;
-}
-
-bool VerifyAssertions(const Match* match) {
- bool result = true;
- grammar::Traverse(match, [&result](const Match* node) {
- if (node->type != Match::kAssertionMatch) {
- // Only validation if all checks so far passed.
- return result;
- }
-
- // Positive assertions are by definition fulfilled,
- // fail if the assertion is negative.
- if (static_cast<const AssertionMatch*>(node)->negative) {
- result = false;
- }
- return result;
- });
- return result;
-}
-
} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/rules-utils.h b/native/utils/grammar/rules-utils.h
index e6ac541..64e8245 100644
--- a/native/utils/grammar/rules-utils.h
+++ b/native/utils/grammar/rules-utils.h
@@ -19,10 +19,8 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_
#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_
-#include <unordered_map>
#include <vector>
-#include "utils/grammar/match.h"
#include "utils/grammar/rules_generated.h"
#include "utils/i18n/locale.h"
@@ -37,22 +35,6 @@
const std::vector<std::vector<Locale>>& shard_locales,
const std::vector<Locale>& locales);
-// Deduplicates rule derivations by containing overlap.
-// The grammar system can output multiple candidates for optional parts.
-// For example if a rule has an optional suffix, we
-// will get two rule derivations when the suffix is present: one with and one
-// without the suffix. We therefore deduplicate by containing overlap, viz. from
-// two candidates we keep the longer one if it completely contains the shorter.
-struct Derivation {
- const Match* match;
- int64 rule_id;
-};
-std::vector<Derivation> DeduplicateDerivations(
- const std::vector<Derivation>& derivations);
-
-// Checks that all assertions of a match tree are fulfilled.
-bool VerifyAssertions(const Match* match);
-
} // namespace libtextclassifier3::grammar
#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_
diff --git a/native/utils/grammar/rules-utils_test.cc b/native/utils/grammar/rules-utils_test.cc
deleted file mode 100644
index 6391be1..0000000
--- a/native/utils/grammar/rules-utils_test.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/grammar/rules-utils.h"
-
-#include <vector>
-
-#include "utils/grammar/match.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3::grammar {
-namespace {
-
-using testing::ElementsAre;
-using testing::Value;
-
-// Create test match object.
-Match CreateMatch(const CodepointIndex begin, const CodepointIndex end) {
- Match match;
- match.Init(0, CodepointSpan{begin, end},
- /*arg_match_offset=*/begin);
- return match;
-}
-
-MATCHER_P(IsDerivation, candidate, "") {
- return Value(arg.rule_id, candidate.rule_id) &&
- Value(arg.match, candidate.match);
-}
-
-TEST(UtilsTest, DeduplicatesMatches) {
- // Overlapping matches from the same rule.
- Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(0, 2)};
- const std::vector<Derivation> candidates = {{&matches[0], /*rule_id=*/0},
- {&matches[1], /*rule_id=*/0},
- {&matches[2], /*rule_id=*/0}};
-
- // Keep longest.
- EXPECT_THAT(DeduplicateDerivations(candidates),
- ElementsAre(IsDerivation(candidates[2])));
-}
-
-TEST(UtilsTest, DeduplicatesMatchesPerRule) {
- // Overlapping matches from different rules.
- Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(0, 2)};
- const std::vector<Derivation> candidates = {{&matches[0], /*rule_id=*/0},
- {&matches[1], /*rule_id=*/0},
- {&matches[2], /*rule_id=*/0},
- {&matches[0], /*rule_id=*/1}};
-
- // Keep longest for rule 0, but also keep match from rule 1.
- EXPECT_THAT(
- DeduplicateDerivations(candidates),
- ElementsAre(IsDerivation(candidates[2]), IsDerivation(candidates[3])));
-}
-
-TEST(UtilsTest, KeepNonoverlapping) {
- // Non-overlapping matches.
- Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(2, 3)};
- const std::vector<Derivation> candidates = {{&matches[0], /*rule_id=*/0},
- {&matches[1], /*rule_id=*/0},
- {&matches[2], /*rule_id=*/0}};
-
- // Keep all matches.
- EXPECT_THAT(
- DeduplicateDerivations(candidates),
- ElementsAre(IsDerivation(candidates[0]), IsDerivation(candidates[1]),
- IsDerivation(candidates[2])));
-}
-
-} // namespace
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/rules.fbs b/native/utils/grammar/rules.fbs
old mode 100755
new mode 100644
index 8052c11..bc0136c
--- a/native/utils/grammar/rules.fbs
+++ b/native/utils/grammar/rules.fbs
@@ -14,6 +14,7 @@
// limitations under the License.
//
+include "utils/grammar/semantics/expression.fbs";
include "utils/i18n/language-tag.fbs";
include "utils/zlib/buffer.fbs";
@@ -147,19 +148,6 @@
annotation_nt:[Nonterminals_.AnnotationNtEntry];
}
-// Callback information.
-namespace libtextclassifier3.grammar.RulesSet_;
-struct Callback {
- // Whether the callback is a filter.
- is_filter:bool;
-}
-
-namespace libtextclassifier3.grammar.RulesSet_;
-struct CallbackEntry {
- key:uint (key);
- value:Callback;
-}
-
namespace libtextclassifier3.grammar.RulesSet_.DebugInformation_;
table NonterminalNamesEntry {
key:int (key);
@@ -205,11 +193,17 @@
terminals:string (shared);
nonterminals:RulesSet_.Nonterminals;
- callback:[RulesSet_.CallbackEntry];
+ reserved_6:int16 (deprecated);
debug_information:RulesSet_.DebugInformation;
regex_annotator:[RulesSet_.RegexAnnotator];
// If true, will compile the regexes only on first use.
lazy_regex_compilation:bool;
+
+ // The semantic expressions associated with rule matches.
+ semantic_expression:[SemanticExpression];
+
+ // The schema defining the semantic results.
+ semantic_values_schema:[ubyte];
}
diff --git a/native/utils/grammar/semantics/composer.cc b/native/utils/grammar/semantics/composer.cc
new file mode 100644
index 0000000..2d69049
--- /dev/null
+++ b/native/utils/grammar/semantics/composer.cc
@@ -0,0 +1,132 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/composer.h"
+
+#include "utils/base/status_macros.h"
+#include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
+#include "utils/grammar/semantics/evaluators/compose-eval.h"
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+#include "utils/grammar/semantics/evaluators/constituent-eval.h"
+#include "utils/grammar/semantics/evaluators/merge-values-eval.h"
+#include "utils/grammar/semantics/evaluators/parse-number-eval.h"
+#include "utils/grammar/semantics/evaluators/span-eval.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Gathers all constituents of a rule and index them.
+ // The constituents are numbered in the rule construction. But constituents could
+// be in optional parts of the rule and might not be present in a match.
+ // This finds all constituents that are present in a match and allows
+ // retrieving them by their index.
+std::unordered_map<int, const ParseTree*> GatherConstituents(
+ const ParseTree* root) {
+ std::unordered_map<int, const ParseTree*> constituents;
+ Traverse(root, [root, &constituents](const ParseTree* node) {
+ switch (node->type) {
+ case ParseTree::Type::kMapping:
+ TC3_CHECK(node->IsUnaryRule());
+ constituents[static_cast<const MappingNode*>(node)->id] =
+ node->unary_rule_rhs();
+ return false;
+ case ParseTree::Type::kDefault:
+ // Continue traversal.
+ return true;
+ default:
+ // Don't continue the traversal if we are not at the root node.
+ // This could e.g. be an assertion node.
+ return (node == root);
+ }
+ });
+ return constituents;
+}
+
+} // namespace
+
+SemanticComposer::SemanticComposer(
+ const reflection::Schema* semantic_values_schema) {
+ evaluators_.emplace(SemanticExpression_::Expression_ArithmeticExpression,
+ std::make_unique<ArithmeticExpressionEvaluator>(this));
+ evaluators_.emplace(SemanticExpression_::Expression_ConstituentExpression,
+ std::make_unique<ConstituentEvaluator>());
+ evaluators_.emplace(SemanticExpression_::Expression_ParseNumberExpression,
+ std::make_unique<ParseNumberEvaluator>(this));
+ evaluators_.emplace(SemanticExpression_::Expression_SpanAsStringExpression,
+ std::make_unique<SpanAsStringEvaluator>());
+ if (semantic_values_schema != nullptr) {
+ // Register semantic functions.
+ evaluators_.emplace(
+ SemanticExpression_::Expression_ComposeExpression,
+ std::make_unique<ComposeEvaluator>(this, semantic_values_schema));
+ evaluators_.emplace(
+ SemanticExpression_::Expression_ConstValueExpression,
+ std::make_unique<ConstEvaluator>(semantic_values_schema));
+ evaluators_.emplace(
+ SemanticExpression_::Expression_MergeValueExpression,
+ std::make_unique<MergeValuesEvaluator>(this, semantic_values_schema));
+ }
+}
+
+StatusOr<const SemanticValue*> SemanticComposer::Eval(
+ const TextContext& text_context, const Derivation& derivation,
+ UnsafeArena* arena) const {
+ if (!derivation.parse_tree->IsUnaryRule() ||
+ derivation.parse_tree->unary_rule_rhs()->type !=
+ ParseTree::Type::kExpression) {
+ return nullptr;
+ }
+ return Eval(text_context,
+ static_cast<const SemanticExpressionNode*>(
+ derivation.parse_tree->unary_rule_rhs()),
+ arena);
+}
+
+StatusOr<const SemanticValue*> SemanticComposer::Eval(
+ const TextContext& text_context, const SemanticExpressionNode* derivation,
+ UnsafeArena* arena) const {
+ // Evaluate constituents.
+ EvalContext context{&text_context, derivation};
+ for (const auto& [constituent_index, constituent] :
+ GatherConstituents(derivation)) {
+ if (constituent->type == ParseTree::Type::kExpression) {
+ TC3_ASSIGN_OR_RETURN(
+ context.rule_constituents[constituent_index],
+ Eval(text_context,
+ static_cast<const SemanticExpressionNode*>(constituent), arena));
+ } else {
+ // Just use the text of the constituent if no semantic expression was
+ // defined.
+ context.rule_constituents[constituent_index] = SemanticValue::Create(
+ text_context.Span(constituent->codepoint_span), arena);
+ }
+ }
+ return Apply(context, derivation->expression, arena);
+}
+
+StatusOr<const SemanticValue*> SemanticComposer::Apply(
+ const EvalContext& context, const SemanticExpression* expression,
+ UnsafeArena* arena) const {
+ const auto handler_it = evaluators_.find(expression->expression_type());
+ if (handler_it == evaluators_.end()) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ std::string("Unhandled expression type: ") +
+ EnumNameExpression(expression->expression_type()));
+ }
+ return handler_it->second->Apply(context, expression, arena);
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/composer.h b/native/utils/grammar/semantics/composer.h
new file mode 100644
index 0000000..135f7d6
--- /dev/null
+++ b/native/utils/grammar/semantics/composer.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_COMPOSER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_COMPOSER_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "utils/base/arena.h"
+#include "utils/base/status.h"
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+#include "utils/grammar/text-context.h"
+
+namespace libtextclassifier3::grammar {
+
+// Semantic value composer.
+// It evaluates a semantic expression of a syntactic parse tree as a semantic
+// value.
+// It evaluates the constituents of a rule match and applies them to semantic
+// expression, calling out to semantic functions that implement the basic
+// building blocks.
+class SemanticComposer : public SemanticExpressionEvaluator {
+ public:
+ // Expects a flatbuffer schema that describes the possible result values of
+ // an evaluation.
+ explicit SemanticComposer(const reflection::Schema* semantic_values_schema);
+
+ // Evaluates a semantic expression that is associated with the root of a parse
+ // tree.
+ StatusOr<const SemanticValue*> Eval(const TextContext& text_context,
+ const Derivation& derivation,
+ UnsafeArena* arena) const;
+
+ // Applies a semantic expression to a list of constituents and
+ // produces an output semantic value.
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override;
+
+ private:
+ // Evaluates a semantic expression against a parse tree.
+ StatusOr<const SemanticValue*> Eval(const TextContext& text_context,
+ const SemanticExpressionNode* derivation,
+ UnsafeArena* arena) const;
+
+ std::unordered_map<SemanticExpression_::Expression,
+ std::unique_ptr<SemanticExpressionEvaluator>>
+ evaluators_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_COMPOSER_H_
diff --git a/native/utils/grammar/semantics/composer_test.cc b/native/utils/grammar/semantics/composer_test.cc
new file mode 100644
index 0000000..e768e18
--- /dev/null
+++ b/native/utils/grammar/semantics/composer_test.cc
@@ -0,0 +1,177 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/composer.h"
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/parser.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "utils/grammar/types.h"
+#include "utils/grammar/utils/rules.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::ElementsAre;
+
+class SemanticComposerTest : public GrammarTest {};
+
+TEST_F(SemanticComposerTest, EvaluatesSimpleMapping) {
+ RulesSetT model;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ const int test_value_type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ {
+ rules.Add("<month>", {"january"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ TestValueT value;
+ value.value = 1;
+ const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.type = test_value_type;
+ const_value.value.assign(serialized_value.begin(), serialized_value.end());
+ model.semantic_expression.emplace_back(new SemanticExpressionT);
+ model.semantic_expression.back()->expression.Set(const_value);
+ }
+ {
+ rules.Add("<month>", {"february"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ TestValueT value;
+ value.value = 2;
+ const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.type = test_value_type;
+ const_value.value.assign(serialized_value.begin(), serialized_value.end());
+ model.semantic_expression.emplace_back(new SemanticExpressionT);
+ model.semantic_expression.back()->expression.Set(const_value);
+ }
+ const int kMonth = 0;
+ rules.Add("<month_rule>", {"<month>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kMonth);
+ rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
+ const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
+ SemanticComposer composer(semantic_values_schema_.get());
+
+ {
+ const TextContext text = TextContextForText("Month: January");
+ const std::vector<Derivation> derivations = parser.Parse(text, &arena_);
+ EXPECT_THAT(derivations, ElementsAre(IsDerivation(kMonth, 7, 14)));
+
+ StatusOr<const SemanticValue*> maybe_value =
+ composer.Eval(text, derivations.front(), &arena_);
+ EXPECT_TRUE(maybe_value.ok());
+
+ const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
+ EXPECT_EQ(value->value(), 1);
+ }
+
+ {
+ const TextContext text = TextContextForText("Month: February");
+ const std::vector<Derivation> derivations = parser.Parse(text, &arena_);
+ EXPECT_THAT(derivations, ElementsAre(IsDerivation(kMonth, 7, 15)));
+
+ StatusOr<const SemanticValue*> maybe_value =
+ composer.Eval(text, derivations.front(), &arena_);
+ EXPECT_TRUE(maybe_value.ok());
+
+ const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
+ EXPECT_EQ(value->value(), 2);
+ }
+}
+
+TEST_F(SemanticComposerTest, RecursivelyEvaluatesConstituents) {
+ RulesSetT model;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ const int test_value_type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ constexpr int kDateRule = 0;
+ {
+ rules.Add("<month>", {"january"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ TestValueT value;
+ value.value = 42;
+ const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
+ ConstValueExpressionT const_value;
+ const_value.type = test_value_type;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.value.assign(serialized_value.begin(), serialized_value.end());
+ model.semantic_expression.emplace_back(new SemanticExpressionT);
+ model.semantic_expression.back()->expression.Set(const_value);
+ }
+ {
+ // Define constituents of the rule.
+ // TODO(smillius): Add support in the rules builder to directly specify
+ // constituent ids in the rule, e.g. `<date> ::= <month>@0? <4_digits>`.
+ rules.Add("<date_@0>", {"<month>"},
+ static_cast<CallbackId>(DefaultCallback::kMapping),
+ /*callback_param=*/1);
+ rules.Add("<date>", {"<date_@0>?", "<4_digits>"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ ConstituentExpressionT constituent;
+ constituent.id = 1;
+ model.semantic_expression.emplace_back(new SemanticExpressionT);
+ model.semantic_expression.back()->expression.Set(constituent);
+ rules.Add("<date_rule>", {"<date>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule),
+ /*callback_param=*/kDateRule);
+ }
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
+ const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
+ SemanticComposer composer(semantic_values_schema_.get());
+
+ {
+ const TextContext text = TextContextForText("Event: January 2020");
+ const std::vector<Derivation> derivations =
+ ValidDeduplicatedDerivations(parser.Parse(text, &arena_));
+ EXPECT_THAT(derivations, ElementsAre(IsDerivation(kDateRule, 7, 19)));
+
+ StatusOr<const SemanticValue*> maybe_value =
+ composer.Eval(text, derivations.front(), &arena_);
+ EXPECT_TRUE(maybe_value.ok());
+
+ const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
+ EXPECT_EQ(value->value(), 42);
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/eval-context.h b/native/utils/grammar/semantics/eval-context.h
new file mode 100644
index 0000000..aab878a
--- /dev/null
+++ b/native/utils/grammar/semantics/eval-context.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVAL_CONTEXT_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVAL_CONTEXT_H_
+
+#include <unordered_map>
+
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/semantics/value.h"
+#include "utils/grammar/text-context.h"
+
+namespace libtextclassifier3::grammar {
+
+// Context for the evaluation of the semantic expression of a rule parse tree.
+// This contains data about the evaluated constituents (named parts) of a rule
+// and its match.
+struct EvalContext {
+ // The input text.
+ const TextContext* text_context = nullptr;
+
+  // The syntactic parse tree that is being evaluated.
+ const ParseTree* parse_tree = nullptr;
+
+  // A map of an id of a rule constituent (named part of a rule match) to its
+ // evaluated semantic value.
+ std::unordered_map<int, const SemanticValue*> rule_constituents;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVAL_CONTEXT_H_
diff --git a/native/utils/grammar/semantics/evaluator.h b/native/utils/grammar/semantics/evaluator.h
new file mode 100644
index 0000000..7b6bf90
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluator.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATOR_H_
+
+#include "utils/base/arena.h"
+#include "utils/base/statusor.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Interface for a semantic function that evaluates an expression and returns
+// a semantic value.
+class SemanticExpressionEvaluator {
+ public:
+ virtual ~SemanticExpressionEvaluator() = default;
+
+ // Applies `expression` to the `context` to produce a semantic value.
+ virtual StatusOr<const SemanticValue*> Apply(
+ const EvalContext& context, const SemanticExpression* expression,
+ UnsafeArena* arena) const = 0;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATOR_H_
diff --git a/native/utils/grammar/semantics/evaluators/arithmetic-eval.cc b/native/utils/grammar/semantics/evaluators/arithmetic-eval.cc
new file mode 100644
index 0000000..76b72c6
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/arithmetic-eval.cc
@@ -0,0 +1,134 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
+
+#include <limits>
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+template <typename T>
+StatusOr<const SemanticValue*> Reduce(
+ const SemanticExpressionEvaluator* composer, const EvalContext& context,
+ const ArithmeticExpression* expression, UnsafeArena* arena) {
+ T result;
+ switch (expression->op()) {
+ case ArithmeticExpression_::Operator_OP_ADD: {
+ result = 0;
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MUL: {
+ result = 1;
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MIN: {
+ result = std::numeric_limits<T>::max();
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MAX: {
+ result = std::numeric_limits<T>::min();
+ break;
+ }
+ default: {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Unexpected op: " +
+ std::string(ArithmeticExpression_::EnumNameOperator(
+ expression->op())));
+ }
+ }
+ if (expression->values() != nullptr) {
+ for (const SemanticExpression* semantic_expression :
+ *expression->values()) {
+ TC3_ASSIGN_OR_RETURN(
+ const SemanticValue* value,
+ composer->Apply(context, semantic_expression, arena));
+ if (value == nullptr) {
+ continue;
+ }
+ if (!value->Has<T>()) {
+ return Status(
+ StatusCode::INVALID_ARGUMENT,
+ "Argument didn't evaluate as expected type: " +
+ std::string(reflection::EnumNameBaseType(value->base_type())));
+ }
+ const T scalar_value = value->Value<T>();
+ switch (expression->op()) {
+ case ArithmeticExpression_::Operator_OP_ADD: {
+ result += scalar_value;
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MUL: {
+ result *= scalar_value;
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MIN: {
+ result = std::min(result, scalar_value);
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MAX: {
+ result = std::max(result, scalar_value);
+ break;
+ }
+ default: {
+ break;
+ }
+ }
+ }
+ }
+ return SemanticValue::Create(result, arena);
+}
+
+} // namespace
+
+StatusOr<const SemanticValue*> ArithmeticExpressionEvaluator::Apply(
+ const EvalContext& context, const SemanticExpression* expression,
+ UnsafeArena* arena) const {
+ TC3_DCHECK_EQ(expression->expression_type(),
+ SemanticExpression_::Expression_ArithmeticExpression);
+ const ArithmeticExpression* arithmetic_expression =
+ expression->expression_as_ArithmeticExpression();
+ switch (arithmetic_expression->base_type()) {
+ case reflection::BaseType::Byte:
+ return Reduce<int8>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::UByte:
+ return Reduce<uint8>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::Short:
+ return Reduce<int16>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::UShort:
+ return Reduce<uint16>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::Int:
+ return Reduce<int32>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::UInt:
+ return Reduce<uint32>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::Long:
+ return Reduce<int64>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::ULong:
+ return Reduce<uint64>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::Float:
+ return Reduce<float>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::Double:
+ return Reduce<double>(composer_, context, arithmetic_expression, arena);
+ default:
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Unsupported for ArithmeticExpression: " +
+ std::string(reflection::EnumNameBaseType(
+ static_cast<reflection::BaseType>(
+ arithmetic_expression->base_type()))));
+ }
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/arithmetic-eval.h b/native/utils/grammar/semantics/evaluators/arithmetic-eval.h
new file mode 100644
index 0000000..38efc57
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/arithmetic-eval.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_ARITHMETIC_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_ARITHMETIC_EVAL_H_
+
+#include "utils/base/arena.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Evaluates an arithmetic expression.
+// Expects zero or more arguments and produces either sum, product, minimum or
+// maximum of its arguments. If no arguments are specified, each operator
+// returns its identity value.
+class ArithmeticExpressionEvaluator : public SemanticExpressionEvaluator {
+ public:
+ explicit ArithmeticExpressionEvaluator(
+ const SemanticExpressionEvaluator* composer)
+ : composer_(composer) {}
+
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override;
+
+ private:
+ const SemanticExpressionEvaluator* composer_;
+};
+
+} // namespace libtextclassifier3::grammar
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_ARITHMETIC_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/arithmetic-eval_test.cc b/native/utils/grammar/semantics/evaluators/arithmetic-eval_test.cc
new file mode 100644
index 0000000..5385fc1
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/arithmetic-eval_test.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
+
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+template <typename T>
+class ArithmeticExpressionEvaluatorTest : public GrammarTest {
+ protected:
+ T Eval(const ArithmeticExpression_::Operator op) {
+ ArithmeticExpressionT arithmetic_expression;
+ arithmetic_expression.base_type = flatbuffers_base_type<T>::value;
+ arithmetic_expression.op = op;
+ arithmetic_expression.values.push_back(
+ CreatePrimitiveConstExpression<T>(1));
+ arithmetic_expression.values.push_back(
+ CreatePrimitiveConstExpression<T>(2));
+ arithmetic_expression.values.push_back(
+ CreatePrimitiveConstExpression<T>(3));
+ arithmetic_expression.values.push_back(
+ CreatePrimitiveConstExpression<T>(4));
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(arithmetic_expression));
+
+ // Setup evaluators.
+ ConstEvaluator const_eval(semantic_values_schema_.get());
+ ArithmeticExpressionEvaluator arithmetic_eval(&const_eval);
+
+ // Run evaluator.
+ StatusOr<const SemanticValue*> result =
+ arithmetic_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ // Check result.
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ EXPECT_TRUE(result_value != nullptr);
+ return result_value->Value<T>();
+ }
+};
+
+using NumberTypes = ::testing::Types<int8, uint8, int16, uint16, int32, uint32,
+ int64, uint64, double, float>;
+TYPED_TEST_SUITE(ArithmeticExpressionEvaluatorTest, NumberTypes);
+
+TYPED_TEST(ArithmeticExpressionEvaluatorTest, ParsesNumber) {
+ EXPECT_EQ(this->Eval(ArithmeticExpression_::Operator_OP_ADD), 10);
+ EXPECT_EQ(this->Eval(ArithmeticExpression_::Operator_OP_MUL), 24);
+ EXPECT_EQ(this->Eval(ArithmeticExpression_::Operator_OP_MIN), 1);
+ EXPECT_EQ(this->Eval(ArithmeticExpression_::Operator_OP_MAX), 4);
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/compose-eval.cc b/native/utils/grammar/semantics/evaluators/compose-eval.cc
new file mode 100644
index 0000000..09bbf5c
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/compose-eval.cc
@@ -0,0 +1,183 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/compose-eval.h"
+
+#include "utils/base/status_macros.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Tries setting a singular field.
+template <typename T>
+Status TrySetField(const reflection::Field* field, const SemanticValue* value,
+ MutableFlatbuffer* result) {
+ if (!result->Set<T>(field, value->Value<T>())) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Could not set field.");
+ }
+ return Status::OK;
+}
+
+template <>
+Status TrySetField<flatbuffers::Table>(const reflection::Field* field,
+ const SemanticValue* value,
+ MutableFlatbuffer* result) {
+ if (!result->Mutable(field)->MergeFrom(value->Table())) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Could not set sub-field in result.");
+ }
+ return Status::OK;
+}
+
+// Tries adding a value to a repeated field.
+template <typename T>
+Status TryAddField(const reflection::Field* field, const SemanticValue* value,
+ MutableFlatbuffer* result) {
+ if (!result->Repeated(field)->Add(value->Value<T>())) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Could not add field.");
+ }
+ return Status::OK;
+}
+
+template <>
+Status TryAddField<flatbuffers::Table>(const reflection::Field* field,
+ const SemanticValue* value,
+ MutableFlatbuffer* result) {
+ if (!result->Repeated(field)->Add()->MergeFrom(value->Table())) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Could not add message to repeated field.");
+ }
+ return Status::OK;
+}
+
+// Tries adding or setting a value for a field.
+template <typename T>
+Status TrySetOrAddValue(const FlatbufferFieldPath* field_path,
+ const SemanticValue* value, MutableFlatbuffer* result) {
+ MutableFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!result->GetFieldWithParent(field_path, &parent, &field)) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Could not get field.");
+ }
+ if (field->type()->base_type() == reflection::Vector) {
+ return TryAddField<T>(field, value, parent);
+ } else {
+ return TrySetField<T>(field, value, parent);
+ }
+}
+
+} // namespace
+
+StatusOr<const SemanticValue*> ComposeEvaluator::Apply(
+ const EvalContext& context, const SemanticExpression* expression,
+ UnsafeArena* arena) const {
+ const ComposeExpression* compose_expression =
+ expression->expression_as_ComposeExpression();
+ std::unique_ptr<MutableFlatbuffer> result =
+ semantic_value_builder_.NewTable(compose_expression->type());
+
+ if (result == nullptr) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Invalid result type.");
+ }
+
+ // Evaluate and set fields.
+ if (compose_expression->fields() != nullptr) {
+ for (const ComposeExpression_::Field* field :
+ *compose_expression->fields()) {
+ // Evaluate argument.
+ TC3_ASSIGN_OR_RETURN(const SemanticValue* value,
+ composer_->Apply(context, field->value(), arena));
+ if (value == nullptr) {
+ continue;
+ }
+
+ switch (value->base_type()) {
+ case reflection::BaseType::Bool: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<bool>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Byte: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<int8>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::UByte: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<uint8>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Short: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<int16>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::UShort: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<uint16>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Int: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<int32>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::UInt: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<uint32>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Long: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<int64>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::ULong: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<uint64>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Float: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<float>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Double: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<double>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::String: {
+ TC3_RETURN_IF_ERROR(TrySetOrAddValue<StringPiece>(
+ field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Obj: {
+ TC3_RETURN_IF_ERROR(TrySetOrAddValue<flatbuffers::Table>(
+ field->path(), value, result.get()));
+ break;
+ }
+ default:
+ return Status(StatusCode::INVALID_ARGUMENT, "Unhandled type.");
+ }
+ }
+ }
+
+ return SemanticValue::Create<const MutableFlatbuffer*>(result.get(), arena);
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/compose-eval.h b/native/utils/grammar/semantics/evaluators/compose-eval.h
new file mode 100644
index 0000000..ba3b6f9
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/compose-eval.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_COMPOSE_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_COMPOSE_EVAL_H_
+
+#include "utils/base/arena.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Combines arguments to a result type.
+class ComposeEvaluator : public SemanticExpressionEvaluator {
+ public:
+ explicit ComposeEvaluator(const SemanticExpressionEvaluator* composer,
+ const reflection::Schema* semantic_values_schema)
+ : composer_(composer), semantic_value_builder_(semantic_values_schema) {}
+
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override;
+
+ private:
+ const SemanticExpressionEvaluator* composer_;
+ const MutableFlatbufferBuilder semantic_value_builder_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_COMPOSE_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/compose-eval_test.cc b/native/utils/grammar/semantics/evaluators/compose-eval_test.cc
new file mode 100644
index 0000000..f26042a
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/compose-eval_test.cc
@@ -0,0 +1,289 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/compose-eval.h"
+
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+class ComposeEvaluatorTest : public GrammarTest {
+ protected:
+ explicit ComposeEvaluatorTest()
+ : const_eval_(semantic_values_schema_.get()) {}
+
+ // Evaluator that just returns a constant value.
+ ConstEvaluator const_eval_;
+};
+
+TEST_F(ComposeEvaluatorTest, SetsSingleField) {
+ TestDateT date;
+ date.day = 1;
+ date.month = 2;
+ date.year = 2020;
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path = CreateUnpackedFieldPath({"date"});
+ compose_expression.fields.back()->value = CreateConstDateExpression(date);
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->date()->day(), 1);
+ EXPECT_EQ(result_test_value->date()->month(), 2);
+ EXPECT_EQ(result_test_value->date()->year(), 2020);
+}
+
+TEST_F(ComposeEvaluatorTest, SetsStringField) {
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path =
+ CreateUnpackedFieldPath({"test_string"});
+ compose_expression.fields.back()->value =
+ CreatePrimitiveConstExpression<StringPiece>("this is a test");
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->test_string()->str(), "this is a test");
+}
+
+TEST_F(ComposeEvaluatorTest, SetsPrimitiveField) {
+ ComposeExpressionT compose_expression;
+ compose_expression.type = TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestDate")
+ .value();
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path = CreateUnpackedFieldPath({"day"});
+ compose_expression.fields.back()->value =
+ CreatePrimitiveConstExpression<int>(1);
+
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestDate");
+ const TestDate* result_date = result_value->Table<TestDate>();
+ EXPECT_EQ(result_date->day(), 1);
+}
+
+TEST_F(ComposeEvaluatorTest, MergesMultipleField) {
+ TestDateT day;
+ day.day = 1;
+
+ TestDateT month;
+ month.month = 2;
+
+ TestDateT year;
+ year.year = 2020;
+
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ for (const TestDateT& component : std::vector<TestDateT>{day, month, year}) {
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path = CreateUnpackedFieldPath({"date"});
+ compose_expression.fields.back()->value =
+ CreateConstDateExpression(component);
+ }
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->date()->day(), 1);
+ EXPECT_EQ(result_test_value->date()->month(), 2);
+ EXPECT_EQ(result_test_value->date()->year(), 2020);
+}
+
+TEST_F(ComposeEvaluatorTest, SucceedsEvenWhenEmpty) {
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path = CreateUnpackedFieldPath({"date"});
+ compose_expression.fields.back()->value.reset(new SemanticExpressionT);
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+  // Evaluator that just returns a null value.
+ struct : public SemanticExpressionEvaluator {
+ StatusOr<const SemanticValue*> Apply(const EvalContext&,
+ const SemanticExpression*,
+ UnsafeArena*) const override {
+ return nullptr;
+ }
+ } null_eval;
+
+ ComposeEvaluator compose_eval(&null_eval, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+}
+
+TEST_F(ComposeEvaluatorTest, AddsRepeatedPrimitiveField) {
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path =
+ CreateUnpackedFieldPath({"repeated_enum"});
+ compose_expression.fields.back()->value =
+ CreatePrimitiveConstExpression<int>(TestEnum_ENUM_1);
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path =
+ CreateUnpackedFieldPath({"repeated_enum"});
+ compose_expression.fields.back()->value =
+ CreatePrimitiveConstExpression<int>(TestEnum_ENUM_2);
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->repeated_enum()->size(), 2);
+ EXPECT_EQ(result_test_value->repeated_enum()->Get(0), TestEnum_ENUM_1);
+ EXPECT_EQ(result_test_value->repeated_enum()->Get(1), TestEnum_ENUM_2);
+}
+
+TEST_F(ComposeEvaluatorTest, AddsRepeatedSubmessage) {
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ {
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path =
+ CreateUnpackedFieldPath({"repeated_date"});
+ TestDateT date;
+ date.day = 1;
+ date.month = 2;
+ date.year = 2020;
+ compose_expression.fields.back()->value = CreateConstDateExpression(date);
+ }
+
+ {
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path =
+ CreateUnpackedFieldPath({"repeated_date"});
+ TestDateT date;
+ date.day = 3;
+ date.month = 4;
+ date.year = 2021;
+ compose_expression.fields.back()->value = CreateConstDateExpression(date);
+ }
+
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->repeated_date()->size(), 2);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(0)->day(), 1);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(0)->month(), 2);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(0)->year(), 2020);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(1)->day(), 3);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(1)->month(), 4);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(1)->year(), 2021);
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/const-eval.h b/native/utils/grammar/semantics/evaluators/const-eval.h
new file mode 100644
index 0000000..67a4c54
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/const-eval.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONST_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONST_EVAL_H_
+
+#include "utils/base/arena.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Returns a constant value of a given type.
+class ConstEvaluator : public SemanticExpressionEvaluator {
+ public:
+ explicit ConstEvaluator(const reflection::Schema* semantic_values_schema)
+ : semantic_values_schema_(semantic_values_schema) {}
+
+ StatusOr<const SemanticValue*> Apply(const EvalContext&,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override {
+ TC3_DCHECK_EQ(expression->expression_type(),
+ SemanticExpression_::Expression_ConstValueExpression);
+ const ConstValueExpression* const_value_expression =
+ expression->expression_as_ConstValueExpression();
+ const reflection::BaseType base_type =
+ static_cast<reflection::BaseType>(const_value_expression->base_type());
+ const StringPiece data = StringPiece(
+ reinterpret_cast<const char*>(const_value_expression->value()->data()),
+ const_value_expression->value()->size());
+
+ if (base_type == reflection::BaseType::Obj) {
+ // Resolve the object type.
+ const int type_id = const_value_expression->type();
+ if (type_id < 0 ||
+ type_id >= semantic_values_schema_->objects()->size()) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Invalid type.");
+ }
+ return SemanticValue::Create(semantic_values_schema_->objects()->Get(
+ const_value_expression->type()),
+ data, arena);
+ } else {
+ return SemanticValue::Create(base_type, data, arena);
+ }
+ }
+
+ private:
+ const reflection::Schema* semantic_values_schema_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONST_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/const-eval_test.cc b/native/utils/grammar/semantics/evaluators/const-eval_test.cc
new file mode 100644
index 0000000..02eea5d
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/const-eval_test.cc
@@ -0,0 +1,105 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+class ConstEvaluatorTest : public GrammarTest {
+ protected:
+ explicit ConstEvaluatorTest() : const_eval_(semantic_values_schema_.get()) {}
+
+ const ConstEvaluator const_eval_;
+};
+
+TEST_F(ConstEvaluatorTest, CreatesConstantSemanticValues) {
+ TestValueT value;
+ value.a_float_value = 64.42;
+ value.test_string = "test string";
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateAndPackConstExpression(value);
+
+ StatusOr<const SemanticValue*> result =
+ const_eval_.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->test_string()->str(), "test string");
+ EXPECT_FLOAT_EQ(result_test_value->a_float_value(), 64.42);
+}
+
+template <typename T>
+class PrimitiveValueTest : public ConstEvaluatorTest {
+ protected:
+ T Eval(const T value) {
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateAndPackPrimitiveConstExpression<T>(value);
+ StatusOr<const SemanticValue*> result =
+ const_eval_.Apply(/*context=*/{}, expression.get(), &arena_);
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ EXPECT_NE(result_value, nullptr);
+ return result_value->Value<T>();
+ }
+};
+
+using PrimitiveTypes = ::testing::Types<int8, uint8, int16, uint16, int32,
+ uint32, int64, uint64, double, float>;
+TYPED_TEST_SUITE(PrimitiveValueTest, PrimitiveTypes);
+
+TYPED_TEST(PrimitiveValueTest, CreatesConstantPrimitiveValues) {
+ EXPECT_EQ(this->Eval(42), 42);
+}
+
+TEST_F(ConstEvaluatorTest, CreatesStringValues) {
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateAndPackPrimitiveConstExpression<StringPiece>("this is a test.");
+ StatusOr<const SemanticValue*> result =
+ const_eval_.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->Value<StringPiece>().ToString(), "this is a test.");
+}
+
+TEST_F(ConstEvaluatorTest, CreatesBoolValues) {
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateAndPackPrimitiveConstExpression<bool>(true);
+ StatusOr<const SemanticValue*> result =
+ const_eval_.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_TRUE(result_value->Value<bool>());
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/constituent-eval.h b/native/utils/grammar/semantics/evaluators/constituent-eval.h
new file mode 100644
index 0000000..4b877fe
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/constituent-eval.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONSTITUENT_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONSTITUENT_EVAL_H_
+
+#include "utils/base/arena.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Returns the semantic value of an evaluated constituent.
+class ConstituentEvaluator : public SemanticExpressionEvaluator {
+ public:
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena*) const override {
+ TC3_DCHECK_EQ(expression->expression_type(),
+ SemanticExpression_::Expression_ConstituentExpression);
+ const ConstituentExpression* constituent_expression =
+ expression->expression_as_ConstituentExpression();
+ const auto constituent_it =
+ context.rule_constituents.find(constituent_expression->id());
+ if (constituent_it != context.rule_constituents.end()) {
+ return constituent_it->second;
+ }
+    // The constituent was not present in the rule parse tree; return a
+ // null value for it.
+ return nullptr;
+ }
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONSTITUENT_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/constituent-eval_test.cc b/native/utils/grammar/semantics/evaluators/constituent-eval_test.cc
new file mode 100644
index 0000000..c40d1cc
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/constituent-eval_test.cc
@@ -0,0 +1,79 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/constituent-eval.h"
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+class ConstituentEvaluatorTest : public GrammarTest {
+ protected:
+ explicit ConstituentEvaluatorTest() {}
+
+ OwnedFlatbuffer<SemanticExpression> CreateConstituentExpression(
+ const int id) {
+ ConstituentExpressionT constituent_expression;
+ constituent_expression.id = id;
+ return CreateExpression(constituent_expression);
+ }
+
+ const ConstituentEvaluator constituent_eval_;
+};
+
+TEST_F(ConstituentEvaluatorTest, HandlesNotDefinedConstituents) {
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateConstituentExpression(/*id=*/42);
+
+ StatusOr<const SemanticValue*> result = constituent_eval_.Apply(
+ /*context=*/{}, expression.get(), /*arena=*/nullptr);
+
+ EXPECT_TRUE(result.ok());
+ EXPECT_EQ(result.ValueOrDie(), nullptr);
+}
+
+TEST_F(ConstituentEvaluatorTest, ForwardsConstituentSemanticValues) {
+ // Create example values for constituents.
+ EvalContext context;
+ TestValueT value_0;
+ value_0.test_string = "constituent 0 value";
+ context.rule_constituents[0] = CreateSemanticValue(value_0);
+
+ TestValueT value_42;
+ value_42.test_string = "constituent 42 value";
+ context.rule_constituents[42] = CreateSemanticValue(value_42);
+
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateConstituentExpression(/*id=*/42);
+
+ StatusOr<const SemanticValue*> result =
+ constituent_eval_.Apply(context, expression.get(), /*arena=*/nullptr);
+
+ EXPECT_TRUE(result.ok());
+ const TestValue* result_value = result.ValueOrDie()->Table<TestValue>();
+ EXPECT_EQ(result_value->test_string()->str(), "constituent 42 value");
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/merge-values-eval.cc b/native/utils/grammar/semantics/evaluators/merge-values-eval.cc
new file mode 100644
index 0000000..d9bf544
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/merge-values-eval.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/merge-values-eval.h"
+
+namespace libtextclassifier3::grammar {
+
+StatusOr<const SemanticValue*> MergeValuesEvaluator::Apply(
+ const EvalContext& context, const SemanticExpression* expression,
+ UnsafeArena* arena) const {
+ const MergeValueExpression* merge_value_expression =
+ expression->expression_as_MergeValueExpression();
+ std::unique_ptr<MutableFlatbuffer> result =
+ semantic_value_builder_.NewTable(merge_value_expression->type());
+
+ if (result == nullptr) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Invalid result type.");
+ }
+
+ for (const SemanticExpression* semantic_expression :
+ *merge_value_expression->values()) {
+ TC3_ASSIGN_OR_RETURN(const SemanticValue* value,
+ composer_->Apply(context, semantic_expression, arena));
+ if (value == nullptr) {
+ continue;
+ }
+ if ((value->type() != result->type()) ||
+ !result->MergeFrom(value->Table())) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Could not merge the results.");
+ }
+ }
+ return SemanticValue::Create<const MutableFlatbuffer*>(result.get(), arena);
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/merge-values-eval.h b/native/utils/grammar/semantics/evaluators/merge-values-eval.h
new file mode 100644
index 0000000..8fe49e3
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/merge-values-eval.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_MERGE_VALUES_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_MERGE_VALUES_EVAL_H_
+
+#include "utils/base/arena.h"
+#include "utils/base/status_macros.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Evaluates the "merge" semantic function expression.
+// Conceptually, the way this merge evaluator works is that each of the
+// arguments (semantic value) is merged into a return type semantic value.
+class MergeValuesEvaluator : public SemanticExpressionEvaluator {
+ public:
+ explicit MergeValuesEvaluator(
+ const SemanticExpressionEvaluator* composer,
+ const reflection::Schema* semantic_values_schema)
+ : composer_(composer), semantic_value_builder_(semantic_values_schema) {}
+
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override;
+
+ private:
+ const SemanticExpressionEvaluator* composer_;
+ const MutableFlatbufferBuilder semantic_value_builder_;
+};
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_MERGE_VALUES_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/merge-values-eval_test.cc b/native/utils/grammar/semantics/evaluators/merge-values-eval_test.cc
new file mode 100644
index 0000000..8d3d70f
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/merge-values-eval_test.cc
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/merge-values-eval.h"
+
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+class MergeValuesEvaluatorTest : public GrammarTest {
+ protected:
+ explicit MergeValuesEvaluatorTest()
+ : const_eval_(semantic_values_schema_.get()) {}
+
+ // Evaluator that just returns a constant value.
+ ConstEvaluator const_eval_;
+};
+
+TEST_F(MergeValuesEvaluatorTest, MergeSemanticValues) {
+  // Set up the data.
+ TestDateT date_value_day;
+ date_value_day.day = 23;
+ TestDateT date_value_month;
+ date_value_month.month = 9;
+ TestDateT date_value_year;
+ date_value_year.year = 2019;
+
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateAndPackMergeValuesExpression(
+ {date_value_day, date_value_month, date_value_year});
+
+ MergeValuesEvaluator merge_values_eval(&const_eval_,
+ semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ merge_values_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestDate");
+ const TestDate* result_test_date = result_value->Table<TestDate>();
+ EXPECT_EQ(result_test_date->day(), 23);
+ EXPECT_EQ(result_test_date->month(), 9);
+ EXPECT_EQ(result_test_date->year(), 2019);
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/parse-number-eval.h b/native/utils/grammar/semantics/evaluators/parse-number-eval.h
new file mode 100644
index 0000000..9171c65
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/parse-number-eval.h
@@ -0,0 +1,110 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_PARSE_NUMBER_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_PARSE_NUMBER_EVAL_H_
+
+#include <string>
+
+#include "utils/base/arena.h"
+#include "utils/base/statusor.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+#include "utils/strings/numbers.h"
+
+namespace libtextclassifier3::grammar {
+
+// Parses a string as a number.
+class ParseNumberEvaluator : public SemanticExpressionEvaluator {
+ public:
+ explicit ParseNumberEvaluator(const SemanticExpressionEvaluator* composer)
+ : composer_(composer) {}
+
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override {
+ TC3_DCHECK_EQ(expression->expression_type(),
+ SemanticExpression_::Expression_ParseNumberExpression);
+ const ParseNumberExpression* parse_number_expression =
+ expression->expression_as_ParseNumberExpression();
+
+ // Evaluate argument.
+ TC3_ASSIGN_OR_RETURN(
+ const SemanticValue* value,
+ composer_->Apply(context, parse_number_expression->value(), arena));
+ if (value == nullptr) {
+ return nullptr;
+ }
+ if (!value->Has<StringPiece>()) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Argument didn't evaluate as a string value.");
+ }
+ const std::string data = value->Value<std::string>();
+
+ // Parse the string data as a number.
+ const reflection::BaseType type =
+ static_cast<reflection::BaseType>(parse_number_expression->base_type());
+ if (flatbuffers::IsLong(type)) {
+ TC3_ASSIGN_OR_RETURN(const int64 value, TryParse<int64>(data));
+ return SemanticValue::Create(type, value, arena);
+ } else if (flatbuffers::IsInteger(type)) {
+ TC3_ASSIGN_OR_RETURN(const int32 value, TryParse<int32>(data));
+ return SemanticValue::Create(type, value, arena);
+ } else if (flatbuffers::IsFloat(type)) {
+ TC3_ASSIGN_OR_RETURN(const double value, TryParse<double>(data));
+ return SemanticValue::Create(type, value, arena);
+ } else {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Unsupported type: " + std::to_string(type));
+ }
+ }
+
+ private:
+ template <typename T>
+ bool Parse(const std::string& data, T* value) const;
+
+ template <>
+ bool Parse(const std::string& data, int32* value) const {
+ return ParseInt32(data.data(), value);
+ }
+
+ template <>
+ bool Parse(const std::string& data, int64* value) const {
+ return ParseInt64(data.data(), value);
+ }
+
+ template <>
+ bool Parse(const std::string& data, double* value) const {
+ return ParseDouble(data.data(), value);
+ }
+
+ template <typename T>
+ StatusOr<T> TryParse(const std::string& data) const {
+ T result;
+ if (!Parse<T>(data, &result)) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Could not parse value.");
+ }
+ return result;
+ }
+
+ const SemanticExpressionEvaluator* composer_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_PARSE_NUMBER_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/parse-number-eval_test.cc b/native/utils/grammar/semantics/evaluators/parse-number-eval_test.cc
new file mode 100644
index 0000000..e9f21d9
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/parse-number-eval_test.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/parse-number-eval.h"
+
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+template <typename T>
+class ParseNumberEvaluatorTest : public GrammarTest {
+ protected:
+ T Eval(const StringPiece value) {
+ ParseNumberExpressionT parse_number_expression;
+ parse_number_expression.base_type = flatbuffers_base_type<T>::value;
+ parse_number_expression.value =
+ CreatePrimitiveConstExpression<StringPiece>(value);
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(parse_number_expression));
+
+ ConstEvaluator const_eval(semantic_values_schema_.get());
+ ParseNumberEvaluator parse_number_eval(&const_eval);
+
+ StatusOr<const SemanticValue*> result =
+ parse_number_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ EXPECT_NE(result_value, nullptr);
+ return result_value->Value<T>();
+ }
+};
+
+using NumberTypes = ::testing::Types<int8, uint8, int16, uint16, int32, uint32,
+ int64, uint64, double, float>;
+TYPED_TEST_SUITE(ParseNumberEvaluatorTest, NumberTypes);
+
+TYPED_TEST(ParseNumberEvaluatorTest, ParsesNumber) {
+ EXPECT_EQ(this->Eval("42"), 42);
+}
+
+TEST_F(GrammarTest, FailsOnInvalidArgument) {
+ ParseNumberExpressionT parse_number_expression;
+ parse_number_expression.base_type = flatbuffers_base_type<int32>::value;
+ parse_number_expression.value = CreatePrimitiveConstExpression<int32>(42);
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(parse_number_expression));
+
+ ConstEvaluator const_eval(semantic_values_schema_.get());
+ ParseNumberEvaluator parse_number_eval(&const_eval);
+
+ StatusOr<const SemanticValue*> result =
+ parse_number_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_FALSE(result.ok());
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/span-eval.h b/native/utils/grammar/semantics/evaluators/span-eval.h
new file mode 100644
index 0000000..f8a5d5b
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/span-eval.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_SPAN_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_SPAN_EVAL_H_
+
+#include "annotator/types.h"
+#include "utils/base/arena.h"
+#include "utils/base/statusor.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Returns a value lifted from a parse tree.
+class SpanAsStringEvaluator : public SemanticExpressionEvaluator {
+ public:
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override {
+ TC3_DCHECK_EQ(expression->expression_type(),
+ SemanticExpression_::Expression_SpanAsStringExpression);
+ return SemanticValue::Create(
+ context.text_context->Span(context.parse_tree->codepoint_span), arena);
+ }
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_SPAN_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/span-eval_test.cc b/native/utils/grammar/semantics/evaluators/span-eval_test.cc
new file mode 100644
index 0000000..daba860
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/span-eval_test.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/span-eval.h"
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "utils/grammar/types.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+class SpanTextEvaluatorTest : public GrammarTest {};
+
+TEST_F(SpanTextEvaluatorTest, CreatesSpanTextValues) {
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(SpanAsStringExpressionT());
+ SpanAsStringEvaluator span_eval;
+ TextContext text = TextContextForText("This a test.");
+ ParseTree derivation(/*lhs=*/kUnassignedNonterm, CodepointSpan{5, 11},
+ /*match_offset=*/0, /*type=*/ParseTree::Type::kDefault);
+
+ StatusOr<const SemanticValue*> result = span_eval.Apply(
+ /*context=*/{&text, &derivation}, expression.get(), &arena_);
+
+ ASSERT_TRUE(result.ok());
+ EXPECT_EQ(result.ValueOrDie()->Value<StringPiece>().ToString(), "a test");
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/expression.fbs b/native/utils/grammar/semantics/expression.fbs
new file mode 100644
index 0000000..5397407
--- /dev/null
+++ b/native/utils/grammar/semantics/expression.fbs
@@ -0,0 +1,119 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "utils/flatbuffers/flatbuffers.fbs";
+
+namespace libtextclassifier3.grammar.SemanticExpression_;
+union Expression {
+ ConstValueExpression,
+ ConstituentExpression,
+ ComposeExpression,
+ SpanAsStringExpression,
+ ParseNumberExpression,
+ MergeValueExpression,
+ ArithmeticExpression,
+}
+
+// A semantic expression.
+namespace libtextclassifier3.grammar;
+table SemanticExpression {
+ expression:SemanticExpression_.Expression;
+}
+
+// A constant flatbuffer value.
+namespace libtextclassifier3.grammar;
+table ConstValueExpression {
+ // The base type of the value.
+ base_type:int;
+
+ // The id of the type of the value.
+ // The id is used for lookup in the semantic values type metadata.
+ type:int;
+
+ // The serialized value.
+ value:[ubyte];
+}
+
+// The value of a rule constituent.
+namespace libtextclassifier3.grammar;
+table ConstituentExpression {
+ // The id of the constituent.
+ id:ushort;
+}
+
+// The fields to set.
+namespace libtextclassifier3.grammar.ComposeExpression_;
+table Field {
+ // The field to set.
+ path:libtextclassifier3.FlatbufferFieldPath;
+
+ // The value.
+ value:SemanticExpression;
+}
+
+// A combination: Compose a result from arguments.
+// https://mitpress.mit.edu/sites/default/files/sicp/full-text/book/book-Z-H-4.html#%_toc_%_sec_1.1.1
+namespace libtextclassifier3.grammar;
+table ComposeExpression {
+ // The id of the type of the result.
+ type:int;
+
+ fields:[ComposeExpression_.Field];
+}
+
+// Lifts a span as a value.
+namespace libtextclassifier3.grammar;
+table SpanAsStringExpression {
+}
+
+// Parses a string as a number.
+namespace libtextclassifier3.grammar;
+table ParseNumberExpression {
+ // The base type of the value.
+ base_type:int;
+
+ value:SemanticExpression;
+}
+
+// Merge the semantic expressions.
+namespace libtextclassifier3.grammar;
+table MergeValueExpression {
+ // The id of the type of the result.
+ type:int;
+
+ values:[SemanticExpression];
+}
+
+// The operator of the arithmetic expression.
+namespace libtextclassifier3.grammar.ArithmeticExpression_;
+enum Operator : int {
+ NO_OP = 0,
+ OP_ADD = 1,
+ OP_MUL = 2,
+ OP_MAX = 3,
+ OP_MIN = 4,
+}
+
+// Simple arithmetic expression.
+namespace libtextclassifier3.grammar;
+table ArithmeticExpression {
+ // The base type of the operation.
+ base_type:int;
+
+ op:ArithmeticExpression_.Operator;
+ values:[SemanticExpression];
+}
+
diff --git a/native/utils/grammar/semantics/value.h b/native/utils/grammar/semantics/value.h
new file mode 100644
index 0000000..abf5eaf
--- /dev/null
+++ b/native/utils/grammar/semantics/value.h
@@ -0,0 +1,218 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_
+
+#include "utils/base/arena.h"
+#include "utils/base/logging.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "flatbuffers/base.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3::grammar {
+
+// A semantic value as a typed, arena-allocated flatbuffer.
+// This denotes the possible results of the evaluation of a semantic expression.
+class SemanticValue {
+ public:
+ // Creates an arena allocated semantic value.
+ template <typename T>
+ static const SemanticValue* Create(const T value, UnsafeArena* arena) {
+ static_assert(!std::is_pointer<T>() && std::is_scalar<T>());
+ if (char* buffer = reinterpret_cast<char*>(
+ arena->AllocAligned(sizeof(T), alignof(T)))) {
+ flatbuffers::WriteScalar<T>(buffer, value);
+ return arena->AllocAndInit<SemanticValue>(
+ libtextclassifier3::flatbuffers_base_type<T>::value,
+ StringPiece(buffer, sizeof(T)));
+ }
+ return nullptr;
+ }
+
+ template <>
+ const SemanticValue* Create(const StringPiece value, UnsafeArena* arena) {
+ return arena->AllocAndInit<SemanticValue>(reflection::BaseType::String,
+ value);
+ }
+
+ template <>
+ const SemanticValue* Create(const UnicodeText value, UnsafeArena* arena) {
+ return arena->AllocAndInit<SemanticValue>(
+ reflection::BaseType::String,
+ StringPiece(value.data(), value.size_bytes()));
+ }
+
+ template <>
+ const SemanticValue* Create(const MutableFlatbuffer* value,
+ UnsafeArena* arena) {
+ const std::string buffer = value->Serialize();
+ return Create(
+ value->type(),
+ StringPiece(arena->Memdup(buffer.data(), buffer.size()), buffer.size()),
+ arena);
+ }
+
+ static const SemanticValue* Create(const reflection::Object* type,
+ const StringPiece data,
+ UnsafeArena* arena) {
+ return arena->AllocAndInit<SemanticValue>(type, data);
+ }
+
+ static const SemanticValue* Create(const reflection::BaseType base_type,
+ const StringPiece data,
+ UnsafeArena* arena) {
+ return arena->AllocAndInit<SemanticValue>(base_type, data);
+ }
+
+ template <typename T>
+ static const SemanticValue* Create(const reflection::BaseType base_type,
+ const T value, UnsafeArena* arena) {
+ switch (base_type) {
+ case reflection::BaseType::Bool:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Bool>::value>(value),
+ arena);
+ case reflection::BaseType::Byte:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Byte>::value>(value),
+ arena);
+ case reflection::BaseType::UByte:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::UByte>::value>(
+ value),
+ arena);
+ case reflection::BaseType::Short:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Short>::value>(
+ value),
+ arena);
+ case reflection::BaseType::UShort:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::UShort>::value>(
+ value),
+ arena);
+ case reflection::BaseType::Int:
+ return Create(
+ static_cast<flatbuffers_cpp_type<reflection::BaseType::Int>::value>(
+ value),
+ arena);
+ case reflection::BaseType::UInt:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::UInt>::value>(value),
+ arena);
+ case reflection::BaseType::Long:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Long>::value>(value),
+ arena);
+ case reflection::BaseType::ULong:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::ULong>::value>(
+ value),
+ arena);
+ case reflection::BaseType::Float:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Float>::value>(
+ value),
+ arena);
+ case reflection::BaseType::Double:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Double>::value>(
+ value),
+ arena);
+ default: {
+ TC3_LOG(ERROR) << "Unhandled type: " << base_type;
+ return nullptr;
+ }
+ }
+ }
+
+ explicit SemanticValue(const reflection::BaseType base_type,
+ const StringPiece data)
+ : base_type_(base_type), type_(nullptr), data_(data) {}
+ explicit SemanticValue(const reflection::Object* type, const StringPiece data)
+ : base_type_(reflection::BaseType::Obj), type_(type), data_(data) {}
+
+ template <typename T>
+ bool Has() const {
+ return base_type_ == libtextclassifier3::flatbuffers_base_type<T>::value;
+ }
+
+ template <>
+ bool Has<flatbuffers::Table>() const {
+ return base_type_ == reflection::BaseType::Obj;
+ }
+
+ template <typename T = flatbuffers::Table>
+ const T* Table() const {
+ TC3_CHECK(Has<flatbuffers::Table>());
+ return flatbuffers::GetRoot<T>(
+ reinterpret_cast<const unsigned char*>(data_.data()));
+ }
+
+ template <typename T>
+ const T Value() const {
+ TC3_CHECK(Has<T>());
+ return flatbuffers::ReadScalar<T>(data_.data());
+ }
+
+ template <>
+ const StringPiece Value<StringPiece>() const {
+ TC3_CHECK(Has<StringPiece>());
+ return data_;
+ }
+
+ template <>
+ const std::string Value<std::string>() const {
+ TC3_CHECK(Has<StringPiece>());
+ return data_.ToString();
+ }
+
+ template <>
+ const UnicodeText Value<UnicodeText>() const {
+ TC3_CHECK(Has<StringPiece>());
+ return UTF8ToUnicodeText(data_, /*do_copy=*/false);
+ }
+
+ const reflection::BaseType base_type() const { return base_type_; }
+ const reflection::Object* type() const { return type_; }
+
+ private:
+ // The base type.
+ const reflection::BaseType base_type_;
+
+ // The object type of the value.
+ const reflection::Object* type_;
+
+ StringPiece data_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_
diff --git a/native/utils/grammar/testing/utils.h b/native/utils/grammar/testing/utils.h
new file mode 100644
index 0000000..709b94a
--- /dev/null
+++ b/native/utils/grammar/testing/utils.h
@@ -0,0 +1,239 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/arena.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/semantics/value.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "utils/grammar/text-context.h"
+#include "utils/i18n/locale.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/test-data-test-utils.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "flatbuffers/base.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+
+inline std::ostream& operator<<(std::ostream& os, const ParseTree* parse_tree) {
+ return os << "ParseTree(lhs=" << parse_tree->lhs
+ << ", begin=" << parse_tree->codepoint_span.first
+ << ", end=" << parse_tree->codepoint_span.second << ")";
+}
+
+inline std::ostream& operator<<(std::ostream& os,
+ const Derivation& derivation) {
+ return os << "Derivation(rule_id=" << derivation.rule_id << ", "
+ << "parse_tree=" << derivation.parse_tree << ")";
+}
+
+MATCHER_P3(IsDerivation, rule_id, begin, end,
+ "is derivation of rule that " +
+ ::testing::DescribeMatcher<int>(rule_id, negation) +
+ ", begin that " +
+ ::testing::DescribeMatcher<int>(begin, negation) +
+ ", end that " + ::testing::DescribeMatcher<int>(end, negation)) {
+ return ::testing::ExplainMatchResult(CodepointSpan(begin, end),
+ arg.parse_tree->codepoint_span,
+ result_listener) &&
+ ::testing::ExplainMatchResult(rule_id, arg.rule_id, result_listener);
+}
+
+// A test fixture with common auxiliary test methods.
+class GrammarTest : public testing::Test {
+ protected:
+ explicit GrammarTest()
+ : unilib_(CreateUniLibForTesting()),
+ arena_(/*block_size=*/16 << 10),
+ semantic_values_schema_(
+ GetTestFileContent("utils/grammar/testing/value.bfbs")),
+ tokenizer_(libtextclassifier3::TokenizationType_ICU, unilib_.get(),
+ /*codepoint_ranges=*/{},
+ /*internal_tokenizer_codepoint_ranges=*/{},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false) {}
+
+ TextContext TextContextForText(const std::string& text) {
+ TextContext context;
+ context.text = UTF8ToUnicodeText(text);
+ context.tokens = tokenizer_.Tokenize(context.text);
+ context.codepoints = context.text.Codepoints();
+ context.codepoints.push_back(context.text.end());
+ context.locales = {Locale::FromBCP47("en")};
+ context.context_span.first = 0;
+ context.context_span.second = context.tokens.size();
+ return context;
+ }
+
+ // Creates a semantic expression union.
+ template <typename T>
+ SemanticExpressionT AsSemanticExpressionUnion(T&& expression) {
+ SemanticExpressionT semantic_expression;
+ semantic_expression.expression.Set(std::forward<T>(expression));
+ return semantic_expression;
+ }
+
+ template <typename T>
+ OwnedFlatbuffer<SemanticExpression> CreateExpression(T&& expression) {
+ return Pack<SemanticExpression>(
+ AsSemanticExpressionUnion(std::forward<T>(expression)));
+ }
+
+ OwnedFlatbuffer<SemanticExpression> CreateEmptyExpression() {
+ return Pack<SemanticExpression>(SemanticExpressionT());
+ }
+
+ // Packs a flatbuffer.
+ template <typename T>
+ OwnedFlatbuffer<T> Pack(const typename T::NativeTableType&& value) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(T::Pack(builder, &value));
+ return OwnedFlatbuffer<T>(builder.Release());
+ }
+
+ // Creates a test semantic value.
+ const SemanticValue* CreateSemanticValue(const TestValueT& value) {
+ const std::string value_buffer = PackFlatbuffer<TestValue>(&value);
+ return arena_.AllocAndInit<SemanticValue>(
+ semantic_values_schema_->objects()->Get(
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value()),
+ StringPiece(arena_.Memdup(value_buffer.data(), value_buffer.size()),
+ value_buffer.size()));
+ }
+
+ // Creates a primitive semantic value.
+ template <typename T>
+ const SemanticValue* CreatePrimitiveSemanticValue(const T value) {
+ return arena_.AllocAndInit<SemanticValue>(value);
+ }
+
+ std::unique_ptr<SemanticExpressionT> CreateConstExpression(
+ const TestValueT& value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.type = TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ const std::string value_buffer = PackFlatbuffer<TestValue>(&value);
+ const_value.value.assign(value_buffer.begin(), value_buffer.end());
+ auto semantic_expression = std::make_unique<SemanticExpressionT>();
+ semantic_expression->expression.Set(const_value);
+ return semantic_expression;
+ }
+
+ OwnedFlatbuffer<SemanticExpression> CreateAndPackConstExpression(
+ const TestValueT& value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.type = TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ const std::string value_buffer = PackFlatbuffer<TestValue>(&value);
+ const_value.value.assign(value_buffer.begin(), value_buffer.end());
+ return CreateExpression(const_value);
+ }
+
+ std::unique_ptr<SemanticExpressionT> CreateConstDateExpression(
+ const TestDateT& value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.type = TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestDate")
+ .value();
+ const std::string value_buffer = PackFlatbuffer<TestDate>(&value);
+ const_value.value.assign(value_buffer.begin(), value_buffer.end());
+ auto semantic_expression = std::make_unique<SemanticExpressionT>();
+ semantic_expression->expression.Set(const_value);
+ return semantic_expression;
+ }
+
+ OwnedFlatbuffer<SemanticExpression> CreateAndPackMergeValuesExpression(
+ const std::vector<TestDateT>& values) {
+ MergeValueExpressionT merge_expression;
+ merge_expression.type = TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestDate")
+ .value();
+ for (const TestDateT& test_date : values) {
+ merge_expression.values.emplace_back(new SemanticExpressionT);
+ merge_expression.values.back() = CreateConstDateExpression(test_date);
+ }
+ return CreateExpression(std::move(merge_expression));
+ }
+
+ template <typename T>
+ std::unique_ptr<SemanticExpressionT> CreatePrimitiveConstExpression(
+ const T value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = flatbuffers_base_type<T>::value;
+ const_value.value.resize(sizeof(T));
+ flatbuffers::WriteScalar(const_value.value.data(), value);
+ auto semantic_expression = std::make_unique<SemanticExpressionT>();
+ semantic_expression->expression.Set(const_value);
+ return semantic_expression;
+ }
+
+ template <typename T>
+ OwnedFlatbuffer<SemanticExpression> CreateAndPackPrimitiveConstExpression(
+ const T value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = flatbuffers_base_type<T>::value;
+ const_value.value.resize(sizeof(T));
+ flatbuffers::WriteScalar(const_value.value.data(), value);
+ return CreateExpression(const_value);
+ }
+
+ template <>
+ OwnedFlatbuffer<SemanticExpression> CreateAndPackPrimitiveConstExpression(
+ const StringPiece value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::String;
+ const_value.value.assign(value.data(), value.data() + value.size());
+ return CreateExpression(const_value);
+ }
+
+ template <>
+ std::unique_ptr<SemanticExpressionT> CreatePrimitiveConstExpression(
+ const StringPiece value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::String;
+ const_value.value.assign(value.data(), value.data() + value.size());
+ auto semantic_expression = std::make_unique<SemanticExpressionT>();
+ semantic_expression->expression.Set(const_value);
+ return semantic_expression;
+ }
+
+ const std::unique_ptr<UniLib> unilib_;
+ UnsafeArena arena_;
+ const OwnedFlatbuffer<reflection::Schema, std::string>
+ semantic_values_schema_;
+ const Tokenizer tokenizer_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_
diff --git a/native/utils/grammar/testing/value.bfbs b/native/utils/grammar/testing/value.bfbs
new file mode 100644
index 0000000..6dd8538
--- /dev/null
+++ b/native/utils/grammar/testing/value.bfbs
Binary files differ
diff --git a/native/utils/grammar/testing/value.fbs b/native/utils/grammar/testing/value.fbs
new file mode 100644
index 0000000..0429491
--- /dev/null
+++ b/native/utils/grammar/testing/value.fbs
@@ -0,0 +1,44 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Test enum
+namespace libtextclassifier3.grammar;
+enum TestEnum : int {
+ UNSPECIFIED = 0,
+ ENUM_1 = 1,
+ ENUM_2 = 2,
+}
+
+// A test semantic value result.
+namespace libtextclassifier3.grammar;
+table TestValue {
+ value:int;
+ a_float_value:double;
+ test_string:string (shared);
+ date:TestDate;
+ enum_value:TestEnum;
+ repeated_enum:[TestEnum];
+ repeated_date:[TestDate];
+}
+
+// A test date value result.
+namespace libtextclassifier3.grammar;
+table TestDate {
+ day:int;
+ month:int;
+ year:int;
+}
+
diff --git a/native/utils/grammar/text-context.h b/native/utils/grammar/text-context.h
new file mode 100644
index 0000000..53e5f8b
--- /dev/null
+++ b/native/utils/grammar/text-context.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TEXT_CONTEXT_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TEXT_CONTEXT_H_
+
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/i18n/locale.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3::grammar {
+
+// Input to the parser.
+struct TextContext {
+ // Returns a view on a span of the text.
+ const UnicodeText Span(const CodepointSpan& span) const {
+ return text.Substring(codepoints[span.first], codepoints[span.second],
+ /*do_copy=*/false);
+ }
+
+ // The input text.
+ UnicodeText text;
+
+ // Pre-enumerated codepoints for fast substring extraction.
+ std::vector<UnicodeText::const_iterator> codepoints;
+
+ // The tokenized input text.
+ std::vector<Token> tokens;
+
+ // Locales of the input text.
+ std::vector<Locale> locales;
+
+ // Text annotations.
+ std::vector<AnnotatedSpan> annotations;
+
+ // The span of tokens to consider.
+ TokenSpan context_span;
+};
+
+}; // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TEXT_CONTEXT_H_
diff --git a/native/utils/grammar/types.h b/native/utils/grammar/types.h
index a79532b..64a618d 100644
--- a/native/utils/grammar/types.h
+++ b/native/utils/grammar/types.h
@@ -38,6 +38,7 @@
kMapping = -3,
kExclusion = -4,
kRootRule = 1,
+ kSemanticExpression = 2,
};
// Special CallbackId indicating that there's no callback associated with a
diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc
index ce074b8..9477dd0 100644
--- a/native/utils/grammar/utils/ir.cc
+++ b/native/utils/grammar/utils/ir.cc
@@ -16,6 +16,7 @@
#include "utils/grammar/utils/ir.h"
+#include "utils/i18n/locale.h"
#include "utils/strings/append.h"
#include "utils/strings/stringpiece.h"
#include "utils/zlib/zlib.h"
@@ -70,7 +71,7 @@
}
}
- return false;
+ return true;
}
Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
@@ -192,15 +193,6 @@
continue;
}
- // If either callback is a filter, we can't share as we must always run
- // both filters.
- if ((lhs.callback.id != kNoCallback &&
- filters_.find(lhs.callback.id) != filters_.end()) ||
- (candidate->callback.id != kNoCallback &&
- filters_.find(candidate->callback.id) != filters_.end())) {
- continue;
- }
-
// If the nonterminal is already defined, it must match for sharing.
if (lhs.nonterminal != kUnassignedNonterm &&
lhs.nonterminal != candidate->nonterminal) {
@@ -406,13 +398,6 @@
void Ir::Serialize(const bool include_debug_information,
RulesSetT* output) const {
- // Set callback information.
- for (const CallbackId filter_callback_id : filters_) {
- output->callback.push_back(RulesSet_::CallbackEntry(
- filter_callback_id, RulesSet_::Callback(/*is_filter=*/true)));
- }
- SortStructsForBinarySearchLookup(&output->callback);
-
// Add information about predefined nonterminal classes.
output->nonterminals.reset(new RulesSet_::NonterminalsT);
output->nonterminals->start_nt = GetNonterminalForName(kStartNonterm);
@@ -461,16 +446,29 @@
}
// Serialize the unary and binary rules.
- for (const RulesShard& shard : shards_) {
+ for (int i = 0; i < shards_.size(); i++) {
output->rules.emplace_back(std::make_unique<RulesSet_::RulesT>());
RulesSet_::RulesT* rules = output->rules.back().get();
+ for (const Locale& shard_locale : locale_shard_map_.GetLocales(i)) {
+ if (shard_locale.IsValid()) {
+        // The language tag '*' (all locales) is special: to stay consistent
+        // with the device-side parser, it is represented by leaving the
+        // language tag list empty instead of enumerating every locale.
+ rules->locale.emplace_back(
+ std::make_unique<libtextclassifier3::LanguageTagT>());
+ libtextclassifier3::LanguageTagT* language_tag =
+ rules->locale.back().get();
+ language_tag->language = shard_locale.Language();
+ language_tag->region = shard_locale.Region();
+ language_tag->script = shard_locale.Script();
+ }
+ }
+
// Serialize the unary rules.
- SerializeUnaryRulesShard(shard.unary_rules, output, rules);
-
+ SerializeUnaryRulesShard(shards_[i].unary_rules, output, rules);
// Serialize the binary rules.
- SerializeBinaryRulesShard(shard.binary_rules, output, rules);
+ SerializeBinaryRulesShard(shards_[i].binary_rules, output, rules);
}
-
// Serialize the terminal rules.
// We keep the rules separate by shard but merge the actual terminals into
// one shared string pool to most effectively exploit reuse.
diff --git a/native/utils/grammar/utils/ir.h b/native/utils/grammar/utils/ir.h
index b05b87f..f056d7a 100644
--- a/native/utils/grammar/utils/ir.h
+++ b/native/utils/grammar/utils/ir.h
@@ -25,6 +25,7 @@
#include "utils/base/integral_types.h"
#include "utils/grammar/rules_generated.h"
#include "utils/grammar/types.h"
+#include "utils/grammar/utils/locale-shard-map.h"
namespace libtextclassifier3::grammar {
@@ -96,21 +97,27 @@
std::unordered_map<TwoNonterms, LhsSet, BinaryRuleHasher> binary_rules;
};
- explicit Ir(const std::unordered_set<CallbackId>& filters = {},
- const int num_shards = 1)
- : num_nonterminals_(0), filters_(filters), shards_(num_shards) {}
+ explicit Ir(const LocaleShardMap& locale_shard_map)
+ : num_nonterminals_(0),
+ locale_shard_map_(locale_shard_map),
+ shards_(locale_shard_map_.GetNumberOfShards()) {}
// Adds a new non-terminal.
Nonterm AddNonterminal(const std::string& name = "") {
const Nonterm nonterminal = ++num_nonterminals_;
if (!name.empty()) {
// Record debug information.
- nonterminal_names_[nonterminal] = name;
- nonterminal_ids_[name] = nonterminal;
+ SetNonterminal(name, nonterminal);
}
return nonterminal;
}
+ // Sets the name of a nonterminal.
+ void SetNonterminal(const std::string& name, const Nonterm nonterminal) {
+ nonterminal_names_[nonterminal] = name;
+ nonterminal_ids_[name] = nonterminal;
+ }
+
// Defines a nonterminal if not yet defined.
Nonterm DefineNonterminal(Nonterm nonterminal) {
return (nonterminal != kUnassignedNonterm) ? nonterminal : AddNonterminal();
@@ -183,6 +190,12 @@
bool include_debug_information = false) const;
const std::vector<RulesShard>& shards() const { return shards_; }
+ const std::vector<std::pair<std::string, Nonterm>>& regex_rules() const {
+ return regex_rules_;
+ }
+ const std::vector<std::pair<std::string, Nonterm>>& annotations() const {
+ return annotations_;
+ }
private:
template <typename R, typename H>
@@ -214,9 +227,8 @@
Nonterm num_nonterminals_;
std::unordered_set<Nonterm> nonshareable_;
- // The set of callbacks that should be treated as filters.
- std::unordered_set<CallbackId> filters_;
-
+ // Locale information for Rules
+ const LocaleShardMap& locale_shard_map_;
// The sharded rules.
std::vector<RulesShard> shards_;
diff --git a/native/utils/grammar/utils/ir_test.cc b/native/utils/grammar/utils/ir_test.cc
index d2438dd..7a386df 100644
--- a/native/utils/grammar/utils/ir_test.cc
+++ b/native/utils/grammar/utils/ir_test.cc
@@ -24,13 +24,16 @@
namespace libtextclassifier3::grammar {
namespace {
+using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::Ne;
using ::testing::SizeIs;
TEST(IrTest, HandlesSharingWithTerminalRules) {
- Ir ir;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Ir ir(locale_shard_map);
// <t1> ::= the
const Nonterm t1 = ir.Add(kUnassignedNonterm, "the");
@@ -71,7 +74,9 @@
}
TEST(IrTest, HandlesSharingWithNonterminalRules) {
- Ir ir;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Ir ir(locale_shard_map);
// Setup a few terminal rules.
const std::vector<Nonterm> rhs = {
@@ -96,52 +101,31 @@
// Test sharing in the presence of callbacks.
constexpr CallbackId kOutput1 = 1;
constexpr CallbackId kOutput2 = 2;
- constexpr CallbackId kFilter1 = 3;
- constexpr CallbackId kFilter2 = 4;
- Ir ir(/*filters=*/{kFilter1, kFilter2});
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Ir ir(locale_shard_map);
const Nonterm x1 = ir.Add(kUnassignedNonterm, "hello");
const Nonterm x2 =
ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput1, 0}}, "hello");
const Nonterm x3 =
- ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter1, 0}}, "hello");
- const Nonterm x4 =
ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput2, 0}}, "hello");
- const Nonterm x5 =
- ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter2, 0}}, "hello");
// Duplicate entry.
- const Nonterm x6 =
+ const Nonterm x4 =
ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput2, 0}}, "hello");
EXPECT_THAT(x2, Eq(x1));
- EXPECT_THAT(x3, Ne(x1));
+ EXPECT_THAT(x3, Eq(x1));
EXPECT_THAT(x4, Eq(x1));
- EXPECT_THAT(x5, Ne(x1));
- EXPECT_THAT(x5, Ne(x3));
- EXPECT_THAT(x6, Ne(x3));
-}
-
-TEST(IrTest, HandlesSharingWithCallbacksWithDifferentParameters) {
- // Test sharing in the presence of callbacks.
- constexpr CallbackId kOutput = 1;
- constexpr CallbackId kFilter = 2;
- Ir ir(/*filters=*/{kFilter});
-
- const Nonterm x1 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput, 0}}, "world");
- const Nonterm x2 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput, 1}}, "world");
- const Nonterm x3 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter, 0}}, "world");
- const Nonterm x4 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter, 1}}, "world");
-
- EXPECT_THAT(x2, Eq(x1));
- EXPECT_THAT(x3, Ne(x1));
- EXPECT_THAT(x4, Ne(x1));
- EXPECT_THAT(x4, Ne(x3));
}
TEST(IrTest, SerializesRulesToFlatbufferFormat) {
constexpr CallbackId kOutput = 1;
- Ir ir;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Ir ir(locale_shard_map);
+
const Nonterm verb = ir.AddUnshareableNonterminal();
ir.Add(verb, "buy");
ir.Add(Ir::Lhs{verb, {kOutput}}, "bring");
@@ -180,7 +164,9 @@
}
TEST(IrTest, HandlesRulesSharding) {
- Ir ir(/*filters=*/{}, /*num_shards=*/2);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({"", "de"});
+ Ir ir(locale_shard_map);
const Nonterm verb = ir.AddUnshareableNonterminal();
const Nonterm set_reminder = ir.AddUnshareableNonterminal();
@@ -234,5 +220,23 @@
EXPECT_THAT(rules.rules[1]->binary_rules, SizeIs(3));
}
+TEST(IrTest, DeduplicatesLhsSets) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Ir ir(locale_shard_map);
+
+ const Nonterm test = ir.AddUnshareableNonterminal();
+ ir.Add(test, "test");
+
+ // Add a second rule for the same nonterminal.
+ ir.Add(test, "again");
+
+ RulesSetT rules;
+ ir.Serialize(/*include_debug_information=*/false, &rules);
+
+ EXPECT_THAT(rules.lhs_set, SizeIs(1));
+ EXPECT_THAT(rules.lhs_set.front()->lhs, ElementsAre(test));
+}
+
} // namespace
} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/locale-shard-map.cc b/native/utils/grammar/utils/locale-shard-map.cc
new file mode 100644
index 0000000..4f7dc5e
--- /dev/null
+++ b/native/utils/grammar/utils/locale-shard-map.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/utils/locale-shard-map.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "utils/i18n/locale-list.h"
+#include "utils/i18n/locale.h"
+#include "utils/strings/append.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+std::vector<Locale> LocaleTagsToLocaleList(const std::string& locale_tags) {
+ std::vector<Locale> locale_list;
+ for (const Locale& locale : LocaleList::ParseFrom(locale_tags).GetLocales()) {
+ if (locale.IsValid()) {
+ locale_list.emplace_back(locale);
+ }
+ }
+ std::sort(locale_list.begin(), locale_list.end(),
+ [](const Locale& a, const Locale& b) { return a < b; });
+ return locale_list;
+}
+
+} // namespace
+
+LocaleShardMap LocaleShardMap::CreateLocaleShardMap(
+ const std::vector<std::string>& locale_tags) {
+ LocaleShardMap locale_shard_map;
+ for (const std::string& locale_tag : locale_tags) {
+ locale_shard_map.AddLocalTags(locale_tag);
+ }
+ return locale_shard_map;
+}
+
+std::vector<Locale> LocaleShardMap::GetLocales(const int shard) const {
+ auto locale_it = shard_to_locale_data_.find(shard);
+ if (locale_it != shard_to_locale_data_.end()) {
+ return locale_it->second;
+ }
+ return std::vector<Locale>();
+}
+
+int LocaleShardMap::GetNumberOfShards() const {
+ return shard_to_locale_data_.size();
+}
+
+int LocaleShardMap::GetShard(const std::vector<Locale> locales) const {
+ for (const auto& [shard, locale_list] : shard_to_locale_data_) {
+ if (std::equal(locales.begin(), locales.end(), locale_list.begin())) {
+ return shard;
+ }
+ }
+ return 0;
+}
+
+int LocaleShardMap::GetShard(const std::string& locale_tags) const {
+ std::vector<Locale> locale_list = LocaleTagsToLocaleList(locale_tags);
+ return GetShard(locale_list);
+}
+
+void LocaleShardMap::AddLocalTags(const std::string& locale_tags) {
+ std::vector<Locale> locale_list = LocaleTagsToLocaleList(locale_tags);
+ int shard_id = shard_to_locale_data_.size();
+ shard_to_locale_data_.insert({shard_id, locale_list});
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/locale-shard-map.h b/native/utils/grammar/utils/locale-shard-map.h
new file mode 100644
index 0000000..5e0f5cb
--- /dev/null
+++ b/native/utils/grammar/utils/locale-shard-map.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_LOCALE_SHARD_MAP_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_LOCALE_SHARD_MAP_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "utils/grammar/types.h"
+#include "utils/i18n/locale-list.h"
+#include "utils/i18n/locale.h"
+#include "utils/optional.h"
+
+namespace libtextclassifier3::grammar {
+
+// Grammar rules are associated with Locale which serve as a filter during rule
+// application. The class holds shard information for Locale, which is used
+// when the Aqua rules are compiled into internal rules.proto flatbuffer.
+class LocaleShardMap {
+ public:
+ static LocaleShardMap CreateLocaleShardMap(
+ const std::vector<std::string>& locale_tags);
+
+ std::vector<Locale> GetLocales(const int shard) const;
+
+ int GetShard(const std::vector<Locale> locales) const;
+ int GetShard(const std::string& locale_tags) const;
+
+ int GetNumberOfShards() const;
+
+ private:
+ explicit LocaleShardMap() {}
+ void AddLocalTags(const std::string& locale_tag);
+
+ std::unordered_map<int, std::vector<Locale>> shard_to_locale_data_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_LOCALE_SHARD_MAP_H_
diff --git a/native/utils/grammar/utils/locale-shard-map_test.cc b/native/utils/grammar/utils/locale-shard-map_test.cc
new file mode 100644
index 0000000..14c9081
--- /dev/null
+++ b/native/utils/grammar/utils/locale-shard-map_test.cc
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/utils/locale-shard-map.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::SizeIs;
+
+TEST(LocaleShardMapTest, HandlesSimpleShard) {
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap(
+ {"ar-EG", "bn-BD", "cs-CZ", "da-DK", "de-DE", "en-US", "es-ES", "fi-FI",
+ "fr-FR", "gu-IN", "id-ID", "it-IT", "ja-JP", "kn-IN", "ko-KR", "ml-IN",
+ "mr-IN", "nl-NL", "no-NO", "pl-PL", "pt-BR", "ru-RU", "sv-SE", "ta-IN",
+ "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-PK", "vi-VN", "zh-TW"});
+
+ EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 31);
+ for (int i = 0; i < 31; i++) {
+ EXPECT_THAT(locale_shard_map.GetLocales(i), SizeIs(1));
+ }
+ EXPECT_EQ(locale_shard_map.GetLocales(0)[0], Locale::FromBCP47("ar-EG"));
+ EXPECT_EQ(locale_shard_map.GetLocales(8)[0], Locale::FromBCP47("fr-FR"));
+ EXPECT_EQ(locale_shard_map.GetLocales(16)[0], Locale::FromBCP47("mr-IN"));
+ EXPECT_EQ(locale_shard_map.GetLocales(24)[0], Locale::FromBCP47("te-IN"));
+ EXPECT_EQ(locale_shard_map.GetLocales(30)[0], Locale::FromBCP47("zh-TW"));
+}
+
+TEST(LocaleTagShardTest, HandlesWildCard) {
+ LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({"*"});
+ EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 1);
+ EXPECT_THAT(locale_shard_map.GetLocales(0), SizeIs(1));
+}
+
+TEST(LocaleTagShardTest, HandlesMultipleLocalePerShard) {
+ LocaleShardMap locale_shard_map =
+ LocaleShardMap::CreateLocaleShardMap({"ar-EG,bn-BD,cs-CZ", "en-*"});
+ EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 2);
+ EXPECT_EQ(locale_shard_map.GetLocales(0)[0], Locale::FromBCP47("ar-EG"));
+ EXPECT_EQ(locale_shard_map.GetLocales(0)[1], Locale::FromBCP47("bn-BD"));
+ EXPECT_EQ(locale_shard_map.GetLocales(0)[2], Locale::FromBCP47("cs-CZ"));
+ EXPECT_EQ(locale_shard_map.GetLocales(1)[0], Locale::FromBCP47("en"));
+
+ EXPECT_EQ(locale_shard_map.GetShard("ar-EG,bn-BD,cs-CZ"), 0);
+ EXPECT_EQ(locale_shard_map.GetShard("bn-BD,cs-CZ,ar-EG"), 0);
+ EXPECT_EQ(locale_shard_map.GetShard("bn-BD,ar-EG,cs-CZ"), 0);
+ EXPECT_EQ(locale_shard_map.GetShard("ar-EG,cs-CZ,bn-BD"), 0);
+}
+
+TEST(LocaleTagShardTest, HandlesEmptyLocaleTag) {
+ LocaleShardMap locale_shard_map =
+ LocaleShardMap::CreateLocaleShardMap({"", "en-US"});
+ EXPECT_EQ(locale_shard_map.GetNumberOfShards(), 2);
+ EXPECT_THAT(locale_shard_map.GetLocales(0), SizeIs(0));
+ EXPECT_THAT(locale_shard_map.GetLocales(1), SizeIs(1));
+ EXPECT_EQ(locale_shard_map.GetLocales(1)[0], Locale::FromBCP47("en-US"));
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/rules.cc b/native/utils/grammar/utils/rules.cc
index d6e4b76..1b545a6 100644
--- a/native/utils/grammar/utils/rules.cc
+++ b/native/utils/grammar/utils/rules.cc
@@ -161,10 +161,16 @@
void Rules::AddAlias(const std::string& nonterminal_name,
const std::string& alias) {
+#ifndef TC3_USE_CXX14
TC3_CHECK_EQ(nonterminal_alias_.insert_or_assign(alias, nonterminal_name)
.first->second,
nonterminal_name)
<< "Cannot redefine alias: " << alias;
+#else
+ nonterminal_alias_[alias] = nonterminal_name;
+ TC3_CHECK_EQ(nonterminal_alias_[alias], nonterminal_name)
+ << "Cannot redefine alias: " << alias;
+#endif
}
// Defines a nonterminal for an externally provided annotation.
@@ -258,7 +264,7 @@
}
std::vector<Rules::RhsElement> Rules::ResolveFillers(
- const std::vector<RhsElement>& rhs) {
+ const std::vector<RhsElement>& rhs, int shard) {
std::vector<RhsElement> result;
for (int i = 0; i < rhs.size();) {
if (i == rhs.size() - 1 || IsNonterminalOfName(rhs[i], kFiller) ||
@@ -278,15 +284,27 @@
/*is_optional=*/false);
if (rhs[i + 1].is_optional) {
// <a_with_tokens> ::= <a>
- Add(with_tokens_nonterminal, {rhs[i]});
+ Add(with_tokens_nonterminal, {rhs[i]},
+ /*callback=*/kNoCallback,
+ /*callback_param=*/0,
+ /*max_whitespace_gap=*/-1,
+ /*case_sensitive=*/false, shard);
} else {
// <a_with_tokens> ::= <a> <token>
- Add(with_tokens_nonterminal, {rhs[i], token});
+ Add(with_tokens_nonterminal, {rhs[i], token},
+ /*callback=*/kNoCallback,
+ /*callback_param=*/0,
+ /*max_whitespace_gap=*/-1,
+ /*case_sensitive=*/false, shard);
}
// <a_with_tokens> ::= <a_with_tokens> <token>
const RhsElement with_tokens(with_tokens_nonterminal,
/*is_optional=*/false);
- Add(with_tokens_nonterminal, {with_tokens, token});
+ Add(with_tokens_nonterminal, {with_tokens, token},
+ /*callback=*/kNoCallback,
+ /*callback_param=*/0,
+ /*max_whitespace_gap=*/-1,
+ /*case_sensitive=*/false, shard);
result.push_back(with_tokens);
i += 2;
}
@@ -294,8 +312,8 @@
}
std::vector<Rules::RhsElement> Rules::OptimizeRhs(
- const std::vector<RhsElement>& rhs) {
- return ResolveFillers(ResolveAnchors(rhs));
+ const std::vector<RhsElement>& rhs, int shard) {
+ return ResolveFillers(ResolveAnchors(rhs), shard);
}
void Rules::Add(const int lhs, const std::vector<RhsElement>& rhs,
@@ -379,6 +397,14 @@
/*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
}
+void Rules::AddValueMapping(const int lhs, const std::vector<RhsElement>& rhs,
+ int64 value, const int8 max_whitespace_gap,
+ const bool case_sensitive, const int shard) {
+ Add(lhs, rhs,
+ /*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
+ /*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
+}
+
void Rules::AddRegex(const std::string& lhs, const std::string& regex_pattern) {
AddRegex(AddNonterminal(lhs), regex_pattern);
}
@@ -388,8 +414,19 @@
regex_rules_.push_back(regex_pattern);
}
+bool Rules::UsesFillers() const {
+ for (const Rule& rule : rules_) {
+ for (const RhsElement& rhs_element : rule.rhs) {
+ if (IsNonterminalOfName(rhs_element, kFiller)) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
- Ir rules(filters_, num_shards_);
+ Ir rules(locale_shard_map_);
std::unordered_map<int, Nonterm> nonterminal_ids;
// Pending rules to process.
@@ -405,22 +442,21 @@
}
// Assign (unmergeable) Nonterm values to any nonterminals that have
- // multiple rules or that have a filter callback on some rule.
+ // multiple rules.
for (int i = 0; i < nonterminals_.size(); i++) {
const NontermInfo& nonterminal = nonterminals_[i];
+
+ // Skip predefined nonterminals, they have already been assigned.
+ if (rules.GetNonterminalForName(nonterminal.name) != kUnassignedNonterm) {
+ continue;
+ }
+
bool unmergeable =
(nonterminal.from_annotation || nonterminal.rules.size() > 1 ||
!nonterminal.regex_rules.empty());
for (const int rule_index : nonterminal.rules) {
- const Rule& rule = rules_[rule_index];
-
// Schedule rule.
scheduled_rules.insert({i, rule_index});
-
- if (rule.callback != kNoCallback &&
- filters_.find(rule.callback) != filters_.end()) {
- unmergeable = true;
- }
}
if (unmergeable) {
@@ -441,6 +477,22 @@
rules.AddAnnotation(nonterminal_ids[nonterminal], annotation);
}
+ // Check whether fillers are still referenced (if they couldn't get optimized
+ // away).
+ if (UsesFillers()) {
+ TC3_LOG(WARNING) << "Rules use fillers that couldn't be optimized, grammar "
+ "matching performance might be impacted.";
+
+ // Add a definition for the filler:
+ // <filler> = <token>
+ // <filler> = <token> <filler>
+ const Nonterm filler = rules.GetNonterminalForName(kFiller);
+ const Nonterm token =
+ rules.DefineNonterminal(rules.GetNonterminalForName(kTokenNonterm));
+ rules.Add(filler, token);
+ rules.Add(filler, std::vector<Nonterm>{token, filler});
+ }
+
// Now, keep adding eligible rules (rules whose rhs is completely assigned)
// until we can't make any more progress.
// Note: The following code is quadratic in the worst case.
diff --git a/native/utils/grammar/utils/rules.h b/native/utils/grammar/utils/rules.h
index 5a2cbc2..c8b2a70 100644
--- a/native/utils/grammar/utils/rules.h
+++ b/native/utils/grammar/utils/rules.h
@@ -34,19 +34,15 @@
// All rules for a grammar will be collected in a rules object.
//
// Rules r;
-// CallbackId date_output_callback = 1;
-// CallbackId day_filter_callback = 2; r.DefineFilter(day_filter_callback);
-// CallbackId year_filter_callback = 3; r.DefineFilter(year_filter_callback);
-// r.Add("<date>", {"<monthname>", "<day>", <year>"},
-// date_output_callback);
+// r.Add("<date>", {"<monthname>", "<day>", <year>"});
// r.Add("<monthname>", {"January"});
// ...
// r.Add("<monthname>", {"December"});
-// r.Add("<day>", {"<string_of_digits>"}, day_filter_callback);
-// r.Add("<year>", {"<string_of_digits>"}, year_filter_callback);
+// r.Add("<day>", {"<string_of_digits>"});
+// r.Add("<year>", {"<string_of_digits>"});
//
-// The Add() method adds a rule with a given lhs, rhs, and (optionally)
-// callback. The rhs is just a list of terminals and nonterminals. Anything
+// The Add() method adds a rule with a given lhs and rhs.
+// The rhs is just a list of terminals and nonterminals. Anything
// surrounded in angle brackets is considered a nonterminal. A "?" can follow
// any element of the RHS, like this:
//
@@ -55,26 +51,35 @@
// This indicates that the <day> and "," parts of the rhs are optional.
// (This is just notational shorthand for adding a bunch of rules.)
//
-// Once you're done adding rules and callbacks to the Rules object,
-// call r.Finalize() on it. This lowers the rule set into an internal
-// representation.
+// Once you're done adding rules, r.Finalize() lowers the rule set into an
+// internal representation.
class Rules {
public:
- explicit Rules(const int num_shards = 1) : num_shards_(num_shards) {}
+ explicit Rules(const LocaleShardMap& locale_shard_map)
+ : locale_shard_map_(locale_shard_map) {}
// Represents one item in a right-hand side, a single terminal or nonterminal.
struct RhsElement {
RhsElement() {}
explicit RhsElement(const std::string& terminal, const bool is_optional)
- : is_terminal(true), terminal(terminal), is_optional(is_optional) {}
- explicit RhsElement(const int nonterminal, const bool is_optional)
+ : is_terminal(true),
+ terminal(terminal),
+ is_optional(is_optional),
+ is_constituent(false) {}
+ explicit RhsElement(const int nonterminal, const bool is_optional,
+ const bool is_constituent = true)
: is_terminal(false),
nonterminal(nonterminal),
- is_optional(is_optional) {}
+ is_optional(is_optional),
+ is_constituent(is_constituent) {}
bool is_terminal;
std::string terminal;
int nonterminal;
bool is_optional;
+
+ // Whether the element is a constituent of a rule - these are the explicit
+ // nonterminals, but not terminals or implicitly added anchors.
+ bool is_constituent;
};
// Represents the right-hand side, and possibly callback, of one rule.
@@ -139,6 +144,9 @@
const std::vector<std::string>& rhs, int64 value,
int8 max_whitespace_gap = -1,
bool case_sensitive = false, int shard = 0);
+ void AddValueMapping(int lhs, const std::vector<RhsElement>& rhs, int64 value,
+ int8 max_whitespace_gap = -1,
+ bool case_sensitive = false, int shard = 0);
// Adds a regex rule.
void AddRegex(const std::string& lhs, const std::string& regex_pattern);
@@ -161,9 +169,6 @@
// nonterminal.
void AddAlias(const std::string& nonterminal_name, const std::string& alias);
- // Defines a new filter id.
- void DefineFilter(const CallbackId filter_id) { filters_.insert(filter_id); }
-
// Lowers the rule set into the intermediate representation.
// Treats nonterminals given by the argument `predefined_nonterminals` as
// defined externally. This allows to define rules that are dependent on
@@ -171,6 +176,9 @@
// fed to the matcher by the lexer.
Ir Finalize(const std::set<std::string>& predefined_nonterminals = {}) const;
+ const std::vector<NontermInfo>& nonterminals() const { return nonterminals_; }
+ const std::vector<Rule>& rules() const { return rules_; }
+
private:
void ExpandOptionals(
int lhs, const std::vector<RhsElement>& rhs, CallbackId callback,
@@ -180,7 +188,8 @@
std::vector<bool>* omit_these);
// Applies optimizations to the right hand side of a rule.
- std::vector<RhsElement> OptimizeRhs(const std::vector<RhsElement>& rhs);
+ std::vector<RhsElement> OptimizeRhs(const std::vector<RhsElement>& rhs,
+ int shard = 0);
// Removes start and end anchors in case they are followed (respectively
// preceded) by unbounded filler.
@@ -198,13 +207,17 @@
// `<a_with_tokens> ::= <a>`
// `<a_with_tokens> ::= <a_with_tokens> <token>`
// In this each occurrence of `<a>` can start a sequence of tokens.
- std::vector<RhsElement> ResolveFillers(const std::vector<RhsElement>& rhs);
+ std::vector<RhsElement> ResolveFillers(const std::vector<RhsElement>& rhs,
+ int shard = 0);
// Checks whether an element denotes a specific nonterminal.
bool IsNonterminalOfName(const RhsElement& element,
const std::string& nonterminal) const;
- const int num_shards_;
+ // Checks whether the fillers are used in any active rule.
+ bool UsesFillers() const;
+
+ const LocaleShardMap& locale_shard_map_;
// Non-terminal to id map.
std::unordered_map<std::string, int> nonterminal_names_;
@@ -215,9 +228,6 @@
// Rules.
std::vector<Rule> rules_;
std::vector<std::string> regex_rules_;
-
- // Ids of callbacks that should be treated as filters.
- std::unordered_set<CallbackId> filters_;
};
} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/rules_test.cc b/native/utils/grammar/utils/rules_test.cc
index 6761118..c71f2b4 100644
--- a/native/utils/grammar/utils/rules_test.cc
+++ b/native/utils/grammar/utils/rules_test.cc
@@ -28,7 +28,9 @@
using ::testing::SizeIs;
TEST(SerializeRulesTest, HandlesSimpleRuleSet) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<verb>", {"buy"});
rules.Add("<verb>", {"bring"});
@@ -49,16 +51,16 @@
}
TEST(SerializeRulesTest, HandlesRulesSetWithCallbacks) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
const CallbackId output = 1;
- const CallbackId filter = 2;
- rules.DefineFilter(filter);
rules.Add("<verb>", {"buy"});
- rules.Add("<verb>", {"bring"}, output, 0);
- rules.Add("<verb>", {"remind"}, output, 0);
+ rules.Add("<verb>", {"bring"});
+ rules.Add("<verb>", {"remind"});
rules.Add("<reminder>", {"remind", "me", "to", "<verb>"});
- rules.Add("<action>", {"<reminder>"}, filter, 0);
+ rules.Add("<action>", {"<reminder>"}, output, 0);
const Ir ir = rules.Finalize();
RulesSetT frozen_rules;
@@ -68,16 +70,16 @@
EXPECT_EQ(frozen_rules.terminals,
std::string("bring\0buy\0me\0remind\0to\0", 23));
- // We have two identical output calls and one filter call in the rule set
- // definition above.
- EXPECT_THAT(frozen_rules.lhs, SizeIs(2));
+ EXPECT_THAT(frozen_rules.lhs, SizeIs(1));
EXPECT_THAT(frozen_rules.rules.front()->binary_rules, SizeIs(3));
EXPECT_THAT(frozen_rules.rules.front()->unary_rules, SizeIs(1));
}
TEST(SerializeRulesTest, HandlesRulesWithWhitespaceGapLimits) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<iata>", {"lx"});
rules.Add("<iata>", {"aa"});
rules.Add("<flight>", {"<iata>", "<4_digits>"}, kNoCallback, 0,
@@ -93,7 +95,9 @@
}
TEST(SerializeRulesTest, HandlesCaseSensitiveTerminals) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<iata>", {"LX"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
/*case_sensitive=*/true);
rules.Add("<iata>", {"AA"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
@@ -113,7 +117,9 @@
}
TEST(SerializeRulesTest, HandlesMultipleShards) {
- Rules rules(/*num_shards=*/2);
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({"", "de"});
+ Rules rules(locale_shard_map);
rules.Add("<iata>", {"LX"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
/*case_sensitive=*/true, /*shard=*/0);
rules.Add("<iata>", {"aa"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
@@ -128,7 +134,10 @@
}
TEST(SerializeRulesTest, HandlesRegexRules) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+
rules.AddRegex("<code>", "[A-Z]+");
rules.AddRegex("<numbers>", "\\d+");
RulesSetT frozen_rules;
@@ -138,7 +147,9 @@
}
TEST(SerializeRulesTest, HandlesAlias) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<iata>", {"lx"});
rules.Add("<iata>", {"aa"});
rules.Add("<flight>", {"<iata>", "<4_digits>"});
@@ -159,7 +170,9 @@
}
TEST(SerializeRulesTest, ResolvesAnchorsAndFillers) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.Add("<code>",
{"<^>", "<filler>", "this", "is", "a", "test", "<filler>", "<$>"});
const Ir ir = rules.Finalize();
@@ -180,8 +193,33 @@
EXPECT_THAT(frozen_rules.lhs, IsEmpty());
}
+TEST(SerializeRulesTest, HandlesFillers) {
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
+ rules.Add("<test>", {"<filler>?", "a", "test"});
+ const Ir ir = rules.Finalize();
+ RulesSetT frozen_rules;
+ ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
+
+ EXPECT_THAT(frozen_rules.rules, SizeIs(1));
+ EXPECT_EQ(frozen_rules.terminals, std::string("a\0test\0", 7));
+
+  // The filler in the rule above is optional, so the rule is binarized into
+  // variants both with and without the filler:
+ // <tmp_0> ::= <filler> a
+ // <test> ::= <tmp_0> test
+ // <test> ::= a test
+ // <filler> ::= <token> <filler>
+ EXPECT_THAT(frozen_rules.rules.front()->binary_rules, SizeIs(4));
+ // <filler> ::= <token>
+ EXPECT_THAT(frozen_rules.rules.front()->unary_rules, SizeIs(1));
+}
+
TEST(SerializeRulesTest, HandlesAnnotations) {
- Rules rules;
+ grammar::LocaleShardMap locale_shard_map =
+ grammar::LocaleShardMap::CreateLocaleShardMap({""});
+ Rules rules(locale_shard_map);
rules.AddAnnotation("phone");
rules.AddAnnotation("url");
rules.AddAnnotation("tracking_number");
diff --git a/native/utils/hash/cityhash.cc b/native/utils/hash/cityhash.cc
new file mode 100644
index 0000000..e2a8596
--- /dev/null
+++ b/native/utils/hash/cityhash.cc
@@ -0,0 +1,188 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/hash/cityhash.h"
+
+#include <cstdint>
+
+#include "absl/base/internal/endian.h"
+#include "absl/numeric/int128.h"
+
+namespace tc3farmhash {
+namespace {
+// Some primes between 2^63 and 2^64 for various uses.
+static const uint64_t k0 = 0xa5b85c5e198ed849ULL;
+static const uint64_t k1 = 0x8d58ac26afe12e47ULL;
+static const uint64_t k2 = 0xc47b6e9e3a970ed3ULL;
+static const uint64_t k3 = 0xc70f6907e782aa0bULL;
+
+// Hash 128 input bits down to 64 bits of output.
+// This is intended to be a reasonably good hash function.
+// It may change from time to time.
+inline uint64_t Hash128to64(const absl::uint128 x) {
+ // Murmur-inspired hashing.
+ const uint64_t kMul = 0xc6a4a7935bd1e995ULL;
+ uint64_t a = (absl::Uint128Low64(x) ^ absl::Uint128High64(x)) * kMul;
+ a ^= (a >> 47);
+ uint64_t b = (absl::Uint128High64(x) ^ a) * kMul;
+ b ^= (b >> 47);
+ b *= kMul;
+ return b;
+}
+
+uint64_t HashLen16(uint64_t u, uint64_t v) {
+ return Hash128to64(absl::MakeUint128(u, v));
+}
+
+static uint64_t Rotate(uint64_t val, size_t shift) {
+ assert(shift <= 63);
+ return (val >> shift) | (val << (-shift & 63));
+}
+
+static uint64_t ShiftMix(uint64_t val) { return val ^ (val >> 47); }
+
+uint64_t HashLen0to16(const char *s, size_t len) {
+ assert(len <= 16);
+ if (len > 8) {
+ uint64_t a = absl::little_endian::Load64(s);
+ uint64_t b = absl::little_endian::Load64(s + len - 8);
+ return HashLen16(a, Rotate(b + len, len)) ^ b;
+ }
+ if (len >= 4) {
+ uint64_t a = absl::little_endian::Load32(s);
+ return HashLen16(len + (a << 3), absl::little_endian::Load32(s + len - 4));
+ }
+ if (len > 0) {
+ uint8_t a = s[0];
+ uint8_t b = s[len >> 1];
+ uint8_t c = s[len - 1];
+ uint32_t y = static_cast<uint32_t>(a) + (static_cast<uint32_t>(b) << 8);
+ uint32_t z = len + (static_cast<uint32_t>(c) << 2);
+ return ShiftMix(y * k2 ^ z * k3) * k2;
+ }
+ return k2;
+}
+
+// Return a 16-byte hash for 48 bytes. Quick and dirty.
+// Callers do best to use "random-looking" values for a and b.
+// (For more, see the code review discussion of CL 18799087.)
+std::pair<uint64_t, uint64_t> WeakHashLen32WithSeeds(uint64_t w, uint64_t x,
+ uint64_t y, uint64_t z,
+ uint64_t a, uint64_t b) {
+ a += w;
+ b = Rotate(b + a + z, 51);
+ uint64_t c = a;
+ a += x;
+ a += y;
+ b += Rotate(a, 23);
+ return std::make_pair(a + z, b + c);
+}
+
+// Return a 16-byte hash for s[0] ... s[31], a, and b. Quick and dirty.
+std::pair<uint64_t, uint64_t> WeakHashLen32WithSeeds(const char *s, uint64_t a,
+ uint64_t b) {
+ return WeakHashLen32WithSeeds(absl::little_endian::Load64(s),
+ absl::little_endian::Load64(s + 8),
+ absl::little_endian::Load64(s + 16),
+ absl::little_endian::Load64(s + 24), a, b);
+}
+
+} // namespace
+
+// This probably works well for 16-byte strings as well, but it may be overkill
+// in that case.
+static uint64_t HashLen17to32(const char *s, size_t len) {
+ assert(len >= 17);
+ assert(len <= 32);
+ uint64_t a = absl::little_endian::Load64(s) * k1;
+ uint64_t b = absl::little_endian::Load64(s + 8);
+ uint64_t c = absl::little_endian::Load64(s + len - 8) * k2;
+ uint64_t d = absl::little_endian::Load64(s + len - 16) * k0;
+ return HashLen16(Rotate(a - b, 43) + Rotate(c, 30) + d,
+ a + Rotate(b ^ k3, 20) - c + len);
+}
+
+// Return an 8-byte hash for 33 to 64 bytes.
+static uint64_t HashLen33to64(const char *s, size_t len) {
+ uint64_t z = absl::little_endian::Load64(s + 24);
+ uint64_t a = absl::little_endian::Load64(s) +
+ (len + absl::little_endian::Load64(s + len - 16)) * k0;
+ uint64_t b = Rotate(a + z, 52);
+ uint64_t c = Rotate(a, 37);
+ a += absl::little_endian::Load64(s + 8);
+ c += Rotate(a, 7);
+ a += absl::little_endian::Load64(s + 16);
+ uint64_t vf = a + z;
+ uint64_t vs = b + Rotate(a, 31) + c;
+ a = absl::little_endian::Load64(s + 16) +
+ absl::little_endian::Load64(s + len - 32);
+ z += absl::little_endian::Load64(s + len - 8);
+ b = Rotate(a + z, 52);
+ c = Rotate(a, 37);
+ a += absl::little_endian::Load64(s + len - 24);
+ c += Rotate(a, 7);
+ a += absl::little_endian::Load64(s + len - 16);
+ uint64_t wf = a + z;
+ uint64_t ws = b + Rotate(a, 31) + c;
+ uint64_t r = ShiftMix((vf + ws) * k2 + (wf + vs) * k0);
+ return ShiftMix(r * k0 + vs) * k2;
+}
+
+uint64_t CityHash64(const char *s, size_t len) {
+ if (len <= 32) {
+ if (len <= 16) {
+ return HashLen0to16(s, len);
+ } else {
+ return HashLen17to32(s, len);
+ }
+ } else if (len <= 64) {
+ return HashLen33to64(s, len);
+ }
+
+ // For strings over 64 bytes we hash the end first, and then as we
+ // loop we keep 56 bytes of state: v, w, x, y, and z.
+ uint64_t x = absl::little_endian::Load64(s + len - 40);
+ uint64_t y = absl::little_endian::Load64(s + len - 16) +
+ absl::little_endian::Load64(s + len - 56);
+ uint64_t z = HashLen16(absl::little_endian::Load64(s + len - 48) + len,
+ absl::little_endian::Load64(s + len - 24));
+ std::pair<uint64_t, uint64_t> v =
+ WeakHashLen32WithSeeds(s + len - 64, len, z);
+ std::pair<uint64_t, uint64_t> w =
+ WeakHashLen32WithSeeds(s + len - 32, y + k1, x);
+ x = x * k1 + absl::little_endian::Load64(s);
+
+ // Decrease len to the nearest multiple of 64, and operate on 64-byte chunks.
+ len = (len - 1) & ~static_cast<size_t>(63);
+ assert(len > 0);
+ assert(len == len / 64 * 64);
+ do {
+ x = Rotate(x + y + v.first + absl::little_endian::Load64(s + 8), 37) * k1;
+ y = Rotate(y + v.second + absl::little_endian::Load64(s + 48), 42) * k1;
+ x ^= w.second;
+ y += v.first + absl::little_endian::Load64(s + 40);
+ z = Rotate(z + w.first, 33) * k1;
+ v = WeakHashLen32WithSeeds(s, v.second * k1, x + w.first);
+ w = WeakHashLen32WithSeeds(s + 32, z + w.second,
+ y + absl::little_endian::Load64(s + 16));
+ std::swap(z, x);
+ s += 64;
+ len -= 64;
+ } while (len != 0);
+ return HashLen16(HashLen16(v.first, w.first) + ShiftMix(y) * k1 + z,
+ HashLen16(v.second, w.second) + x);
+}
+} // namespace tc3farmhash
diff --git a/native/utils/hash/cityhash.h b/native/utils/hash/cityhash.h
new file mode 100644
index 0000000..9ede3d6
--- /dev/null
+++ b/native/utils/hash/cityhash.h
@@ -0,0 +1,27 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_HASH_CITYHASH_H_
+#define LIBTEXTCLASSIFIER_UTILS_HASH_CITYHASH_H_
+
+#include <cstddef>
+#include <cstdint>
+
+namespace tc3farmhash {
+uint64_t CityHash64(const char *s, size_t len);
+} // namespace tc3farmhash
+
+#endif // LIBTEXTCLASSIFIER_UTILS_HASH_CITYHASH_H_
diff --git a/native/utils/i18n/language-tag.fbs b/native/utils/i18n/language-tag.fbs
old mode 100755
new mode 100644
diff --git a/native/utils/i18n/locale-list.cc b/native/utils/i18n/locale-list.cc
new file mode 100644
index 0000000..a0be5ac
--- /dev/null
+++ b/native/utils/i18n/locale-list.cc
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/i18n/locale-list.h"
+
+#include <string>
+
+namespace libtextclassifier3 {
+
+LocaleList LocaleList::ParseFrom(const std::string& locale_tags) {
+ std::vector<StringPiece> split_locales = strings::Split(locale_tags, ',');
+ std::string reference_locale;
+ if (!split_locales.empty()) {
+ // Assigns the first parsed locale to reference_locale.
+ reference_locale = split_locales[0].ToString();
+ } else {
+ reference_locale = "";
+ }
+ std::vector<Locale> locales;
+ for (const StringPiece& locale_str : split_locales) {
+ const Locale locale = Locale::FromBCP47(locale_str.ToString());
+ if (!locale.IsValid()) {
+ TC3_LOG(WARNING) << "Failed to parse the detected_text_language_tag: "
+ << locale_str.ToString();
+ }
+ locales.push_back(locale);
+ }
+ return LocaleList(locales, split_locales, reference_locale);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/i18n/locale-list.h b/native/utils/i18n/locale-list.h
new file mode 100644
index 0000000..cf2e06d
--- /dev/null
+++ b/native/utils/i18n/locale-list.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_
+#define LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_
+
+#include <string>
+
+#include "utils/i18n/locale.h"
+#include "utils/strings/split.h"
+
+namespace libtextclassifier3 {
+
+// Parses and holds data about a list of locales (locale tags separated by ',').
+class LocaleList {
+ public:
+  // Constructs:
+  //  - The collection of locale tags split from locale_tags.
+  //  - The collection of Locale objects parsed from each BCP47 tag (if a tag
+  //    is invalid, an object is still created but IsValid() returns false).
+  //  - reference_locale, assigned from the first parsed locale tag.
+ static LocaleList ParseFrom(const std::string& locale_tags);
+
+ std::vector<Locale> GetLocales() const { return locales_; }
+ std::vector<StringPiece> GetLocaleTags() const { return split_locales_; }
+ std::string GetReferenceLocale() const { return reference_locale_; }
+
+ private:
+ LocaleList(const std::vector<Locale>& locales,
+ const std::vector<StringPiece>& split_locales,
+ const StringPiece& reference_locale)
+ : locales_(locales),
+ split_locales_(split_locales),
+ reference_locale_(reference_locale.ToString()) {}
+
+ const std::vector<Locale> locales_;
+ const std::vector<StringPiece> split_locales_;
+ const std::string reference_locale_;
+};
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_
diff --git a/native/utils/i18n/locale-list_test.cc b/native/utils/i18n/locale-list_test.cc
new file mode 100644
index 0000000..d7cfd17
--- /dev/null
+++ b/native/utils/i18n/locale-list_test.cc
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/i18n/locale-list.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::SizeIs;
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(LocaleTest, ParsedLocalesSanityCheck) {
+ LocaleList locale_list = LocaleList::ParseFrom("en-US,zh-CN,ar,en");
+ EXPECT_THAT(locale_list.GetLocales(), SizeIs(4));
+ EXPECT_THAT(locale_list.GetLocaleTags(), SizeIs(4));
+ EXPECT_EQ(locale_list.GetReferenceLocale(), "en-US");
+}
+
+TEST(LocaleTest, ParsedLocalesEmpty) {
+ LocaleList locale_list = LocaleList::ParseFrom("");
+ EXPECT_THAT(locale_list.GetLocales(), SizeIs(0));
+ EXPECT_THAT(locale_list.GetLocaleTags(), SizeIs(0));
+ EXPECT_EQ(locale_list.GetReferenceLocale(), "");
+}
+
+TEST(LocaleTest, ParsedLocalesInvalid) {
+ LocaleList locale_list = LocaleList::ParseFrom("en,invalid");
+ EXPECT_THAT(locale_list.GetLocales(), SizeIs(2));
+ EXPECT_THAT(locale_list.GetLocaleTags(), SizeIs(2));
+ EXPECT_EQ(locale_list.GetReferenceLocale(), "en");
+ EXPECT_TRUE(locale_list.GetLocales()[0].IsValid());
+ EXPECT_FALSE(locale_list.GetLocales()[1].IsValid());
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/i18n/locale.cc b/native/utils/i18n/locale.cc
index d5a1109..3719079 100644
--- a/native/utils/i18n/locale.cc
+++ b/native/utils/i18n/locale.cc
@@ -16,6 +16,8 @@
#include "utils/i18n/locale.h"
+#include <string>
+
#include "utils/strings/split.h"
namespace libtextclassifier3 {
@@ -196,6 +198,20 @@
return false;
}
+bool Locale::operator==(const Locale& locale) const {
+ return language_ == locale.language_ && region_ == locale.region_ &&
+ script_ == locale.script_;
+}
+
+bool Locale::operator<(const Locale& locale) const {
+ return std::tie(language_, region_, script_) <
+ std::tie(locale.language_, locale.region_, locale.script_);
+}
+
+bool Locale::operator!=(const Locale& locale) const {
+ return !(*this == locale);
+}
+
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
const Locale& locale) {
return stream << "Locale(language=" << locale.Language()
diff --git a/native/utils/i18n/locale.h b/native/utils/i18n/locale.h
index 308846d..036bacd 100644
--- a/native/utils/i18n/locale.h
+++ b/native/utils/i18n/locale.h
@@ -60,6 +60,10 @@
const std::vector<Locale>& supported_locales,
bool default_value);
+ bool operator==(const Locale& locale) const;
+ bool operator!=(const Locale& locale) const;
+ bool operator<(const Locale& locale) const;
+
private:
Locale(const std::string& language, const std::string& script,
const std::string& region)
diff --git a/native/utils/intents/intent-config.fbs b/native/utils/intents/intent-config.fbs
old mode 100755
new mode 100644
diff --git a/native/utils/intents/intent-generator-test-lib.cc b/native/utils/intents/intent-generator-test-lib.cc
new file mode 100644
index 0000000..4207a3e
--- /dev/null
+++ b/native/utils/intents/intent-generator-test-lib.cc
@@ -0,0 +1,662 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <jni.h>
+
+#include <memory>
+#include <vector>
+
+#include "utils/flatbuffers/mutable.h"
+#include "utils/intents/intent-generator.h"
+#include "utils/intents/remote-action-template.h"
+#include "utils/java/jni-helper.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/resources_generated.h"
+#include "utils/testing/logging_event_listener.h"
+#include "utils/variant.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+using ::testing::SizeIs;
+
+flatbuffers::DetachedBuffer BuildTestIntentFactoryModel(
+ const std::string& entity_type, const std::string& generator_code) {
+ // Test intent generation options.
+ IntentFactoryModelT options;
+ options.generator.emplace_back(new IntentFactoryModel_::IntentGeneratorT());
+ options.generator.back()->type = entity_type;
+ options.generator.back()->lua_template_generator = std::vector<unsigned char>(
+ generator_code.data(), generator_code.data() + generator_code.size());
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(IntentFactoryModel::Pack(builder, &options));
+ return builder.Release();
+}
+
+flatbuffers::DetachedBuffer BuildTestResources() {
+ // Custom string resources.
+ ResourcePoolT test_resources;
+ test_resources.locale.emplace_back(new LanguageTagT);
+ test_resources.locale.back()->language = "en";
+ test_resources.locale.emplace_back(new LanguageTagT);
+ test_resources.locale.back()->language = "de";
+
+ // Add `add_calendar_event`
+ test_resources.resource_entry.emplace_back(new ResourceEntryT);
+ test_resources.resource_entry.back()->name = "add_calendar_event";
+
+ // en
+ test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+ test_resources.resource_entry.back()->resource.back()->content = "Schedule";
+ test_resources.resource_entry.back()->resource.back()->locale.push_back(0);
+
+ // Add `add_calendar_event_desc`
+ test_resources.resource_entry.emplace_back(new ResourceEntryT);
+ test_resources.resource_entry.back()->name = "add_calendar_event_desc";
+
+ // en
+ test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+ test_resources.resource_entry.back()->resource.back()->content =
+ "Schedule event for selected time";
+ test_resources.resource_entry.back()->resource.back()->locale.push_back(0);
+
+ // Add `map`.
+ test_resources.resource_entry.emplace_back(new ResourceEntryT);
+ test_resources.resource_entry.back()->name = "map";
+
+ // en
+ test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+ test_resources.resource_entry.back()->resource.back()->content = "Map";
+ test_resources.resource_entry.back()->resource.back()->locale.push_back(0);
+
+ // de
+ test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+ test_resources.resource_entry.back()->resource.back()->content = "Karte";
+ test_resources.resource_entry.back()->resource.back()->locale.push_back(1);
+
+ // Add `map_desc`.
+ test_resources.resource_entry.emplace_back(new ResourceEntryT);
+ test_resources.resource_entry.back()->name = "map_desc";
+
+ // en
+ test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+ test_resources.resource_entry.back()->resource.back()->content =
+ "Locate selected address";
+ test_resources.resource_entry.back()->resource.back()->locale.push_back(0);
+
+ // de
+ test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+ test_resources.resource_entry.back()->resource.back()->content =
+ "Ausgewählte Adresse finden";
+ test_resources.resource_entry.back()->resource.back()->locale.push_back(1);
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(ResourcePool::Pack(builder, &test_resources));
+ return builder.Release();
+}
+
+// Common methods for intent generator tests.
+class IntentGeneratorTest : public testing::Test {
+ protected:
+ explicit IntentGeneratorTest()
+ : jni_cache_(JniCache::Create(GetJenv())),
+ resource_buffer_(BuildTestResources()),
+ resources_(
+ flatbuffers::GetRoot<ResourcePool>(resource_buffer_.data())) {}
+
+ const std::shared_ptr<JniCache> jni_cache_;
+ const flatbuffers::DetachedBuffer resource_buffer_;
+ const ResourcePool* resources_;
+};
+
+TEST_F(IntentGeneratorTest, HandlesDefaultClassification) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("unused", "");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ /*options=*/flatbuffers::GetRoot<IntentFactoryModel>(
+ intent_factory_model.data()),
+ /*resources=*/resources_,
+ /*jni_cache=*/jni_cache_);
+ ClassificationResult classification;
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ /*device_locales=*/nullptr, classification, /*reference_time_ms_utc=*/0,
+ /*text=*/"", /*selection_indices=*/{kInvalidIndex, kInvalidIndex},
+ /*context=*/nullptr,
+ /*annotations_entity_data_schema=*/nullptr, &intents));
+ EXPECT_THAT(intents, IsEmpty());
+}
+
+TEST_F(IntentGeneratorTest, FailsGracefully) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("test", R"lua(
+return {
+ {
+ -- Should fail, as no app GetAndroidContext() is provided.
+ data = external.android.package_name,
+ }
+})lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ ClassificationResult classification = {"test", 1.0};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_FALSE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ classification,
+ /*reference_time_ms_utc=*/0, "test", {0, 4}, /*context=*/nullptr,
+ /*annotations_entity_data_schema=*/nullptr, &intents));
+ EXPECT_THAT(intents, IsEmpty());
+}
+
+TEST_F(IntentGeneratorTest, HandlesEntityIntentGeneration) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("address", R"lua(
+return {
+ {
+ title_without_entity = external.android.R.map,
+ title_with_entity = external.entity.text,
+ description = external.android.R.map_desc,
+ action = "android.intent.action.VIEW",
+ data = "geo:0,0?q=" ..
+ external.android.urlencode(external.entity.text),
+ }
+})lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ ClassificationResult classification = {"address", 1.0};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ classification,
+ /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
+ GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr, &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_EQ(intents[0].title_without_entity.value(), "Map");
+ EXPECT_EQ(intents[0].title_with_entity.value(), "333 E Wonderview Ave");
+ EXPECT_EQ(intents[0].description.value(), "Locate selected address");
+ EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
+ EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333%20E%20Wonderview%20Ave");
+}
+
+TEST_F(IntentGeneratorTest, HandlesCallbacks) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("test", R"lua(
+local test = external.entity["text"]
+return {
+ {
+ data = "encoded=" .. external.android.urlencode(test),
+ category = { "test_category" },
+ extra = {
+ { name = "package", string_value = external.android.package_name},
+ { name = "scheme",
+ string_value = external.android.url_schema("https://google.com")},
+ { name = "host",
+ string_value = external.android.url_host("https://google.com/search")},
+ { name = "permission",
+ bool_value = external.android.user_restrictions["no_sms"] },
+ { name = "language",
+ string_value = external.android.device_locales[1].language },
+ { name = "description",
+ string_value = external.format("$1 $0", "hello", "world") },
+ },
+ request_code = external.hash(test)
+ }
+})lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ ClassificationResult classification = {"test", 1.0};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ classification,
+ /*reference_time_ms_utc=*/0, "this is a test", {0, 14},
+ GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr, &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_EQ(intents[0].data.value(), "encoded=this%20is%20a%20test");
+ EXPECT_THAT(intents[0].category, ElementsAre("test_category"));
+ EXPECT_THAT(intents[0].extra, SizeIs(6));
+ EXPECT_EQ(intents[0].extra["package"].ConstRefValue<std::string>(),
+ "com.google.android.textclassifier.tests"
+ );
+ EXPECT_EQ(intents[0].extra["scheme"].ConstRefValue<std::string>(), "https");
+ EXPECT_EQ(intents[0].extra["host"].ConstRefValue<std::string>(),
+ "google.com");
+ EXPECT_FALSE(intents[0].extra["permission"].Value<bool>());
+ EXPECT_EQ(intents[0].extra["language"].ConstRefValue<std::string>(), "en");
+ EXPECT_TRUE(intents[0].request_code.has_value());
+ EXPECT_EQ(intents[0].extra["description"].ConstRefValue<std::string>(),
+ "world hello");
+}
+
+TEST_F(IntentGeneratorTest, HandlesActionIntentGeneration) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("view_map", R"lua(
+return {
+ {
+ title_without_entity = external.android.R.map,
+ description = external.android.R.map_desc,
+ description_with_app_name = external.android.R.map,
+ action = "android.intent.action.VIEW",
+ data = "geo:0,0?q=" ..
+ external.android.urlencode(external.entity.annotation["location"].text),
+ }
+})lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
+ ActionSuggestionAnnotation annotation;
+ annotation.entity = {"address", 1.0};
+ annotation.span = {/*message_index=*/0,
+ /*span=*/{6, 11},
+ /*text=*/"there"};
+ annotation.name = "location";
+ ActionSuggestion suggestion = {/*response_text=""*/ "",
+ /*type=*/"view_map",
+ /*score=*/1.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/
+ {annotation}};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ suggestion, conversation, GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr, &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_EQ(intents[0].title_without_entity.value(), "Map");
+ EXPECT_EQ(intents[0].description.value(), "Locate selected address");
+ EXPECT_EQ(intents[0].description_with_app_name.value(), "Map");
+ EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
+ EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=there");
+}
+
+TEST_F(IntentGeneratorTest, HandlesTimezoneAndReferenceTime) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("test", R"lua(
+local conversation = external.conversation
+return {
+ {
+ extra = {
+ { name = "timezone", string_value = conversation[#conversation].timezone },
+ { name = "num_messages", int_value = #conversation },
+ { name = "reference_time", long_value = conversation[#conversation].time_ms_utc }
+ },
+ }
+})lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ Conversation conversation = {
+ {{/*user_id=*/0, "hello there", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Testing/Test"},
+ {/*user_id=*/1, "general retesti", /*reference_time_ms_utc=*/1000,
+ /*reference_timezone=*/"Europe/Zurich"}}};
+ ActionSuggestion suggestion = {/*response_text=""*/ "",
+ /*type=*/"test",
+ /*score=*/1.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/
+ {}};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ suggestion, conversation, GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr, &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_EQ(intents[0].extra["timezone"].ConstRefValue<std::string>(),
+ "Europe/Zurich");
+ EXPECT_EQ(intents[0].extra["num_messages"].Value<int>(), 2);
+ EXPECT_EQ(intents[0].extra["reference_time"].Value<int64>(), 1000);
+}
+
+TEST_F(IntentGeneratorTest, HandlesActionIntentGenerationMultipleAnnotations) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("create_event", R"lua(
+return {
+ {
+ title_without_entity = external.android.R.add_calendar_event,
+ description = external.android.R.add_calendar_event_desc,
+ extra = {
+ {name = "time", string_value =
+ external.entity.annotation["time"].text},
+ {name = "location",
+ string_value = external.entity.annotation["location"].text},
+ }
+ }
+})lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ Conversation conversation = {{{/*user_id=*/1, "hello there at 1pm"}}};
+ ActionSuggestionAnnotation location_annotation, time_annotation;
+ location_annotation.entity = {"address", 1.0};
+ location_annotation.span = {/*message_index=*/0,
+ /*span=*/{6, 11},
+ /*text=*/"there"};
+ location_annotation.name = "location";
+ time_annotation.entity = {"datetime", 1.0};
+ time_annotation.span = {/*message_index=*/0,
+ /*span=*/{15, 18},
+ /*text=*/"1pm"};
+ time_annotation.name = "time";
+ ActionSuggestion suggestion = {/*response_text=""*/ "",
+ /*type=*/"create_event",
+ /*score=*/1.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/
+ {location_annotation, time_annotation}};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ suggestion, conversation, GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr, &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_EQ(intents[0].title_without_entity.value(), "Schedule");
+ EXPECT_THAT(intents[0].extra, SizeIs(2));
+ EXPECT_EQ(intents[0].extra["location"].ConstRefValue<std::string>(), "there");
+ EXPECT_EQ(intents[0].extra["time"].ConstRefValue<std::string>(), "1pm");
+}
+
+TEST_F(IntentGeneratorTest,
+ HandlesActionIntentGenerationMultipleAnnotationsWithIndices) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("time_range", R"lua(
+return {
+ {
+ title_without_entity = "test",
+ description = "test",
+ extra = {
+ {name = "from", string_value = external.entity.annotation[1].text},
+ {name = "to", string_value = external.entity.annotation[2].text},
+ }
+ }
+})lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ Conversation conversation = {{{/*user_id=*/1, "from 1pm to 2pm"}}};
+ ActionSuggestionAnnotation from_annotation, to_annotation;
+ from_annotation.entity = {"datetime", 1.0};
+ from_annotation.span = {/*message_index=*/0,
+ /*span=*/{5, 8},
+ /*text=*/"1pm"};
+ to_annotation.entity = {"datetime", 1.0};
+ to_annotation.span = {/*message_index=*/0,
+ /*span=*/{12, 15},
+ /*text=*/"2pm"};
+ ActionSuggestion suggestion = {/*response_text=""*/ "",
+ /*type=*/"time_range",
+ /*score=*/1.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/
+ {from_annotation, to_annotation}};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ suggestion, conversation, GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr, &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_THAT(intents[0].extra, SizeIs(2));
+ EXPECT_EQ(intents[0].extra["from"].ConstRefValue<std::string>(), "1pm");
+ EXPECT_EQ(intents[0].extra["to"].ConstRefValue<std::string>(), "2pm");
+}
+
+TEST_F(IntentGeneratorTest, HandlesResources) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("address", R"lua(
+return {
+ {
+ title_without_entity = external.android.R.map,
+ description = external.android.R.map_desc,
+ action = "android.intent.action.VIEW",
+ data = "geo:0,0?q=" ..
+ external.android.urlencode(external.entity.text),
+ }
+})lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ resources_, jni_cache_);
+ ClassificationResult classification = {"address", 1.0};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "de-DE").ValueOrDie().get(),
+ classification,
+ /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
+ GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr, &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_EQ(intents[0].title_without_entity.value(), "Karte");
+ EXPECT_EQ(intents[0].description.value(), "Ausgewählte Adresse finden");
+ EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
+ EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333%20E%20Wonderview%20Ave");
+}
+
+TEST_F(IntentGeneratorTest, HandlesIteration) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("iteration_test", R"lua(
+local extra = {{ name = "length", int_value = #external.entity.annotation }}
+for annotation_id, annotation in pairs(external.entity.annotation) do
+ table.insert(extra,
+ { name = annotation.name,
+ string_value = annotation.text })
+end
+return {{ extra = extra }})lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
+ ActionSuggestionAnnotation location_annotation;
+ location_annotation.entity = {"address", 1.0};
+ location_annotation.span = {/*message_index=*/0,
+ /*span=*/{6, 11},
+ /*text=*/"there"};
+ location_annotation.name = "location";
+ ActionSuggestionAnnotation greeting_annotation;
+ greeting_annotation.entity = {"greeting", 1.0};
+ greeting_annotation.span = {/*message_index=*/0,
+ /*span=*/{0, 5},
+ /*text=*/"hello"};
+ greeting_annotation.name = "greeting";
+ ActionSuggestion suggestion = {/*response_text=""*/ "",
+ /*type=*/"iteration_test",
+ /*score=*/1.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/
+ {location_annotation, greeting_annotation}};
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ suggestion, conversation, GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr, &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_EQ(intents[0].extra["length"].Value<int>(), 2);
+ EXPECT_EQ(intents[0].extra["location"].ConstRefValue<std::string>(), "there");
+ EXPECT_EQ(intents[0].extra["greeting"].ConstRefValue<std::string>(), "hello");
+}
+
+TEST_F(IntentGeneratorTest, HandlesEntityDataLookups) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("fake", R"lua(
+local person = external.entity.person
+return {
+ {
+ title_without_entity = "Add to contacts",
+ extra = {
+ {name = "name", string_value = string.lower(person.name)},
+ {name = "encoded_phone", string_value = external.android.urlencode(person.phone)},
+ {name = "age", int_value = person.age_years},
+ }
+ }
+})lua");
+
+ // Create fake entity data schema meta data.
+ // Cannot use object oriented API here as that is not available for the
+ // reflection schema.
+ flatbuffers::FlatBufferBuilder schema_builder;
+ std::vector<flatbuffers::Offset<reflection::Field>> person_fields = {
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("name"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/0,
+ /*offset=*/4),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("phone"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/1,
+ /*offset=*/6),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("age_years"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Int),
+ /*id=*/2,
+ /*offset=*/8),
+ };
+ std::vector<flatbuffers::Offset<reflection::Field>> entity_data_fields = {
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("person"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Obj,
+ /*element=*/reflection::None,
+ /*index=*/1),
+ /*id=*/0,
+ /*offset=*/4)};
+ std::vector<flatbuffers::Offset<reflection::Enum>> enums;
+ std::vector<flatbuffers::Offset<reflection::Object>> objects = {
+ reflection::CreateObject(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("EntityData"),
+ /*fields=*/
+ schema_builder.CreateVectorOfSortedTables(&entity_data_fields)),
+ reflection::CreateObject(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("person"),
+ /*fields=*/
+ schema_builder.CreateVectorOfSortedTables(&person_fields))};
+ schema_builder.Finish(reflection::CreateSchema(
+ schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
+ schema_builder.CreateVectorOfSortedTables(&enums),
+ /*(unused) file_ident=*/0,
+ /*(unused) file_ext=*/0,
+ /*root_table*/ objects[0]));
+ const reflection::Schema* entity_data_schema =
+ flatbuffers::GetRoot<reflection::Schema>(
+ schema_builder.GetBufferPointer());
+
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+
+ ClassificationResult classification = {"fake", 1.0};
+
+ // Build test entity data.
+ MutableFlatbufferBuilder entity_data_builder(entity_data_schema);
+ std::unique_ptr<MutableFlatbuffer> entity_data_buffer =
+ entity_data_builder.NewRoot();
+ MutableFlatbuffer* person = entity_data_buffer->Mutable("person");
+ person->Set("name", "Kenobi");
+ person->Set("phone", "1 800 HIGHGROUND");
+ person->Set("age_years", 38);
+ classification.serialized_entity_data = entity_data_buffer->Serialize();
+
+ std::vector<RemoteActionTemplate> intents;
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ classification,
+ /*reference_time_ms_utc=*/0, "highground", {0, 10}, GetAndroidContext(),
+ /*annotations_entity_data_schema=*/entity_data_schema, &intents));
+ EXPECT_THAT(intents, SizeIs(1));
+ EXPECT_THAT(intents[0].extra, SizeIs(3));
+ EXPECT_EQ(intents[0].extra["name"].ConstRefValue<std::string>(), "kenobi");
+ EXPECT_EQ(intents[0].extra["encoded_phone"].ConstRefValue<std::string>(),
+ "1%20800%20HIGHGROUND");
+ EXPECT_EQ(intents[0].extra["age"].Value<int>(), 38);
+}
+
+TEST_F(IntentGeneratorTest, ReadExtras) {
+ flatbuffers::DetachedBuffer intent_factory_model =
+ BuildTestIntentFactoryModel("test", R"lua(
+ return {
+ {
+ extra = {
+ { name = "languages", string_array_value = {"en", "zh"}},
+ { name = "scores", float_array_value = {0.6, 0.4}},
+ { name = "ints", int_array_value = {7, 2, 1}},
+ { name = "bundle",
+ named_variant_array_value =
+ {
+ { name = "inner_string", string_value = "a" },
+ { name = "inner_int", int_value = 42 }
+ }
+ }
+ }
+ }}
+ )lua");
+ std::unique_ptr<IntentGenerator> generator = IntentGenerator::Create(
+ flatbuffers::GetRoot<IntentFactoryModel>(intent_factory_model.data()),
+ /*resources=*/resources_, jni_cache_);
+ const ClassificationResult classification = {"test", 1.0};
+ std::vector<RemoteActionTemplate> intents;
+
+ EXPECT_TRUE(generator->GenerateIntents(
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ classification,
+ /*reference_time_ms_utc=*/0, "test", {0, 4}, GetAndroidContext(),
+ /*annotations_entity_data_schema=*/nullptr, &intents));
+
+ EXPECT_THAT(intents, SizeIs(1));
+ RemoteActionTemplate intent = intents[0];
+ EXPECT_THAT(intent.extra, SizeIs(4));
+ EXPECT_THAT(
+ intent.extra["languages"].ConstRefValue<std::vector<std::string>>(),
+ ElementsAre("en", "zh"));
+ EXPECT_THAT(intent.extra["scores"].ConstRefValue<std::vector<float>>(),
+ ElementsAre(0.6, 0.4));
+ EXPECT_THAT(intent.extra["ints"].ConstRefValue<std::vector<int>>(),
+ ElementsAre(7, 2, 1));
+ const std::map<std::string, Variant>& map =
+ intent.extra["bundle"].ConstRefValue<std::map<std::string, Variant>>();
+ EXPECT_THAT(map, SizeIs(2));
+ EXPECT_EQ(map.at("inner_string").ConstRefValue<std::string>(), "a");
+ EXPECT_EQ(map.at("inner_int").Value<int>(), 42);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/intents/intent-generator.cc b/native/utils/intents/intent-generator.cc
index 4cb3e40..7edef41 100644
--- a/native/utils/intents/intent-generator.cc
+++ b/native/utils/intents/intent-generator.cc
@@ -18,21 +18,11 @@
#include <vector>
-#include "actions/types.h"
-#include "annotator/types.h"
#include "utils/base/logging.h"
-#include "utils/base/statusor.h"
-#include "utils/hash/farmhash.h"
-#include "utils/java/jni-base.h"
+#include "utils/intents/jni-lua.h"
#include "utils/java/jni-helper.h"
-#include "utils/java/string_utils.h"
-#include "utils/lua-utils.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/strings/substitute.h"
#include "utils/utf8/unicodetext.h"
-#include "utils/variant.h"
#include "utils/zlib/zlib.h"
-#include "flatbuffers/reflection_generated.h"
#ifdef __cplusplus
extern "C" {
@@ -47,696 +37,6 @@
namespace {
static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
-static constexpr const char* kHashKey = "hash";
-static constexpr const char* kUrlSchemaKey = "url_schema";
-static constexpr const char* kUrlHostKey = "url_host";
-static constexpr const char* kUrlEncodeKey = "urlencode";
-static constexpr const char* kPackageNameKey = "package_name";
-static constexpr const char* kDeviceLocaleKey = "device_locales";
-static constexpr const char* kFormatKey = "format";
-
-// An Android specific Lua environment with JNI backed callbacks.
-class JniLuaEnvironment : public LuaEnvironment {
- public:
- JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
- const jobject context,
- const std::vector<Locale>& device_locales);
- // Environment setup.
- bool Initialize();
-
- // Runs an intent generator snippet.
- bool RunIntentGenerator(const std::string& generator_snippet,
- std::vector<RemoteActionTemplate>* remote_actions);
-
- protected:
- virtual void SetupExternalHook();
-
- int HandleExternalCallback();
- int HandleAndroidCallback();
- int HandleUserRestrictionsCallback();
- int HandleUrlEncode();
- int HandleUrlSchema();
- int HandleHash();
- int HandleFormat();
- int HandleAndroidStringResources();
- int HandleUrlHost();
-
- // Checks and retrieves string resources from the model.
- bool LookupModelStringResource() const;
-
- // Reads and create a RemoteAction result from Lua.
- RemoteActionTemplate ReadRemoteActionTemplateResult() const;
-
- // Reads the extras from the Lua result.
- std::map<std::string, Variant> ReadExtras() const;
-
- // Retrieves user manager if not previously done.
- bool RetrieveUserManager();
-
- // Retrieves system resources if not previously done.
- bool RetrieveSystemResources();
-
- // Parse the url string by using Uri.parse from Java.
- StatusOr<ScopedLocalRef<jobject>> ParseUri(StringPiece url) const;
-
- // Read remote action templates from lua generator.
- int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
-
- const Resources& resources_;
- JNIEnv* jenv_;
- const JniCache* jni_cache_;
- const jobject context_;
- std::vector<Locale> device_locales_;
-
- ScopedGlobalRef<jobject> usermanager_;
- // Whether we previously attempted to retrieve the UserManager before.
- bool usermanager_retrieved_;
-
- ScopedGlobalRef<jobject> system_resources_;
- // Whether we previously attempted to retrieve the system resources.
- bool system_resources_resources_retrieved_;
-
- // Cached JNI references for Java strings `string` and `android`.
- ScopedGlobalRef<jstring> string_;
- ScopedGlobalRef<jstring> android_;
-};
-
-JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
- const JniCache* jni_cache,
- const jobject context,
- const std::vector<Locale>& device_locales)
- : resources_(resources),
- jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
- jni_cache_(jni_cache),
- context_(context),
- device_locales_(device_locales),
- usermanager_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
- usermanager_retrieved_(false),
- system_resources_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
- system_resources_resources_retrieved_(false),
- string_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
- android_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
-
-bool JniLuaEnvironment::Initialize() {
- TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> string_value,
- JniHelper::NewStringUTF(jenv_, "string"));
- string_ = MakeGlobalRef(string_value.get(), jenv_, jni_cache_->jvm);
- TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> android_value,
- JniHelper::NewStringUTF(jenv_, "android"));
- android_ = MakeGlobalRef(android_value.get(), jenv_, jni_cache_->jvm);
- if (string_ == nullptr || android_ == nullptr) {
- TC3_LOG(ERROR) << "Could not allocate constant strings references.";
- return false;
- }
- return (RunProtected([this] {
- LoadDefaultLibraries();
- SetupExternalHook();
- lua_setglobal(state_, "external");
- return LUA_OK;
- }) == LUA_OK);
-}
-
-void JniLuaEnvironment::SetupExternalHook() {
- // This exposes an `external` object with the following fields:
- // * entity: the bundle with all information about a classification.
- // * android: callbacks into specific android provided methods.
- // * android.user_restrictions: callbacks to check user permissions.
- // * android.R: callbacks to retrieve string resources.
- PushLazyObject(&JniLuaEnvironment::HandleExternalCallback);
-
- // android
- PushLazyObject(&JniLuaEnvironment::HandleAndroidCallback);
- {
- // android.user_restrictions
- PushLazyObject(&JniLuaEnvironment::HandleUserRestrictionsCallback);
- lua_setfield(state_, /*idx=*/-2, "user_restrictions");
-
- // android.R
- // Callback to access android string resources.
- PushLazyObject(&JniLuaEnvironment::HandleAndroidStringResources);
- lua_setfield(state_, /*idx=*/-2, "R");
- }
- lua_setfield(state_, /*idx=*/-2, "android");
-}
-
-int JniLuaEnvironment::HandleExternalCallback() {
- const StringPiece key = ReadString(kIndexStackTop);
- if (key.Equals(kHashKey)) {
- PushFunction(&JniLuaEnvironment::HandleHash);
- return 1;
- } else if (key.Equals(kFormatKey)) {
- PushFunction(&JniLuaEnvironment::HandleFormat);
- return 1;
- } else {
- TC3_LOG(ERROR) << "Undefined external access " << key;
- lua_error(state_);
- return 0;
- }
-}
-
-int JniLuaEnvironment::HandleAndroidCallback() {
- const StringPiece key = ReadString(kIndexStackTop);
- if (key.Equals(kDeviceLocaleKey)) {
- // Provide the locale as table with the individual fields set.
- lua_newtable(state_);
- for (int i = 0; i < device_locales_.size(); i++) {
- // Adjust index to 1-based indexing for Lua.
- lua_pushinteger(state_, i + 1);
- lua_newtable(state_);
- PushString(device_locales_[i].Language());
- lua_setfield(state_, -2, "language");
- PushString(device_locales_[i].Region());
- lua_setfield(state_, -2, "region");
- PushString(device_locales_[i].Script());
- lua_setfield(state_, -2, "script");
- lua_settable(state_, /*idx=*/-3);
- }
- return 1;
- } else if (key.Equals(kPackageNameKey)) {
- if (context_ == nullptr) {
- TC3_LOG(ERROR) << "Context invalid.";
- lua_error(state_);
- return 0;
- }
-
- StatusOr<ScopedLocalRef<jstring>> status_or_package_name_str =
- JniHelper::CallObjectMethod<jstring>(
- jenv_, context_, jni_cache_->context_get_package_name);
-
- if (!status_or_package_name_str.ok()) {
- TC3_LOG(ERROR) << "Error calling Context.getPackageName";
- lua_error(state_);
- return 0;
- }
- StatusOr<std::string> status_or_package_name_std_str =
- ToStlString(jenv_, status_or_package_name_str.ValueOrDie().get());
- if (!status_or_package_name_std_str.ok()) {
- lua_error(state_);
- return 0;
- }
- PushString(status_or_package_name_std_str.ValueOrDie());
- return 1;
- } else if (key.Equals(kUrlEncodeKey)) {
- PushFunction(&JniLuaEnvironment::HandleUrlEncode);
- return 1;
- } else if (key.Equals(kUrlHostKey)) {
- PushFunction(&JniLuaEnvironment::HandleUrlHost);
- return 1;
- } else if (key.Equals(kUrlSchemaKey)) {
- PushFunction(&JniLuaEnvironment::HandleUrlSchema);
- return 1;
- } else {
- TC3_LOG(ERROR) << "Undefined android reference " << key;
- lua_error(state_);
- return 0;
- }
-}
-
-int JniLuaEnvironment::HandleUserRestrictionsCallback() {
- if (jni_cache_->usermanager_class == nullptr ||
- jni_cache_->usermanager_get_user_restrictions == nullptr) {
- // UserManager is only available for API level >= 17 and
- // getUserRestrictions only for API level >= 18, so we just return false
- // normally here.
- lua_pushboolean(state_, false);
- return 1;
- }
-
- // Get user manager if not previously retrieved.
- if (!RetrieveUserManager()) {
- TC3_LOG(ERROR) << "Error retrieving user manager.";
- lua_error(state_);
- return 0;
- }
-
- StatusOr<ScopedLocalRef<jobject>> status_or_bundle =
- JniHelper::CallObjectMethod(
- jenv_, usermanager_.get(),
- jni_cache_->usermanager_get_user_restrictions);
- if (!status_or_bundle.ok() || status_or_bundle.ValueOrDie() == nullptr) {
- TC3_LOG(ERROR) << "Error calling getUserRestrictions";
- lua_error(state_);
- return 0;
- }
-
- const StringPiece key_str = ReadString(kIndexStackTop);
- if (key_str.empty()) {
- TC3_LOG(ERROR) << "Expected string, got null.";
- lua_error(state_);
- return 0;
- }
-
- const StatusOr<ScopedLocalRef<jstring>> status_or_key =
- jni_cache_->ConvertToJavaString(key_str);
- if (!status_or_key.ok()) {
- lua_error(state_);
- return 0;
- }
- const StatusOr<bool> status_or_permission = JniHelper::CallBooleanMethod(
- jenv_, status_or_bundle.ValueOrDie().get(),
- jni_cache_->bundle_get_boolean, status_or_key.ValueOrDie().get());
- if (!status_or_permission.ok()) {
- TC3_LOG(ERROR) << "Error getting bundle value";
- lua_pushboolean(state_, false);
- } else {
- lua_pushboolean(state_, status_or_permission.ValueOrDie());
- }
- return 1;
-}
-
-int JniLuaEnvironment::HandleUrlEncode() {
- const StringPiece input = ReadString(/*index=*/1);
- if (input.empty()) {
- TC3_LOG(ERROR) << "Expected string, got null.";
- lua_error(state_);
- return 0;
- }
-
- // Call Java URL encoder.
- const StatusOr<ScopedLocalRef<jstring>> status_or_input_str =
- jni_cache_->ConvertToJavaString(input);
- if (!status_or_input_str.ok()) {
- lua_error(state_);
- return 0;
- }
- StatusOr<ScopedLocalRef<jstring>> status_or_encoded_str =
- JniHelper::CallStaticObjectMethod<jstring>(
- jenv_, jni_cache_->urlencoder_class.get(),
- jni_cache_->urlencoder_encode, status_or_input_str.ValueOrDie().get(),
- jni_cache_->string_utf8.get());
-
- if (!status_or_encoded_str.ok()) {
- TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
- lua_error(state_);
- return 0;
- }
- const StatusOr<std::string> status_or_encoded_std_str =
- ToStlString(jenv_, status_or_encoded_str.ValueOrDie().get());
- if (!status_or_encoded_std_str.ok()) {
- lua_error(state_);
- return 0;
- }
- PushString(status_or_encoded_std_str.ValueOrDie());
- return 1;
-}
-
-StatusOr<ScopedLocalRef<jobject>> JniLuaEnvironment::ParseUri(
- StringPiece url) const {
- if (url.empty()) {
- return {Status::UNKNOWN};
- }
-
- // Call to Java URI parser.
- TC3_ASSIGN_OR_RETURN(
- const StatusOr<ScopedLocalRef<jstring>> status_or_url_str,
- jni_cache_->ConvertToJavaString(url));
-
- // Try to parse uri and get scheme.
- TC3_ASSIGN_OR_RETURN(
- ScopedLocalRef<jobject> uri,
- JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->uri_class.get(),
- jni_cache_->uri_parse,
- status_or_url_str.ValueOrDie().get()));
- if (uri == nullptr) {
- TC3_LOG(ERROR) << "Error calling Uri.parse";
- return {Status::UNKNOWN};
- }
- return uri;
-}
-
-int JniLuaEnvironment::HandleUrlSchema() {
- StringPiece url = ReadString(/*index=*/1);
-
- const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
- if (!status_or_parsed_uri.ok()) {
- lua_error(state_);
- return 0;
- }
-
- const StatusOr<ScopedLocalRef<jstring>> status_or_scheme_str =
- JniHelper::CallObjectMethod<jstring>(
- jenv_, status_or_parsed_uri.ValueOrDie().get(),
- jni_cache_->uri_get_scheme);
- if (!status_or_scheme_str.ok()) {
- TC3_LOG(ERROR) << "Error calling Uri.getScheme";
- lua_error(state_);
- return 0;
- }
- if (status_or_scheme_str.ValueOrDie() == nullptr) {
- lua_pushnil(state_);
- } else {
- const StatusOr<std::string> status_or_scheme_std_str =
- ToStlString(jenv_, status_or_scheme_str.ValueOrDie().get());
- if (!status_or_scheme_std_str.ok()) {
- lua_error(state_);
- return 0;
- }
- PushString(status_or_scheme_std_str.ValueOrDie());
- }
- return 1;
-}
-
-int JniLuaEnvironment::HandleUrlHost() {
- const StringPiece url = ReadString(kIndexStackTop);
-
- const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
- if (!status_or_parsed_uri.ok()) {
- lua_error(state_);
- return 0;
- }
-
- const StatusOr<ScopedLocalRef<jstring>> status_or_host_str =
- JniHelper::CallObjectMethod<jstring>(
- jenv_, status_or_parsed_uri.ValueOrDie().get(),
- jni_cache_->uri_get_host);
- if (!status_or_host_str.ok()) {
- TC3_LOG(ERROR) << "Error calling Uri.getHost";
- lua_error(state_);
- return 0;
- }
-
- if (status_or_host_str.ValueOrDie() == nullptr) {
- lua_pushnil(state_);
- } else {
- const StatusOr<std::string> status_or_host_std_str =
- ToStlString(jenv_, status_or_host_str.ValueOrDie().get());
- if (!status_or_host_std_str.ok()) {
- lua_error(state_);
- return 0;
- }
- PushString(status_or_host_std_str.ValueOrDie());
- }
- return 1;
-}
-
-int JniLuaEnvironment::HandleHash() {
- const StringPiece input = ReadString(kIndexStackTop);
- lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
- return 1;
-}
-
-int JniLuaEnvironment::HandleFormat() {
- const int num_args = lua_gettop(state_);
- std::vector<StringPiece> args(num_args - 1);
- for (int i = 0; i < num_args - 1; i++) {
- args[i] = ReadString(/*index=*/i + 2);
- }
- PushString(strings::Substitute(ReadString(/*index=*/1), args));
- return 1;
-}
-
-bool JniLuaEnvironment::LookupModelStringResource() const {
- // Handle only lookup by name.
- if (lua_type(state_, kIndexStackTop) != LUA_TSTRING) {
- return false;
- }
-
- const StringPiece resource_name = ReadString(kIndexStackTop);
- std::string resource_content;
- if (!resources_.GetResourceContent(device_locales_, resource_name,
- &resource_content)) {
- // Resource cannot be provided by the model.
- return false;
- }
-
- PushString(resource_content);
- return true;
-}
-
-int JniLuaEnvironment::HandleAndroidStringResources() {
- // Check whether the requested resource can be served from the model data.
- if (LookupModelStringResource()) {
- return 1;
- }
-
- // Get system resources if not previously retrieved.
- if (!RetrieveSystemResources()) {
- TC3_LOG(ERROR) << "Error retrieving system resources.";
- lua_error(state_);
- return 0;
- }
-
- int resource_id;
- switch (lua_type(state_, kIndexStackTop)) {
- case LUA_TNUMBER:
- resource_id = Read<int>(/*index=*/kIndexStackTop);
- break;
- case LUA_TSTRING: {
- const StringPiece resource_name_str = ReadString(kIndexStackTop);
- if (resource_name_str.empty()) {
- TC3_LOG(ERROR) << "No resource name provided.";
- lua_error(state_);
- return 0;
- }
- const StatusOr<ScopedLocalRef<jstring>> status_or_resource_name =
- jni_cache_->ConvertToJavaString(resource_name_str);
- if (!status_or_resource_name.ok()) {
- TC3_LOG(ERROR) << "Invalid resource name.";
- lua_error(state_);
- return 0;
- }
- StatusOr<int> status_or_resource_id = JniHelper::CallIntMethod(
- jenv_, system_resources_.get(), jni_cache_->resources_get_identifier,
- status_or_resource_name.ValueOrDie().get(), string_.get(),
- android_.get());
- if (!status_or_resource_id.ok()) {
- TC3_LOG(ERROR) << "Error calling getIdentifier.";
- lua_error(state_);
- return 0;
- }
- resource_id = status_or_resource_id.ValueOrDie();
- break;
- }
- default:
- TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
- lua_error(state_);
- return 0;
- }
- if (resource_id == 0) {
- TC3_LOG(ERROR) << "Resource not found.";
- lua_pushnil(state_);
- return 1;
- }
- StatusOr<ScopedLocalRef<jstring>> status_or_resource_str =
- JniHelper::CallObjectMethod<jstring>(jenv_, system_resources_.get(),
- jni_cache_->resources_get_string,
- resource_id);
- if (!status_or_resource_str.ok()) {
- TC3_LOG(ERROR) << "Error calling getString.";
- lua_error(state_);
- return 0;
- }
-
- if (status_or_resource_str.ValueOrDie() == nullptr) {
- lua_pushnil(state_);
- } else {
- StatusOr<std::string> status_or_resource_std_str =
- ToStlString(jenv_, status_or_resource_str.ValueOrDie().get());
- if (!status_or_resource_std_str.ok()) {
- lua_error(state_);
- return 0;
- }
- PushString(status_or_resource_std_str.ValueOrDie());
- }
- return 1;
-}
-
-bool JniLuaEnvironment::RetrieveSystemResources() {
- if (system_resources_resources_retrieved_) {
- return (system_resources_ != nullptr);
- }
- system_resources_resources_retrieved_ = true;
- TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jobject> system_resources_ref,
- JniHelper::CallStaticObjectMethod(
- jenv_, jni_cache_->resources_class.get(),
- jni_cache_->resources_get_system));
- system_resources_ =
- MakeGlobalRef(system_resources_ref.get(), jenv_, jni_cache_->jvm);
- return (system_resources_ != nullptr);
-}
-
-bool JniLuaEnvironment::RetrieveUserManager() {
- if (context_ == nullptr) {
- return false;
- }
- if (usermanager_retrieved_) {
- return (usermanager_ != nullptr);
- }
- usermanager_retrieved_ = true;
- TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> service,
- JniHelper::NewStringUTF(jenv_, "user"));
- TC3_ASSIGN_OR_RETURN_FALSE(
- const ScopedLocalRef<jobject> usermanager_ref,
- JniHelper::CallObjectMethod(jenv_, context_,
- jni_cache_->context_get_system_service,
- service.get()));
-
- usermanager_ = MakeGlobalRef(usermanager_ref.get(), jenv_, jni_cache_->jvm);
- return (usermanager_ != nullptr);
-}
-
-RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() const {
- RemoteActionTemplate result;
- // Read intent template.
- lua_pushnil(state_);
- while (Next(/*index=*/-2)) {
- const StringPiece key = ReadString(/*index=*/-2);
- if (key.Equals("title_without_entity")) {
- result.title_without_entity = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("title_with_entity")) {
- result.title_with_entity = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("description")) {
- result.description = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("description_with_app_name")) {
- result.description_with_app_name =
- Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("action")) {
- result.action = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("data")) {
- result.data = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("type")) {
- result.type = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("flags")) {
- result.flags = Read<int>(/*index=*/kIndexStackTop);
- } else if (key.Equals("package_name")) {
- result.package_name = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("request_code")) {
- result.request_code = Read<int>(/*index=*/kIndexStackTop);
- } else if (key.Equals("category")) {
- result.category = ReadVector<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("extra")) {
- result.extra = ReadExtras();
- } else {
- TC3_LOG(INFO) << "Unknown entry: " << key;
- }
- lua_pop(state_, 1);
- }
- lua_pop(state_, 1);
- return result;
-}
-
-std::map<std::string, Variant> JniLuaEnvironment::ReadExtras() const {
- if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected extras table, got: "
- << lua_type(state_, kIndexStackTop);
- lua_pop(state_, 1);
- return {};
- }
- std::map<std::string, Variant> extras;
- lua_pushnil(state_);
- while (Next(/*index=*/-2)) {
- // Each entry is a table specifying name and value.
- // The value is specified via a type specific field as Lua doesn't allow
- // to easily distinguish between different number types.
- if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected a table for an extra, got: "
- << lua_type(state_, kIndexStackTop);
- lua_pop(state_, 1);
- return {};
- }
- std::string name;
- Variant value;
-
- lua_pushnil(state_);
- while (Next(/*index=*/-2)) {
- const StringPiece key = ReadString(/*index=*/-2);
- if (key.Equals("name")) {
- name = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("int_value")) {
- value = Variant(Read<int>(/*index=*/kIndexStackTop));
- } else if (key.Equals("long_value")) {
- value = Variant(Read<int64>(/*index=*/kIndexStackTop));
- } else if (key.Equals("float_value")) {
- value = Variant(Read<float>(/*index=*/kIndexStackTop));
- } else if (key.Equals("bool_value")) {
- value = Variant(Read<bool>(/*index=*/kIndexStackTop));
- } else if (key.Equals("string_value")) {
- value = Variant(Read<std::string>(/*index=*/kIndexStackTop));
- } else if (key.Equals("string_array_value")) {
- value = Variant(ReadVector<std::string>(/*index=*/kIndexStackTop));
- } else if (key.Equals("float_array_value")) {
- value = Variant(ReadVector<float>(/*index=*/kIndexStackTop));
- } else if (key.Equals("int_array_value")) {
- value = Variant(ReadVector<int>(/*index=*/kIndexStackTop));
- } else if (key.Equals("named_variant_array_value")) {
- value = Variant(ReadExtras());
- } else {
- TC3_LOG(INFO) << "Unknown extra field: " << key;
- }
- lua_pop(state_, 1);
- }
- if (!name.empty()) {
- extras[name] = value;
- } else {
- TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
- }
- lua_pop(state_, 1);
- }
- return extras;
-}
-
-int JniLuaEnvironment::ReadRemoteActionTemplates(
- std::vector<RemoteActionTemplate>* result) {
- // Read result.
- if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Unexpected result for snippet: "
- << lua_type(state_, kIndexStackTop);
- lua_error(state_);
- return LUA_ERRRUN;
- }
-
- // Read remote action templates array.
- lua_pushnil(state_);
- while (Next(/*index=*/-2)) {
- if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected intent table, got: "
- << lua_type(state_, kIndexStackTop);
- lua_pop(state_, 1);
- continue;
- }
- result->push_back(ReadRemoteActionTemplateResult());
- }
- lua_pop(state_, /*n=*/1);
- return LUA_OK;
-}
-
-bool JniLuaEnvironment::RunIntentGenerator(
- const std::string& generator_snippet,
- std::vector<RemoteActionTemplate>* remote_actions) {
- int status;
- status = luaL_loadbuffer(state_, generator_snippet.data(),
- generator_snippet.size(),
- /*name=*/nullptr);
- if (status != LUA_OK) {
- TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
- return false;
- }
- status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
- if (status != LUA_OK) {
- TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
- return false;
- }
- if (RunProtected(
- [this, remote_actions] {
- return ReadRemoteActionTemplates(remote_actions);
- },
- /*num_args=*/1) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not read results.";
- return false;
- }
- // Check that we correctly cleaned-up the state.
- const int stack_size = lua_gettop(state_);
- if (stack_size > 0) {
- TC3_LOG(ERROR) << "Unexpected stack size.";
- lua_settop(state_, 0);
- return false;
- }
- return true;
-}
// Lua environment for classfication result intent generation.
class AnnotatorJniEnvironment : public JniLuaEnvironment {
@@ -855,15 +155,15 @@
TC3_LOG(ERROR) << "No locales provided.";
return {};
}
- ScopedStringChars locales_str =
- GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
- if (locales_str == nullptr) {
- TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
+ StatusOr<std::string> status_or_locales_str =
+ JStringToUtf8String(jni_cache_->GetEnv(), device_locales);
+ if (!status_or_locales_str.ok()) {
+ TC3_LOG(ERROR)
+ << "JStringToUtf8String failed, cannot retrieve provided locales.";
return {};
}
std::vector<Locale> locales;
- if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
- &locales)) {
+ if (!ParseLocales(status_or_locales_str.ValueOrDie(), &locales)) {
TC3_LOG(ERROR) << "Cannot parse locales.";
return {};
}
diff --git a/native/utils/intents/intent-generator.h b/native/utils/intents/intent-generator.h
index 2a45191..c5cbb1d 100644
--- a/native/utils/intents/intent-generator.h
+++ b/native/utils/intents/intent-generator.h
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
#define LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
@@ -18,6 +34,7 @@
#include "utils/resources.h"
#include "utils/resources_generated.h"
#include "utils/strings/stringpiece.h"
+#include "flatbuffers/reflection_generated.h"
namespace libtextclassifier3 {
diff --git a/native/utils/intents/jni-lua.cc b/native/utils/intents/jni-lua.cc
new file mode 100644
index 0000000..71a466e
--- /dev/null
+++ b/native/utils/intents/jni-lua.cc
@@ -0,0 +1,669 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/intents/jni-lua.h"
+
+#include "utils/hash/farmhash.h"
+#include "utils/java/jni-helper.h"
+#include "utils/strings/substitute.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lua.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+
+static constexpr const char* kHashKey = "hash";
+static constexpr const char* kUrlSchemaKey = "url_schema";
+static constexpr const char* kUrlHostKey = "url_host";
+static constexpr const char* kUrlEncodeKey = "urlencode";
+static constexpr const char* kPackageNameKey = "package_name";
+static constexpr const char* kDeviceLocaleKey = "device_locales";
+static constexpr const char* kFormatKey = "format";
+
+} // namespace
+
+JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
+ const JniCache* jni_cache,
+ const jobject context,
+ const std::vector<Locale>& device_locales)
+ : LuaEnvironment(),
+ resources_(resources),
+ jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
+ jni_cache_(jni_cache),
+ context_(context),
+ device_locales_(device_locales),
+ usermanager_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+ usermanager_retrieved_(false),
+ system_resources_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+ system_resources_resources_retrieved_(false),
+ string_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+ android_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
+
+bool JniLuaEnvironment::PreallocateConstantJniStrings() {
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> string_value,
+ JniHelper::NewStringUTF(jenv_, "string"));
+ string_ = MakeGlobalRef(string_value.get(), jenv_, jni_cache_->jvm);
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> android_value,
+ JniHelper::NewStringUTF(jenv_, "android"));
+ android_ = MakeGlobalRef(android_value.get(), jenv_, jni_cache_->jvm);
+ if (string_ == nullptr || android_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not allocate constant strings references.";
+ return false;
+ }
+ return true;
+}
+
+bool JniLuaEnvironment::Initialize() {
+ if (!PreallocateConstantJniStrings()) {
+ return false;
+ }
+ return (RunProtected([this] {
+ LoadDefaultLibraries();
+ SetupExternalHook();
+ lua_setglobal(state_, "external");
+ return LUA_OK;
+ }) == LUA_OK);
+}
+
+void JniLuaEnvironment::SetupExternalHook() {
+ // This exposes an `external` object with the following fields:
+ // * entity: the bundle with all information about a classification.
+ // * android: callbacks into specific android provided methods.
+ // * android.user_restrictions: callbacks to check user permissions.
+ // * android.R: callbacks to retrieve string resources.
+ PushLazyObject(&JniLuaEnvironment::HandleExternalCallback);
+
+ // android
+ PushLazyObject(&JniLuaEnvironment::HandleAndroidCallback);
+ {
+ // android.user_restrictions
+ PushLazyObject(&JniLuaEnvironment::HandleUserRestrictionsCallback);
+ lua_setfield(state_, /*idx=*/-2, "user_restrictions");
+
+ // android.R
+ // Callback to access android string resources.
+ PushLazyObject(&JniLuaEnvironment::HandleAndroidStringResources);
+ lua_setfield(state_, /*idx=*/-2, "R");
+ }
+ lua_setfield(state_, /*idx=*/-2, "android");
+}
+
+int JniLuaEnvironment::HandleExternalCallback() {
+ const StringPiece key = ReadString(kIndexStackTop);
+ if (key.Equals(kHashKey)) {
+ PushFunction(&JniLuaEnvironment::HandleHash);
+ return 1;
+ } else if (key.Equals(kFormatKey)) {
+ PushFunction(&JniLuaEnvironment::HandleFormat);
+ return 1;
+ } else {
+ TC3_LOG(ERROR) << "Undefined external access " << key;
+ lua_error(state_);
+ return 0;
+ }
+}
+
+int JniLuaEnvironment::HandleAndroidCallback() {
+ const StringPiece key = ReadString(kIndexStackTop);
+ if (key.Equals(kDeviceLocaleKey)) {
+ // Provide the locale as table with the individual fields set.
+ lua_newtable(state_);
+ for (int i = 0; i < device_locales_.size(); i++) {
+ // Adjust index to 1-based indexing for Lua.
+ lua_pushinteger(state_, i + 1);
+ lua_newtable(state_);
+ PushString(device_locales_[i].Language());
+ lua_setfield(state_, -2, "language");
+ PushString(device_locales_[i].Region());
+ lua_setfield(state_, -2, "region");
+ PushString(device_locales_[i].Script());
+ lua_setfield(state_, -2, "script");
+ lua_settable(state_, /*idx=*/-3);
+ }
+ return 1;
+ } else if (key.Equals(kPackageNameKey)) {
+ if (context_ == nullptr) {
+ TC3_LOG(ERROR) << "Context invalid.";
+ lua_error(state_);
+ return 0;
+ }
+
+ StatusOr<ScopedLocalRef<jstring>> status_or_package_name_str =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv_, context_, jni_cache_->context_get_package_name);
+
+ if (!status_or_package_name_str.ok()) {
+ TC3_LOG(ERROR) << "Error calling Context.getPackageName";
+ lua_error(state_);
+ return 0;
+ }
+ StatusOr<std::string> status_or_package_name_std_str = JStringToUtf8String(
+ jenv_, status_or_package_name_str.ValueOrDie().get());
+ if (!status_or_package_name_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_package_name_std_str.ValueOrDie());
+ return 1;
+ } else if (key.Equals(kUrlEncodeKey)) {
+ PushFunction(&JniLuaEnvironment::HandleUrlEncode);
+ return 1;
+ } else if (key.Equals(kUrlHostKey)) {
+ PushFunction(&JniLuaEnvironment::HandleUrlHost);
+ return 1;
+ } else if (key.Equals(kUrlSchemaKey)) {
+ PushFunction(&JniLuaEnvironment::HandleUrlSchema);
+ return 1;
+ } else {
+ TC3_LOG(ERROR) << "Undefined android reference " << key;
+ lua_error(state_);
+ return 0;
+ }
+}
+
+int JniLuaEnvironment::HandleUserRestrictionsCallback() {
+ if (jni_cache_->usermanager_class == nullptr ||
+ jni_cache_->usermanager_get_user_restrictions == nullptr) {
+ // UserManager is only available for API level >= 17 and
+ // getUserRestrictions only for API level >= 18, so we just return false
+ // normally here.
+ lua_pushboolean(state_, false);
+ return 1;
+ }
+
+ // Get user manager if not previously retrieved.
+ if (!RetrieveUserManager()) {
+ TC3_LOG(ERROR) << "Error retrieving user manager.";
+ lua_error(state_);
+ return 0;
+ }
+
+ StatusOr<ScopedLocalRef<jobject>> status_or_bundle =
+ JniHelper::CallObjectMethod(
+ jenv_, usermanager_.get(),
+ jni_cache_->usermanager_get_user_restrictions);
+ if (!status_or_bundle.ok() || status_or_bundle.ValueOrDie() == nullptr) {
+ TC3_LOG(ERROR) << "Error calling getUserRestrictions";
+ lua_error(state_);
+ return 0;
+ }
+
+ const StringPiece key_str = ReadString(kIndexStackTop);
+ if (key_str.empty()) {
+ TC3_LOG(ERROR) << "Expected string, got null.";
+ lua_error(state_);
+ return 0;
+ }
+
+ const StatusOr<ScopedLocalRef<jstring>> status_or_key =
+ jni_cache_->ConvertToJavaString(key_str);
+ if (!status_or_key.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ const StatusOr<bool> status_or_permission = JniHelper::CallBooleanMethod(
+ jenv_, status_or_bundle.ValueOrDie().get(),
+ jni_cache_->bundle_get_boolean, status_or_key.ValueOrDie().get());
+ if (!status_or_permission.ok()) {
+ TC3_LOG(ERROR) << "Error getting bundle value";
+ lua_pushboolean(state_, false);
+ } else {
+ lua_pushboolean(state_, status_or_permission.ValueOrDie());
+ }
+ return 1;
+}
+
+int JniLuaEnvironment::HandleUrlEncode() {
+ const StringPiece input = ReadString(/*index=*/1);
+ if (input.empty()) {
+ TC3_LOG(ERROR) << "Expected string, got null.";
+ lua_error(state_);
+ return 0;
+ }
+
+ // Call Java Uri encode.
+ const StatusOr<ScopedLocalRef<jstring>> status_or_input_str =
+ jni_cache_->ConvertToJavaString(input);
+ if (!status_or_input_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ StatusOr<ScopedLocalRef<jstring>> status_or_encoded_str =
+ JniHelper::CallStaticObjectMethod<jstring>(
+ jenv_, jni_cache_->uri_class.get(), jni_cache_->uri_encode,
+ status_or_input_str.ValueOrDie().get());
+
+ if (!status_or_encoded_str.ok()) {
+ TC3_LOG(ERROR) << "Error calling Uri.encode";
+ lua_error(state_);
+ return 0;
+ }
+ const StatusOr<std::string> status_or_encoded_std_str =
+ JStringToUtf8String(jenv_, status_or_encoded_str.ValueOrDie().get());
+ if (!status_or_encoded_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_encoded_std_str.ValueOrDie());
+ return 1;
+}
+
+StatusOr<ScopedLocalRef<jobject>> JniLuaEnvironment::ParseUri(
+ StringPiece url) const {
+ if (url.empty()) {
+ return {Status::UNKNOWN};
+ }
+
+ // Call to Java URI parser.
+ TC3_ASSIGN_OR_RETURN(
+ const StatusOr<ScopedLocalRef<jstring>> status_or_url_str,
+ jni_cache_->ConvertToJavaString(url));
+
+ // Try to parse uri and get scheme.
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> uri,
+ JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->uri_class.get(),
+ jni_cache_->uri_parse,
+ status_or_url_str.ValueOrDie().get()));
+ if (uri == nullptr) {
+ TC3_LOG(ERROR) << "Error calling Uri.parse";
+ return {Status::UNKNOWN};
+ }
+ return uri;
+}
+
+int JniLuaEnvironment::HandleUrlSchema() {
+ StringPiece url = ReadString(/*index=*/1);
+
+ const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
+ if (!status_or_parsed_uri.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+
+ const StatusOr<ScopedLocalRef<jstring>> status_or_scheme_str =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv_, status_or_parsed_uri.ValueOrDie().get(),
+ jni_cache_->uri_get_scheme);
+ if (!status_or_scheme_str.ok()) {
+ TC3_LOG(ERROR) << "Error calling Uri.getScheme";
+ lua_error(state_);
+ return 0;
+ }
+ if (status_or_scheme_str.ValueOrDie() == nullptr) {
+ lua_pushnil(state_);
+ } else {
+ const StatusOr<std::string> status_or_scheme_std_str =
+ JStringToUtf8String(jenv_, status_or_scheme_str.ValueOrDie().get());
+ if (!status_or_scheme_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_scheme_std_str.ValueOrDie());
+ }
+ return 1;
+}
+
+int JniLuaEnvironment::HandleUrlHost() {
+ const StringPiece url = ReadString(kIndexStackTop);
+
+ const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
+ if (!status_or_parsed_uri.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+
+ const StatusOr<ScopedLocalRef<jstring>> status_or_host_str =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv_, status_or_parsed_uri.ValueOrDie().get(),
+ jni_cache_->uri_get_host);
+ if (!status_or_host_str.ok()) {
+ TC3_LOG(ERROR) << "Error calling Uri.getHost";
+ lua_error(state_);
+ return 0;
+ }
+
+ if (status_or_host_str.ValueOrDie() == nullptr) {
+ lua_pushnil(state_);
+ } else {
+ const StatusOr<std::string> status_or_host_std_str =
+ JStringToUtf8String(jenv_, status_or_host_str.ValueOrDie().get());
+ if (!status_or_host_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_host_std_str.ValueOrDie());
+ }
+ return 1;
+}
+
+int JniLuaEnvironment::HandleHash() {
+ const StringPiece input = ReadString(kIndexStackTop);
+ lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
+ return 1;
+}
+
+int JniLuaEnvironment::HandleFormat() {
+ const int num_args = lua_gettop(state_);
+ std::vector<StringPiece> args(num_args - 1);
+ for (int i = 0; i < num_args - 1; i++) {
+ args[i] = ReadString(/*index=*/i + 2);
+ }
+ PushString(strings::Substitute(ReadString(/*index=*/1), args));
+ return 1;
+}
+
+bool JniLuaEnvironment::LookupModelStringResource() const {
+ // Handle only lookup by name.
+ if (lua_type(state_, kIndexStackTop) != LUA_TSTRING) {
+ return false;
+ }
+
+ const StringPiece resource_name = ReadString(kIndexStackTop);
+ std::string resource_content;
+ if (!resources_.GetResourceContent(device_locales_, resource_name,
+ &resource_content)) {
+ // Resource cannot be provided by the model.
+ return false;
+ }
+
+ PushString(resource_content);
+ return true;
+}
+
+int JniLuaEnvironment::HandleAndroidStringResources() {
+ // Check whether the requested resource can be served from the model data.
+ if (LookupModelStringResource()) {
+ return 1;
+ }
+
+ // Get system resources if not previously retrieved.
+ if (!RetrieveSystemResources()) {
+ TC3_LOG(ERROR) << "Error retrieving system resources.";
+ lua_error(state_);
+ return 0;
+ }
+
+ int resource_id;
+ switch (lua_type(state_, kIndexStackTop)) {
+ case LUA_TNUMBER:
+ resource_id = Read<int>(/*index=*/kIndexStackTop);
+ break;
+ case LUA_TSTRING: {
+ const StringPiece resource_name_str = ReadString(kIndexStackTop);
+ if (resource_name_str.empty()) {
+ TC3_LOG(ERROR) << "No resource name provided.";
+ lua_error(state_);
+ return 0;
+ }
+ const StatusOr<ScopedLocalRef<jstring>> status_or_resource_name =
+ jni_cache_->ConvertToJavaString(resource_name_str);
+ if (!status_or_resource_name.ok()) {
+ TC3_LOG(ERROR) << "Invalid resource name.";
+ lua_error(state_);
+ return 0;
+ }
+ StatusOr<int> status_or_resource_id = JniHelper::CallIntMethod(
+ jenv_, system_resources_.get(), jni_cache_->resources_get_identifier,
+ status_or_resource_name.ValueOrDie().get(), string_.get(),
+ android_.get());
+ if (!status_or_resource_id.ok()) {
+ TC3_LOG(ERROR) << "Error calling getIdentifier.";
+ lua_error(state_);
+ return 0;
+ }
+ resource_id = status_or_resource_id.ValueOrDie();
+ break;
+ }
+ default:
+ TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
+ lua_error(state_);
+ return 0;
+ }
+ if (resource_id == 0) {
+ TC3_LOG(ERROR) << "Resource not found.";
+ lua_pushnil(state_);
+ return 1;
+ }
+ StatusOr<ScopedLocalRef<jstring>> status_or_resource_str =
+ JniHelper::CallObjectMethod<jstring>(jenv_, system_resources_.get(),
+ jni_cache_->resources_get_string,
+ resource_id);
+ if (!status_or_resource_str.ok()) {
+ TC3_LOG(ERROR) << "Error calling getString.";
+ lua_error(state_);
+ return 0;
+ }
+
+ if (status_or_resource_str.ValueOrDie() == nullptr) {
+ lua_pushnil(state_);
+ } else {
+ StatusOr<std::string> status_or_resource_std_str =
+ JStringToUtf8String(jenv_, status_or_resource_str.ValueOrDie().get());
+ if (!status_or_resource_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_resource_std_str.ValueOrDie());
+ }
+ return 1;
+}
+
+bool JniLuaEnvironment::RetrieveSystemResources() {
+ if (system_resources_resources_retrieved_) {
+ return (system_resources_ != nullptr);
+ }
+ system_resources_resources_retrieved_ = true;
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jobject> system_resources_ref,
+ JniHelper::CallStaticObjectMethod(
+ jenv_, jni_cache_->resources_class.get(),
+ jni_cache_->resources_get_system));
+ system_resources_ =
+ MakeGlobalRef(system_resources_ref.get(), jenv_, jni_cache_->jvm);
+ return (system_resources_ != nullptr);
+}
+
+bool JniLuaEnvironment::RetrieveUserManager() {
+ if (context_ == nullptr) {
+ return false;
+ }
+ if (usermanager_retrieved_) {
+ return (usermanager_ != nullptr);
+ }
+ usermanager_retrieved_ = true;
+ TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> service,
+ JniHelper::NewStringUTF(jenv_, "user"));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ const ScopedLocalRef<jobject> usermanager_ref,
+ JniHelper::CallObjectMethod(jenv_, context_,
+ jni_cache_->context_get_system_service,
+ service.get()));
+
+ usermanager_ = MakeGlobalRef(usermanager_ref.get(), jenv_, jni_cache_->jvm);
+ return (usermanager_ != nullptr);
+}
+
+RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() const {
+ RemoteActionTemplate result;
+ // Read intent template.
+ lua_pushnil(state_);
+ while (Next(/*index=*/-2)) {
+ const StringPiece key = ReadString(/*index=*/-2);
+ if (key.Equals("title_without_entity")) {
+ result.title_without_entity = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("title_with_entity")) {
+ result.title_with_entity = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("description")) {
+ result.description = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("description_with_app_name")) {
+ result.description_with_app_name =
+ Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("action")) {
+ result.action = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("data")) {
+ result.data = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("type")) {
+ result.type = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("flags")) {
+ result.flags = Read<int>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("package_name")) {
+ result.package_name = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("request_code")) {
+ result.request_code = Read<int>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("category")) {
+ result.category = ReadVector<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("extra")) {
+ result.extra = ReadExtras();
+ } else {
+ TC3_LOG(INFO) << "Unknown entry: " << key;
+ }
+ lua_pop(state_, 1);
+ }
+ lua_pop(state_, 1);
+ return result;
+}
+
+std::map<std::string, Variant> JniLuaEnvironment::ReadExtras() const {
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected extras table, got: "
+ << lua_type(state_, kIndexStackTop);
+ lua_pop(state_, 1);
+ return {};
+ }
+ std::map<std::string, Variant> extras;
+ lua_pushnil(state_);
+ while (Next(/*index=*/-2)) {
+ // Each entry is a table specifying name and value.
+ // The value is specified via a type specific field as Lua doesn't allow
+ // to easily distinguish between different number types.
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected a table for an extra, got: "
+ << lua_type(state_, kIndexStackTop);
+ lua_pop(state_, 1);
+ return {};
+ }
+ std::string name;
+ Variant value;
+
+ lua_pushnil(state_);
+ while (Next(/*index=*/-2)) {
+ const StringPiece key = ReadString(/*index=*/-2);
+ if (key.Equals("name")) {
+ name = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("int_value")) {
+ value = Variant(Read<int>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("long_value")) {
+ value = Variant(Read<int64>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("float_value")) {
+ value = Variant(Read<float>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("bool_value")) {
+ value = Variant(Read<bool>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("string_value")) {
+ value = Variant(Read<std::string>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("string_array_value")) {
+ value = Variant(ReadVector<std::string>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("float_array_value")) {
+ value = Variant(ReadVector<float>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("int_array_value")) {
+ value = Variant(ReadVector<int>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("named_variant_array_value")) {
+ value = Variant(ReadExtras());
+ } else {
+ TC3_LOG(INFO) << "Unknown extra field: " << key;
+ }
+ lua_pop(state_, 1);
+ }
+ if (!name.empty()) {
+ extras[name] = value;
+ } else {
+ TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
+ }
+ lua_pop(state_, 1);
+ }
+ return extras;
+}
+
+int JniLuaEnvironment::ReadRemoteActionTemplates(
+ std::vector<RemoteActionTemplate>* result) {
+ // Read result.
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Unexpected result for snippet: "
+ << lua_type(state_, kIndexStackTop);
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+
+ // Read remote action templates array.
+ lua_pushnil(state_);
+ while (Next(/*index=*/-2)) {
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected intent table, got: "
+ << lua_type(state_, kIndexStackTop);
+ lua_pop(state_, 1);
+ continue;
+ }
+ result->push_back(ReadRemoteActionTemplateResult());
+ }
+ lua_pop(state_, /*n=*/1);
+ return LUA_OK;
+}
+
+bool JniLuaEnvironment::RunIntentGenerator(
+ const std::string& generator_snippet,
+ std::vector<RemoteActionTemplate>* remote_actions) {
+ int status;
+ status = luaL_loadbuffer(state_, generator_snippet.data(),
+ generator_snippet.size(),
+ /*name=*/nullptr);
+ if (status != LUA_OK) {
+ TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
+ return false;
+ }
+ status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
+ if (status != LUA_OK) {
+ TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
+ return false;
+ }
+ if (RunProtected(
+ [this, remote_actions] {
+ return ReadRemoteActionTemplates(remote_actions);
+ },
+ /*num_args=*/1) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not read results.";
+ return false;
+ }
+ // Check that we correctly cleaned-up the state.
+ const int stack_size = lua_gettop(state_);
+ if (stack_size > 0) {
+ TC3_LOG(ERROR) << "Unexpected stack size.";
+ lua_settop(state_, 0);
+ return false;
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/intents/jni-lua.h b/native/utils/intents/jni-lua.h
new file mode 100644
index 0000000..ab7bc96
--- /dev/null
+++ b/native/utils/intents/jni-lua.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_LUA_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_LUA_H_
+
+#include <map>
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/i18n/locale.h"
+#include "utils/intents/remote-action-template.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-cache.h"
+#include "utils/lua-utils.h"
+#include "utils/resources.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
+
+namespace libtextclassifier3 {
+
+// An Android specific Lua environment with JNI backed callbacks.
+class JniLuaEnvironment : public LuaEnvironment {
+ public:
+ JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
+ const jobject context,
+ const std::vector<Locale>& device_locales);
+ // Environment setup.
+ bool Initialize();
+
+ // Runs an intent generator snippet.
+ bool RunIntentGenerator(const std::string& generator_snippet,
+ std::vector<RemoteActionTemplate>* remote_actions);
+
+ protected:
+ virtual void SetupExternalHook();
+ bool PreallocateConstantJniStrings();
+
+ int HandleExternalCallback();
+ int HandleAndroidCallback();
+ int HandleUserRestrictionsCallback();
+ int HandleUrlEncode();
+ int HandleUrlSchema();
+ int HandleHash();
+ int HandleFormat();
+ int HandleAndroidStringResources();
+ int HandleUrlHost();
+
+ // Checks and retrieves string resources from the model.
+ bool LookupModelStringResource() const;
+
+ // Reads and create a RemoteAction result from Lua.
+ RemoteActionTemplate ReadRemoteActionTemplateResult() const;
+
+ // Reads the extras from the Lua result.
+ std::map<std::string, Variant> ReadExtras() const;
+
+ // Retrieves user manager if not previously done.
+ bool RetrieveUserManager();
+
+ // Retrieves system resources if not previously done.
+ bool RetrieveSystemResources();
+
+ // Parse the url string by using Uri.parse from Java.
+ StatusOr<ScopedLocalRef<jobject>> ParseUri(StringPiece url) const;
+
+ // Read remote action templates from lua generator.
+ int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
+
+ const Resources& resources_;
+ JNIEnv* jenv_;
+ const JniCache* jni_cache_;
+ const jobject context_;
+ std::vector<Locale> device_locales_;
+
+ ScopedGlobalRef<jobject> usermanager_;
+ // Whether we previously attempted to retrieve the UserManager before.
+ bool usermanager_retrieved_;
+
+ ScopedGlobalRef<jobject> system_resources_;
+ // Whether we previously attempted to retrieve the system resources.
+ bool system_resources_resources_retrieved_;
+
+ // Cached JNI references for Java strings `string` and `android`.
+ ScopedGlobalRef<jstring> string_;
+ ScopedGlobalRef<jstring> android_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_LUA_H_
diff --git a/native/utils/intents/jni.cc b/native/utils/intents/jni.cc
index 051d078..c95f03b 100644
--- a/native/utils/intents/jni.cc
+++ b/native/utils/intents/jni.cc
@@ -18,6 +18,7 @@
#include <memory>
+#include "utils/base/status_macros.h"
#include "utils/base/statusor.h"
#include "utils/java/jni-base.h"
#include "utils/java/jni-helper.h"
@@ -27,19 +28,19 @@
// The macros below are intended to reduce the boilerplate and avoid
// easily introduced copy/paste errors.
#define TC3_CHECK_JNI_PTR(PTR) TC3_CHECK((PTR) != nullptr)
-#define TC3_GET_CLASS(FIELD, NAME) \
- { \
- StatusOr<ScopedLocalRef<jclass>> status_or_clazz = \
- JniHelper::FindClass(env, NAME); \
- handler->FIELD = MakeGlobalRef(status_or_clazz.ValueOrDie().release(), \
- env, jni_cache->jvm); \
- TC3_CHECK_JNI_PTR(handler->FIELD) << "Error finding class: " << NAME; \
+#define TC3_GET_CLASS(FIELD, NAME) \
+ { \
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jclass> clazz, \
+ JniHelper::FindClass(env, NAME)); \
+ handler->FIELD = MakeGlobalRef(clazz.release(), env, jni_cache->jvm); \
+ TC3_CHECK_JNI_PTR(handler->FIELD) << "Error finding class: " << NAME; \
}
-#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- handler->FIELD = env->GetMethodID(handler->CLASS.get(), NAME, SIGNATURE); \
- TC3_CHECK(handler->FIELD) << "Error finding method: " << NAME;
+#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ TC3_ASSIGN_OR_RETURN( \
+ handler->FIELD, \
+ JniHelper::GetMethodID(env, handler->CLASS.get(), NAME, SIGNATURE));
-std::unique_ptr<RemoteActionTemplatesHandler>
+StatusOr<std::unique_ptr<RemoteActionTemplatesHandler>>
RemoteActionTemplatesHandler::Create(
const std::shared_ptr<JniCache>& jni_cache) {
JNIEnv* env = jni_cache->GetEnv();
@@ -127,8 +128,8 @@
for (int k = 0; k < values.size(); k++) {
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> value_str,
jni_cache_->ConvertToJavaString(values[k]));
- jni_cache_->GetEnv()->SetObjectArrayElement(result.get(), k,
- value_str.get());
+ TC3_RETURN_IF_ERROR(JniHelper::SetObjectArrayElement(
+ jni_cache_->GetEnv(), result.get(), k, value_str.get()));
}
return result;
}
@@ -144,9 +145,9 @@
ScopedLocalRef<jfloatArray> result,
JniHelper::NewFloatArray(jni_cache_->GetEnv(), values.size()));
- jni_cache_->GetEnv()->SetFloatArrayRegion(result.get(), /*start=*/0,
- /*len=*/values.size(),
- &(values[0]));
+ TC3_RETURN_IF_ERROR(JniHelper::SetFloatArrayRegion(
+ jni_cache_->GetEnv(), result.get(), /*start=*/0,
+ /*len=*/values.size(), &(values[0])));
return result;
}
@@ -160,8 +161,9 @@
ScopedLocalRef<jintArray> result,
JniHelper::NewIntArray(jni_cache_->GetEnv(), values.size()));
- jni_cache_->GetEnv()->SetIntArrayRegion(result.get(), /*start=*/0,
- /*len=*/values.size(), &(values[0]));
+ TC3_RETURN_IF_ERROR(JniHelper::SetIntArrayRegion(
+ jni_cache_->GetEnv(), result.get(), /*start=*/0,
+ /*len=*/values.size(), &(values[0])));
return result;
}
@@ -275,8 +277,8 @@
TC3_ASSIGN_OR_RETURN(
StatusOr<ScopedLocalRef<jobject>> named_extra,
AsNamedVariant(key_value_pair.first, key_value_pair.second));
- env->SetObjectArrayElement(result.get(), element_index,
- named_extra.ValueOrDie().get());
+ TC3_RETURN_IF_ERROR(JniHelper::SetObjectArrayElement(
+ env, result.get(), element_index, named_extra.ValueOrDie().get()));
element_index++;
}
return result;
@@ -335,7 +337,8 @@
type.ValueOrDie().get(), flags.ValueOrDie().get(),
category.ValueOrDie().get(), package.ValueOrDie().get(),
extra.ValueOrDie().get(), request_code.ValueOrDie().get()));
- env->SetObjectArrayElement(results.get(), i, result.get());
+ TC3_RETURN_IF_ERROR(
+ JniHelper::SetObjectArrayElement(env, results.get(), i, result.get()));
}
return results;
}
@@ -344,8 +347,8 @@
RemoteActionTemplatesHandler::EntityDataAsNamedVariantArray(
const reflection::Schema* entity_data_schema,
const std::string& serialized_entity_data) const {
- ReflectiveFlatbufferBuilder entity_data_builder(entity_data_schema);
- std::unique_ptr<ReflectiveFlatbuffer> buffer = entity_data_builder.NewRoot();
+ MutableFlatbufferBuilder entity_data_builder(entity_data_schema);
+ std::unique_ptr<MutableFlatbuffer> buffer = entity_data_builder.NewRoot();
buffer->MergeFromSerializedFlatbuffer(serialized_entity_data);
std::map<std::string, Variant> entity_data_map = buffer->AsFlatMap();
return AsNamedVariantArray(entity_data_map);
diff --git a/native/utils/intents/jni.h b/native/utils/intents/jni.h
index ada2631..895c63d 100644
--- a/native/utils/intents/jni.h
+++ b/native/utils/intents/jni.h
@@ -25,7 +25,8 @@
#include <vector>
#include "utils/base/statusor.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/intents/remote-action-template.h"
#include "utils/java/jni-base.h"
#include "utils/java/jni-cache.h"
@@ -51,7 +52,7 @@
// A helper class to create RemoteActionTemplate object from model results.
class RemoteActionTemplatesHandler {
public:
- static std::unique_ptr<RemoteActionTemplatesHandler> Create(
+ static StatusOr<std::unique_ptr<RemoteActionTemplatesHandler>> Create(
const std::shared_ptr<JniCache>& jni_cache);
StatusOr<ScopedLocalRef<jstring>> AsUTF8String(
diff --git a/native/utils/java/jni-base.cc b/native/utils/java/jni-base.cc
index e0829b7..39ade45 100644
--- a/native/utils/java/jni-base.cc
+++ b/native/utils/java/jni-base.cc
@@ -17,7 +17,6 @@
#include "utils/java/jni-base.h"
#include "utils/base/status.h"
-#include "utils/java/string_utils.h"
namespace libtextclassifier3 {
@@ -25,22 +24,16 @@
return env->EnsureLocalCapacity(capacity) == JNI_OK;
}
-bool JniExceptionCheckAndClear(JNIEnv* env) {
+bool JniExceptionCheckAndClear(JNIEnv* env, bool print_exception_on_error) {
TC3_CHECK(env != nullptr);
const bool result = env->ExceptionCheck();
if (result) {
- env->ExceptionDescribe();
+ if (print_exception_on_error) {
+ env->ExceptionDescribe();
+ }
env->ExceptionClear();
}
return result;
}
-StatusOr<std::string> ToStlString(JNIEnv* env, const jstring& str) {
- std::string result;
- if (!JStringToUtf8String(env, str, &result)) {
- return {Status::UNKNOWN};
- }
- return result;
-}
-
} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-base.h b/native/utils/java/jni-base.h
index c7b04e6..211000a 100644
--- a/native/utils/java/jni-base.h
+++ b/native/utils/java/jni-base.h
@@ -65,9 +65,8 @@
bool EnsureLocalCapacity(JNIEnv* env, int capacity);
// Returns true if there was an exception. Also it clears the exception.
-bool JniExceptionCheckAndClear(JNIEnv* env);
-
-StatusOr<std::string> ToStlString(JNIEnv* env, const jstring& str);
+bool JniExceptionCheckAndClear(JNIEnv* env,
+ bool print_exception_on_error = true);
// A deleter to be used with std::unique_ptr to delete JNI global references.
class GlobalRefDeleter {
diff --git a/native/utils/java/jni-cache.cc b/native/utils/java/jni-cache.cc
index 0be769d..824141a 100644
--- a/native/utils/java/jni-cache.cc
+++ b/native/utils/java/jni-cache.cc
@@ -17,6 +17,7 @@
#include "utils/java/jni-cache.h"
#include "utils/base/logging.h"
+#include "utils/base/status_macros.h"
#include "utils/java/jni-base.h"
#include "utils/java/jni-helper.h"
@@ -33,8 +34,7 @@
breakiterator_class(nullptr, jvm),
integer_class(nullptr, jvm),
calendar_class(nullptr, jvm),
- timezone_class(nullptr, jvm),
- urlencoder_class(nullptr, jvm)
+ timezone_class(nullptr, jvm)
#ifdef __ANDROID__
,
context_class(nullptr, jvm),
@@ -72,59 +72,61 @@
} \
}
-#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- result->CLASS##_##FIELD = \
- env->GetMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
- << "Error finding method: " << NAME;
+#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ result->CLASS##_##FIELD, \
+ JniHelper::GetMethodID(env, result->CLASS##_class.get(), NAME, \
+ SIGNATURE));
-#define TC3_GET_OPTIONAL_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- if (result->CLASS##_class != nullptr) { \
- result->CLASS##_##FIELD = \
- env->GetMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- env->ExceptionClear(); \
+#define TC3_GET_OPTIONAL_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ TC3_GET_OPTIONAL_METHOD_INTERNAL(CLASS, FIELD, NAME, SIGNATURE, GetMethodID)
+
+#define TC3_GET_OPTIONAL_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ TC3_GET_OPTIONAL_METHOD_INTERNAL(CLASS, FIELD, NAME, SIGNATURE, \
+ GetStaticMethodID)
+
+#define TC3_GET_OPTIONAL_METHOD_INTERNAL(CLASS, FIELD, NAME, SIGNATURE, \
+ METHOD_NAME) \
+ if (result->CLASS##_class != nullptr) { \
+ if (StatusOr<jmethodID> status_or_method_id = JniHelper::METHOD_NAME( \
+ env, result->CLASS##_class.get(), NAME, SIGNATURE); \
+ status_or_method_id.ok()) { \
+ result->CLASS##_##FIELD = status_or_method_id.ValueOrDie(); \
+ } \
}
-#define TC3_GET_OPTIONAL_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- if (result->CLASS##_class != nullptr) { \
- result->CLASS##_##FIELD = \
- env->GetStaticMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- env->ExceptionClear(); \
+#define TC3_GET_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ result->CLASS##_##FIELD, \
+ JniHelper::GetStaticMethodID(env, result->CLASS##_class.get(), NAME, \
+ SIGNATURE));
+
+#define TC3_GET_STATIC_OBJECT_FIELD_OR_RETURN_NULL(CLASS, FIELD, NAME, \
+ SIGNATURE) \
+ { \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ const jfieldID CLASS##_##FIELD##_field, \
+ JniHelper::GetStaticFieldID(env, result->CLASS##_class.get(), NAME, \
+ SIGNATURE)); \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ ScopedLocalRef<jobject> static_object, \
+ JniHelper::GetStaticObjectField(env, result->CLASS##_class.get(), \
+ CLASS##_##FIELD##_field)); \
+ result->CLASS##_##FIELD = MakeGlobalRef(static_object.get(), env, jvm); \
+ if (result->CLASS##_##FIELD == nullptr) { \
+ TC3_LOG(ERROR) << "Error finding field: " << NAME; \
+ return nullptr; \
+ } \
}
-#define TC3_GET_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- result->CLASS##_##FIELD = \
- env->GetStaticMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
- << "Error finding method: " << NAME;
-
-#define TC3_GET_STATIC_OBJECT_FIELD_OR_RETURN_NULL(CLASS, FIELD, NAME, \
- SIGNATURE) \
- { \
- const jfieldID CLASS##_##FIELD##_field = \
- env->GetStaticFieldID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- TC3_CHECK_JNI_RESULT(CLASS##_##FIELD##_field) \
- << "Error finding field id: " << NAME; \
- TC3_ASSIGN_OR_RETURN_NULL( \
- ScopedLocalRef<jobject> static_object, \
- JniHelper::GetStaticObjectField(env, result->CLASS##_class.get(), \
- CLASS##_##FIELD##_field)); \
- result->CLASS##_##FIELD = MakeGlobalRef(static_object.get(), env, jvm); \
- if (result->CLASS##_##FIELD == nullptr) { \
- TC3_LOG(ERROR) << "Error finding field: " << NAME; \
- return nullptr; \
- } \
- }
-
-#define TC3_GET_STATIC_INT_FIELD(CLASS, FIELD, NAME) \
- const jfieldID CLASS##_##FIELD##_field = \
- env->GetStaticFieldID(result->CLASS##_class.get(), NAME, "I"); \
- TC3_CHECK_JNI_RESULT(CLASS##_##FIELD##_field) \
- << "Error finding field id: " << NAME; \
- result->CLASS##_##FIELD = env->GetStaticIntField( \
- result->CLASS##_class.get(), CLASS##_##FIELD##_field); \
- TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
- << "Error finding field: " << NAME;
+#define TC3_GET_STATIC_INT_FIELD(CLASS, FIELD, NAME) \
+ TC3_ASSIGN_OR_RETURN_NULL(const jfieldID CLASS##_##FIELD##_field, \
+ JniHelper::GetStaticFieldID( \
+ env, result->CLASS##_class.get(), NAME, "I")); \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ result->CLASS##_##FIELD, \
+ JniHelper::GetStaticIntField(env, result->CLASS##_class.get(), \
+ CLASS##_##FIELD##_field));
std::unique_ptr<JniCache> JniCache::Create(JNIEnv* env) {
if (env == nullptr) {
@@ -219,12 +221,6 @@
TC3_GET_STATIC_METHOD(timezone, get_timezone, "getTimeZone",
"(Ljava/lang/String;)Ljava/util/TimeZone;");
- // URLEncoder.
- TC3_GET_CLASS_OR_RETURN_NULL(urlencoder, "java/net/URLEncoder");
- TC3_GET_STATIC_METHOD(
- urlencoder, encode, "encode",
- "(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;");
-
#ifdef __ANDROID__
// Context.
TC3_GET_CLASS_OR_RETURN_NULL(context, "android/content/Context");
@@ -239,6 +235,8 @@
"(Ljava/lang/String;)Landroid/net/Uri;");
TC3_GET_METHOD(uri, get_scheme, "getScheme", "()Ljava/lang/String;");
TC3_GET_METHOD(uri, get_host, "getHost", "()Ljava/lang/String;");
+ TC3_GET_STATIC_METHOD(uri, encode, "encode",
+ "(Ljava/lang/String;)Ljava/lang/String;");
// UserManager.
TC3_GET_OPTIONAL_CLASS(usermanager, "android/os/UserManager");
@@ -290,8 +288,9 @@
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jbyteArray> text_java_utf8,
JniHelper::NewByteArray(jenv, utf8_text_size_bytes));
- jenv->SetByteArrayRegion(text_java_utf8.get(), 0, utf8_text_size_bytes,
- reinterpret_cast<const jbyte*>(utf8_text));
+ TC3_RETURN_IF_ERROR(JniHelper::SetByteArrayRegion(
+ jenv, text_java_utf8.get(), 0, utf8_text_size_bytes,
+ reinterpret_cast<const jbyte*>(utf8_text)));
// Create the string with a UTF-8 charset.
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> result,
diff --git a/native/utils/java/jni-cache.h b/native/utils/java/jni-cache.h
index ab48419..8754f4c 100644
--- a/native/utils/java/jni-cache.h
+++ b/native/utils/java/jni-cache.h
@@ -107,10 +107,6 @@
ScopedGlobalRef<jclass> timezone_class;
jmethodID timezone_get_timezone = nullptr;
- // java.net.URLEncoder
- ScopedGlobalRef<jclass> urlencoder_class;
- jmethodID urlencoder_encode = nullptr;
-
// android.content.Context
ScopedGlobalRef<jclass> context_class;
jmethodID context_get_package_name = nullptr;
@@ -121,6 +117,7 @@
jmethodID uri_parse = nullptr;
jmethodID uri_get_scheme = nullptr;
jmethodID uri_get_host = nullptr;
+ jmethodID uri_encode = nullptr;
// android.os.UserManager
ScopedGlobalRef<jclass> usermanager_class;
diff --git a/native/utils/java/jni-helper.cc b/native/utils/java/jni-helper.cc
index d1677e4..c7d8012 100644
--- a/native/utils/java/jni-helper.cc
+++ b/native/utils/java/jni-helper.cc
@@ -16,6 +16,8 @@
#include "utils/java/jni-helper.h"
+#include "utils/base/status_macros.h"
+
namespace libtextclassifier3 {
StatusOr<ScopedLocalRef<jclass>> JniHelper::FindClass(JNIEnv* env,
@@ -27,10 +29,46 @@
return result;
}
+StatusOr<ScopedLocalRef<jclass>> JniHelper::GetObjectClass(JNIEnv* env,
+ jobject object) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ ScopedLocalRef<jclass> result(env->GetObjectClass(object), env);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ TC3_NOT_NULL_OR_RETURN;
+ return result;
+}
+
StatusOr<jmethodID> JniHelper::GetMethodID(JNIEnv* env, jclass clazz,
const char* method_name,
- const char* return_type) {
- jmethodID result = env->GetMethodID(clazz, method_name, return_type);
+ const char* signature) {
+ jmethodID result = env->GetMethodID(clazz, method_name, signature);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ TC3_NOT_NULL_OR_RETURN;
+ return result;
+}
+
+StatusOr<jmethodID> JniHelper::GetStaticMethodID(JNIEnv* env, jclass clazz,
+ const char* method_name,
+ const char* signature) {
+ jmethodID result = env->GetStaticMethodID(clazz, method_name, signature);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ TC3_NOT_NULL_OR_RETURN;
+ return result;
+}
+
+StatusOr<jfieldID> JniHelper::GetFieldID(JNIEnv* env, jclass clazz,
+ const char* field_name,
+ const char* signature) {
+ jfieldID result = env->GetFieldID(clazz, field_name, signature);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ TC3_NOT_NULL_OR_RETURN;
+ return result;
+}
+
+StatusOr<jfieldID> JniHelper::GetStaticFieldID(JNIEnv* env, jclass clazz,
+ const char* field_name,
+ const char* signature) {
+ jfieldID result = env->GetStaticFieldID(clazz, field_name, signature);
TC3_NO_EXCEPTION_OR_RETURN;
TC3_NOT_NULL_OR_RETURN;
return result;
@@ -46,6 +84,14 @@
return result;
}
+StatusOr<jint> JniHelper::GetStaticIntField(JNIEnv* env, jclass class_name,
+ jfieldID field_id) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ jint result = env->GetStaticIntField(class_name, field_id);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
StatusOr<ScopedLocalRef<jbyteArray>> JniHelper::NewByteArray(JNIEnv* env,
jsize length) {
TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
@@ -147,6 +193,46 @@
return Status::OK;
}
+StatusOr<jsize> JniHelper::GetArrayLength(JNIEnv* env, jarray array) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ jsize result = env->GetArrayLength(array);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+Status JniHelper::GetByteArrayRegion(JNIEnv* env, jbyteArray array, jsize start,
+ jsize len, jbyte* buf) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ env->GetByteArrayRegion(array, start, len, buf);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return Status::OK;
+}
+
+Status JniHelper::SetByteArrayRegion(JNIEnv* env, jbyteArray array, jsize start,
+ jsize len, const jbyte* buf) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ env->SetByteArrayRegion(array, start, len, buf);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return Status::OK;
+}
+
+Status JniHelper::SetIntArrayRegion(JNIEnv* env, jintArray array, jsize start,
+ jsize len, const jint* buf) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ env->SetIntArrayRegion(array, start, len, buf);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return Status::OK;
+}
+
+Status JniHelper::SetFloatArrayRegion(JNIEnv* env, jfloatArray array,
+ jsize start, jsize len,
+ const jfloat* buf) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ env->SetFloatArrayRegion(array, start, len, buf);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return Status::OK;
+}
+
StatusOr<ScopedLocalRef<jobjectArray>> JniHelper::NewObjectArray(
JNIEnv* env, jsize length, jclass element_class, jobject initial_element) {
TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
@@ -157,14 +243,6 @@
return result;
}
-StatusOr<jsize> JniHelper::GetArrayLength(JNIEnv* env,
- jarray jinput_fragments) {
- TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
- jsize result = env->GetArrayLength(jinput_fragments);
- TC3_NO_EXCEPTION_OR_RETURN;
- return result;
-}
-
StatusOr<ScopedLocalRef<jstring>> JniHelper::NewStringUTF(JNIEnv* env,
const char* bytes) {
TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
@@ -174,4 +252,37 @@
return result;
}
+StatusOr<std::string> JByteArrayToString(JNIEnv* env, jbyteArray array) {
+ std::string result;
+ TC3_ASSIGN_OR_RETURN(const int array_length,
+ JniHelper::GetArrayLength(env, array));
+ result.resize(array_length);
+ TC3_RETURN_IF_ERROR(JniHelper::GetByteArrayRegion(
+ env, array, 0, array_length,
+ reinterpret_cast<jbyte*>(const_cast<char*>(result.data()))));
+ return result;
+}
+
+StatusOr<std::string> JStringToUtf8String(JNIEnv* env, jstring jstr) {
+ if (jstr == nullptr) {
+ return "";
+ }
+
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jclass> string_class,
+ JniHelper::FindClass(env, "java/lang/String"));
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_bytes_id,
+ JniHelper::GetMethodID(env, string_class.get(), "getBytes",
+ "(Ljava/lang/String;)[B"));
+
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> encoding,
+ JniHelper::NewStringUTF(env, "UTF-8"));
+
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jbyteArray> array,
+ JniHelper::CallObjectMethod<jbyteArray>(
+ env, jstr, get_bytes_id, encoding.get()));
+
+ return JByteArrayToString(env, array.get());
+}
+
} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-helper.h b/native/utils/java/jni-helper.h
index 55d4696..5ac60ef 100644
--- a/native/utils/java/jni-helper.h
+++ b/native/utils/java/jni-helper.h
@@ -74,16 +74,31 @@
static StatusOr<ScopedLocalRef<jclass>> FindClass(JNIEnv* env,
const char* class_name);
+ static StatusOr<ScopedLocalRef<jclass>> GetObjectClass(JNIEnv* env,
+ jobject object);
+
template <typename T = jobject>
static StatusOr<ScopedLocalRef<T>> GetObjectArrayElement(JNIEnv* env,
jobjectArray array,
jsize index);
static StatusOr<jmethodID> GetMethodID(JNIEnv* env, jclass clazz,
const char* method_name,
- const char* return_type);
+ const char* signature);
+ static StatusOr<jmethodID> GetStaticMethodID(JNIEnv* env, jclass clazz,
+ const char* method_name,
+ const char* signature);
+
+ static StatusOr<jfieldID> GetFieldID(JNIEnv* env, jclass clazz,
+ const char* field_name,
+ const char* signature);
+ static StatusOr<jfieldID> GetStaticFieldID(JNIEnv* env, jclass clazz,
+ const char* field_name,
+ const char* signature);
static StatusOr<ScopedLocalRef<jobject>> GetStaticObjectField(
JNIEnv* env, jclass class_name, jfieldID field_id);
+ static StatusOr<jint> GetStaticIntField(JNIEnv* env, jclass class_name,
+ jfieldID field_id);
// New* methods.
TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(NewObject, jobject, jclass,
@@ -100,11 +115,23 @@
static StatusOr<ScopedLocalRef<jfloatArray>> NewFloatArray(JNIEnv* env,
jsize length);
- static StatusOr<jsize> GetArrayLength(JNIEnv* env, jarray jinput_fragments);
+ static StatusOr<jsize> GetArrayLength(JNIEnv* env, jarray array);
static Status SetObjectArrayElement(JNIEnv* env, jobjectArray array,
jsize index, jobject val);
+ static Status GetByteArrayRegion(JNIEnv* env, jbyteArray array, jsize start,
+ jsize len, jbyte* buf);
+
+ static Status SetByteArrayRegion(JNIEnv* env, jbyteArray array, jsize start,
+ jsize len, const jbyte* buf);
+
+ static Status SetIntArrayRegion(JNIEnv* env, jintArray array, jsize start,
+ jsize len, const jint* buf);
+
+ static Status SetFloatArrayRegion(JNIEnv* env, jfloatArray array, jsize start,
+ jsize len, const jfloat* buf);
+
// Call* methods.
TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(CallObjectMethod, jobject,
jobject, TC3_JNI_NO_CHECK);
@@ -125,8 +152,10 @@
jmethodID method_id, ...);
template <class T>
- static StatusOr<T> CallStaticIntMethod(JNIEnv* env, jclass clazz,
- jmethodID method_id, ...);
+ static StatusOr<T> CallStaticIntMethod(JNIEnv* env,
+ bool print_exception_on_error,
+ jclass clazz, jmethodID method_id,
+ ...);
};
template <typename T>
@@ -142,17 +171,28 @@
}
template <class T>
-StatusOr<T> JniHelper::CallStaticIntMethod(JNIEnv* env, jclass clazz,
- jmethodID method_id, ...) {
+StatusOr<T> JniHelper::CallStaticIntMethod(JNIEnv* env,
+ bool print_exception_on_error,
+ jclass clazz, jmethodID method_id,
+ ...) {
va_list args;
va_start(args, method_id);
jint result = env->CallStaticIntMethodV(clazz, method_id, args);
va_end(args);
- TC3_NO_EXCEPTION_OR_RETURN;
+ if (JniExceptionCheckAndClear(env, print_exception_on_error)) {
+ return {Status::UNKNOWN};
+ }
+
return result;
}
+// Converts Java byte[] object to std::string.
+StatusOr<std::string> JByteArrayToString(JNIEnv* env, jbyteArray array);
+
+// Converts Java String object to UTF8-encoded std::string.
+StatusOr<std::string> JStringToUtf8String(JNIEnv* env, jstring jstr);
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_HELPER_H_
diff --git a/native/utils/java/string_utils.cc b/native/utils/java/string_utils.cc
deleted file mode 100644
index ca518a0..0000000
--- a/native/utils/java/string_utils.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/java/string_utils.h"
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
- std::string* result) {
- jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
- if (array_bytes == nullptr) {
- return false;
- }
-
- const int array_length = env->GetArrayLength(array);
- *result = std::string(reinterpret_cast<char*>(array_bytes), array_length);
-
- env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
-
- return true;
-}
-
-bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
- std::string* result) {
- if (jstr == nullptr) {
- *result = std::string();
- return true;
- }
-
- jclass string_class = env->FindClass("java/lang/String");
- if (!string_class) {
- TC3_LOG(ERROR) << "Can't find String class";
- return false;
- }
-
- jmethodID get_bytes_id =
- env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
-
- jstring encoding = env->NewStringUTF("UTF-8");
-
- jbyteArray array = reinterpret_cast<jbyteArray>(
- env->CallObjectMethod(jstr, get_bytes_id, encoding));
-
- JByteArrayToString(env, array, result);
-
- // Release the array.
- env->DeleteLocalRef(array);
- env->DeleteLocalRef(string_class);
- env->DeleteLocalRef(encoding);
-
- return true;
-}
-
-ScopedStringChars GetScopedStringChars(JNIEnv* env, jstring string,
- jboolean* is_copy) {
- return ScopedStringChars(env->GetStringUTFChars(string, is_copy),
- StringCharsReleaser(env, string));
-}
-
-} // namespace libtextclassifier3
diff --git a/native/utils/java/string_utils.h b/native/utils/java/string_utils.h
deleted file mode 100644
index 172a938..0000000
--- a/native/utils/java/string_utils.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_STRING_UTILS_H_
-#define LIBTEXTCLASSIFIER_UTILS_JAVA_STRING_UTILS_H_
-
-#include <jni.h>
-#include <memory>
-#include <string>
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
- std::string* result);
-bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, std::string* result);
-
-// A deleter to be used with std::unique_ptr to release Java string chars.
-class StringCharsReleaser {
- public:
- StringCharsReleaser() : env_(nullptr) {}
-
- StringCharsReleaser(JNIEnv* env, jstring jstr) : env_(env), jstr_(jstr) {}
-
- StringCharsReleaser(const StringCharsReleaser& orig) = default;
-
- // Copy assignment to allow move semantics in StringCharsReleaser.
- StringCharsReleaser& operator=(const StringCharsReleaser& rhs) {
- // As the releaser and its state are thread-local, it's enough to only
- // ensure the envs are consistent but do nothing.
- TC3_CHECK_EQ(env_, rhs.env_);
- return *this;
- }
-
- // The delete operator.
- void operator()(const char* chars) const {
- if (env_ != nullptr) {
- env_->ReleaseStringUTFChars(jstr_, chars);
- }
- }
-
- private:
- // The env_ stashed to use for deletion. Thread-local, don't share!
- JNIEnv* const env_;
-
- // The referenced jstring.
- jstring jstr_;
-};
-
-// A smart pointer that releases string chars when it goes out of scope.
-// of scope.
-// Note that this class is not thread-safe since it caches JNIEnv in
-// the deleter. Do not use the same jobject across different threads.
-using ScopedStringChars = std::unique_ptr<const char, StringCharsReleaser>;
-
-// Returns a scoped pointer to the array of Unicode characters of a string.
-ScopedStringChars GetScopedStringChars(JNIEnv* env, jstring string,
- jboolean* is_copy = nullptr);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_STRING_UTILS_H_
diff --git a/native/utils/jvm-test-utils.h b/native/utils/jvm-test-utils.h
new file mode 100644
index 0000000..4600248
--- /dev/null
+++ b/native/utils/jvm-test-utils.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_JVM_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_JVM_TEST_UTILS_H_
+
+#include "utils/calendar/calendar.h"
+#include "utils/utf8/unilib.h"
+
+#if defined(__ANDROID__)
+#include <jni.h>
+
+// Use these with jvm_test_launcher which provides these global variables.
+extern JNIEnv* g_jenv;
+extern jobject g_context;
+#endif
+
+namespace libtextclassifier3 {
+inline std::unique_ptr<UniLib> CreateUniLibForTesting() {
+#if defined TC3_UNILIB_JAVAICU
+ return std::make_unique<UniLib>(JniCache::Create(g_jenv));
+#else
+ return std::make_unique<UniLib>();
+#endif
+}
+
+inline std::unique_ptr<CalendarLib> CreateCalendarLibForTesting() {
+#if defined TC3_CALENDAR_JAVAICU
+ return std::make_unique<CalendarLib>(JniCache::Create(g_jenv));
+#else
+ return std::make_unique<CalendarLib>();
+#endif
+}
+
+#if defined(__ANDROID__)
+inline JNIEnv* GetJenv() { return g_jenv; }
+
+inline jobject GetAndroidContext() { return g_context; }
+#endif
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_JVM_TEST_UTILS_H_
diff --git a/native/utils/lua-utils.cc b/native/utils/lua-utils.cc
index fe3d12d..9117c54 100644
--- a/native/utils/lua-utils.cc
+++ b/native/utils/lua-utils.cc
@@ -16,11 +16,6 @@
#include "utils/lua-utils.h"
-// lua_dump takes an extra argument "strip" in 5.3, but not in 5.2.
-#ifndef TC3_AOSP
-#define lua_dump(L, w, d, s) lua_dump((L), (w), (d))
-#endif
-
namespace libtextclassifier3 {
namespace {
static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base},
@@ -86,6 +81,8 @@
const reflection::BaseType field_type = field->type()->base_type();
switch (field_type) {
case reflection::Bool:
+ Push(table->GetField<bool>(field->offset(), field->default_integer()));
+ break;
case reflection::UByte:
Push(table->GetField<uint8>(field->offset(), field->default_integer()));
break;
@@ -98,12 +95,6 @@
case reflection::UInt:
Push(table->GetField<uint32>(field->offset(), field->default_integer()));
break;
- case reflection::Short:
- Push(table->GetField<int16>(field->offset(), field->default_integer()));
- break;
- case reflection::UShort:
- Push(table->GetField<uint16>(field->offset(), field->default_integer()));
- break;
case reflection::Long:
Push(table->GetField<int64>(field->offset(), field->default_integer()));
break;
@@ -144,6 +135,9 @@
}
switch (field->type()->element()) {
case reflection::Bool:
+ PushRepeatedField(table->GetPointer<const flatbuffers::Vector<bool>*>(
+ field->offset()));
+ break;
case reflection::UByte:
PushRepeatedField(
table->GetPointer<const flatbuffers::Vector<uint8>*>(
@@ -163,16 +157,6 @@
table->GetPointer<const flatbuffers::Vector<uint32>*>(
field->offset()));
break;
- case reflection::Short:
- PushRepeatedField(
- table->GetPointer<const flatbuffers::Vector<int16>*>(
- field->offset()));
- break;
- case reflection::UShort:
- PushRepeatedField(
- table->GetPointer<const flatbuffers::Vector<uint16>*>(
- field->offset()));
- break;
case reflection::Long:
PushRepeatedField(
table->GetPointer<const flatbuffers::Vector<int64>*>(
@@ -221,7 +205,7 @@
}
int LuaEnvironment::ReadFlatbuffer(const int index,
- ReflectiveFlatbuffer* buffer) const {
+ MutableFlatbuffer* buffer) const {
if (buffer == nullptr) {
TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer: " << index;
lua_error(state_);
@@ -322,8 +306,8 @@
buffer->Repeated(field));
break;
case reflection::Obj:
- ReadRepeatedField<ReflectiveFlatbuffer>(/*index=*/kIndexStackTop,
- buffer->Repeated(field));
+ ReadRepeatedField<MutableFlatbuffer>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
default:
TC3_LOG(ERROR) << "Unsupported repeated field type: "
@@ -542,7 +526,7 @@
classification.serialized_entity_data =
Read<std::string>(/*index=*/kIndexStackTop);
} else if (key.Equals(kEntityKey)) {
- auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot();
+ auto buffer = MutableFlatbufferBuilder(entity_data_schema).NewRoot();
ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
classification.serialized_entity_data = buffer->Serialize();
} else {
@@ -610,7 +594,7 @@
ReadAnnotations(actions_entity_data_schema, &action.annotations);
} else if (key.Equals(kEntityKey)) {
auto buffer =
- ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot();
+ MutableFlatbufferBuilder(actions_entity_data_schema).NewRoot();
ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
action.serialized_entity_data = buffer->Serialize();
} else {
diff --git a/native/utils/lua-utils.h b/native/utils/lua-utils.h
index b01471a..a76c790 100644
--- a/native/utils/lua-utils.h
+++ b/native/utils/lua-utils.h
@@ -21,7 +21,7 @@
#include "actions/types.h"
#include "annotator/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/reflection_generated.h"
@@ -65,7 +65,7 @@
class LuaEnvironment {
public:
virtual ~LuaEnvironment();
- LuaEnvironment();
+ explicit LuaEnvironment();
// Compile a lua snippet into binary bytecode.
// NOTE: The compiled bytecode might not be compatible across Lua versions
@@ -137,42 +137,42 @@
template <>
int64 Read<int64>(const int index) const {
- return static_cast<int64>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<int64>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint64 Read<uint64>(const int index) const {
- return static_cast<uint64>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<uint64>(lua_tointeger(state_, /*idx=*/index));
}
template <>
int32 Read<int32>(const int index) const {
- return static_cast<int32>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<int32>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint32 Read<uint32>(const int index) const {
- return static_cast<uint32>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<uint32>(lua_tointeger(state_, /*idx=*/index));
}
template <>
int16 Read<int16>(const int index) const {
- return static_cast<int16>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<int16>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint16 Read<uint16>(const int index) const {
- return static_cast<uint16>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<uint16>(lua_tointeger(state_, /*idx=*/index));
}
template <>
int8 Read<int8>(const int index) const {
- return static_cast<int8>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<int8>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint8 Read<uint8>(const int index) const {
- return static_cast<uint8>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<uint8>(lua_tointeger(state_, /*idx=*/index));
}
template <>
@@ -213,7 +213,7 @@
}
// Reads a flatbuffer from the stack.
- int ReadFlatbuffer(int index, ReflectiveFlatbuffer* buffer) const;
+ int ReadFlatbuffer(int index, MutableFlatbuffer* buffer) const;
// Pushes an iterator.
template <typename ItemCallback, typename KeyCallback>
@@ -507,14 +507,14 @@
// Reads a repeated field from lua.
template <typename T>
void ReadRepeatedField(const int index, RepeatedField* result) const {
- for (const auto& element : ReadVector<T>(index)) {
+ for (const T& element : ReadVector<T>(index)) {
result->Add(element);
}
}
template <>
- void ReadRepeatedField<ReflectiveFlatbuffer>(const int index,
- RepeatedField* result) const {
+ void ReadRepeatedField<MutableFlatbuffer>(const int index,
+ RepeatedField* result) const {
lua_pushnil(state_);
while (Next(index - 1)) {
ReadFlatbuffer(index, result->Add());
diff --git a/native/utils/lua-utils_test.cc b/native/utils/lua-utils_test.cc
index 8c9f8de..44190b8 100644
--- a/native/utils/lua-utils_test.cc
+++ b/native/utils/lua-utils_test.cc
@@ -16,96 +16,31 @@
#include "utils/lua-utils.h"
+#include <memory>
#include <string>
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/lua_utils_tests_generated.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/test-data-test-utils.h"
+#include "utils/testing/test_data_generator.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier3 {
namespace {
+using testing::DoubleEq;
using testing::ElementsAre;
using testing::Eq;
using testing::FloatEq;
-std::string TestFlatbufferSchema() {
- // Creates a test schema for flatbuffer passing tests.
- // Cannot use the object oriented API here as that is not available for the
- // reflection schema.
- flatbuffers::FlatBufferBuilder schema_builder;
- std::vector<flatbuffers::Offset<reflection::Field>> fields = {
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("float_field"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Float),
- /*id=*/0,
- /*offset=*/4),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("nested_field"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Obj,
- /*element=*/reflection::None,
- /*index=*/0 /* self */),
- /*id=*/1,
- /*offset=*/6),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("repeated_nested_field"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Vector,
- /*element=*/reflection::Obj,
- /*index=*/0 /* self */),
- /*id=*/2,
- /*offset=*/8),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("repeated_string_field"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Vector,
- /*element=*/reflection::String),
- /*id=*/3,
- /*offset=*/10),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("string_field"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/4,
- /*offset=*/12)};
-
- std::vector<flatbuffers::Offset<reflection::Enum>> enums;
- std::vector<flatbuffers::Offset<reflection::Object>> objects = {
- reflection::CreateObject(
- schema_builder,
- /*name=*/schema_builder.CreateString("TestData"),
- /*fields=*/
- schema_builder.CreateVectorOfSortedTables(&fields))};
- schema_builder.Finish(reflection::CreateSchema(
- schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
- schema_builder.CreateVectorOfSortedTables(&enums),
- /*(unused) file_ident=*/0,
- /*(unused) file_ext=*/0,
- /*root_table*/ objects[0]));
- return std::string(
- reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
- schema_builder.GetSize());
-}
-
class LuaUtilsTest : public testing::Test, protected LuaEnvironment {
protected:
LuaUtilsTest()
- : serialized_flatbuffer_schema_(TestFlatbufferSchema()),
- schema_(flatbuffers::GetRoot<reflection::Schema>(
- serialized_flatbuffer_schema_.data())),
- flatbuffer_builder_(schema_) {
+ : schema_(GetTestFileContent("utils/lua_utils_tests.bfbs")),
+ flatbuffer_builder_(schema_.get()) {
EXPECT_THAT(RunProtected([this] {
LoadDefaultLibraries();
return LUA_OK;
@@ -122,163 +57,283 @@
Eq(LUA_OK));
}
- const std::string serialized_flatbuffer_schema_;
- const reflection::Schema* schema_;
- ReflectiveFlatbufferBuilder flatbuffer_builder_;
+ OwnedFlatbuffer<reflection::Schema, std::string> schema_;
+ MutableFlatbufferBuilder flatbuffer_builder_;
+ TestDataGenerator test_data_generator_;
};
-TEST_F(LuaUtilsTest, HandlesVectors) {
- {
- PushVector(std::vector<int64>{1, 2, 3, 4, 5});
- EXPECT_THAT(ReadVector<int64>(), ElementsAre(1, 2, 3, 4, 5));
- }
- {
- PushVector(std::vector<std::string>{"hello", "there"});
- EXPECT_THAT(ReadVector<std::string>(), ElementsAre("hello", "there"));
- }
- {
- PushVector(std::vector<bool>{true, true, false});
- EXPECT_THAT(ReadVector<bool>(), ElementsAre(true, true, false));
- }
+template <typename T>
+class TypedLuaUtilsTest : public LuaUtilsTest {};
+
+using testing::Types;
+using LuaTypes =
+ ::testing::Types<int64, uint64, int32, uint32, int16, uint16, int8, uint8,
+ float, double, bool, std::string>;
+TYPED_TEST_SUITE(TypedLuaUtilsTest, LuaTypes);
+
+TYPED_TEST(TypedLuaUtilsTest, HandlesVectors) {
+ std::vector<TypeParam> elements(5);
+ std::generate_n(elements.begin(), 5, [&]() {
+ return this->test_data_generator_.template generate<TypeParam>();
+ });
+
+ this->PushVector(elements);
+
+ EXPECT_THAT(this->template ReadVector<TypeParam>(),
+ testing::ContainerEq(elements));
}
-TEST_F(LuaUtilsTest, HandlesVectorIterators) {
- {
- const std::vector<int64> elements = {1, 2, 3, 4, 5};
- PushVectorIterator(&elements);
- EXPECT_THAT(ReadVector<int64>(), ElementsAre(1, 2, 3, 4, 5));
- }
- {
- const std::vector<std::string> elements = {"hello", "there"};
- PushVectorIterator(&elements);
- EXPECT_THAT(ReadVector<std::string>(), ElementsAre("hello", "there"));
- }
- {
- const std::vector<bool> elements = {true, true, false};
- PushVectorIterator(&elements);
- EXPECT_THAT(ReadVector<bool>(), ElementsAre(true, true, false));
- }
+TYPED_TEST(TypedLuaUtilsTest, HandlesVectorIterators) {
+ std::vector<TypeParam> elements(5);
+ std::generate_n(elements.begin(), 5, [&]() {
+ return this->test_data_generator_.template generate<TypeParam>();
+ });
+
+ this->PushVectorIterator(&elements);
+
+ EXPECT_THAT(this->template ReadVector<TypeParam>(),
+ testing::ContainerEq(elements));
}
-TEST_F(LuaUtilsTest, ReadsFlatbufferResults) {
- // Setup.
+TEST_F(LuaUtilsTest, IndexCallback) {
+ test::TestDataT input_data;
+ input_data.repeated_byte_field = {1, 2};
+ input_data.repeated_ubyte_field = {1, 2};
+ input_data.repeated_int_field = {1, 2};
+ input_data.repeated_uint_field = {1, 2};
+ input_data.repeated_long_field = {1, 2};
+ input_data.repeated_ulong_field = {1, 2};
+ input_data.repeated_bool_field = {true, false};
+ input_data.repeated_float_field = {1, 2};
+ input_data.repeated_double_field = {1, 2};
+ input_data.repeated_string_field = {"1", "2"};
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(test::TestData::Pack(builder, &input_data));
+ const flatbuffers::DetachedBuffer input_buffer = builder.Release();
+ PushFlatbuffer(schema_.get(),
+ flatbuffers::GetRoot<flatbuffers::Table>(input_buffer.data()));
+ lua_setglobal(state_, "arg");
+ // A Lua script that reads the vectors and returns the first value of each.
+ // This should trigger the __index callback.
RunScript(R"lua(
return {
- float_field = 42.1,
- string_field = "hello there",
+ byte_field = arg.repeated_byte_field[1],
+ ubyte_field = arg.repeated_ubyte_field[1],
+ int_field = arg.repeated_int_field[1],
+ uint_field = arg.repeated_uint_field[1],
+ long_field = arg.repeated_long_field[1],
+ ulong_field = arg.repeated_ulong_field[1],
+ bool_field = arg.repeated_bool_field[1],
+ float_field = arg.repeated_float_field[1],
+ double_field = arg.repeated_double_field[1],
+ string_field = arg.repeated_string_field[1],
+ }
+ )lua");
- -- Nested field.
+ // Read the flatbuffer.
+ std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ ReadFlatbuffer(/*index=*/-1, buffer.get());
+ const std::string serialized_buffer = buffer->Serialize();
+ std::unique_ptr<test::TestDataT> test_data =
+ LoadAndVerifyMutableFlatbuffer<test::TestData>(buffer->Serialize());
+
+ EXPECT_THAT(test_data->byte_field, 1);
+ EXPECT_THAT(test_data->ubyte_field, 1);
+ EXPECT_THAT(test_data->int_field, 1);
+ EXPECT_THAT(test_data->uint_field, 1);
+ EXPECT_THAT(test_data->long_field, 1);
+ EXPECT_THAT(test_data->ulong_field, 1);
+ EXPECT_THAT(test_data->bool_field, true);
+ EXPECT_THAT(test_data->float_field, FloatEq(1));
+ EXPECT_THAT(test_data->double_field, DoubleEq(1));
+ EXPECT_THAT(test_data->string_field, "1");
+}
+
+TEST_F(LuaUtilsTest, PairCallback) {
+ test::TestDataT input_data;
+ input_data.repeated_byte_field = {1, 2};
+ input_data.repeated_ubyte_field = {1, 2};
+ input_data.repeated_int_field = {1, 2};
+ input_data.repeated_uint_field = {1, 2};
+ input_data.repeated_long_field = {1, 2};
+ input_data.repeated_ulong_field = {1, 2};
+ input_data.repeated_bool_field = {true, false};
+ input_data.repeated_float_field = {1, 2};
+ input_data.repeated_double_field = {1, 2};
+ input_data.repeated_string_field = {"1", "2"};
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(test::TestData::Pack(builder, &input_data));
+ const flatbuffers::DetachedBuffer input_buffer = builder.Release();
+ PushFlatbuffer(schema_.get(),
+ flatbuffers::GetRoot<flatbuffers::Table>(input_buffer.data()));
+ lua_setglobal(state_, "arg");
+
+ // Iterate over the pushed repeated fields by using the pairs API and check
+ // if the values are correct. This should trigger the __pairs callback.
+ RunScript(R"lua(
+ function equal(table1, table2)
+ for key, value in pairs(table1) do
+ if value ~= table2[key] then
+ return false
+ end
+ end
+ return true
+ end
+
+ local valid = equal(arg.repeated_byte_field, {[1]=1,[2]=2})
+ valid = valid and equal(arg.repeated_ubyte_field, {[1]=1,[2]=2})
+ valid = valid and equal(arg.repeated_int_field, {[1]=1,[2]=2})
+ valid = valid and equal(arg.repeated_uint_field, {[1]=1,[2]=2})
+ valid = valid and equal(arg.repeated_long_field, {[1]=1,[2]=2})
+ valid = valid and equal(arg.repeated_ulong_field, {[1]=1,[2]=2})
+ valid = valid and equal(arg.repeated_bool_field, {[1]=true,[2]=false})
+ valid = valid and equal(arg.repeated_float_field, {[1]=1,[2]=2})
+ valid = valid and equal(arg.repeated_double_field, {[1]=1,[2]=2})
+ valid = valid and equal(arg.repeated_string_field, {[1]="1",[2]="2"})
+
+ return {
+ bool_field = valid
+ }
+ )lua");
+
+ // Read the flatbuffer.
+ std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ ReadFlatbuffer(/*index=*/-1, buffer.get());
+ const std::string serialized_buffer = buffer->Serialize();
+ std::unique_ptr<test::TestDataT> test_data =
+ LoadAndVerifyMutableFlatbuffer<test::TestData>(buffer->Serialize());
+
+ EXPECT_THAT(test_data->bool_field, true);
+}
+
+TEST_F(LuaUtilsTest, PushAndReadsFlatbufferRoundTrip) {
+ // Setup.
+ test::TestDataT input_data;
+ input_data.byte_field = 1;
+ input_data.ubyte_field = 2;
+ input_data.int_field = 10;
+ input_data.uint_field = 11;
+ input_data.long_field = 20;
+ input_data.ulong_field = 21;
+ input_data.bool_field = true;
+ input_data.float_field = 42.1;
+ input_data.double_field = 12.4;
+ input_data.string_field = "hello there";
+ // Nested field.
+ input_data.nested_field = std::make_unique<test::TestDataT>();
+ input_data.nested_field->float_field = 64;
+ input_data.nested_field->string_field = "hello nested";
+ // Repeated fields.
+ input_data.repeated_byte_field = {1, 2, 1};
+ input_data.repeated_byte_field = {1, 2, 1};
+ input_data.repeated_ubyte_field = {2, 4, 2};
+ input_data.repeated_int_field = {1, 2, 3};
+ input_data.repeated_uint_field = {2, 4, 6};
+ input_data.repeated_long_field = {4, 5, 6};
+ input_data.repeated_ulong_field = {8, 10, 12};
+ input_data.repeated_bool_field = {true, false, true};
+ input_data.repeated_float_field = {1.23, 2.34, 3.45};
+ input_data.repeated_double_field = {1.11, 2.22, 3.33};
+ input_data.repeated_string_field = {"a", "bold", "one"};
+ // Repeated nested fields.
+ input_data.repeated_nested_field.push_back(
+ std::make_unique<test::TestDataT>());
+ input_data.repeated_nested_field.back()->string_field = "a";
+ input_data.repeated_nested_field.push_back(
+ std::make_unique<test::TestDataT>());
+ input_data.repeated_nested_field.back()->string_field = "b";
+ input_data.repeated_nested_field.push_back(
+ std::make_unique<test::TestDataT>());
+ input_data.repeated_nested_field.back()->repeated_string_field = {"nested",
+ "nested2"};
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(test::TestData::Pack(builder, &input_data));
+ const flatbuffers::DetachedBuffer input_buffer = builder.Release();
+ PushFlatbuffer(schema_.get(),
+ flatbuffers::GetRoot<flatbuffers::Table>(input_buffer.data()));
+ lua_setglobal(state_, "arg");
+
+ RunScript(R"lua(
+ return {
+ byte_field = arg.byte_field,
+ ubyte_field = arg.ubyte_field,
+ int_field = arg.int_field,
+ uint_field = arg.uint_field,
+ long_field = arg.long_field,
+ ulong_field = arg.ulong_field,
+ bool_field = arg.bool_field,
+ float_field = arg.float_field,
+ double_field = arg.double_field,
+ string_field = arg.string_field,
nested_field = {
- float_field = 64,
- string_field = "hello nested",
+ float_field = arg.nested_field.float_field,
+ string_field = arg.nested_field.string_field,
},
-
- -- Repeated fields.
- repeated_string_field = { "a", "bold", "one" },
+ repeated_byte_field = arg.repeated_byte_field,
+ repeated_ubyte_field = arg.repeated_ubyte_field,
+ repeated_int_field = arg.repeated_int_field,
+ repeated_uint_field = arg.repeated_uint_field,
+ repeated_long_field = arg.repeated_long_field,
+ repeated_ulong_field = arg.repeated_ulong_field,
+ repeated_bool_field = arg.repeated_bool_field,
+ repeated_float_field = arg.repeated_float_field,
+ repeated_double_field = arg.repeated_double_field,
+ repeated_string_field = arg.repeated_string_field,
repeated_nested_field = {
- { string_field = "a" },
- { string_field = "b" },
- { repeated_string_field = { "nested", "nested2" } },
+ { string_field = arg.repeated_nested_field[1].string_field },
+ { string_field = arg.repeated_nested_field[2].string_field },
+ { repeated_string_field = arg.repeated_nested_field[3].repeated_string_field },
},
}
)lua");
// Read the flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
ReadFlatbuffer(/*index=*/-1, buffer.get());
const std::string serialized_buffer = buffer->Serialize();
+ std::unique_ptr<test::TestDataT> test_data =
+ LoadAndVerifyMutableFlatbuffer<test::TestData>(buffer->Serialize());
- // Check fields. As we do not have flatbuffer compiled generated code for the
- // ad hoc generated test schema, we have to read by manually using field
- // offsets.
- const flatbuffers::Table* flatbuffer_data =
- flatbuffers::GetRoot<flatbuffers::Table>(serialized_buffer.data());
- EXPECT_THAT(flatbuffer_data->GetField<float>(/*field=*/4, /*defaultval=*/0),
- FloatEq(42.1));
- EXPECT_THAT(
- flatbuffer_data->GetPointer<const flatbuffers::String*>(/*field=*/12)
- ->str(),
- "hello there");
-
- // Read the nested field.
- const flatbuffers::Table* nested_field =
- flatbuffer_data->GetPointer<const flatbuffers::Table*>(/*field=*/6);
- EXPECT_THAT(nested_field->GetField<float>(/*field=*/4, /*defaultval=*/0),
- FloatEq(64));
- EXPECT_THAT(
- nested_field->GetPointer<const flatbuffers::String*>(/*field=*/12)->str(),
- "hello nested");
-
- // Read the repeated string field.
- auto repeated_strings = flatbuffer_data->GetPointer<
- flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
- /*field=*/10);
- EXPECT_THAT(repeated_strings->size(), Eq(3));
- EXPECT_THAT(repeated_strings->GetAsString(0)->str(), Eq("a"));
- EXPECT_THAT(repeated_strings->GetAsString(1)->str(), Eq("bold"));
- EXPECT_THAT(repeated_strings->GetAsString(2)->str(), Eq("one"));
-
- // Read the repeated nested field.
- auto repeated_nested_fields = flatbuffer_data->GetPointer<
- flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>*>(
- /*field=*/8);
- EXPECT_THAT(repeated_nested_fields->size(), Eq(3));
- EXPECT_THAT(repeated_nested_fields->Get(0)
- ->GetPointer<const flatbuffers::String*>(/*field=*/12)
- ->str(),
- "a");
- EXPECT_THAT(repeated_nested_fields->Get(1)
- ->GetPointer<const flatbuffers::String*>(/*field=*/12)
- ->str(),
- "b");
-}
-
-TEST_F(LuaUtilsTest, HandlesSimpleFlatbufferFields) {
- // Create test flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
- buffer->Set("float_field", 42.f);
- const std::string serialized_buffer = buffer->Serialize();
- PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
- serialized_buffer.data()));
- lua_setglobal(state_, "arg");
-
- // Setup.
- RunScript(R"lua(
- return arg.float_field
- )lua");
-
- EXPECT_THAT(Read<float>(), FloatEq(42));
-}
-
-TEST_F(LuaUtilsTest, HandlesRepeatedFlatbufferFields) {
- // Create test flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
- RepeatedField* repeated_field = buffer->Repeated("repeated_string_field");
- repeated_field->Add("this");
- repeated_field->Add("is");
- repeated_field->Add("a");
- repeated_field->Add("test");
- const std::string serialized_buffer = buffer->Serialize();
- PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
- serialized_buffer.data()));
- lua_setglobal(state_, "arg");
-
- // Return flatbuffer repeated field as vector.
- RunScript(R"lua(
- return arg.repeated_string_field
- )lua");
-
- EXPECT_THAT(ReadVector<std::string>(),
- ElementsAre("this", "is", "a", "test"));
+ EXPECT_THAT(test_data->byte_field, 1);
+ EXPECT_THAT(test_data->ubyte_field, 2);
+ EXPECT_THAT(test_data->int_field, 10);
+ EXPECT_THAT(test_data->uint_field, 11);
+ EXPECT_THAT(test_data->long_field, 20);
+ EXPECT_THAT(test_data->ulong_field, 21);
+ EXPECT_THAT(test_data->bool_field, true);
+ EXPECT_THAT(test_data->float_field, FloatEq(42.1));
+ EXPECT_THAT(test_data->double_field, DoubleEq(12.4));
+ EXPECT_THAT(test_data->string_field, "hello there");
+ EXPECT_THAT(test_data->repeated_byte_field, ElementsAre(1, 2, 1));
+ EXPECT_THAT(test_data->repeated_ubyte_field, ElementsAre(2, 4, 2));
+ EXPECT_THAT(test_data->repeated_int_field, ElementsAre(1, 2, 3));
+ EXPECT_THAT(test_data->repeated_uint_field, ElementsAre(2, 4, 6));
+ EXPECT_THAT(test_data->repeated_long_field, ElementsAre(4, 5, 6));
+ EXPECT_THAT(test_data->repeated_ulong_field, ElementsAre(8, 10, 12));
+ EXPECT_THAT(test_data->repeated_bool_field, ElementsAre(true, false, true));
+ EXPECT_THAT(test_data->repeated_float_field, ElementsAre(1.23, 2.34, 3.45));
+ EXPECT_THAT(test_data->repeated_double_field, ElementsAre(1.11, 2.22, 3.33));
+ EXPECT_THAT(test_data->repeated_string_field,
+ ElementsAre("a", "bold", "one"));
+ // Nested fields.
+ EXPECT_THAT(test_data->nested_field->float_field, FloatEq(64));
+ EXPECT_THAT(test_data->nested_field->string_field, "hello nested");
+ // Repeated nested fields.
+ EXPECT_THAT(test_data->repeated_nested_field[0]->string_field, "a");
+ EXPECT_THAT(test_data->repeated_nested_field[1]->string_field, "b");
+ EXPECT_THAT(test_data->repeated_nested_field[2]->repeated_string_field,
+ ElementsAre("nested", "nested2"));
}
TEST_F(LuaUtilsTest, HandlesRepeatedNestedFlatbufferFields) {
// Create test flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
RepeatedField* repeated_field = buffer->Repeated("repeated_nested_field");
repeated_field->Add()->Set("string_field", "hello");
repeated_field->Add()->Set("string_field", "my");
- ReflectiveFlatbuffer* nested = repeated_field->Add();
+ MutableFlatbuffer* nested = repeated_field->Add();
nested->Set("string_field", "old");
RepeatedField* nested_repeated = nested->Repeated("repeated_string_field");
nested_repeated->Add("friend");
@@ -286,8 +341,8 @@
nested_repeated->Add("are");
repeated_field->Add()->Set("string_field", "you?");
const std::string serialized_buffer = buffer->Serialize();
- PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
- serialized_buffer.data()));
+ PushFlatbuffer(schema_.get(), flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer.data()));
lua_setglobal(state_, "arg");
RunScript(R"lua(
@@ -308,18 +363,18 @@
TEST_F(LuaUtilsTest, CorrectlyReadsTwoFlatbuffersSimultaneously) {
// The first flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
buffer->Set("string_field", "first");
const std::string serialized_buffer = buffer->Serialize();
- PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
- serialized_buffer.data()));
+ PushFlatbuffer(schema_.get(), flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer.data()));
lua_setglobal(state_, "arg");
// The second flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer2 = flatbuffer_builder_.NewRoot();
+ std::unique_ptr<MutableFlatbuffer> buffer2 = flatbuffer_builder_.NewRoot();
buffer2->Set("string_field", "second");
const std::string serialized_buffer2 = buffer2->Serialize();
- PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
- serialized_buffer2.data()));
+ PushFlatbuffer(schema_.get(), flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer2.data()));
lua_setglobal(state_, "arg2");
RunScript(R"lua(
diff --git a/native/utils/lua_utils_tests.bfbs b/native/utils/lua_utils_tests.bfbs
new file mode 100644
index 0000000..acb731b
--- /dev/null
+++ b/native/utils/lua_utils_tests.bfbs
Binary files differ
diff --git a/native/utils/lua_utils_tests.fbs b/native/utils/lua_utils_tests.fbs
new file mode 100644
index 0000000..6d8ad38
--- /dev/null
+++ b/native/utils/lua_utils_tests.fbs
@@ -0,0 +1,45 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.test;
+
+table TestData {
+ byte_field: byte;
+ ubyte_field: ubyte;
+ int_field: int;
+ uint_field: uint;
+ long_field: int64;
+ ulong_field: uint64;
+ bool_field: bool;
+ float_field: float;
+ double_field: double;
+ string_field: string;
+ nested_field: TestData;
+
+ repeated_byte_field: [byte];
+ repeated_ubyte_field: [ubyte];
+ repeated_int_field: [int];
+ repeated_uint_field: [uint];
+ repeated_long_field: [int64];
+ repeated_ulong_field: [uint64];
+ repeated_bool_field: [bool];
+ repeated_float_field: [float];
+ repeated_double_field: [double];
+ repeated_string_field: [string];
+ repeated_nested_field: [TestData];
+}
+
+root_type libtextclassifier3.test.TestData;
diff --git a/native/utils/normalization.fbs b/native/utils/normalization.fbs
old mode 100755
new mode 100644
diff --git a/native/utils/optional.h b/native/utils/optional.h
index 15d2619..572350d 100644
--- a/native/utils/optional.h
+++ b/native/utils/optional.h
@@ -62,7 +62,7 @@
return value_;
}
- T const& value_or(T&& default_value) {
+ T const& value_or(T&& default_value) const& {
return (init_ ? value_ : default_value);
}
diff --git a/native/utils/regex-match_test.cc b/native/utils/regex-match_test.cc
index c45fb29..c7a7740 100644
--- a/native/utils/regex-match_test.cc
+++ b/native/utils/regex-match_test.cc
@@ -18,6 +18,7 @@
#include <memory>
+#include "utils/jvm-test-utils.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "gmock/gmock.h"
@@ -28,11 +29,10 @@
class RegexMatchTest : public testing::Test {
protected:
- RegexMatchTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- UniLib unilib_;
+ RegexMatchTest() : unilib_(libtextclassifier3::CreateUniLibForTesting()) {}
+ std::unique_ptr<UniLib> unilib_;
};
-#ifdef TC3_UNILIB_ICU
#ifndef TC3_DISABLE_LUA
TEST_F(RegexMatchTest, HandlesSimpleVerification) {
EXPECT_TRUE(VerifyMatch(/*context=*/"", /*matcher=*/nullptr, "return true;"));
@@ -65,7 +65,7 @@
return luhn(match[1].text);
)";
const std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib_.CreateRegexPattern(pattern);
+ unilib_->CreateRegexPattern(pattern);
ASSERT_TRUE(regex_pattern != nullptr);
const std::unique_ptr<UniLib::RegexMatcher> matcher =
regex_pattern->Matcher(message);
@@ -83,7 +83,7 @@
UTF8ToUnicodeText("never gonna (?:give (you) up|let (you) down)",
/*do_copy=*/true);
const std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib_.CreateRegexPattern(pattern);
+ unilib_->CreateRegexPattern(pattern);
ASSERT_TRUE(regex_pattern != nullptr);
UnicodeText message =
UTF8ToUnicodeText("never gonna give you up - never gonna let you down");
@@ -108,7 +108,6 @@
EXPECT_THAT(GetCapturingGroupText(matcher.get(), 2).value(),
testing::Eq("you"));
}
-#endif
} // namespace
} // namespace libtextclassifier3
diff --git a/native/utils/resources.cc b/native/utils/resources.cc
index 2ae2def..24b3a6f 100644
--- a/native/utils/resources.cc
+++ b/native/utils/resources.cc
@@ -18,7 +18,6 @@
#include "utils/base/logging.h"
#include "utils/zlib/buffer_generated.h"
-#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
namespace {
@@ -128,121 +127,8 @@
if (resource->content() != nullptr) {
*result = resource->content()->str();
return true;
- } else if (resource->compressed_content() != nullptr) {
- std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(
- resources_->compression_dictionary()->data(),
- resources_->compression_dictionary()->size());
- if (decompressor != nullptr &&
- decompressor->MaybeDecompress(resource->compressed_content(), result)) {
- return true;
- }
}
return false;
}
-bool CompressResources(ResourcePoolT* resources,
- const bool build_compression_dictionary,
- const int dictionary_sample_every) {
- std::vector<unsigned char> dictionary;
- if (build_compression_dictionary) {
- {
- // Build up a compression dictionary.
- std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
- int i = 0;
- for (auto& entry : resources->resource_entry) {
- for (auto& resource : entry->resource) {
- if (resource->content.empty()) {
- continue;
- }
- i++;
-
- // Use a sample of the entries to build up a custom compression
- // dictionary. Using all entries will generally not give a benefit
- // for small data sizes, so we subsample here.
- if (i % dictionary_sample_every != 0) {
- continue;
- }
- CompressedBufferT compressed_content;
- compressor->Compress(resource->content, &compressed_content);
- }
- }
- compressor->GetDictionary(&dictionary);
- resources->compression_dictionary.assign(
- dictionary.data(), dictionary.data() + dictionary.size());
- }
- }
-
- for (auto& entry : resources->resource_entry) {
- for (auto& resource : entry->resource) {
- if (resource->content.empty()) {
- continue;
- }
- // Try compressing the data.
- std::unique_ptr<ZlibCompressor> compressor =
- build_compression_dictionary
- ? ZlibCompressor::Instance(dictionary.data(), dictionary.size())
- : ZlibCompressor::Instance();
- if (!compressor) {
- TC3_LOG(ERROR) << "Cannot create zlib compressor.";
- return false;
- }
-
- CompressedBufferT compressed_content;
- compressor->Compress(resource->content, &compressed_content);
-
- // Only keep compressed version if smaller.
- if (compressed_content.uncompressed_size >
- compressed_content.buffer.size()) {
- resource->content.clear();
- resource->compressed_content.reset(new CompressedBufferT);
- *resource->compressed_content = compressed_content;
- }
- }
- }
- return true;
-}
-
-std::string CompressSerializedResources(const std::string& resources,
- const int dictionary_sample_every) {
- std::unique_ptr<ResourcePoolT> unpacked_resources(
- flatbuffers::GetRoot<ResourcePool>(resources.data())->UnPack());
- TC3_CHECK(unpacked_resources != nullptr);
- TC3_CHECK(
- CompressResources(unpacked_resources.get(), dictionary_sample_every));
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(ResourcePool::Pack(builder, unpacked_resources.get()));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-bool DecompressResources(ResourcePoolT* resources,
- const bool build_compression_dictionary) {
- std::vector<unsigned char> dictionary;
-
- for (auto& entry : resources->resource_entry) {
- for (auto& resource : entry->resource) {
- if (resource->compressed_content == nullptr) {
- continue;
- }
-
- std::unique_ptr<ZlibDecompressor> zlib_decompressor =
- build_compression_dictionary
- ? ZlibDecompressor::Instance(dictionary.data(), dictionary.size())
- : ZlibDecompressor::Instance();
- if (!zlib_decompressor) {
- TC3_LOG(ERROR) << "Cannot initialize decompressor.";
- return false;
- }
-
- if (!zlib_decompressor->MaybeDecompress(
- resource->compressed_content.get(), &resource->content)) {
- TC3_LOG(ERROR) << "Cannot decompress resource.";
- return false;
- }
- resource->compressed_content.reset(nullptr);
- }
- }
- return true;
-}
-
} // namespace libtextclassifier3
diff --git a/native/utils/resources.fbs b/native/utils/resources.fbs
old mode 100755
new mode 100644
index aae57cf..b4d9b83
--- a/native/utils/resources.fbs
+++ b/native/utils/resources.fbs
@@ -21,7 +21,6 @@
table Resource {
locale:[int];
content:string (shared);
- compressed_content:CompressedBuffer;
}
namespace libtextclassifier3;
@@ -34,6 +33,5 @@
table ResourcePool {
locale:[LanguageTag];
resource_entry:[ResourceEntry];
- compression_dictionary:[ubyte];
}
diff --git a/native/utils/resources.h b/native/utils/resources.h
index 96f9683..ca601fe 100644
--- a/native/utils/resources.h
+++ b/native/utils/resources.h
@@ -63,18 +63,6 @@
const ResourcePool* resources_;
};
-// Compresses resources in place.
-bool CompressResources(ResourcePoolT* resources,
- const bool build_compression_dictionary = false,
- const int dictionary_sample_every = 1);
-std::string CompressSerializedResources(
- const std::string& resources,
- const bool build_compression_dictionary = false,
- const int dictionary_sample_every = 1);
-
-bool DecompressResources(ResourcePoolT* resources,
- const bool build_compression_dictionary = false);
-
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
diff --git a/native/utils/resources_test.cc b/native/utils/resources_test.cc
index c385f39..6e3d0a1 100644
--- a/native/utils/resources_test.cc
+++ b/native/utils/resources_test.cc
@@ -15,6 +15,7 @@
*/
#include "utils/resources.h"
+
#include "utils/i18n/locale.h"
#include "utils/resources_generated.h"
#include "gmock/gmock.h"
@@ -23,8 +24,7 @@
namespace libtextclassifier3 {
namespace {
-class ResourcesTest
- : public testing::TestWithParam<testing::tuple<bool, bool>> {
+class ResourcesTest : public testing::Test {
protected:
ResourcesTest() {}
@@ -57,7 +57,7 @@
test_resources.locale.back()->language = "zh";
test_resources.locale.emplace_back(new LanguageTagT);
test_resources.locale.back()->language = "fr";
- test_resources.locale.back()->language = "fr-CA";
+ test_resources.locale.back()->region = "CA";
if (add_default_language) {
test_resources.locale.emplace_back(new LanguageTagT); // default
}
@@ -115,12 +115,6 @@
test_resources.resource_entry.back()->resource.back()->content = "龍";
test_resources.resource_entry.back()->resource.back()->locale.push_back(7);
- if (compress()) {
- EXPECT_TRUE(CompressResources(
- &test_resources,
- /*build_compression_dictionary=*/build_dictionary()));
- }
-
flatbuffers::FlatBufferBuilder builder;
builder.Finish(ResourcePool::Pack(builder, &test_resources));
@@ -128,16 +122,9 @@
reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
}
-
- bool compress() const { return testing::get<0>(GetParam()); }
-
- bool build_dictionary() const { return testing::get<1>(GetParam()); }
};
-INSTANTIATE_TEST_SUITE_P(Compression, ResourcesTest,
- testing::Combine(testing::Bool(), testing::Bool()));
-
-TEST_P(ResourcesTest, CorrectlyHandlesExactMatch) {
+TEST_F(ResourcesTest, CorrectlyHandlesExactMatch) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
@@ -162,7 +149,7 @@
EXPECT_EQ("localiser", content);
}
-TEST_P(ResourcesTest, CorrectlyHandlesTie) {
+TEST_F(ResourcesTest, CorrectlyHandlesTie) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
@@ -173,7 +160,7 @@
EXPECT_EQ("localize", content);
}
-TEST_P(ResourcesTest, RequiresLanguageMatch) {
+TEST_F(ResourcesTest, RequiresLanguageMatch) {
{
std::string test_resources =
BuildTestResources(/*add_default_language=*/false);
@@ -196,7 +183,7 @@
}
}
-TEST_P(ResourcesTest, HandlesFallback) {
+TEST_F(ResourcesTest, HandlesFallback) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
@@ -217,7 +204,7 @@
EXPECT_EQ("localize", content);
}
-TEST_P(ResourcesTest, HandlesFallbackMultipleLocales) {
+TEST_F(ResourcesTest, HandlesFallbackMultipleLocales) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
@@ -252,7 +239,7 @@
EXPECT_EQ("localize", content);
}
-TEST_P(ResourcesTest, PreferGenericCallback) {
+TEST_F(ResourcesTest, PreferGenericCallback) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
@@ -271,7 +258,7 @@
EXPECT_EQ("龍", content); // Falls back to zh, not zh-Hans-CN.
}
-TEST_P(ResourcesTest, PreferGenericWhenGeneric) {
+TEST_F(ResourcesTest, PreferGenericWhenGeneric) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
diff --git a/native/utils/sentencepiece/normalizer.cc b/native/utils/sentencepiece/normalizer.cc
index 4cee507..d2b0c06 100644
--- a/native/utils/sentencepiece/normalizer.cc
+++ b/native/utils/sentencepiece/normalizer.cc
@@ -124,8 +124,8 @@
}
const bool no_match = match.match_length <= 0;
if (no_match) {
- const int char_length = ValidUTF8CharLength(input.data(), input.size());
- if (char_length <= 0) {
+ int char_length;
+ if (!IsValidChar(input.data(), input.size(), &char_length)) {
// Found a malformed utf8.
// The rune is set to be 0xFFFD (REPLACEMENT CHARACTER),
// which is a valid Unicode of three bytes in utf8,
diff --git a/native/utils/sentencepiece/normalizer_test.cc b/native/utils/sentencepiece/normalizer_test.cc
new file mode 100644
index 0000000..57debe3
--- /dev/null
+++ b/native/utils/sentencepiece/normalizer_test.cc
@@ -0,0 +1,199 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/normalizer.h"
+
+#include <fstream>
+#include <string>
+
+#include "utils/container/double-array-trie.h"
+#include "utils/sentencepiece/test_utils.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/test-data-test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+std::string GetTestConfigPath() {
+ return GetTestDataPath("utils/sentencepiece/test_data/nmt_nfkc_charsmap.bin");
+}
+
+TEST(NormalizerTest, NormalizesAsReferenceNormalizer) {
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/true,
+ /*remove_extra_whitespaces=*/true,
+ /*escape_whitespaces=*/true);
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "▁hello▁there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "▁when▁is▁the▁world▁cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "▁general▁kenobi");
+ }
+
+ // NFKC char to multi-char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("㍿", &normalized));
+ EXPECT_EQ(normalized, "▁株式会社");
+ }
+
+ // Half width katakana, character composition happens.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize(" グーグル ", &normalized));
+ EXPECT_EQ(normalized, "▁グーグル");
+ }
+
+ // NFKC char to char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("①②③", &normalized));
+ EXPECT_EQ(normalized, "▁123");
+ }
+}
+
+TEST(NormalizerTest, NoDummyPrefix) {
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/true,
+ /*escape_whitespaces=*/true);
+
+ // NFKC char to char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "hello▁there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "when▁is▁the▁world▁cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "general▁kenobi");
+ }
+
+ // NFKC char to multi-char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("㍿", &normalized));
+ EXPECT_EQ(normalized, "株式会社");
+ }
+
+ // Half width katakana, character composition happens.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize(" グーグル ", &normalized));
+ EXPECT_EQ(normalized, "グーグル");
+ }
+
+ // NFKC char to char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("①②③", &normalized));
+ EXPECT_EQ(normalized, "123");
+ }
+}
+
+TEST(NormalizerTest, NoRemoveExtraWhitespace) {
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/false,
+ /*escape_whitespaces=*/true);
+
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "hello▁there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "when▁is▁▁the▁▁world▁cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "general▁kenobi");
+ }
+}
+
+TEST(NormalizerTest, NoEscapeWhitespaces) {
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/false,
+ /*escape_whitespaces=*/false);
+
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "hello there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "when is the world cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "general kenobi");
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/sentencepiece/test_data/nmt_nfkc_charsmap.bin b/native/utils/sentencepiece/test_data/nmt_nfkc_charsmap.bin
new file mode 100644
index 0000000..74da62d
--- /dev/null
+++ b/native/utils/sentencepiece/test_data/nmt_nfkc_charsmap.bin
Binary files differ
diff --git a/native/utils/sentencepiece/test_utils.cc b/native/utils/sentencepiece/test_utils.cc
new file mode 100644
index 0000000..f277a14
--- /dev/null
+++ b/native/utils/sentencepiece/test_utils.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/test_utils.h"
+
+#include <memory>
+
+#include "utils/base/integral_types.h"
+#include "utils/container/double-array-trie.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
+ bool add_dummy_prefix,
+ bool remove_extra_whitespaces,
+ bool escape_whitespaces) {
+ const uint32 trie_blob_size = reinterpret_cast<const uint32*>(spec.data())[0];
+ spec.RemovePrefix(sizeof(trie_blob_size));
+ const TrieNode* trie_blob = reinterpret_cast<const TrieNode*>(spec.data());
+ spec.RemovePrefix(trie_blob_size);
+ const int num_nodes = trie_blob_size / sizeof(TrieNode);
+ return SentencePieceNormalizer(
+ DoubleArrayTrie(trie_blob, num_nodes),
+ /*charsmap_normalized=*/StringPiece(spec.data(), spec.size()),
+ add_dummy_prefix, remove_extra_whitespaces, escape_whitespaces);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/sentencepiece/test_utils.h b/native/utils/sentencepiece/test_utils.h
new file mode 100644
index 0000000..0c833da
--- /dev/null
+++ b/native/utils/sentencepiece/test_utils.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "utils/sentencepiece/normalizer.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
+ bool add_dummy_prefix,
+ bool remove_extra_whitespaces,
+ bool escape_whitespaces);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
diff --git a/native/utils/strings/utf8.cc b/native/utils/strings/utf8.cc
index 932e2a5..b3ed0af 100644
--- a/native/utils/strings/utf8.cc
+++ b/native/utils/strings/utf8.cc
@@ -19,10 +19,11 @@
#include "utils/base/logging.h"
namespace libtextclassifier3 {
+
bool IsValidUTF8(const char *src, int size) {
+ int char_length;
for (int i = 0; i < size;) {
- const int char_length = ValidUTF8CharLength(src + i, size - i);
- if (char_length <= 0) {
+ if (!IsValidChar(src + i, size - i, &char_length)) {
return false;
}
i += char_length;
@@ -30,27 +31,6 @@
return true;
}
-int ValidUTF8CharLength(const char *src, int size) {
- // Unexpected trail byte.
- if (IsTrailByte(src[0])) {
- return -1;
- }
-
- const int num_codepoint_bytes = GetNumBytesForUTF8Char(&src[0]);
- if (num_codepoint_bytes <= 0 || num_codepoint_bytes > size) {
- return -1;
- }
-
- // Check that remaining bytes in the codepoint are trailing bytes.
- for (int k = 1; k < num_codepoint_bytes; k++) {
- if (!IsTrailByte(src[k])) {
- return -1;
- }
- }
-
- return num_codepoint_bytes;
-}
-
int SafeTruncateLength(const char *str, int truncate_at) {
// Always want to truncate at the start of a character, so if
// it's in a middle, back up toward the start
@@ -88,6 +68,47 @@
((byte3 & 0x3F) << 6) | (byte4 & 0x3F);
}
+bool IsValidChar(const char *str, int size, int *num_bytes) {
+ // Unexpected trail byte.
+ if (IsTrailByte(str[0])) {
+ return false;
+ }
+
+ *num_bytes = GetNumBytesForUTF8Char(str);
+ if (*num_bytes <= 0 || *num_bytes > size) {
+ return false;
+ }
+
+ // Check that remaining bytes in the codepoint are trailing bytes.
+ for (int k = 1; k < *num_bytes; k++) {
+ if (!IsTrailByte(str[k])) {
+ return false;
+ }
+ }
+
+ // Exclude overlong encodings.
+ // Check that the codepoint is encoded with the minimum number of required
+ // bytes. An ascii value could be encoded in 4, 3 or 2 bytes but requires
+ // only 1. There is a unique valid encoding for each code point.
+ // This ensures that string comparisons and searches are well-defined.
+ // See: https://en.wikipedia.org/wiki/UTF-8
+ const char32 codepoint = ValidCharToRune(str);
+ switch (*num_bytes) {
+ case 1:
+ return true;
+ case 2:
+ // Everything below 128 can be encoded in one byte.
+ return (codepoint >= (1 << 7 /* num. payload bits in one byte */));
+ case 3:
+ return (codepoint >= (1 << 11 /* num. payload bits in two utf8 bytes */));
+ case 4:
+ return (codepoint >=
+ (1 << 16 /* num. payload bits in three utf8 bytes */)) &&
+ (codepoint < 0x10FFFF /* maximum rune value */);
+ }
+ return false;
+}
+
int ValidRuneToChar(const char32 rune, char *dest) {
// Convert to unsigned for range check.
uint32 c;
diff --git a/native/utils/strings/utf8.h b/native/utils/strings/utf8.h
index e871731..370cf23 100644
--- a/native/utils/strings/utf8.h
+++ b/native/utils/strings/utf8.h
@@ -41,10 +41,6 @@
// Returns true iff src points to a well-formed UTF-8 string.
bool IsValidUTF8(const char *src, int size);
-// Returns byte length of the first valid codepoint in the string, otherwise -1
-// if pointing to an ill-formed UTF-8 character.
-int ValidUTF8CharLength(const char *src, int size);
-
// Helper to ensure that strings are not truncated in the middle of
// multi-byte UTF-8 characters.
// Given a string, and a position at which to truncate, returns the
@@ -55,6 +51,10 @@
// Gets a unicode codepoint from a valid utf8 encoding.
char32 ValidCharToRune(const char *str);
+// Checks whether a utf8 encoding is a valid codepoint and returns the number of
+// bytes of the codepoint.
+bool IsValidChar(const char *str, int size, int *num_bytes);
+
// Converts a valid codepoint to utf8.
// Returns the length of the encoding.
int ValidRuneToChar(const char32 rune, char *dest);
diff --git a/native/utils/strings/utf8_test.cc b/native/utils/strings/utf8_test.cc
index 28d971b..5b4b748 100644
--- a/native/utils/strings/utf8_test.cc
+++ b/native/utils/strings/utf8_test.cc
@@ -34,25 +34,18 @@
EXPECT_TRUE(IsValidUTF8("\u304A\u00B0\u106B", 8));
EXPECT_TRUE(IsValidUTF8("this is a test😋😋😋", 26));
EXPECT_TRUE(IsValidUTF8("\xf0\x9f\x98\x8b", 4));
+ // Example with first byte payload of zero.
+ EXPECT_TRUE(IsValidUTF8("\xf0\x90\x80\x80", 4));
// Too short (string is too short).
EXPECT_FALSE(IsValidUTF8("\xf0\x9f", 2));
// Too long (too many trailing bytes).
EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x8b\x8b", 5));
// Too short (too few trailing bytes).
EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x61\x61", 5));
-}
-
-TEST(Utf8Test, ValidUTF8CharLength) {
- EXPECT_EQ(ValidUTF8CharLength("1234😋hello", 13), 1);
- EXPECT_EQ(ValidUTF8CharLength("\u304A\u00B0\u106B", 8), 3);
- EXPECT_EQ(ValidUTF8CharLength("this is a test😋😋😋", 26), 1);
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x8b", 4), 4);
- // Too short (string is too short).
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f", 2), -1);
- // Too long (too many trailing bytes). First character is valid.
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x8b\x8b", 5), 4);
- // Too short (too few trailing bytes).
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x61\x61", 5), -1);
+ // Invalid continuation byte (can be encoded in less bytes).
+ EXPECT_FALSE(IsValidUTF8("\xc0\x81", 2));
+ // Invalid continuation byte (can be encoded in less bytes).
+ EXPECT_FALSE(IsValidUTF8("\xf0\x8a\x85\x8f", 4));
}
TEST(Utf8Test, CorrectlyTruncatesStrings) {
diff --git a/native/utils/test-data-test-utils.h b/native/utils/test-data-test-utils.h
new file mode 100644
index 0000000..61f6d97
--- /dev/null
+++ b/native/utils/test-data-test-utils.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utilities for accessing test data.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_DATA_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_TEST_DATA_TEST_UTILS_H_
+#include <fstream>
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+// Get the file path to the test data.
+inline std::string GetTestDataPath(const std::string& relative_path) {
+ return "/data/local/tmp/" + relative_path;
+}
+
+inline std::string GetTestFileContent(const std::string& relative_path) {
+ const std::string full_path = GetTestDataPath(relative_path);
+ std::ifstream file_stream(full_path);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TEST_DATA_TEST_UTILS_H_
diff --git a/native/utils/test-utils.cc b/native/utils/test-utils.cc
deleted file mode 100644
index 8996a4a..0000000
--- a/native/utils/test-utils.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/test-utils.h"
-
-#include <iterator>
-
-#include "utils/codepoint-range.h"
-#include "utils/strings/utf8.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3 {
-
-using libtextclassifier3::Token;
-
-std::vector<Token> TokenizeOnSpace(const std::string& text) {
- return TokenizeOnDelimiters(text, {' '});
-}
-
-std::vector<Token> TokenizeOnDelimiters(
- const std::string& text, const std::unordered_set<char32>& delimiters) {
- const UnicodeText unicode_text = UTF8ToUnicodeText(text, /*do_copy=*/false);
-
- std::vector<Token> result;
-
- int token_start_codepoint = 0;
- auto token_start_it = unicode_text.begin();
- int codepoint_idx = 0;
-
- UnicodeText::const_iterator it;
- for (it = unicode_text.begin(); it < unicode_text.end(); it++) {
- if (delimiters.find(*it) != delimiters.end()) {
- // Only add a token when the string is non-empty.
- if (token_start_it != it) {
- result.push_back(Token{UnicodeText::UTF8Substring(token_start_it, it),
- token_start_codepoint, codepoint_idx});
- }
-
- token_start_codepoint = codepoint_idx + 1;
- token_start_it = it;
- token_start_it++;
- }
-
- codepoint_idx++;
- }
- // Only add a token when the string is non-empty.
- if (token_start_it != it) {
- result.push_back(Token{UnicodeText::UTF8Substring(token_start_it, it),
- token_start_codepoint, codepoint_idx});
- }
-
- return result;
-}
-
-} // namespace libtextclassifier3
diff --git a/native/utils/test-utils.h b/native/utils/test-utils.h
deleted file mode 100644
index 0e75190..0000000
--- a/native/utils/test-utils.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Utilities for tests.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
-#define LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
-
-#include <string>
-
-#include "annotator/types.h"
-
-namespace libtextclassifier3 {
-
-// Returns a list of Tokens for a given input string, by tokenizing on space.
-std::vector<Token> TokenizeOnSpace(const std::string& text);
-
-// Returns a list of Tokens for a given input string, by tokenizing on the
-// given set of delimiter codepoints.
-std::vector<Token> TokenizeOnDelimiters(
- const std::string& text, const std::unordered_set<char32>& delimiters);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
diff --git a/native/utils/test-utils_test.cc b/native/utils/test-utils_test.cc
deleted file mode 100644
index bdaa285..0000000
--- a/native/utils/test-utils_test.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/test-utils.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(TestUtilTest, TokenizeOnSpace) {
- std::vector<Token> tokens =
- TokenizeOnSpace("Where is Jörg Borg located? Maybe in Zürich ...");
-
- EXPECT_EQ(tokens.size(), 9);
-
- EXPECT_EQ(tokens[0].value, "Where");
- EXPECT_EQ(tokens[0].start, 0);
- EXPECT_EQ(tokens[0].end, 5);
-
- EXPECT_EQ(tokens[1].value, "is");
- EXPECT_EQ(tokens[1].start, 6);
- EXPECT_EQ(tokens[1].end, 8);
-
- EXPECT_EQ(tokens[2].value, "Jörg");
- EXPECT_EQ(tokens[2].start, 9);
- EXPECT_EQ(tokens[2].end, 13);
-
- EXPECT_EQ(tokens[3].value, "Borg");
- EXPECT_EQ(tokens[3].start, 14);
- EXPECT_EQ(tokens[3].end, 18);
-
- EXPECT_EQ(tokens[4].value, "located?");
- EXPECT_EQ(tokens[4].start, 19);
- EXPECT_EQ(tokens[4].end, 27);
-
- EXPECT_EQ(tokens[5].value, "Maybe");
- EXPECT_EQ(tokens[5].start, 28);
- EXPECT_EQ(tokens[5].end, 33);
-
- EXPECT_EQ(tokens[6].value, "in");
- EXPECT_EQ(tokens[6].start, 34);
- EXPECT_EQ(tokens[6].end, 36);
-
- EXPECT_EQ(tokens[7].value, "Zürich");
- EXPECT_EQ(tokens[7].start, 37);
- EXPECT_EQ(tokens[7].end, 43);
-
- EXPECT_EQ(tokens[8].value, "...");
- EXPECT_EQ(tokens[8].start, 44);
- EXPECT_EQ(tokens[8].end, 47);
-}
-
-TEST(TestUtilTest, TokenizeOnDelimiters) {
- std::vector<Token> tokens = TokenizeOnDelimiters(
- "This might be čomplíčateď?!: Oder?", {' ', '?', '!'});
-
- EXPECT_EQ(tokens.size(), 6);
-
- EXPECT_EQ(tokens[0].value, "This");
- EXPECT_EQ(tokens[0].start, 0);
- EXPECT_EQ(tokens[0].end, 4);
-
- EXPECT_EQ(tokens[1].value, "might");
- EXPECT_EQ(tokens[1].start, 7);
- EXPECT_EQ(tokens[1].end, 12);
-
- EXPECT_EQ(tokens[2].value, "be");
- EXPECT_EQ(tokens[2].start, 13);
- EXPECT_EQ(tokens[2].end, 15);
-
- EXPECT_EQ(tokens[3].value, "čomplíčateď");
- EXPECT_EQ(tokens[3].start, 16);
- EXPECT_EQ(tokens[3].end, 27);
-
- EXPECT_EQ(tokens[4].value, ":");
- EXPECT_EQ(tokens[4].start, 29);
- EXPECT_EQ(tokens[4].end, 30);
-
- EXPECT_EQ(tokens[5].value, "Oder");
- EXPECT_EQ(tokens[5].start, 31);
- EXPECT_EQ(tokens[5].end, 35);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/testing/annotator.cc b/native/utils/testing/annotator.cc
new file mode 100644
index 0000000..47c3fb7
--- /dev/null
+++ b/native/utils/testing/annotator.cc
@@ -0,0 +1,206 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/testing/annotator.h"
+
+#include "utils/flatbuffers/mutable.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+
+std::string FirstResult(const std::vector<ClassificationResult>& results) {
+ if (results.empty()) {
+ return "<INVALID RESULTS>";
+ }
+ return results[0].collection;
+}
+
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+std::unique_ptr<RegexModel_::PatternT> MakePattern(
+ const std::string& collection_name, const std::string& pattern,
+ const bool enabled_for_classification, const bool enabled_for_selection,
+ const bool enabled_for_annotation, const float score,
+ const float priority_score) {
+ std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
+ result->collection_name = collection_name;
+ result->pattern = pattern;
+ // We cannot directly operate with |= on the flag, so use an int here.
+ int enabled_modes = ModeFlag_NONE;
+ if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
+ if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
+ if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
+ result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
+ result->target_classification_score = score;
+ result->priority_score = priority_score;
+ return result;
+}
+
+// Shortcut function that doesn't need to specify the priority score.
+std::unique_ptr<RegexModel_::PatternT> MakePattern(
+ const std::string& collection_name, const std::string& pattern,
+ const bool enabled_for_classification, const bool enabled_for_selection,
+ const bool enabled_for_annotation, const float score) {
+ return MakePattern(collection_name, pattern, enabled_for_classification,
+ enabled_for_selection, enabled_for_annotation,
+ /*score=*/score,
+ /*priority_score=*/score);
+}
+
+void AddTestRegexModel(ModelT* unpacked_model) {
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person_with_age", "(Barack) (?:(Obama) )?is (\\d+) years old",
+ /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, 1.0));
+
+ // Use meta data to generate custom serialized entity data.
+ MutableFlatbufferBuilder entity_data_builder(
+ flatbuffers::GetRoot<reflection::Schema>(
+ unpacked_model->entity_data_schema.data()));
+ RegexModel_::PatternT* pattern =
+ unpacked_model->regex_model->patterns.back().get();
+
+ {
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder.NewRoot();
+ entity_data->Set("is_alive", true);
+ pattern->serialized_entity_data = entity_data->Serialize();
+ }
+ pattern->capturing_group.emplace_back(new CapturingGroupT);
+ pattern->capturing_group.emplace_back(new CapturingGroupT);
+ pattern->capturing_group.emplace_back(new CapturingGroupT);
+ pattern->capturing_group.emplace_back(new CapturingGroupT);
+ // Group 0 is the full match, capturing groups starting at 1.
+ pattern->capturing_group[1]->entity_field_path.reset(
+ new FlatbufferFieldPathT);
+ pattern->capturing_group[1]->entity_field_path->field.emplace_back(
+ new FlatbufferFieldT);
+ pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
+ "first_name";
+ pattern->capturing_group[2]->entity_field_path.reset(
+ new FlatbufferFieldPathT);
+ pattern->capturing_group[2]->entity_field_path->field.emplace_back(
+ new FlatbufferFieldT);
+ pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
+ "last_name";
+ // Set `former_us_president` field if we match Obama.
+ {
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder.NewRoot();
+ entity_data->Set("former_us_president", true);
+ pattern->capturing_group[2]->serialized_entity_data =
+ entity_data->Serialize();
+ }
+ pattern->capturing_group[3]->entity_field_path.reset(
+ new FlatbufferFieldPathT);
+ pattern->capturing_group[3]->entity_field_path->field.emplace_back(
+ new FlatbufferFieldT);
+ pattern->capturing_group[3]->entity_field_path->field.back()->field_name =
+ "age";
+}
+
+std::string CreateEmptyModel(
+ const std::function<void(ModelT* model)> model_update_fn) {
+ ModelT model;
+ model_update_fn(&model);
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, &model));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+// Create fake entity data schema meta data.
+void AddTestEntitySchemaData(ModelT* unpacked_model) {
+ // Cannot use object oriented API here as that is not available for the
+ // reflection schema.
+ flatbuffers::FlatBufferBuilder schema_builder;
+ std::vector<flatbuffers::Offset<reflection::Field>> fields = {
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("first_name"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/0,
+ /*offset=*/4),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("is_alive"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Bool),
+ /*id=*/1,
+ /*offset=*/6),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("last_name"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/2,
+ /*offset=*/8),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("age"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Int),
+ /*id=*/3,
+ /*offset=*/10),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("former_us_president"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Bool),
+ /*id=*/4,
+ /*offset=*/12)};
+ std::vector<flatbuffers::Offset<reflection::Enum>> enums;
+ std::vector<flatbuffers::Offset<reflection::Object>> objects = {
+ reflection::CreateObject(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("EntityData"),
+ /*fields=*/
+ schema_builder.CreateVectorOfSortedTables(&fields))};
+ schema_builder.Finish(reflection::CreateSchema(
+ schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
+ schema_builder.CreateVectorOfSortedTables(&enums),
+ /*(unused) file_ident=*/0,
+ /*(unused) file_ext=*/0,
+ /*root_table*/ objects[0]));
+
+ unpacked_model->entity_data_schema.assign(
+ schema_builder.GetBufferPointer(),
+ schema_builder.GetBufferPointer() + schema_builder.GetSize());
+}
+
+AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
+ const std::string& collection,
+ const float score,
+ AnnotatedSpan::Source source) {
+ AnnotatedSpan result;
+ result.span = span;
+ result.classification.push_back({collection, score});
+ result.source = source;
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/testing/annotator.h b/native/utils/testing/annotator.h
new file mode 100644
index 0000000..794e55f
--- /dev/null
+++ b/native/utils/testing/annotator.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Helper utilities for testing Annotator.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
+
+#include <fstream>
+#include <memory>
+#include <string>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3 {
+
+// Loads FlatBuffer model, unpacks it and passes it to the visitor_fn so that it
+// can modify it. Afterwards the modified unpacked model is serialized back to a
+// flatbuffer.
+template <typename Fn>
+std::string ModifyAnnotatorModel(const std::string& model_flatbuffer,
+ Fn visitor_fn) {
+ std::unique_ptr<ModelT> unpacked_model =
+ UnPackModel(model_flatbuffer.c_str());
+
+ visitor_fn(unpacked_model.get());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ return std::string(reinterpret_cast<char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+std::string FirstResult(const std::vector<ClassificationResult>& results);
+
+std::string ReadFile(const std::string& file_name);
+
+std::unique_ptr<RegexModel_::PatternT> MakePattern(
+ const std::string& collection_name, const std::string& pattern,
+ const bool enabled_for_classification, const bool enabled_for_selection,
+ const bool enabled_for_annotation, const float score,
+ const float priority_score);
+
+// Shortcut function that doesn't need to specify the priority score.
+std::unique_ptr<RegexModel_::PatternT> MakePattern(
+ const std::string& collection_name, const std::string& pattern,
+ const bool enabled_for_classification, const bool enabled_for_selection,
+ const bool enabled_for_annotation, const float score);
+
+void AddTestRegexModel(ModelT* unpacked_model);
+
+// Creates empty serialized Annotator model. Takes optional function that can
+// modify the unpacked ModelT before the serialization.
+std::string CreateEmptyModel(
+ const std::function<void(ModelT* model)> model_update_fn =
+ [](ModelT* model) {});
+
+// Create fake entity data schema meta data.
+void AddTestEntitySchemaData(ModelT* unpacked_model);
+
+AnnotatedSpan MakeAnnotatedSpan(
+ CodepointSpan span, const std::string& collection, const float score,
+ AnnotatedSpan::Source source = AnnotatedSpan::Source::OTHER);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
diff --git a/native/utils/testing/logging_event_listener.cc b/native/utils/testing/logging_event_listener.cc
new file mode 100644
index 0000000..bfc6b95
--- /dev/null
+++ b/native/utils/testing/logging_event_listener.cc
@@ -0,0 +1,119 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/testing/logging_event_listener.h"
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+void LoggingEventListener::OnTestProgramStart(
+ const testing::UnitTest& /* unit_test */) {}
+
+void LoggingEventListener::OnTestIterationStart(
+ const testing::UnitTest& unit_test, int iteration) {
+ TC3_LOG(INFO) << "[==========] Running " << unit_test.test_to_run_count()
+ << " test(s) from " << unit_test.test_case_to_run_count()
+ << " test case(s)";
+}
+
+void LoggingEventListener::OnEnvironmentsSetUpStart(
+ const testing::UnitTest& unit_test) {
+ TC3_LOG(INFO) << "[----------] Global test environment set-up.";
+}
+
+void LoggingEventListener::OnEnvironmentsSetUpEnd(
+ const testing::UnitTest& /* unit_test */) {}
+
+void LoggingEventListener::OnTestCaseStart(const testing::TestCase& test_case) {
+ std::string param_text;
+ if (test_case.type_param()) {
+ param_text.append(", where TypeParam = ").append(test_case.type_param());
+ }
+ TC3_LOG(INFO) << "[----------] " << test_case.test_to_run_count()
+ << " test(s) from " << test_case.name() << param_text;
+}
+
+void LoggingEventListener::OnTestStart(const testing::TestInfo& test_info) {
+ TC3_LOG(INFO) << "[ RUN ] " << test_info.test_case_name() << "."
+ << test_info.name();
+}
+
+void LoggingEventListener::OnTestPartResult(
+ const testing::TestPartResult& test_part_result) {
+ if (test_part_result.type() != testing::TestPartResult::kSuccess) {
+ TC3_LOG(ERROR) << test_part_result.file_name() << ":"
+ << test_part_result.line_number() << ": Failure "
+ << test_part_result.message();
+ }
+}
+
+void LoggingEventListener::OnTestEnd(const testing::TestInfo& test_info) {
+ if (test_info.result()->Passed()) {
+ TC3_LOG(INFO) << "[ OK ] " << test_info.test_case_name() << "."
+ << test_info.name();
+ } else {
+ TC3_LOG(ERROR) << "[ FAILED ] " << test_info.test_case_name() << "."
+ << test_info.name();
+ }
+}
+
+void LoggingEventListener::OnTestCaseEnd(const testing::TestCase& test_case) {
+ TC3_LOG(INFO) << "[----------] " << test_case.test_to_run_count()
+ << " test(s) from " << test_case.name() << " ("
+ << test_case.elapsed_time() << " ms total)";
+}
+
+void LoggingEventListener::OnEnvironmentsTearDownStart(
+ const testing::UnitTest& unit_test) {
+ TC3_LOG(INFO) << "[----------] Global test environment tear-down.";
+}
+
+void LoggingEventListener::OnEnvironmentsTearDownEnd(
+ const testing::UnitTest& /* unit_test */) {}
+
+void LoggingEventListener::OnTestIterationEnd(
+ const testing::UnitTest& unit_test, int iteration) {
+ TC3_LOG(INFO) << "[==========] " << unit_test.test_to_run_count()
+ << " test(s) from " << unit_test.test_case_to_run_count()
+ << " test case(s) ran. (" << unit_test.elapsed_time()
+ << " ms total)";
+ TC3_LOG(INFO) << "[ PASSED ] " << unit_test.successful_test_count()
+ << " test(s)";
+ if (!unit_test.Passed()) {
+ TC3_LOG(ERROR) << "[ FAILED ] " << unit_test.failed_test_count()
+ << " test(s), listed below:";
+ for (int i = 0; i < unit_test.total_test_case_count(); ++i) {
+ const testing::TestCase& test_case = *unit_test.GetTestCase(i);
+ if (!test_case.should_run() || (test_case.failed_test_count() == 0)) {
+ continue;
+ }
+ for (int j = 0; j < test_case.total_test_count(); ++j) {
+ const testing::TestInfo& test_info = *test_case.GetTestInfo(j);
+ if (!test_info.should_run() || test_info.result()->Passed()) {
+ continue;
+ }
+ TC3_LOG(ERROR) << "[ FAILED ] " << test_case.name() << "."
+ << test_info.name();
+ }
+ }
+ }
+}
+
+void LoggingEventListener::OnTestProgramEnd(
+ const testing::UnitTest& /* unit_test */) {}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/testing/logging_event_listener.h b/native/utils/testing/logging_event_listener.h
new file mode 100644
index 0000000..2663a9c
--- /dev/null
+++ b/native/utils/testing/logging_event_listener.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_
+#define LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+// TestEventListener that writes test results to the log so that they will be
+// visible in the logcat output in Sponge.
+// The formatting of the output is patterned after the output produced by the
+// standard PrettyUnitTestResultPrinter.
+class LoggingEventListener : public ::testing::TestEventListener {
+ public:
+ void OnTestProgramStart(const testing::UnitTest& unit_test) override;
+
+ void OnTestIterationStart(const testing::UnitTest& unit_test,
+ int iteration) override;
+
+ void OnEnvironmentsSetUpStart(const testing::UnitTest& unit_test) override;
+
+ void OnEnvironmentsSetUpEnd(const testing::UnitTest& unit_test) override;
+
+ void OnTestCaseStart(const testing::TestCase& test_case) override;
+
+ void OnTestStart(const testing::TestInfo& test_info) override;
+
+ void OnTestPartResult(
+ const testing::TestPartResult& test_part_result) override;
+
+ void OnTestEnd(const testing::TestInfo& test_info) override;
+
+ void OnTestCaseEnd(const testing::TestCase& test_case) override;
+
+ void OnEnvironmentsTearDownStart(const testing::UnitTest& unit_test) override;
+
+ void OnEnvironmentsTearDownEnd(const testing::UnitTest& unit_test) override;
+
+ void OnTestIterationEnd(const testing::UnitTest& unit_test,
+ int iteration) override;
+
+ void OnTestProgramEnd(const testing::UnitTest& unit_test) override;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_
diff --git a/native/utils/testing/test_data_generator.h b/native/utils/testing/test_data_generator.h
new file mode 100644
index 0000000..30c7aed
--- /dev/null
+++ b/native/utils/testing/test_data_generator.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_TEST_DATA_GENERATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_TESTING_TEST_DATA_GENERATOR_H_
+
+#include <algorithm>
+#include <iostream>
+#include <random>
+
+#include "utils/strings/stringpiece.h"
+
+// Generates test data randomly.
+class TestDataGenerator {
+ public:
+ explicit TestDataGenerator() : random_engine_(0) {}
+
+ template <typename T,
+ typename std::enable_if_t<std::is_integral<T>::value>* = nullptr>
+ T generate() {
+ std::uniform_int_distribution<T> dist;
+ return dist(random_engine_);
+ }
+
+ template <typename T, typename std::enable_if_t<
+ std::is_floating_point<T>::value>* = nullptr>
+ T generate() {
+ std::uniform_real_distribution<T> dist;
+ return dist(random_engine_);
+ }
+
+ template <typename T, typename std::enable_if_t<
+ std::is_same<std::string, T>::value>* = nullptr>
+ T generate() {
+ std::uniform_int_distribution<> dist(1, 10);
+ return std::string(dist(random_engine_), '*');
+ }
+
+ private:
+ std::default_random_engine random_engine_;
+};
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_TEST_DATA_GENERATOR_H_
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 55faea5..36db3e9 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -18,6 +18,7 @@
#include "utils/base/logging.h"
#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/schema/schema_generated.h"
// Forward declaration of custom TensorFlow Lite ops for registration.
namespace tflite {
@@ -33,15 +34,21 @@
TfLiteRegistration* Register_MUL();
TfLiteRegistration* Register_RESHAPE();
TfLiteRegistration* Register_REDUCE_MAX();
+TfLiteRegistration* Register_REDUCE_MIN();
TfLiteRegistration* Register_REDUCE_ANY();
TfLiteRegistration* Register_SOFTMAX();
TfLiteRegistration* Register_GATHER();
+TfLiteRegistration* Register_GATHER_ND();
+TfLiteRegistration* Register_IF();
+TfLiteRegistration* Register_ROUND();
+TfLiteRegistration* Register_ZEROS_LIKE();
TfLiteRegistration* Register_TRANSPOSE();
TfLiteRegistration* Register_SUB();
TfLiteRegistration* Register_DIV();
TfLiteRegistration* Register_STRIDED_SLICE();
TfLiteRegistration* Register_EXP();
TfLiteRegistration* Register_TOPK_V2();
+TfLiteRegistration* Register_SLICE();
TfLiteRegistration* Register_SPLIT();
TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_MAXIMUM();
@@ -49,6 +56,7 @@
TfLiteRegistration* Register_NEG();
TfLiteRegistration* Register_SLICE();
TfLiteRegistration* Register_LOG();
+TfLiteRegistration* Register_LOGISTIC();
TfLiteRegistration* Register_SUM();
TfLiteRegistration* Register_PACK();
TfLiteRegistration* Register_DEQUANTIZE();
@@ -59,14 +67,37 @@
TfLiteRegistration* Register_RSQRT();
TfLiteRegistration* Register_LOG_SOFTMAX();
TfLiteRegistration* Register_WHERE();
+TfLiteRegistration* Register_ONE_HOT();
+TfLiteRegistration* Register_POW();
+TfLiteRegistration* Register_TANH();
+TfLiteRegistration* Register_UNIQUE();
+TfLiteRegistration* Register_REDUCE_PROD();
+TfLiteRegistration* Register_SHAPE();
+TfLiteRegistration* Register_NOT_EQUAL();
+TfLiteRegistration* Register_CUMSUM();
+TfLiteRegistration* Register_EXPAND_DIMS();
+TfLiteRegistration* Register_FILL();
+TfLiteRegistration* Register_PADV2();
} // namespace builtin
} // namespace ops
} // namespace tflite
#ifdef TC3_WITH_ACTIONS_OPS
+#include "utils/tflite/blacklist.h"
#include "utils/tflite/dist_diversification.h"
+#include "utils/tflite/string_projection.h"
#include "utils/tflite/text_encoder.h"
#include "utils/tflite/token_encoder.h"
+namespace tflite {
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
+TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
+TfLiteRegistration* Register_RAGGED_RANGE();
+TfLiteRegistration* Register_RANDOM_UNIFORM();
+} // namespace custom
+} // namespace ops
+} // namespace tflite
void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
@@ -80,14 +111,14 @@
resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
tflite::ops::builtin::Register_CONV_2D(),
/*min_version=*/1,
- /*max_version=*/3);
+ /*max_version=*/5);
resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL,
::tflite::ops::builtin::Register_EQUAL());
resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
tflite::ops::builtin::Register_FULLY_CONNECTED(),
/*min_version=*/1,
- /*max_version=*/4);
+ /*max_version=*/9);
resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER_EQUAL,
::tflite::ops::builtin::Register_GREATER_EQUAL());
resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
@@ -100,6 +131,8 @@
tflite::ops::builtin::Register_RESHAPE());
resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MAX,
::tflite::ops::builtin::Register_REDUCE_MAX());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MIN,
+ ::tflite::ops::builtin::Register_REDUCE_MIN());
resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_ANY,
::tflite::ops::builtin::Register_REDUCE_ANY());
resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
@@ -110,6 +143,15 @@
tflite::ops::builtin::Register_GATHER(),
/*min_version=*/1,
/*max_version=*/2);
+ resolver->AddBuiltin(::tflite::BuiltinOperator_GATHER_ND,
+ ::tflite::ops::builtin::Register_GATHER_ND(),
+ /*version=*/2);
+ resolver->AddBuiltin(::tflite::BuiltinOperator_IF,
+                       ::tflite::ops::builtin::Register_IF());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_ROUND,
+ ::tflite::ops::builtin::Register_ROUND());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_ZEROS_LIKE,
+ ::tflite::ops::builtin::Register_ZEROS_LIKE());
resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
tflite::ops::builtin::Register_TRANSPOSE(),
/*min_version=*/1,
@@ -130,6 +172,10 @@
tflite::ops::builtin::Register_TOPK_V2(),
/*min_version=*/1,
/*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
+ tflite::ops::builtin::Register_SLICE(),
+ /*min_version=*/1,
+ /*max_version=*/3);
resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
tflite::ops::builtin::Register_SPLIT(),
/*min_version=*/1,
@@ -152,6 +198,8 @@
/*max_version=*/2);
resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
tflite::ops::builtin::Register_LOG());
+ resolver->AddBuiltin(tflite::BuiltinOperator_LOGISTIC,
+ tflite::ops::builtin::Register_LOGISTIC());
resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
tflite::ops::builtin::Register_SUM());
resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
@@ -176,6 +224,34 @@
tflite::ops::builtin::Register_LOG_SOFTMAX());
resolver->AddBuiltin(::tflite::BuiltinOperator_WHERE,
::tflite::ops::builtin::Register_WHERE());
+ resolver->AddBuiltin(tflite::BuiltinOperator_ONE_HOT,
+ tflite::ops::builtin::Register_ONE_HOT(),
+ /*min_version=*/1,
+ /*max_version=*/1);
+ resolver->AddBuiltin(tflite::BuiltinOperator_POW,
+ tflite::ops::builtin::Register_POW(),
+ /*min_version=*/1,
+ /*max_version=*/1);
+ resolver->AddBuiltin(tflite::BuiltinOperator_TANH,
+ tflite::ops::builtin::Register_TANH(),
+ /*min_version=*/1,
+ /*max_version=*/1);
+ resolver->AddBuiltin(::tflite::BuiltinOperator_UNIQUE,
+ ::tflite::ops::builtin::Register_UNIQUE());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_PROD,
+ ::tflite::ops::builtin::Register_REDUCE_PROD());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_SHAPE,
+ ::tflite::ops::builtin::Register_SHAPE());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_NOT_EQUAL,
+ ::tflite::ops::builtin::Register_NOT_EQUAL());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_CUMSUM,
+ ::tflite::ops::builtin::Register_CUMSUM());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_EXPAND_DIMS,
+ ::tflite::ops::builtin::Register_EXPAND_DIMS());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_FILL,
+ ::tflite::ops::builtin::Register_FILL());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_PADV2,
+ ::tflite::ops::builtin::Register_PADV2());
}
#else
void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
@@ -186,7 +262,12 @@
namespace libtextclassifier3 {
-inline std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
+std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
+ return BuildOpResolver([](tflite::MutableOpResolver* mutable_resolver) {});
+}
+
+std::unique_ptr<tflite::OpResolver> BuildOpResolver(
+ const std::function<void(tflite::MutableOpResolver*)>& customize_fn) {
#ifdef TC3_USE_SELECTIVE_REGISTRATION
std::unique_ptr<tflite::MutableOpResolver> resolver(
new tflite::MutableOpResolver);
@@ -202,7 +283,24 @@
tflite::ops::custom::Register_TEXT_ENCODER());
resolver->AddCustom("TokenEncoder",
tflite::ops::custom::Register_TOKEN_ENCODER());
+ resolver->AddCustom(
+ "TFSentencepieceTokenizeOp",
+ ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
+ resolver->AddCustom("RaggedRange",
+ ::tflite::ops::custom::Register_RAGGED_RANGE());
+ resolver->AddCustom(
+ "RaggedTensorToTensor",
+ ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
+ resolver->AddCustom(
+ "STRING_PROJECTION",
+ ::tflite::ops::custom::libtextclassifier3::Register_STRING_PROJECTION());
+ resolver->AddCustom(
+ "BLACKLIST",
+ ::tflite::ops::custom::libtextclassifier3::Register_BLACKLIST());
+ resolver->AddCustom("RandomUniform",
+ ::tflite::ops::custom::Register_RANDOM_UNIFORM());
#endif // TC3_WITH_ACTIONS_OPS
+ customize_fn(resolver.get());
return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
}
diff --git a/native/utils/tflite-model-executor.h b/native/utils/tflite-model-executor.h
index a4432ff..faa1295 100644
--- a/native/utils/tflite-model-executor.h
+++ b/native/utils/tflite-model-executor.h
@@ -28,12 +28,21 @@
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"
namespace libtextclassifier3 {
+// Creates a TF.Lite Op resolver in default configuration, with ops for
+// Annotator and Actions models.
std::unique_ptr<tflite::OpResolver> BuildOpResolver();
+
+// Like above, but allows passage of a function that can register additional
+// ops.
+std::unique_ptr<tflite::OpResolver> BuildOpResolver(
+ const std::function<void(tflite::MutableOpResolver*)>& customize_fn);
+
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
const tflite::Model*);
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
diff --git a/native/utils/tflite/blacklist.cc b/native/utils/tflite/blacklist.cc
new file mode 100644
index 0000000..b41fba1
--- /dev/null
+++ b/native/utils/tflite/blacklist.cc
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/blacklist.h"
+
+#include "utils/tflite/blacklist_base.h"
+#include "utils/tflite/skipgram_finder.h"
+#include "flatbuffers/flexbuffers.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace libtextclassifier3 {
+namespace blacklist {
+
+// Generates prediction vectors for input strings using a skipgram blacklist.
+// This uses the framework in `blacklist_base.h`, with the implementation detail
+// that the input is a string tensor of messages and the terms are skipgrams.
+class BlacklistOp : public BlacklistOpBase {
+ public:
+ explicit BlacklistOp(const flexbuffers::Map& custom_options)
+ : BlacklistOpBase(custom_options),
+ skipgram_finder_(custom_options["max_skip_size"].AsInt32()),
+ input_(nullptr) {
+ auto blacklist = custom_options["blacklist"].AsTypedVector();
+ auto blacklist_category =
+ custom_options["blacklist_category"].AsTypedVector();
+ for (int i = 0; i < blacklist.size(); i++) {
+ int category = blacklist_category[i].AsInt32();
+ flexbuffers::String s = blacklist[i].AsString();
+ skipgram_finder_.AddSkipgram(std::string(s.c_str(), s.length()),
+ category);
+ }
+ }
+
+ TfLiteStatus InitializeInput(TfLiteContext* context,
+ TfLiteNode* node) override {
+ input_ = &context->tensors[node->inputs->data[kInputMessage]];
+ return kTfLiteOk;
+ }
+
+ absl::flat_hash_set<int> GetCategories(int i) const override {
+ StringRef input = GetString(input_, i);
+ return skipgram_finder_.FindSkipgrams(std::string(input.str, input.len));
+ }
+
+ void FinalizeInput() override { input_ = nullptr; }
+
+ TfLiteIntArray* GetInputShape(TfLiteContext* context,
+ TfLiteNode* node) override {
+ return context->tensors[node->inputs->data[kInputMessage]].dims;
+ }
+
+ private:
+ ::libtextclassifier3::SkipgramFinder skipgram_finder_;
+ TfLiteTensor* input_;
+
+ static constexpr int kInputMessage = 0;
+};
+
+void* BlacklistOpInit(TfLiteContext* context, const char* buffer,
+ size_t length) {
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ return new BlacklistOp(flexbuffers::GetRoot(buffer_t, length).AsMap());
+}
+
+} // namespace blacklist
+
+TfLiteRegistration* Register_BLACKLIST() {
+ static TfLiteRegistration r = {libtextclassifier3::blacklist::BlacklistOpInit,
+ libtextclassifier3::blacklist::Free,
+ libtextclassifier3::blacklist::Resize,
+ libtextclassifier3::blacklist::Eval};
+ return &r;
+}
+
+} // namespace libtextclassifier3
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/native/utils/tflite/blacklist.h b/native/utils/tflite/blacklist.h
new file mode 100644
index 0000000..0fcf5c4
--- /dev/null
+++ b/native/utils/tflite/blacklist.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_H_
+
+#include "tensorflow/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace libtextclassifier3 {
+
+TfLiteRegistration* Register_BLACKLIST();
+
+} // namespace libtextclassifier3
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_H_
diff --git a/native/utils/tflite/blacklist_base.cc b/native/utils/tflite/blacklist_base.cc
new file mode 100644
index 0000000..214283b
--- /dev/null
+++ b/native/utils/tflite/blacklist_base.cc
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/blacklist_base.h"
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/lite/context.h"
+#include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace libtextclassifier3 {
+namespace blacklist {
+
+static const int kOutputCategories = 0;
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<BlacklistOpBase*>(buffer);
+}
+
+TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
+ auto* op = reinterpret_cast<BlacklistOpBase*>(node->user_data);
+
+ TfLiteIntArray* input_dims = op->GetInputShape(context, node);
+ TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims->size + 1);
+ for (int i = 0; i < input_dims->size; i++) {
+ output_dims->data[i] = input_dims->data[i];
+ }
+ output_dims->data[input_dims->size] = op->categories();
+ return context->ResizeTensor(
+ context, &context->tensors[node->outputs->data[kOutputCategories]],
+ output_dims);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* op = reinterpret_cast<BlacklistOpBase*>(node->user_data);
+
+ TfLiteTensor* output_categories =
+ &context->tensors[node->outputs->data[kOutputCategories]];
+
+ TfLiteIntArray* input_dims = op->GetInputShape(context, node);
+ int input_size = 1;
+ for (int i = 0; i < input_dims->size; i++) {
+ input_size *= input_dims->data[i];
+ }
+ const int n_categories = op->categories();
+
+ TF_LITE_ENSURE_STATUS(op->InitializeInput(context, node));
+ if (output_categories->type == kTfLiteFloat32) {
+ for (int i = 0; i < input_size; i++) {
+ absl::flat_hash_set<int> categories = op->GetCategories(i);
+ if (categories.empty()) {
+ for (int j = 0; j < n_categories; j++) {
+ output_categories->data.f[i * n_categories + j] =
+ (j < op->negative_categories()) ? 1.0 : 0.0;
+ }
+ } else {
+ for (int j = 0; j < n_categories; j++) {
+ output_categories->data.f[i * n_categories + j] =
+ (categories.find(j) != categories.end()) ? 1.0 : 0.0;
+ }
+ }
+ }
+ } else if (output_categories->type == kTfLiteUInt8) {
+ const uint8_t one =
+ ::seq_flow_lite::PodQuantize(1.0, output_categories->params.zero_point,
+ 1.0 / output_categories->params.scale);
+ const uint8_t zero =
+ ::seq_flow_lite::PodQuantize(0.0, output_categories->params.zero_point,
+ 1.0 / output_categories->params.scale);
+ for (int i = 0; i < input_size; i++) {
+ absl::flat_hash_set<int> categories = op->GetCategories(i);
+ if (categories.empty()) {
+ for (int j = 0; j < n_categories; j++) {
+ output_categories->data.uint8[i * n_categories + j] =
+ (j < op->negative_categories()) ? one : zero;
+ }
+ } else {
+ for (int j = 0; j < n_categories; j++) {
+ output_categories->data.uint8[i * n_categories + j] =
+ (categories.find(j) != categories.end()) ? one : zero;
+ }
+ }
+ }
+ }
+ op->FinalizeInput();
+ return kTfLiteOk;
+}
+
+} // namespace blacklist
+} // namespace libtextclassifier3
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/native/utils/tflite/blacklist_base.h b/native/utils/tflite/blacklist_base.h
new file mode 100644
index 0000000..3da1ed7
--- /dev/null
+++ b/native/utils/tflite/blacklist_base.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_BASE_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_BASE_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace libtextclassifier3 {
+namespace blacklist {
+
+/*
+ * A framework for writing ops that generate prediction vectors using a
+ * blacklist.
+ *
+ * Input is defined by the specific implementation.
+ *
+ * Attributes:
+ * blacklist: string[n]
+ * Terms in the blacklist.
+ * blacklist_category: int[n]
+ * Category for each term in the blacklist. Each category must be in
+ * [0, categories).
+ * categories: int[]
+ * Total number of categories.
+ * negative_categories: int[]
+ * Total number of negative categories.
+ *
+ * Output:
+ * tensor[0]: Category indicators for each message, float[..., categories]
+ *
+ */
+
+class BlacklistOpBase {
+ public:
+ explicit BlacklistOpBase(const flexbuffers::Map& custom_options)
+ : categories_(custom_options["categories"].AsInt32()),
+ negative_categories_(custom_options["negative_categories"].AsInt32()) {}
+
+ virtual ~BlacklistOpBase() {}
+
+ int categories() const { return categories_; }
+ int negative_categories() const { return negative_categories_; }
+
+ virtual TfLiteStatus InitializeInput(TfLiteContext* context,
+ TfLiteNode* node) = 0;
+ virtual absl::flat_hash_set<int> GetCategories(int i) const = 0;
+ virtual void FinalizeInput() = 0;
+
+ // Returns the input shape. TfLiteIntArray is owned by the object.
+ virtual TfLiteIntArray* GetInputShape(TfLiteContext* context,
+ TfLiteNode* node) = 0;
+
+ private:
+ int categories_;
+ int negative_categories_;
+};
+
+// Individual ops should define an Init() function that returns a
+// BlacklistOpBase.
+
+void Free(TfLiteContext* context, void* buffer);
+
+TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node);
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node);
+} // namespace blacklist
+} // namespace libtextclassifier3
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_BASE_H_
diff --git a/native/utils/tflite/encoder_common_test.cc b/native/utils/tflite/encoder_common_test.cc
new file mode 100644
index 0000000..247689f
--- /dev/null
+++ b/native/utils/tflite/encoder_common_test.cc
@@ -0,0 +1,34 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/encoder_common.h"
+
+#include "gtest/gtest.h"
+#include "tensorflow/lite/model.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(EncoderUtilsTest, CreateIntArray) {
+ TfLiteIntArray* a = CreateIntArray({1, 2, 3});
+ EXPECT_EQ(a->data[0], 1);
+ EXPECT_EQ(a->data[1], 2);
+ EXPECT_EQ(a->data[2], 3);
+ TfLiteIntArrayFree(a);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/tflite/skipgram_finder.cc b/native/utils/tflite/skipgram_finder.cc
new file mode 100644
index 0000000..c69193e
--- /dev/null
+++ b/native/utils/tflite/skipgram_finder.cc
@@ -0,0 +1,203 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/skipgram_finder.h"
+
+#include <cctype>
+#include <deque>
+#include <string>
+#include <vector>
+
+#include "utils/strings/utf8.h"
+#include "utils/utf8/unilib-common.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::tflite::StringRef;
+
+void PreprocessToken(std::string& token) {
+ size_t in = 0;
+ size_t out = 0;
+ while (in < token.size()) {
+ const char* in_data = token.data() + in;
+ const int n = GetNumBytesForUTF8Char(in_data);
+ if (n < 0 || n > token.size() - in) {
+      // Invalid UTF-8 sequence.
+ break;
+ }
+ in += n;
+ const char32 r = ValidCharToRune(in_data);
+ if (IsPunctuation(r)) {
+ continue;
+ }
+ const char32 rl = ToLower(r);
+ char output_buffer[4];
+ int encoded_length = ValidRuneToChar(rl, output_buffer);
+ if (encoded_length > n) {
+ // This is a hack, but there are exactly two unicode characters whose
+ // lowercase versions have longer UTF-8 encodings (0x23a to 0x2c65,
+ // 0x23e to 0x2c66). So, to avoid sizing issues, they're not lowercased.
+ encoded_length = ValidRuneToChar(r, output_buffer);
+ }
+ memcpy(token.data() + out, output_buffer, encoded_length);
+ out += encoded_length;
+ }
+
+ size_t remaining = token.size() - in;
+ if (remaining > 0) {
+ memmove(token.data() + out, token.data() + in, remaining);
+ out += remaining;
+ }
+ token.resize(out);
+}
+
+} // namespace
+
+void SkipgramFinder::AddSkipgram(const std::string& skipgram, int category) {
+ std::vector<std::string> tokens = absl::StrSplit(skipgram, ' ');
+
+ // Store the skipgram in a trie-like structure that uses tokens as the
+ // edge labels, instead of characters. Each node represents a skipgram made
+ // from the tokens used to reach the node, and stores the categories the
+ // skipgram is associated with.
+ TrieNode* cur = &skipgram_trie_;
+ for (auto& token : tokens) {
+ if (absl::EndsWith(token, ".*")) {
+ token.resize(token.size() - 2);
+ PreprocessToken(token);
+ auto iter = cur->prefix_to_node.find(token);
+ if (iter != cur->prefix_to_node.end()) {
+ cur = &iter->second;
+ } else {
+ cur = &cur->prefix_to_node
+ .emplace(std::piecewise_construct,
+ std::forward_as_tuple(token), std::make_tuple<>())
+ .first->second;
+ }
+ continue;
+ }
+
+ PreprocessToken(token);
+ auto iter = cur->token_to_node.find(token);
+ if (iter != cur->token_to_node.end()) {
+ cur = &iter->second;
+ } else {
+ cur = &cur->token_to_node
+ .emplace(std::piecewise_construct,
+ std::forward_as_tuple(token), std::make_tuple<>())
+ .first->second;
+ }
+ }
+ cur->categories.insert(category);
+}
+
+absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
+ const std::string& input) const {
+ std::vector<std::string> tokens = absl::StrSplit(input, ' ');
+ std::vector<absl::string_view> sv_tokens;
+ sv_tokens.reserve(tokens.size());
+ for (auto& token : tokens) {
+ PreprocessToken(token);
+ sv_tokens.emplace_back(token.data(), token.size());
+ }
+ return FindSkipgrams(sv_tokens);
+}
+
+absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
+ const std::vector<StringRef>& tokens) const {
+ std::vector<absl::string_view> sv_tokens;
+ sv_tokens.reserve(tokens.size());
+ for (auto& token : tokens) {
+ sv_tokens.emplace_back(token.str, token.len);
+ }
+ return FindSkipgrams(sv_tokens);
+}
+
+absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
+ const std::vector<absl::string_view>& tokens) const {
+ absl::flat_hash_set<int> categories;
+
+ // Tracks skipgram prefixes and the index of their last token.
+ std::deque<std::pair<int, const TrieNode*>> indices_and_skipgrams;
+
+ for (int token_i = 0; token_i < tokens.size(); token_i++) {
+ const absl::string_view& token = tokens[token_i];
+
+ std::vector<absl::string_view> token_prefixes;
+ {
+ const char* s = token.data();
+ int n = token.size();
+ while (n > 0) {
+ const int rlen = GetNumBytesForUTF8Char(s);
+ if (rlen < 0 || rlen > n) {
+ // Invalid UTF8.
+ break;
+ }
+ n -= rlen;
+ s += rlen;
+ token_prefixes.emplace_back(token.data(), token.size() - n);
+ }
+ }
+
+    // Drop any skipgram prefixes which would skip more than `max_skip_size_`
+ // tokens between the end of the prefix and the current token.
+ while (!indices_and_skipgrams.empty()) {
+ if (indices_and_skipgrams.front().first + max_skip_size_ + 1 < token_i) {
+ indices_and_skipgrams.pop_front();
+ } else {
+ break;
+ }
+ }
+
+ // Check if we can form a valid skipgram prefix (or skipgram) by adding
+ // the current token to any of the existing skipgram prefixes, or
+ // if the current token is a valid skipgram prefix (or skipgram).
+ size_t size = indices_and_skipgrams.size();
+ for (size_t skipgram_i = 0; skipgram_i <= size; skipgram_i++) {
+ const auto& node = skipgram_i < size
+ ? *indices_and_skipgrams[skipgram_i].second
+ : skipgram_trie_;
+
+ auto iter = node.token_to_node.find(token);
+ if (iter != node.token_to_node.end()) {
+ categories.insert(iter->second.categories.begin(),
+ iter->second.categories.end());
+ indices_and_skipgrams.push_back(std::make_pair(token_i, &iter->second));
+ }
+
+ for (const auto& token_prefix : token_prefixes) {
+ auto iter = node.prefix_to_node.find(token_prefix);
+ if (iter != node.prefix_to_node.end()) {
+ categories.insert(iter->second.categories.begin(),
+ iter->second.categories.end());
+ indices_and_skipgrams.push_back(
+ std::make_pair(token_i, &iter->second));
+ }
+ }
+ }
+ }
+
+ return categories;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/tflite/skipgram_finder.h b/native/utils/tflite/skipgram_finder.h
new file mode 100644
index 0000000..e7e8547
--- /dev/null
+++ b/native/utils/tflite/skipgram_finder.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+
+// SkipgramFinder finds skipgrams in strings.
+//
+// To use: First, add skipgrams using AddSkipgram() - each skipgram is
+// associated with some category. Then, call FindSkipgrams() on a string,
+// which will return the set of categories of the skipgrams in the string.
+//
+// Both the skipgrams and the input strings will be tokenized by splitting
+// on spaces. Additionally, the tokens will be lowercased and have any
+// trailing punctuation removed.
+class SkipgramFinder {
+ public:
+ explicit SkipgramFinder(int max_skip_size) : max_skip_size_(max_skip_size) {}
+
+ // Adds a skipgram that SkipgramFinder should look for in input strings.
+ // Tokens may use the regex '.*' as a suffix.
+ void AddSkipgram(const std::string& skipgram, int category);
+
+ // Find all of the skipgrams in `input`, and return their categories.
+ absl::flat_hash_set<int> FindSkipgrams(const std::string& input) const;
+
+ // Find all of the skipgrams in `tokens`, and return their categories.
+ absl::flat_hash_set<int> FindSkipgrams(
+ const std::vector<absl::string_view>& tokens) const;
+ absl::flat_hash_set<int> FindSkipgrams(
+ const std::vector<::tflite::StringRef>& tokens) const;
+
+ private:
+ struct TrieNode {
+ absl::flat_hash_set<int> categories;
+ // Maps tokens to the next node in the trie.
+ absl::flat_hash_map<std::string, TrieNode> token_to_node;
+ // Maps token prefixes (<prefix>.*) to the next node in the trie.
+ absl::flat_hash_map<std::string, TrieNode> prefix_to_node;
+ };
+
+ TrieNode skipgram_trie_;
+ int max_skip_size_;
+};
+
+} // namespace libtextclassifier3
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_
diff --git a/native/utils/tflite/string_projection.cc b/native/utils/tflite/string_projection.cc
new file mode 100644
index 0000000..9f8d36e
--- /dev/null
+++ b/native/utils/tflite/string_projection.cc
@@ -0,0 +1,579 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/string_projection.h"
+
+#include <string>
+#include <unordered_map>
+
+#include "utils/strings/utf8.h"
+#include "utils/tflite/string_projection_base.h"
+#include "utils/utf8/unilib-common.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/match.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace libtextclassifier3 {
+namespace string_projection {
+namespace {
+
+const char kStartToken[] = "<S>";
+const char kEndToken[] = "<E>";
+const char kEmptyToken[] = "<S> <E>";
+constexpr size_t kEntireString = SIZE_MAX;
+constexpr size_t kAllTokens = SIZE_MAX;
+constexpr int kInvalid = -1;
+
+constexpr char kApostrophe = '\'';
+constexpr char kSpace = ' ';
+constexpr char kComma = ',';
+constexpr char kDot = '.';
+
+// Returns true if the given text contains a number.
+bool IsDigitString(const std::string& text) {
+ for (size_t i = 0; i < text.length();) {
+ const int bytes_read =
+ ::libtextclassifier3::GetNumBytesForUTF8Char(text.data());
+ if (bytes_read <= 0 || bytes_read > text.length() - i) {
+ break;
+ }
+ const char32_t rune = ::libtextclassifier3::ValidCharToRune(text.data());
+ if (::libtextclassifier3::IsDigit(rune)) return true;
+ i += bytes_read;
+ }
+ return false;
+}
+
+// Gets the string containing |num_chars| characters from |start| position.
+std::string GetCharToken(const std::vector<std::string>& char_tokens, int start,
+ int num_chars) {
+ std::string char_token = "";
+ if (start + num_chars <= char_tokens.size()) {
+ for (int i = 0; i < num_chars; ++i) {
+ char_token.append(char_tokens[start + i]);
+ }
+ }
+ return char_token;
+}
+
+// Counts how many times |pattern| appeared from |start| position.
+int GetNumPattern(const std::vector<std::string>& char_tokens, size_t start,
+ size_t num_chars, const std::string& pattern) {
+ int count = 0;
+ for (int i = start; i < char_tokens.size(); i += num_chars) {
+ std::string cur_pattern = GetCharToken(char_tokens, i, num_chars);
+ if (pattern == cur_pattern) {
+ ++count;
+ } else {
+ break;
+ }
+ }
+ return count;
+}
+
+inline size_t FindNextSpace(const char* input_ptr, size_t from, size_t length) {
+ size_t space_index;
+ for (space_index = from; space_index < length; space_index++) {
+ if (input_ptr[space_index] == kSpace) {
+ break;
+ }
+ }
+ return space_index == length ? kInvalid : space_index;
+}
+
+template <typename T>
+void SplitByCharInternal(std::vector<T>* tokens, const char* input_ptr,
+ size_t len, size_t max_tokens) {
+ for (size_t i = 0; i < len;) {
+ auto bytes_read =
+ ::libtextclassifier3::GetNumBytesForUTF8Char(input_ptr + i);
+ if (bytes_read <= 0 || bytes_read > len - i) break;
+ tokens->emplace_back(input_ptr + i, bytes_read);
+ if (max_tokens != kInvalid && tokens->size() == max_tokens) {
+ break;
+ }
+ i += bytes_read;
+ }
+}
+
+std::vector<std::string> SplitByChar(const char* input_ptr, size_t len,
+ size_t max_tokens) {
+ std::vector<std::string> tokens;
+ SplitByCharInternal(&tokens, input_ptr, len, max_tokens);
+ return tokens;
+}
+
+std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) {
+ // This function contracts patterns whose length is |num_chars| and appeared
+ // more than twice. So if the input is shorter than 3 * |num_chars|, do not
+ // apply any contraction.
+ if (len < 3 * num_chars) {
+ return input_ptr;
+ }
+ std::vector<std::string> char_tokens = SplitByChar(input_ptr, len, len);
+
+ std::string token;
+ token.reserve(len);
+ for (int i = 0; i < char_tokens.size();) {
+ std::string cur_pattern = GetCharToken(char_tokens, i, num_chars);
+
+ // Count how many times this pattern appeared.
+ int num_cur_patterns = 0;
+ if (!absl::StrContains(cur_pattern, " ") && !IsDigitString(cur_pattern)) {
+ num_cur_patterns =
+ GetNumPattern(char_tokens, i + num_chars, num_chars, cur_pattern);
+ }
+
+ if (num_cur_patterns >= 2) {
+ // If this pattern is repeated, store it only twice.
+ token.append(cur_pattern);
+ token.append(cur_pattern);
+ i += (num_cur_patterns + 1) * num_chars;
+ } else {
+ token.append(char_tokens[i]);
+ ++i;
+ }
+ }
+
+ return token;
+}
+
+template <typename T>
+void SplitBySpaceInternal(std::vector<T>* tokens, const char* input_ptr,
+ size_t len, size_t max_input, size_t max_tokens) {
+ size_t last_index =
+ max_input == kEntireString ? len : (len < max_input ? len : max_input);
+ size_t start = 0;
+  // Skip leading spaces.
+ while (start < last_index && input_ptr[start] == kSpace) {
+ start++;
+ }
+ auto end = FindNextSpace(input_ptr, start, last_index);
+ while (end != kInvalid &&
+ (max_tokens == kAllTokens || tokens->size() < max_tokens - 1)) {
+ auto length = end - start;
+ if (length > 0) {
+ tokens->emplace_back(input_ptr + start, length);
+ }
+
+ start = end + 1;
+ end = FindNextSpace(input_ptr, start, last_index);
+ }
+ auto length = end == kInvalid ? (last_index - start) : (end - start);
+ if (length > 0) {
+ tokens->emplace_back(input_ptr + start, length);
+ }
+}
+
+std::vector<std::string> SplitBySpace(const char* input_ptr, size_t len,
+ size_t max_input, size_t max_tokens) {
+ std::vector<std::string> tokens;
+ SplitBySpaceInternal(&tokens, input_ptr, len, max_input, max_tokens);
+ return tokens;
+}
+
+bool prepend_separator(char separator) { return separator == kApostrophe; }
+
+bool is_numeric(char c) { return c >= '0' && c <= '9'; }
+
+class ProjectionNormalizer {
+ public:
+ explicit ProjectionNormalizer(const std::string& separators,
+ bool normalize_repetition = false) {
+ InitializeSeparators(separators);
+ normalize_repetition_ = normalize_repetition;
+ }
+
+ // Normalizes the repeated characters (except numbers) which consecutively
+ // appeared more than twice in a word.
+ std::string Normalize(const std::string& input, size_t max_input = 300) {
+ return Normalize(input.data(), input.length(), max_input);
+ }
+ std::string Normalize(const char* input_ptr, size_t len,
+ size_t max_input = 300) {
+ std::string normalized(input_ptr, std::min(len, max_input));
+
+ if (normalize_repetition_) {
+ // Remove repeated 1 char (e.g. soooo => soo)
+ normalized = ContractToken(normalized.data(), normalized.length(), 1);
+
+ // Remove repeated 2 chars from the beginning (e.g. hahaha =>
+ // haha, xhahaha => xhaha, xyhahaha => xyhaha).
+ normalized = ContractToken(normalized.data(), normalized.length(), 2);
+
+ // Remove repeated 3 chars from the beginning
+ // (e.g. wowwowwow => wowwow, abcdbcdbcd => abcdbcd).
+ normalized = ContractToken(normalized.data(), normalized.length(), 3);
+ }
+
+ if (!separators_.empty()) {
+ // Add space around separators_.
+ normalized = NormalizeInternal(normalized.data(), normalized.length());
+ }
+ return normalized;
+ }
+
+ private:
+ // Parses and extracts supported separators.
+ void InitializeSeparators(const std::string& separators) {
+ for (int i = 0; i < separators.length(); ++i) {
+ if (separators[i] != ' ') {
+ separators_.insert(separators[i]);
+ }
+ }
+ }
+
+ // Removes repeated chars.
+ std::string NormalizeInternal(const char* input_ptr, size_t len) {
+ std::string normalized;
+ normalized.reserve(len * 2);
+ for (int i = 0; i < len; ++i) {
+ char c = input_ptr[i];
+ bool matched_separator = separators_.find(c) != separators_.end();
+ if (matched_separator) {
+ if (i > 0 && input_ptr[i - 1] != ' ' && normalized.back() != ' ') {
+ normalized.append(" ");
+ }
+ }
+ normalized.append(1, c);
+ if (matched_separator) {
+ if (i + 1 < len && input_ptr[i + 1] != ' ' && c != '\'') {
+ normalized.append(" ");
+ }
+ }
+ }
+ return normalized;
+ }
+
+ absl::flat_hash_set<char> separators_;
+ bool normalize_repetition_;
+};
+
+class ProjectionTokenizer {
+ public:
+ explicit ProjectionTokenizer(const std::string& separators) {
+ InitializeSeparators(separators);
+ }
+
+ // Tokenizes the input by separators_. Limit to max_tokens, when it is not -1.
+ std::vector<std::string> Tokenize(const std::string& input, size_t max_input,
+ size_t max_tokens) const {
+ return Tokenize(input.c_str(), input.size(), max_input, max_tokens);
+ }
+
+ std::vector<std::string> Tokenize(const char* input_ptr, size_t len,
+ size_t max_input, size_t max_tokens) const {
+ // If separators_ is not given, tokenize the input with a space.
+ if (separators_.empty()) {
+ return SplitBySpace(input_ptr, len, max_input, max_tokens);
+ }
+
+ std::vector<std::string> tokens;
+ size_t last_index =
+ max_input == kEntireString ? len : (len < max_input ? len : max_input);
+ size_t start = 0;
+ // Skip leading spaces.
+ while (start < last_index && input_ptr[start] == kSpace) {
+ start++;
+ }
+ auto end = FindNextSeparator(input_ptr, start, last_index);
+
+ while (end != kInvalid &&
+ (max_tokens == kAllTokens || tokens.size() < max_tokens - 1)) {
+ auto length = end - start;
+ if (length > 0) tokens.emplace_back(input_ptr + start, length);
+
+ // Add the separator (except space and apostrophe) as a token
+ char separator = input_ptr[end];
+ if (separator != kSpace && separator != kApostrophe) {
+ tokens.emplace_back(input_ptr + end, 1);
+ }
+
+ start = end + (prepend_separator(separator) ? 0 : 1);
+ end = FindNextSeparator(input_ptr, end + 1, last_index);
+ }
+ auto length = end == kInvalid ? (last_index - start) : (end - start);
+ if (length > 0) tokens.emplace_back(input_ptr + start, length);
+ return tokens;
+ }
+
+ private:
+ // Parses and extracts supported separators.
+ void InitializeSeparators(const std::string& separators) {
+ for (int i = 0; i < separators.length(); ++i) {
+ separators_.insert(separators[i]);
+ }
+ }
+
+ // Starting from input_ptr[from], search for the next occurrence of
+ // separators_. Don't search beyond input_ptr[length](non-inclusive). Return
+ // -1 if not found.
+ size_t FindNextSeparator(const char* input_ptr, size_t from,
+ size_t length) const {
+ auto index = from;
+ while (index < length) {
+ char c = input_ptr[index];
+ // Do not break a number (e.g. "10,000", "0.23").
+ if (c == kComma || c == kDot) {
+ if (index + 1 < length && is_numeric(input_ptr[index + 1])) {
+ c = input_ptr[++index];
+ }
+ }
+ if (separators_.find(c) != separators_.end()) {
+ break;
+ }
+ ++index;
+ }
+ return index == length ? kInvalid : index;
+ }
+
+ absl::flat_hash_set<char> separators_;
+};
+
+inline void StripTrailingAsciiPunctuation(std::string* str) {
+ auto it = std::find_if_not(str->rbegin(), str->rend(), ::ispunct);
+ str->erase(str->rend() - it);
+}
+
+std::string PreProcessString(const char* str, int len,
+ const bool remove_punctuation) {
+ std::string output_str(str, len);
+ std::transform(output_str.begin(), output_str.end(), output_str.begin(),
+ ::tolower);
+
+ // Remove trailing punctuation.
+ if (remove_punctuation) {
+ StripTrailingAsciiPunctuation(&output_str);
+ }
+
+ if (output_str.empty()) {
+ output_str.assign(str, len);
+ }
+ return output_str;
+}
+
+bool ShouldIncludeCurrentNgram(const SkipGramParams& params, int size) {
+ if (size <= 0) {
+ return false;
+ }
+ if (params.include_all_ngrams) {
+ return size <= params.ngram_size;
+ } else {
+ return size == params.ngram_size;
+ }
+}
+
+bool ShouldStepInRecursion(const std::vector<int>& stack, int stack_idx,
+ int num_words, const SkipGramParams& params) {
+ // If current stack size and next word enumeration are within valid range.
+ if (stack_idx < params.ngram_size && stack[stack_idx] + 1 < num_words) {
+ // If this stack is empty, step in for first word enumeration.
+ if (stack_idx == 0) {
+ return true;
+ }
+    // If the next word enumeration is within the range of max_skip_size.
+ // NOTE: equivalent to
+ // next_word_idx = stack[stack_idx] + 1
+ // next_word_idx - stack[stack_idx-1] <= max_skip_size + 1
+ if (stack[stack_idx] - stack[stack_idx - 1] <= params.max_skip_size) {
+ return true;
+ }
+ }
+ return false;
+}
+
+std::string JoinTokensBySpace(const std::vector<int>& stack, int stack_idx,
+ const std::vector<std::string>& tokens) {
+ int len = 0;
+ for (int i = 0; i < stack_idx; i++) {
+ len += tokens[stack[i]].size();
+ }
+ len += stack_idx - 1;
+
+ std::string res;
+ res.reserve(len);
+ res.append(tokens[stack[0]]);
+ for (int i = 1; i < stack_idx; i++) {
+ res.append(" ");
+ res.append(tokens[stack[i]]);
+ }
+
+ return res;
+}
+
+std::unordered_map<std::string, int> ExtractSkipGramsImpl(
+ const std::vector<std::string>& tokens, const SkipGramParams& params) {
+ // Ignore positional tokens.
+ static auto* blacklist = new std::unordered_set<std::string>({
+ kStartToken,
+ kEndToken,
+ kEmptyToken,
+ });
+
+ std::unordered_map<std::string, int> res;
+
+ // Stack stores the index of word used to generate ngram.
+ // The size of stack is the size of ngram.
+ std::vector<int> stack(params.ngram_size + 1, 0);
+ // Stack index that indicates which depth the recursion is operating at.
+ int stack_idx = 1;
+ int num_words = tokens.size();
+
+ while (stack_idx >= 0) {
+ if (ShouldStepInRecursion(stack, stack_idx, num_words, params)) {
+ // When current depth can fill with a new word
+ // and the new word is within the max range to skip,
+ // fill this word to stack, recurse into next depth.
+ stack[stack_idx]++;
+ stack_idx++;
+ stack[stack_idx] = stack[stack_idx - 1];
+ } else {
+ if (ShouldIncludeCurrentNgram(params, stack_idx)) {
+ // Add n-gram to tensor buffer when the stack has filled with enough
+ // words to generate the ngram.
+ std::string ngram = JoinTokensBySpace(stack, stack_idx, tokens);
+ if (blacklist->find(ngram) == blacklist->end()) {
+ res[ngram] = stack_idx;
+ }
+ }
+ // When current depth cannot fill with a valid new word,
+ // and not in last depth to generate ngram,
+ // step back to previous depth to iterate to next possible word.
+ stack_idx--;
+ }
+ }
+
+ return res;
+}
+
+std::unordered_map<std::string, int> ExtractSkipGrams(
+ const std::string& input, ProjectionTokenizer* tokenizer,
+ ProjectionNormalizer* normalizer, const SkipGramParams& params) {
+ // Normalize the input.
+ const std::string& normalized =
+ normalizer == nullptr
+ ? input
+ : normalizer->Normalize(input, params.max_input_chars);
+
+ // Split sentence to words.
+ std::vector<std::string> tokens;
+ if (params.char_level) {
+ tokens = SplitByChar(normalized.data(), normalized.size(),
+ params.max_input_chars);
+ } else {
+ tokens = tokenizer->Tokenize(normalized.data(), normalized.size(),
+ params.max_input_chars, kAllTokens);
+ }
+
+ // Process tokens
+ for (int i = 0; i < tokens.size(); ++i) {
+ if (params.preprocess) {
+ tokens[i] = PreProcessString(tokens[i].data(), tokens[i].size(),
+ params.remove_punctuation);
+ }
+ }
+
+ tokens.insert(tokens.begin(), kStartToken);
+ tokens.insert(tokens.end(), kEndToken);
+
+ return ExtractSkipGramsImpl(tokens, params);
+}
+} // namespace
+// Generates LSH projections for input strings. This uses the framework in
+// `string_projection_base.h`, with the implementation details that the input is
+// a string tensor of messages and the op will perform tokenization.
+//
+// Input:
+// tensor[0]: Input message, string[...]
+//
+// Additional attributes:
+// max_input_chars: int[]
+// maximum number of input characters to use from each message.
+// token_separators: string[]
+// the list of separators used to tokenize the input.
+// normalize_repetition: bool[]
+// if true, remove repeated characters in tokens ('loool' -> 'lol').
+
+static const int kInputMessage = 0;
+
+class StringProjectionOp : public StringProjectionOpBase {
+ public:
+ explicit StringProjectionOp(const flexbuffers::Map& custom_options)
+ : StringProjectionOpBase(custom_options),
+ projection_normalizer_(
+ custom_options["token_separators"].AsString().str(),
+ custom_options["normalize_repetition"].AsBool()),
+ projection_tokenizer_(" ") {
+ if (custom_options["max_input_chars"].IsInt()) {
+ skip_gram_params().max_input_chars =
+ custom_options["max_input_chars"].AsInt32();
+ }
+ }
+
+ TfLiteStatus InitializeInput(TfLiteContext* context,
+ TfLiteNode* node) override {
+ input_ = &context->tensors[node->inputs->data[kInputMessage]];
+ return kTfLiteOk;
+ }
+
+ std::unordered_map<std::string, int> ExtractSkipGrams(int i) override {
+ StringRef input = GetString(input_, i);
+ return ::tflite::ops::custom::libtextclassifier3::string_projection::
+ ExtractSkipGrams({input.str, static_cast<size_t>(input.len)},
+ &projection_tokenizer_, &projection_normalizer_,
+ skip_gram_params());
+ }
+
+ void FinalizeInput() override { input_ = nullptr; }
+
+ TfLiteIntArray* GetInputShape(TfLiteContext* context,
+ TfLiteNode* node) override {
+ return context->tensors[node->inputs->data[kInputMessage]].dims;
+ }
+
+ private:
+ ProjectionNormalizer projection_normalizer_;
+ ProjectionTokenizer projection_tokenizer_;
+
+ TfLiteTensor* input_;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ return new StringProjectionOp(flexbuffers::GetRoot(buffer_t, length).AsMap());
+}
+
+} // namespace string_projection
+
+// This op converts a list of strings to integers via LSH projections.
+TfLiteRegistration* Register_STRING_PROJECTION() {
+ static TfLiteRegistration r = {libtextclassifier3::string_projection::Init,
+ libtextclassifier3::string_projection::Free,
+ libtextclassifier3::string_projection::Resize,
+ libtextclassifier3::string_projection::Eval};
+ return &r;
+}
+
+} // namespace libtextclassifier3
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/native/utils/tflite/string_projection.h b/native/utils/tflite/string_projection.h
new file mode 100644
index 0000000..ba86a21
--- /dev/null
+++ b/native/utils/tflite/string_projection.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_H_
+
+#include "tensorflow/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace libtextclassifier3 {
+
+TfLiteRegistration* Register_STRING_PROJECTION();
+
+} // namespace libtextclassifier3
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_H_
diff --git a/native/utils/tflite/string_projection_base.cc b/native/utils/tflite/string_projection_base.cc
new file mode 100644
index 0000000..d185f52
--- /dev/null
+++ b/native/utils/tflite/string_projection_base.cc
@@ -0,0 +1,255 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/string_projection_base.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "utils/hash/cityhash.h"
+#include "utils/hash/farmhash.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/lite/context.h"
+#include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace libtextclassifier3 {
+namespace string_projection {
+
+namespace {
+const int32_t kMaxInputChars = 300;
+
+const int kOutputLabel = 0;
+const char kFastHash[] = "[DEV] FastHash";
+const char kAXB[] = "[DEV] AXB";
+
+const int kSeedSize = sizeof(float);
+const int kInputItemBytes = sizeof(int32_t);
+const int kKeyBytes = sizeof(float) + sizeof(int32_t);
+
+} // namespace
+
+StringProjectionOpBase::StringProjectionOpBase(
+ const flexbuffers::Map& custom_options)
+ : hash_function_(custom_options["hash_function"].AsTypedVector()),
+ num_hash_(custom_options["num_hash"].AsInt32()),
+ num_bits_(custom_options["num_bits"].AsInt32()),
+ binary_projection_(custom_options["binary_projection"].AsBool()),
+ hash_method_(custom_options["hash_method"].ToString()),
+ axb_scale_(custom_options["axb_scale"].AsFloat()) {
+ skip_gram_params_ = {
+ .ngram_size = custom_options["ngram_size"].AsInt32(),
+ .max_skip_size = custom_options["max_skip_size"].AsInt32(),
+ .include_all_ngrams = custom_options["include_all_ngrams"].AsBool(),
+ .preprocess = custom_options["preprocess"].AsBool(),
+ .char_level = custom_options["char_level"].AsBool(),
+ .remove_punctuation = custom_options["remove_punctuation"].AsBool(),
+ .max_input_chars = kMaxInputChars,
+ };
+}
+
+void StringProjectionOpBase::GetFeatureWeights(
+ const std::unordered_map<std::string, int>& feature_counts,
+ std::vector<std::vector<int64_t>>* batch_ids,
+ std::vector<std::vector<float>>* batch_weights) {
+ std::vector<int64_t> ids;
+ std::vector<float> weights;
+ for (const auto& iter : feature_counts) {
+ if (hash_method_ == kFastHash || hash_method_ == kAXB) {
+ int32_t feature_id =
+ tc3farmhash::CityHash64(iter.first.c_str(), iter.first.size());
+ ids.push_back(feature_id);
+ weights.push_back(iter.second);
+ } else {
+ int64_t feature_id =
+ tc3farmhash::Fingerprint64(iter.first.c_str(), iter.first.size());
+ ids.push_back(feature_id);
+ weights.push_back(iter.second);
+ }
+ }
+
+ batch_ids->push_back(ids);
+ batch_weights->push_back(weights);
+}
+
+void StringProjectionOpBase::DenseLshProjection(
+ const int batch_size, const std::vector<std::vector<int64_t>>& batch_ids,
+ const std::vector<std::vector<float>>& batch_weights,
+ TfLiteTensor* output) {
+ auto key = std::unique_ptr<char[]>(
+ new char[kKeyBytes]); // NOLINT: modernize-make-unique
+
+ if (output->type == kTfLiteFloat32) {
+ for (int batch = 0; batch < batch_size; ++batch) {
+ const std::vector<int64_t>& input = batch_ids[batch];
+ const std::vector<float>& weight = batch_weights[batch];
+
+ for (int i = 0; i < num_hash_; i++) {
+ for (int j = 0; j < num_bits_; j++) {
+ int hash_bit = i * num_bits_ + j;
+ float seed = hash_function_[hash_bit].AsFloat();
+ float bit = running_sign_bit(input, weight, seed, key.get());
+ output->data.f[batch * num_hash_ * num_bits_ + hash_bit] = bit;
+ }
+ }
+ }
+ } else if (output->type == kTfLiteUInt8) {
+ const float inverse_scale = 1.0 / output->params.scale;
+ for (int batch = 0; batch < batch_size; ++batch) {
+ const std::vector<int64_t>& input = batch_ids[batch];
+ const std::vector<float>& weight = batch_weights[batch];
+
+ for (int i = 0; i < num_hash_; i++) {
+ for (int j = 0; j < num_bits_; j++) {
+ int hash_bit = i * num_bits_ + j;
+ float seed = hash_function_[hash_bit].AsFloat();
+ float bit = running_sign_bit(input, weight, seed, key.get());
+ output->data.uint8[batch * num_hash_ * num_bits_ + hash_bit] =
+ seq_flow_lite::PodQuantize(bit, output->params.zero_point,
+ inverse_scale);
+ }
+ }
+ }
+ }
+}
+
+namespace {
+
+int32_t hash32(int32_t value, uint32_t seed) {
+ uint32_t hash = value;
+ hash = (hash ^ 61) ^ (hash >> 16);
+ hash = hash + (hash << 3);
+ hash = hash ^ (hash >> 4);
+ hash = hash * seed;
+ hash = hash ^ (hash >> 15);
+ return static_cast<int32_t>(hash);
+}
+
+double axb(int32_t value, float seed, float scale) {
+  // Convert the seed to a larger range; the 1e5 multiplier avoids
+  // precision differences across hardware.
+ int64_t hash_signature =
+ static_cast<int64_t>(scale) * static_cast<int64_t>(seed * 1e5) * value;
+ hash_signature %= 0x100000000;
+ hash_signature = fabs(hash_signature);
+ if (hash_signature >= 0x80000000) {
+ hash_signature -= 0x100000000;
+ }
+ return hash_signature;
+}
+
+} // namespace
+
+// Compute sign bit of dot product of hash(seed, input) and weight.
+float StringProjectionOpBase::running_sign_bit(
+ const std::vector<int64_t>& input, const std::vector<float>& weight,
+ float seed, char* key) {
+ double score = 0.0;
+ memcpy(key, &seed, kSeedSize);
+ int cnt = 0;
+ for (int i = 0; i < input.size(); ++i) {
+ if (weight[i] == 0.0) continue;
+ cnt++;
+ const int32_t curr_input = input[i];
+ memcpy(key + kSeedSize, &curr_input, kInputItemBytes);
+
+ // Create running hash id and value for current dimension.
+ if (hash_method_ == kFastHash) {
+ int32_t hash_signature =
+ hash32(input[i], *reinterpret_cast<uint32_t*>(&seed));
+ score += static_cast<double>(weight[i]) * hash_signature;
+ } else if (hash_method_ == kAXB) {
+ score += weight[i] * axb(input[i], seed, axb_scale_);
+ } else {
+ int64_t hash_signature = tc3farmhash::Fingerprint64(key, kKeyBytes);
+ double running_value = static_cast<double>(hash_signature);
+ score += weight[i] * running_value;
+ }
+ }
+
+ const double inverse_normalizer = 0.00000000046566129;
+ if (!binary_projection_) {
+ if (hash_method_ == kAXB) {
+ return tanh(score / cnt * inverse_normalizer);
+ } else {
+ return tanh(score * inverse_normalizer);
+ }
+ }
+
+ return (score > 0) ? 1 : 0;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<StringProjectionOpBase*>(buffer);
+}
+
+TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
+ auto* op = reinterpret_cast<StringProjectionOpBase*>(node->user_data);
+
+ // The shape of the output should be the shape of the input + a new inner
+ // dimension equal to the number of features.
+ TfLiteIntArray* input_shape = op->GetInputShape(context, node);
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_shape->size + 1);
+ for (int i = 0; i < input_shape->size; ++i) {
+ output_shape->data[i] = input_shape->data[i];
+ }
+ output_shape->data[input_shape->size] = op->num_hash() * op->num_bits();
+ context->ResizeTensor(context,
+ &context->tensors[node->outputs->data[kOutputLabel]],
+ output_shape);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* op = reinterpret_cast<StringProjectionOpBase*>(node->user_data);
+
+ TfLiteTensor* label = &context->tensors[node->outputs->data[kOutputLabel]];
+
+ TfLiteIntArray* input_shape = op->GetInputShape(context, node);
+ int input_size = 1;
+ for (int i = 0; i < input_shape->size; ++i) {
+ input_size *= input_shape->data[i];
+ }
+
+ TF_LITE_ENSURE_STATUS(op->InitializeInput(context, node));
+
+ std::vector<std::vector<int64_t>> batch_ids;
+ std::vector<std::vector<float>> batch_weights;
+ for (int i = 0; i < input_size; ++i) {
+ std::unordered_map<std::string, int> feature_counts =
+ op->ExtractSkipGrams(i);
+ op->GetFeatureWeights(feature_counts, &batch_ids, &batch_weights);
+ }
+
+ op->DenseLshProjection(input_size, batch_ids, batch_weights, label);
+
+ op->FinalizeInput();
+
+ return kTfLiteOk;
+}
+
+} // namespace string_projection
+} // namespace libtextclassifier3
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/native/utils/tflite/string_projection_base.h b/native/utils/tflite/string_projection_base.h
new file mode 100644
index 0000000..61b1708
--- /dev/null
+++ b/native/utils/tflite/string_projection_base.h
@@ -0,0 +1,156 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_BASE_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_BASE_H_
+
+/**
+ * String projection op used in Self-Governing Neural Network (SGNN)
+ * and other ProjectionNet models for text prediction.
+ * The code is copied/adapted from
+ * learning/expander/pod/deep_pod/tflite_handlers/
+ */
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace libtextclassifier3 {
+namespace string_projection {
+
+struct SkipGramParams {
+ // Num of tokens in ngram.
+ int ngram_size;
+
+ // Max num of tokens to skip in skip gram.
+ int max_skip_size;
+
+ // True when include all k-grams where k <= ngram_size.
+ bool include_all_ngrams;
+
+ // True when include preprocess.
+ bool preprocess;
+
+ // True when tokens are chars, false when tokens are whitespace separated.
+ bool char_level;
+
+ // True when punctuations are removed.
+ bool remove_punctuation;
+
+ // Max num of chars to process in input.
+ int max_input_chars;
+};
+
+/**
+ * A framework for writing TFLite ops that convert strings to integers via LSH
+ * projections. Input is defined by the specific implementation.
+ * NOTE: Only supports dense projection.
+ *
+ * Attributes:
+ * num_hash: int[]
+ * number of hash functions
+ * num_bits: int[]
+ * number of bits in each hash function
+ * hash_function: float[num_hash * num_bits]
+ * hash_functions used to generate projections
+ * ngram_size: int[]
+ * maximum number of tokens in skipgrams
+ * max_skip_size: int[]
+ * maximum number of tokens to skip between tokens in skipgrams.
+ * include_all_ngrams: bool[]
+ * if false, only use skipgrams with ngram_size tokens
+ * preprocess: bool[]
+ * if true, normalize input strings (lower case, remove punctuation)
+ * hash_method: string[]
+ * hashing function to use
+ * char_level: bool[]
+ * if true, treat each character as a token
+ * binary_projection: bool[]
+ * if true, output features are 0 or 1
+ * remove_punctuation: bool[]
+ * if true, remove punctuation during normalization/preprocessing
+ *
+ * Output:
+ * tensor[0]: computed projections. float32[..., num_func * num_bits]
+ */
+
+class StringProjectionOpBase {
+ public:
+ explicit StringProjectionOpBase(const flexbuffers::Map& custom_options);
+
+ virtual ~StringProjectionOpBase() {}
+
+ void GetFeatureWeights(
+ const std::unordered_map<std::string, int>& feature_counts,
+ std::vector<std::vector<int64_t>>* batch_ids,
+ std::vector<std::vector<float>>* batch_weights);
+
+ void DenseLshProjection(const int batch_size,
+ const std::vector<std::vector<int64_t>>& batch_ids,
+ const std::vector<std::vector<float>>& batch_weights,
+ TfLiteTensor* output);
+
+ inline int num_hash() { return num_hash_; }
+ inline int num_bits() { return num_bits_; }
+ virtual TfLiteStatus InitializeInput(TfLiteContext* context,
+ TfLiteNode* node) = 0;
+ virtual std::unordered_map<std::string, int> ExtractSkipGrams(int i) = 0;
+ virtual void FinalizeInput() = 0;
+
+ // Returns the input shape. TfLiteIntArray is owned by the object.
+ virtual TfLiteIntArray* GetInputShape(TfLiteContext* context,
+ TfLiteNode* node) = 0;
+
+ protected:
+ SkipGramParams& skip_gram_params() { return skip_gram_params_; }
+
+ private:
+ ::flexbuffers::TypedVector hash_function_;
+ int num_hash_;
+ int num_bits_;
+ bool binary_projection_;
+ std::string hash_method_;
+ float axb_scale_;
+ SkipGramParams skip_gram_params_;
+
+ // Compute sign bit of dot product of hash(seed, input) and weight.
+ float running_sign_bit(const std::vector<int64_t>& input,
+ const std::vector<float>& weight, float seed,
+ char* key);
+};
+
+// Individual ops should define an Init() function that returns a
+// StringProjectionOpBase.
+
+void Free(TfLiteContext* context, void* buffer);
+
+TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node);
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node);
+
+} // namespace string_projection
+} // namespace libtextclassifier3
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_BASE_H_
diff --git a/native/utils/tflite/text_encoder.cc b/native/utils/tflite/text_encoder.cc
index 78bb51a..80b5854 100644
--- a/native/utils/tflite/text_encoder.cc
+++ b/native/utils/tflite/text_encoder.cc
@@ -100,14 +100,14 @@
config->add_dummy_prefix(), config->remove_extra_whitespaces(),
config->escape_whitespaces()));
- const int num_pieces = config->pieces_scores()->Length();
+ const int num_pieces = config->pieces_scores()->size();
switch (config->matcher_type()) {
case SentencePieceMatcherType_MAPPED_TRIE: {
const TrieNode* pieces_trie_nodes =
reinterpret_cast<const TrieNode*>(config->pieces()->Data());
const int pieces_trie_nodes_length =
- config->pieces()->Length() / sizeof(TrieNode);
+ config->pieces()->size() / sizeof(TrieNode);
encoder_op->matcher.reset(
new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
break;
@@ -115,7 +115,7 @@
case SentencePieceMatcherType_SORTED_STRING_TABLE: {
encoder_op->matcher.reset(new SortedStringsTable(
num_pieces, config->pieces_offsets()->data(),
- StringPiece(config->pieces()->data(), config->pieces()->Length())));
+ StringPiece(config->pieces()->data(), config->pieces()->size())));
break;
}
default: {
diff --git a/native/utils/tokenizer-utils.cc b/native/utils/tokenizer-utils.cc
new file mode 100644
index 0000000..c812acf
--- /dev/null
+++ b/native/utils/tokenizer-utils.cc
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenizer-utils.h"
+
+#include <iterator>
+
+#include "utils/codepoint-range.h"
+#include "utils/strings/utf8.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib-common.h"
+#include "absl/container/flat_hash_set.h"
+
+namespace libtextclassifier3 {
+
+using libtextclassifier3::Token;
+
+std::vector<Token> TokenizeOnSpace(const std::string& text) {
+ return TokenizeOnDelimiters(text, {' '});
+}
+
+std::vector<Token> TokenizeOnDelimiters(
+ const std::string& text, const absl::flat_hash_set<char32>& delimiters,
+ bool create_tokens_for_non_space_delimiters) {
+ return TokenizeWithFilter(text, [&](char32 codepoint) {
+ bool to_split = delimiters.find(codepoint) != delimiters.end();
+ bool to_keep =
+ (create_tokens_for_non_space_delimiters) ? codepoint != ' ' : false;
+ return FilterResult{to_split, to_keep};
+ });
+}
+
+std::vector<Token> TokenizeOnWhiteSpacePunctuationAndChineseLetter(
+ const absl::string_view text) {
+ return TokenizeWithFilter(text, [](char32 codepoint) {
+ bool is_whitespace = IsWhitespace(codepoint);
+ bool to_split =
+ is_whitespace || IsPunctuation(codepoint) || IsChineseLetter(codepoint);
+ bool to_keep = !is_whitespace;
+ return FilterResult{to_split, to_keep};
+ });
+}
+} // namespace libtextclassifier3
diff --git a/native/utils/tokenizer-utils.h b/native/utils/tokenizer-utils.h
new file mode 100644
index 0000000..7d850d9
--- /dev/null
+++ b/native/utils/tokenizer-utils.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utilities for tokenization.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENIZER_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_TOKENIZER_UTILS_H_
+
+#include <string>
+
+#include "annotator/types.h"
+#include "utils/codepoint-range.h"
+#include "utils/strings/utf8.h"
+#include "utils/utf8/unicodetext.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+
+namespace libtextclassifier3 {
+
+struct FilterResult {
+ // Whether split on this codepoint.
+ bool to_split;
+ // If the codepoint is used to split the text, whether to output it as a
+ // token.
+ bool to_keep;
+};
+
+// Returns a list of Tokens for a given input string, by tokenizing on space.
+std::vector<Token> TokenizeOnSpace(const std::string& text);
+
+// Returns a list of Tokens for a given input string, by tokenizing on the
+// given set of delimiter codepoints.
+// If create_tokens_for_non_space_delimiters is true, create tokens for
+// delimiters which are not white spaces. For example "This, is" -> {"This",
+// ",", "is"}.
+std::vector<Token> TokenizeOnDelimiters(
+ const std::string& text, const absl::flat_hash_set<char32>& delimiters,
+ bool create_tokens_for_non_space_delimiters = false);
+
+// This replicates how the original bert_tokenizer from the tflite-support
+// library pretokenize text by using regex_split with these default regexes.
+// It splits the text on spaces, punctuations and chinese characters and
+// output all the tokens except spaces.
+// So far, the only difference between this and the original implementation
+// we are aware of is that the original regexes has 8 ranges of chinese
+// unicodes. We have all these 8 ranges plus two extra ranges.
+std::vector<Token> TokenizeOnWhiteSpacePunctuationAndChineseLetter(
+ const absl::string_view text);
+
+// Returns a list of Tokens for a given input string, by tokenizing on the
+// given filter function. Caller can control which codepoint to split and
+// whether a delimiter should be output as a token.
+template <typename FilterFn>
+std::vector<Token> TokenizeWithFilter(const absl::string_view input,
+ FilterFn filter) {
+ const UnicodeText input_unicode = UTF8ToUnicodeText(input, /*do_copy=*/false);
+ std::vector<Token> tokens;
+ UnicodeText::const_iterator start_it = input_unicode.begin();
+ int token_start_codepoint = 0;
+ int codepoint_idx = 0;
+
+ for (auto it = input_unicode.begin(); it != input_unicode.end(); ++it) {
+ const char32 code_point = *it;
+ FilterResult filter_result = filter(code_point);
+ if (filter_result.to_split) {
+ const std::string token_text = UnicodeText::UTF8Substring(start_it, it);
+ if (!token_text.empty()) {
+ tokens.push_back(
+ Token{token_text, token_start_codepoint, codepoint_idx});
+ }
+ if (filter_result.to_keep) {
+ const std::string delimiter =
+ UnicodeText::UTF8Substring(it, std::next(it));
+ tokens.push_back(Token{delimiter, codepoint_idx, codepoint_idx + 1});
+ }
+ start_it = std::next(it);
+ token_start_codepoint = codepoint_idx + 1;
+ }
+ codepoint_idx++;
+ }
+ // Flush the last token if any.
+ if (start_it != input_unicode.end()) {
+ const std::string token_text =
+ UnicodeText::UTF8Substring(start_it, input_unicode.end());
+ tokens.push_back(Token{token_text, token_start_codepoint, codepoint_idx});
+ }
+ return tokens;
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TOKENIZER_UTILS_H_
diff --git a/native/utils/tokenizer-utils_test.cc b/native/utils/tokenizer-utils_test.cc
new file mode 100644
index 0000000..d4a1bc0
--- /dev/null
+++ b/native/utils/tokenizer-utils_test.cc
@@ -0,0 +1,201 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenizer-utils.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(TokenizerUtilTest, TokenizeOnSpace) {
+ std::vector<Token> tokens =
+ TokenizeOnSpace("Where is Jörg Borg located? Maybe in Zürich ...");
+
+ EXPECT_EQ(tokens.size(), 9);
+
+ EXPECT_EQ(tokens[0].value, "Where");
+ EXPECT_EQ(tokens[0].start, 0);
+ EXPECT_EQ(tokens[0].end, 5);
+
+ EXPECT_EQ(tokens[1].value, "is");
+ EXPECT_EQ(tokens[1].start, 6);
+ EXPECT_EQ(tokens[1].end, 8);
+
+ EXPECT_EQ(tokens[2].value, "Jörg");
+ EXPECT_EQ(tokens[2].start, 9);
+ EXPECT_EQ(tokens[2].end, 13);
+
+ EXPECT_EQ(tokens[3].value, "Borg");
+ EXPECT_EQ(tokens[3].start, 14);
+ EXPECT_EQ(tokens[3].end, 18);
+
+ EXPECT_EQ(tokens[4].value, "located?");
+ EXPECT_EQ(tokens[4].start, 19);
+ EXPECT_EQ(tokens[4].end, 27);
+
+ EXPECT_EQ(tokens[5].value, "Maybe");
+ EXPECT_EQ(tokens[5].start, 28);
+ EXPECT_EQ(tokens[5].end, 33);
+
+ EXPECT_EQ(tokens[6].value, "in");
+ EXPECT_EQ(tokens[6].start, 34);
+ EXPECT_EQ(tokens[6].end, 36);
+
+ EXPECT_EQ(tokens[7].value, "Zürich");
+ EXPECT_EQ(tokens[7].start, 37);
+ EXPECT_EQ(tokens[7].end, 43);
+
+ EXPECT_EQ(tokens[8].value, "...");
+ EXPECT_EQ(tokens[8].start, 44);
+ EXPECT_EQ(tokens[8].end, 47);
+}
+
+TEST(TokenizerUtilTest, TokenizeOnDelimiters) {
+ std::vector<Token> tokens = TokenizeOnDelimiters(
+ "This might be čomplíčateď?!: Oder?", {' ', '?', '!'});
+
+ EXPECT_EQ(tokens.size(), 6);
+
+ EXPECT_EQ(tokens[0].value, "This");
+ EXPECT_EQ(tokens[0].start, 0);
+ EXPECT_EQ(tokens[0].end, 4);
+
+ EXPECT_EQ(tokens[1].value, "might");
+ EXPECT_EQ(tokens[1].start, 7);
+ EXPECT_EQ(tokens[1].end, 12);
+
+ EXPECT_EQ(tokens[2].value, "be");
+ EXPECT_EQ(tokens[2].start, 13);
+ EXPECT_EQ(tokens[2].end, 15);
+
+ EXPECT_EQ(tokens[3].value, "čomplíčateď");
+ EXPECT_EQ(tokens[3].start, 16);
+ EXPECT_EQ(tokens[3].end, 27);
+
+ EXPECT_EQ(tokens[4].value, ":");
+ EXPECT_EQ(tokens[4].start, 29);
+ EXPECT_EQ(tokens[4].end, 30);
+
+ EXPECT_EQ(tokens[5].value, "Oder");
+ EXPECT_EQ(tokens[5].start, 31);
+ EXPECT_EQ(tokens[5].end, 35);
+}
+
+TEST(TokenizerUtilTest, TokenizeOnDelimitersKeepNoSpace) {
+ std::vector<Token> tokens = TokenizeOnDelimiters(
+ "This might be čomplíčateď?!: Oder?", {' ', '?', '!'},
+ /* create_tokens_for_non_space_delimiters =*/true);
+
+ EXPECT_EQ(tokens.size(), 9);
+
+ EXPECT_EQ(tokens[0].value, "This");
+ EXPECT_EQ(tokens[0].start, 0);
+ EXPECT_EQ(tokens[0].end, 4);
+
+ EXPECT_EQ(tokens[1].value, "might");
+ EXPECT_EQ(tokens[1].start, 7);
+ EXPECT_EQ(tokens[1].end, 12);
+
+ EXPECT_EQ(tokens[2].value, "be");
+ EXPECT_EQ(tokens[2].start, 13);
+ EXPECT_EQ(tokens[2].end, 15);
+
+ EXPECT_EQ(tokens[3].value, "čomplíčateď");
+ EXPECT_EQ(tokens[3].start, 16);
+ EXPECT_EQ(tokens[3].end, 27);
+
+ EXPECT_EQ(tokens[4].value, "?");
+ EXPECT_EQ(tokens[4].start, 27);
+ EXPECT_EQ(tokens[4].end, 28);
+
+ EXPECT_EQ(tokens[5].value, "!");
+ EXPECT_EQ(tokens[5].start, 28);
+ EXPECT_EQ(tokens[5].end, 29);
+
+ EXPECT_EQ(tokens[6].value, ":");
+ EXPECT_EQ(tokens[6].start, 29);
+ EXPECT_EQ(tokens[6].end, 30);
+
+ EXPECT_EQ(tokens[7].value, "Oder");
+ EXPECT_EQ(tokens[7].start, 31);
+ EXPECT_EQ(tokens[7].end, 35);
+
+ EXPECT_EQ(tokens[8].value, "?");
+ EXPECT_EQ(tokens[8].start, 35);
+ EXPECT_EQ(tokens[8].end, 36);
+}
+
+TEST(TokenizerUtilTest, SimpleEnglishWithPunctuation) {
+ absl::string_view input = "I am fine, thanks!";
+
+ std::vector<Token> tokens =
+ TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
+
+ EXPECT_THAT(tokens, testing::ElementsAreArray(
+ {Token{"I", 0, 1}, Token{"am", 2, 4},
+ Token{"fine", 5, 9}, Token{",", 9, 10},
+ Token{"thanks", 11, 17}, Token{"!", 17, 18}}));
+}
+
+TEST(TokenizerUtilTest, InputDoesNotEndWithDelimiter) {
+ absl::string_view input = "Good! Cool";
+
+ std::vector<Token> tokens =
+ TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
+
+ EXPECT_THAT(tokens,
+ testing::ElementsAreArray({Token{"Good", 0, 4}, Token{"!", 4, 5},
+ Token{"Cool", 6, 10}}));
+}
+
+TEST(TokenizerUtilTest, OnlySpace) {
+ absl::string_view input = " \t";
+
+ std::vector<Token> tokens =
+ TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
+
+ ASSERT_TRUE(tokens.empty());
+}
+
+TEST(TokenizerUtilTest, Punctuation) {
+ absl::string_view input = "!-/:-@[-`{-~";
+
+ std::vector<Token> tokens =
+ TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
+
+ EXPECT_THAT(tokens,
+ testing::ElementsAreArray(
+ {Token{"!", 0, 1}, Token{"-", 1, 2}, Token{"/", 2, 3},
+ Token{":", 3, 4}, Token{"-", 4, 5}, Token{"@", 5, 6},
+ Token{"[", 6, 7}, Token{"-", 7, 8}, Token{"`", 8, 9},
+ Token{"{", 9, 10}, Token{"-", 10, 11}, Token{"~", 11, 12}}));
+}
+
+TEST(TokenizerUtilTest, ChineseCharacters) {
+ absl::string_view input = "你好嗎三個字";
+
+ std::vector<Token> tokens =
+ TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
+
+ EXPECT_THAT(tokens,
+ testing::ElementsAreArray(
+ {Token{"你", 0, 1}, Token{"好", 1, 2}, Token{"嗎", 2, 3},
+ Token{"三", 3, 4}, Token{"個", 4, 5}, Token{"字", 5, 6}}));
+}
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/tokenizer.cc b/native/utils/tokenizer.cc
index bd47592..20f72c4 100644
--- a/native/utils/tokenizer.cc
+++ b/native/utils/tokenizer.cc
@@ -50,6 +50,10 @@
SortCodepointRanges(internal_tokenizer_codepoint_ranges,
&internal_tokenizer_codepoint_ranges_);
+ if (type_ == TokenizationType_MIXED && split_on_script_change) {
+ TC3_LOG(ERROR) << "The option `split_on_script_change` is unavailable for "
+ "the selected tokenizer type (mixed).";
+ }
}
const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
@@ -234,18 +238,20 @@
if (!break_iterator) {
return false;
}
- int last_break_index = 0;
- int break_index = 0;
+ const int context_unicode_size = context_unicode.size_codepoints();
int last_unicode_index = 0;
int unicode_index = 0;
auto token_begin_it = context_unicode.begin();
- while ((break_index = break_iterator->Next()) !=
+ while ((unicode_index = break_iterator->Next()) !=
UniLib::BreakIterator::kDone) {
- const int token_length = break_index - last_break_index;
- unicode_index = last_unicode_index + token_length;
+ const int token_length = unicode_index - last_unicode_index;
+ if (token_length + last_unicode_index > context_unicode_size) {
+ return false;
+ }
auto token_end_it = token_begin_it;
std::advance(token_end_it, token_length);
+ TC3_CHECK(token_end_it <= context_unicode.end());
// Determine if the whole token is whitespace.
bool is_whitespace = true;
@@ -264,7 +270,6 @@
/*is_padding=*/false, is_whitespace));
}
- last_break_index = break_index;
last_unicode_index = unicode_index;
token_begin_it = token_end_it;
}
diff --git a/native/utils/tokenizer.fbs b/native/utils/tokenizer.fbs
old mode 100755
new mode 100644
diff --git a/native/utils/utf8/unicodetext.cc b/native/utils/utf8/unicodetext.cc
index 7b56ce2..a8bc9fb 100644
--- a/native/utils/utf8/unicodetext.cc
+++ b/native/utils/utf8/unicodetext.cc
@@ -22,6 +22,7 @@
#include "utils/base/logging.h"
#include "utils/strings/utf8.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
@@ -202,6 +203,22 @@
return IsValidUTF8(repr_.data_, repr_.size_);
}
+std::vector<UnicodeText::const_iterator> UnicodeText::Codepoints() const {
+ std::vector<UnicodeText::const_iterator> codepoints;
+ for (auto it = begin(); it != end(); it++) {
+ codepoints.push_back(it);
+ }
+ return codepoints;
+}
+
+std::vector<char32> UnicodeText::CodepointsChar32() const {
+ std::vector<char32> codepoints;
+ for (auto it = begin(); it != end(); it++) {
+ codepoints.push_back(*it);
+ }
+ return codepoints;
+}
+
bool UnicodeText::operator==(const UnicodeText& other) const {
if (repr_.size_ != other.repr_.size_) {
return false;
@@ -320,4 +337,8 @@
return UTF8ToUnicodeText(str.data(), str.size(), do_copy);
}
+UnicodeText UTF8ToUnicodeText(absl::string_view str, bool do_copy) {
+ return UTF8ToUnicodeText(str.data(), str.size(), do_copy);
+}
+
} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unicodetext.h b/native/utils/utf8/unicodetext.h
index 9810480..1eb41bc 100644
--- a/native/utils/utf8/unicodetext.h
+++ b/native/utils/utf8/unicodetext.h
@@ -20,10 +20,12 @@
#include <iterator>
#include <string>
#include <utility>
+#include <vector>
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
#include "utils/strings/stringpiece.h"
+#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
@@ -174,6 +176,12 @@
UnicodeText& push_back(char32 ch);
void clear();
+ // Returns an iterator for each codepoint.
+ std::vector<const_iterator> Codepoints() const;
+
+ // Returns the list of codepoints of the UnicodeText.
+ std::vector<char32> CodepointsChar32() const;
+
std::string ToUTF8String() const;
std::string UTF8Substring(int begin_codepoint, int end_codepoint) const;
static std::string UTF8Substring(const const_iterator& it_begin,
@@ -230,6 +238,7 @@
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy = true);
UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy = true);
UnicodeText UTF8ToUnicodeText(StringPiece str, bool do_copy = true);
+UnicodeText UTF8ToUnicodeText(absl::string_view str, bool do_copy = true);
inline logging::LoggingStringStream& operator<<(
logging::LoggingStringStream& stream, const UnicodeText& message) {
diff --git a/native/utils/utf8/unilib-common.cc b/native/utils/utf8/unilib-common.cc
index de52086..70b8fec 100644
--- a/native/utils/utf8/unilib-common.cc
+++ b/native/utils/utf8/unilib-common.cc
@@ -61,6 +61,12 @@
0x1F501, 0x1F502, 0x1F503, 0x1F504, 0x1F5D8, 0x1F5DE};
constexpr int kNumWhitespaces = ARRAYSIZE(kWhitespaces);
+// https://en.wikipedia.org/wiki/Bidirectional_text
+constexpr char32 kBidirectional[] = {0x061C, 0x200E, 0x200F, 0x202A,
+ 0x202B, 0x202C, 0x202D, 0x202E,
+ 0x2066, 0x2067, 0x2068, 0x2069};
+constexpr int kNumBidirectional = ARRAYSIZE(kBidirectional);
+
// grep -E "Nd" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
// As the name suggests, these ranges are always 10 codepoints long, so we just
// store the end of the range.
@@ -389,6 +395,22 @@
constexpr char32 kDots[] = {0x002e, 0xfe52, 0xff0e};
constexpr int kNumDots = ARRAYSIZE(kDots);
+// Source: https://unicode-search.net/unicode-namesearch.pl?term=Apostrophe
+constexpr char32 kApostrophe[] = {0x0027, 0x02BC, 0x02EE, 0x055A,
+ 0x07F4, 0x07F5, 0xFF07};
+constexpr int kNumApostrophe = ARRAYSIZE(kApostrophe);
+
+// Source: https://unicode-search.net/unicode-namesearch.pl?term=Quotation
+constexpr char32 kQuotation[] = {
+ 0x0022, 0x00AB, 0x00BB, 0x2018, 0x2019, 0x201A, 0x201B, 0x201C,
+ 0x201D, 0x201E, 0x201F, 0x2039, 0x203A, 0x275B, 0x275C, 0x275D,
+ 0x275E, 0x276E, 0x276F, 0x2E42, 0x301D, 0x301E, 0x301F, 0xFF02};
+constexpr int kNumQuotation = ARRAYSIZE(kQuotation);
+
+// Source: https://unicode-search.net/unicode-namesearch.pl?term=ampersand
+constexpr char32 kAmpersand[] = {0x0026, 0xFE60, 0xFF06, 0x1F674, 0x1F675};
+constexpr int kNumAmpersand = ARRAYSIZE(kAmpersand);
+
#undef ARRAYSIZE
static_assert(kNumOpeningBrackets == kNumClosingBrackets,
@@ -502,6 +524,10 @@
return GetMatchIndex(kWhitespaces, kNumWhitespaces, codepoint) >= 0;
}
+bool IsBidirectional(char32 codepoint) {
+ return GetMatchIndex(kBidirectional, kNumBidirectional, codepoint) >= 0;
+}
+
bool IsDigit(char32 codepoint) {
return GetOverlappingRangeIndex(kDecimalDigitRangesEnd,
kNumDecimalDigitRangesEnd,
@@ -566,6 +592,18 @@
return GetMatchIndex(kDots, kNumDots, codepoint) >= 0;
}
+bool IsApostrophe(char32 codepoint) {
+ return GetMatchIndex(kApostrophe, kNumApostrophe, codepoint) >= 0;
+}
+
+bool IsQuotation(char32 codepoint) {
+ return GetMatchIndex(kQuotation, kNumQuotation, codepoint) >= 0;
+}
+
+bool IsAmpersand(char32 codepoint) {
+ return GetMatchIndex(kAmpersand, kNumAmpersand, codepoint) >= 0;
+}
+
bool IsLatinLetter(char32 codepoint) {
return (GetOverlappingRangeIndex(
kLatinLettersRangesStart, kLatinLettersRangesEnd,
diff --git a/native/utils/utf8/unilib-common.h b/native/utils/utf8/unilib-common.h
index 4f03de7..b192034 100644
--- a/native/utils/utf8/unilib-common.h
+++ b/native/utils/utf8/unilib-common.h
@@ -25,6 +25,7 @@
bool IsOpeningBracket(char32 codepoint);
bool IsClosingBracket(char32 codepoint);
bool IsWhitespace(char32 codepoint);
+bool IsBidirectional(char32 codepoint);
bool IsDigit(char32 codepoint);
bool IsLower(char32 codepoint);
bool IsUpper(char32 codepoint);
@@ -34,6 +35,9 @@
bool IsMinus(char32 codepoint);
bool IsNumberSign(char32 codepoint);
bool IsDot(char32 codepoint);
+bool IsApostrophe(char32 codepoint);
+bool IsQuotation(char32 codepoint);
+bool IsAmpersand(char32 codepoint);
bool IsLatinLetter(char32 codepoint);
bool IsArabicLetter(char32 codepoint);
@@ -49,6 +53,23 @@
char32 ToUpper(char32 codepoint);
char32 GetPairedBracket(char32 codepoint);
+// Checks whether the text is unlikely to be a valid number. Used to avoid
+// most of the Java exceptions thrown when parsing fails.
+template <class T>
+bool PassesIntPreChesks(const UnicodeText& text, const T result) {
+ if (text.empty() ||
+ (std::is_same<T, int32>::value && text.size_codepoints() > 10) ||
+ (std::is_same<T, int64>::value && text.size_codepoints() > 19)) {
+ return false;
+ }
+ for (auto it = text.begin(); it != text.end(); ++it) {
+ if (!IsDigit(*it)) {
+ return false;
+ }
+ }
+ return true;
+}
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_COMMON_H_
diff --git a/native/utils/utf8/unilib-javaicu.cc b/native/utils/utf8/unilib-javaicu.cc
index de6b5ed..befe639 100644
--- a/native/utils/utf8/unilib-javaicu.cc
+++ b/native/utils/utf8/unilib-javaicu.cc
@@ -25,9 +25,8 @@
#include "utils/base/logging.h"
#include "utils/base/statusor.h"
#include "utils/java/jni-base.h"
-#include "utils/java/string_utils.h"
+#include "utils/java/jni-helper.h"
#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib-common.h"
namespace libtextclassifier3 {
@@ -82,6 +81,20 @@
// Implementations that call out to JVM. Behold the beauty.
// -----------------------------------------------------------------------------
+StatusOr<int32> UniLibBase::Length(const UnicodeText& text) const {
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> text_java,
+ jni_cache_->ConvertToJavaString(text));
+
+ JNIEnv* jenv = jni_cache_->GetEnv();
+ TC3_ASSIGN_OR_RETURN(int utf16_length,
+ JniHelper::CallIntMethod(jenv, text_java.get(),
+ jni_cache_->string_length));
+
+ return JniHelper::CallIntMethod(jenv, text_java.get(),
+ jni_cache_->string_code_point_count, 0,
+ utf16_length);
+}
+
bool UniLibBase::ParseInt32(const UnicodeText& text, int32* result) const {
return ParseInt(text, result);
}
@@ -95,29 +108,23 @@
return false;
}
- JNIEnv* env = jni_cache_->GetEnv();
auto it_dot = text.begin();
for (; it_dot != text.end() && !IsDot(*it_dot); it_dot++) {
}
- int64 integer_part;
+ int32 integer_part;
if (!ParseInt(UnicodeText::Substring(text.begin(), it_dot, /*do_copy=*/false),
&integer_part)) {
return false;
}
- int64 fractional_part = 0;
+ int32 fractional_part = 0;
if (it_dot != text.end()) {
- std::string fractional_part_str =
- UnicodeText::UTF8Substring(++it_dot, text.end());
- TC3_ASSIGN_OR_RETURN_FALSE(
- const ScopedLocalRef<jstring> fractional_text_java,
- jni_cache_->ConvertToJavaString(fractional_part_str));
- TC3_ASSIGN_OR_RETURN_FALSE(
- fractional_part,
- JniHelper::CallStaticIntMethod<int64>(
- env, jni_cache_->integer_class.get(), jni_cache_->integer_parse_int,
- fractional_text_java.get()));
+ if (!ParseInt(
+ UnicodeText::Substring(++it_dot, text.end(), /*do_copy=*/false),
+ &fractional_part)) {
+ return false;
+ }
}
double factional_part_double = fractional_part;
@@ -420,14 +427,14 @@
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
- std::string result;
- if (!JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get(),
- &result)) {
+ StatusOr<std::string> status_or_result =
+ JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get());
+ if (!status_or_result.ok()) {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
*status = kNoError;
- return UTF8ToUnicodeText(result, /*do_copy=*/true);
+ return UTF8ToUnicodeText(status_or_result.ValueOrDie(), /*do_copy=*/true);
} else {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
@@ -455,14 +462,14 @@
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
- std::string result;
- if (!JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get(),
- &result)) {
+ StatusOr<std::string> status_or_result =
+ JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get());
+ if (!status_or_result.ok()) {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
*status = kNoError;
- return UTF8ToUnicodeText(result, /*do_copy=*/true);
+ return UTF8ToUnicodeText(status_or_result.ValueOrDie(), /*do_copy=*/true);
} else {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
diff --git a/native/utils/utf8/unilib-javaicu.h b/native/utils/utf8/unilib-javaicu.h
index d208730..8b04789 100644
--- a/native/utils/utf8/unilib-javaicu.h
+++ b/native/utils/utf8/unilib-javaicu.h
@@ -31,8 +31,8 @@
#include "utils/java/jni-base.h"
#include "utils/java/jni-cache.h"
#include "utils/java/jni-helper.h"
-#include "utils/java/string_utils.h"
#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib-common.h"
namespace libtextclassifier3 {
@@ -57,6 +57,8 @@
char32 ToUpper(char32 codepoint) const;
char32 GetPairedBracket(char32 codepoint) const;
+ StatusOr<int32> Length(const UnicodeText& text) const;
+
// Forward declaration for friend.
class RegexPattern;
@@ -115,9 +117,13 @@
// Returns the matched text (the 0th capturing group).
std::string Text() const {
- ScopedStringChars text_str =
- GetScopedStringChars(jni_cache_->GetEnv(), text_.get());
- return text_str.get();
+ StatusOr<std::string> status_or_result =
+ JStringToUtf8String(jni_cache_->GetEnv(), text_.get());
+ if (!status_or_result.ok()) {
+ TC3_LOG(ERROR) << "JStringToUtf8String failed.";
+ return "";
+ }
+ return status_or_result.ValueOrDie();
}
private:
@@ -194,13 +200,21 @@
return false;
}
+ // Avoid throwing exceptions when the text is unlikely to be a number.
+ int32 result32 = 0;
+ if (!PassesIntPreChesks(text, result32)) {
+ return false;
+ }
+
JNIEnv* env = jni_cache_->GetEnv();
TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> text_java,
jni_cache_->ConvertToJavaString(text));
TC3_ASSIGN_OR_RETURN_FALSE(
- *result, JniHelper::CallStaticIntMethod<T>(
- env, jni_cache_->integer_class.get(),
- jni_cache_->integer_parse_int, text_java.get()));
+ *result,
+ JniHelper::CallStaticIntMethod<T>(
+ env,
+ /*print_exception_on_error=*/false, jni_cache_->integer_class.get(),
+ jni_cache_->integer_parse_int, text_java.get()));
return true;
}
diff --git a/native/utils/utf8/unilib.h b/native/utils/utf8/unilib.h
index d0e6164..ffda7d9 100644
--- a/native/utils/utf8/unilib.h
+++ b/native/utils/utf8/unilib.h
@@ -30,9 +30,6 @@
#elif defined TC3_UNILIB_APPLE
#include "utils/utf8/unilib-apple.h"
#define INIT_UNILIB_FOR_TESTING(VAR) VAR()
-#elif defined TC3_UNILIB_DUMMY
-#include "utils/utf8/unilib-dummy.h"
-#define INIT_UNILIB_FOR_TESTING(VAR) VAR()
#else
#error No TC3_UNILIB implementation specified.
#endif
@@ -108,6 +105,18 @@
return libtextclassifier3::IsDot(codepoint);
}
+ bool IsApostrophe(char32 codepoint) const {
+ return libtextclassifier3::IsApostrophe(codepoint);
+ }
+
+ bool IsQuotation(char32 codepoint) const {
+ return libtextclassifier3::IsQuotation(codepoint);
+ }
+
+ bool IsAmpersand(char32 codepoint) const {
+ return libtextclassifier3::IsAmpersand(codepoint);
+ }
+
bool IsLatinLetter(char32 codepoint) const {
return libtextclassifier3::IsLatinLetter(codepoint);
}
@@ -143,6 +152,31 @@
bool IsLetter(char32 codepoint) const {
return libtextclassifier3::IsLetter(codepoint);
}
+
+ bool IsValidUtf8(const UnicodeText& text) const {
+ // Basic check of structural validity of UTF8.
+ if (!text.is_valid()) {
+ return false;
+ }
+    // In addition to that, we declare UTF8 valid only when the number of
+    // codepoints in the string as measured by ICU equals the number of
+    // codepoints as measured by UnicodeText. Without this check, the
+    // indices might differ and cause trouble, because the assumption
+    // throughout the code is that ICU indices and UnicodeText indices are
+    // the same.
+ // NOTE: This is not perfect, as this doesn't check the alignment of the
+ // codepoints, but for the practical purposes should be enough.
+ const StatusOr<int32> icu_length = Length(text);
+ if (!icu_length.ok()) {
+ return false;
+ }
+
+ if (icu_length.ValueOrDie() != text.size_codepoints()) {
+ return false;
+ }
+
+ return true;
+ }
};
} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib_test-include.cc b/native/utils/utf8/unilib_test-include.cc
new file mode 100644
index 0000000..ed0f184
--- /dev/null
+++ b/native/utils/utf8/unilib_test-include.cc
@@ -0,0 +1,548 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/utf8/unilib_test-include.h"
+
+#include "utils/base/logging.h"
+#include "gmock/gmock.h"
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+using ::testing::ElementsAre;
+
+TEST_F(UniLibTest, CharacterClassesAscii) {
+ EXPECT_TRUE(unilib_->IsOpeningBracket('('));
+ EXPECT_TRUE(unilib_->IsClosingBracket(')'));
+ EXPECT_FALSE(unilib_->IsWhitespace(')'));
+ EXPECT_TRUE(unilib_->IsWhitespace(' '));
+ EXPECT_FALSE(unilib_->IsDigit(')'));
+ EXPECT_TRUE(unilib_->IsDigit('0'));
+ EXPECT_TRUE(unilib_->IsDigit('9'));
+ EXPECT_FALSE(unilib_->IsUpper(')'));
+ EXPECT_TRUE(unilib_->IsUpper('A'));
+ EXPECT_TRUE(unilib_->IsUpper('Z'));
+ EXPECT_FALSE(unilib_->IsLower(')'));
+ EXPECT_TRUE(unilib_->IsLower('a'));
+ EXPECT_TRUE(unilib_->IsLower('z'));
+ EXPECT_TRUE(unilib_->IsPunctuation('!'));
+ EXPECT_TRUE(unilib_->IsPunctuation('?'));
+ EXPECT_TRUE(unilib_->IsPunctuation('#'));
+ EXPECT_TRUE(unilib_->IsPunctuation('('));
+ EXPECT_FALSE(unilib_->IsPunctuation('0'));
+ EXPECT_FALSE(unilib_->IsPunctuation('$'));
+ EXPECT_TRUE(unilib_->IsPercentage('%'));
+ EXPECT_TRUE(unilib_->IsPercentage(u'%'));
+ EXPECT_TRUE(unilib_->IsSlash('/'));
+ EXPECT_TRUE(unilib_->IsSlash(u'/'));
+ EXPECT_TRUE(unilib_->IsMinus('-'));
+ EXPECT_TRUE(unilib_->IsMinus(u'-'));
+ EXPECT_TRUE(unilib_->IsNumberSign('#'));
+ EXPECT_TRUE(unilib_->IsNumberSign(u'#'));
+ EXPECT_TRUE(unilib_->IsDot('.'));
+ EXPECT_TRUE(unilib_->IsDot(u'.'));
+ EXPECT_TRUE(unilib_->IsApostrophe('\''));
+ EXPECT_TRUE(unilib_->IsApostrophe(u'ߴ'));
+ EXPECT_TRUE(unilib_->IsQuotation(u'"'));
+ EXPECT_TRUE(unilib_->IsQuotation(u'”'));
+ EXPECT_TRUE(unilib_->IsAmpersand(u'&'));
+ EXPECT_TRUE(unilib_->IsAmpersand(u'﹠'));
+ EXPECT_TRUE(unilib_->IsAmpersand(u'&'));
+
+ EXPECT_TRUE(unilib_->IsLatinLetter('A'));
+ EXPECT_TRUE(unilib_->IsArabicLetter(u'ب')); // ARABIC LETTER BEH
+ EXPECT_TRUE(
+ unilib_->IsCyrillicLetter(u'ᲀ')); // CYRILLIC SMALL LETTER ROUNDED VE
+ EXPECT_TRUE(unilib_->IsChineseLetter(u'豈')); // CJK COMPATIBILITY IDEOGRAPH
+ EXPECT_TRUE(unilib_->IsJapaneseLetter(u'ぁ')); // HIRAGANA LETTER SMALL A
+ EXPECT_TRUE(unilib_->IsKoreanLetter(u'ㄱ')); // HANGUL LETTER KIYEOK
+ EXPECT_TRUE(unilib_->IsThaiLetter(u'ก')); // THAI CHARACTER KO KAI
+ EXPECT_TRUE(unilib_->IsCJTletter(u'ก')); // THAI CHARACTER KO KAI
+ EXPECT_FALSE(unilib_->IsCJTletter('A'));
+
+ EXPECT_TRUE(unilib_->IsLetter('A'));
+ EXPECT_TRUE(unilib_->IsLetter(u'A'));
+ EXPECT_TRUE(unilib_->IsLetter(u'ト')); // KATAKANA LETTER TO
+ EXPECT_TRUE(unilib_->IsLetter(u'ト')); // HALFWIDTH KATAKANA LETTER TO
+ EXPECT_TRUE(unilib_->IsLetter(u'豈')); // CJK COMPATIBILITY IDEOGRAPH
+
+ EXPECT_EQ(unilib_->ToLower('A'), 'a');
+ EXPECT_EQ(unilib_->ToLower('Z'), 'z');
+ EXPECT_EQ(unilib_->ToLower(')'), ')');
+ EXPECT_EQ(unilib_->ToLowerText(UTF8ToUnicodeText("Never gonna give you up."))
+ .ToUTF8String(),
+ "never gonna give you up.");
+ EXPECT_EQ(unilib_->ToUpper('a'), 'A');
+ EXPECT_EQ(unilib_->ToUpper('z'), 'Z');
+ EXPECT_EQ(unilib_->ToUpper(')'), ')');
+ EXPECT_EQ(unilib_->ToUpperText(UTF8ToUnicodeText("Never gonna let you down."))
+ .ToUTF8String(),
+ "NEVER GONNA LET YOU DOWN.");
+ EXPECT_EQ(unilib_->GetPairedBracket(')'), '(');
+ EXPECT_EQ(unilib_->GetPairedBracket('}'), '{');
+}
+
+TEST_F(UniLibTest, CharacterClassesUnicode) {
+ EXPECT_TRUE(unilib_->IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON
+ EXPECT_TRUE(unilib_->IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS
+ EXPECT_FALSE(unilib_->IsWhitespace(0x23F0)); // ALARM CLOCK
+ EXPECT_TRUE(unilib_->IsWhitespace(0x2003)); // EM SPACE
+ EXPECT_FALSE(unilib_->IsDigit(0xA619)); // VAI SYMBOL JONG
+ EXPECT_TRUE(unilib_->IsDigit(0xA620)); // VAI DIGIT ZERO
+ EXPECT_TRUE(unilib_->IsDigit(0xA629)); // VAI DIGIT NINE
+ EXPECT_FALSE(unilib_->IsDigit(0xA62A)); // VAI SYLLABLE NDOLE MA
+ EXPECT_FALSE(unilib_->IsUpper(0x0211)); // SMALL R WITH DOUBLE GRAVE
+ EXPECT_TRUE(unilib_->IsUpper(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
+ EXPECT_TRUE(unilib_->IsUpper(0x0391)); // GREEK CAPITAL ALPHA
+ EXPECT_TRUE(unilib_->IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL
+ EXPECT_FALSE(unilib_->IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
+ EXPECT_TRUE(unilib_->IsLower(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
+ EXPECT_TRUE(unilib_->IsLower(0x03B1)); // GREEK SMALL ALPHA
+ EXPECT_TRUE(unilib_->IsLower(0x03CB)); // GREEK SMALL UPSILON
+ EXPECT_TRUE(unilib_->IsLower(0x0211)); // SMALL R WITH DOUBLE GRAVE
+ EXPECT_TRUE(unilib_->IsLower(0x03C0)); // GREEK SMALL PI
+ EXPECT_TRUE(unilib_->IsLower(0x007A)); // SMALL Z
+ EXPECT_FALSE(unilib_->IsLower(0x005A)); // CAPITAL Z
+ EXPECT_FALSE(unilib_->IsLower(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
+ EXPECT_FALSE(unilib_->IsLower(0x0391)); // GREEK CAPITAL ALPHA
+ EXPECT_TRUE(unilib_->IsPunctuation(0x055E)); // ARMENIAN QUESTION MARK
+ EXPECT_TRUE(unilib_->IsPunctuation(0x066C)); // ARABIC THOUSANDS SEPARATOR
+ EXPECT_TRUE(unilib_->IsPunctuation(0x07F7)); // NKO SYMBOL GBAKURUNEN
+ EXPECT_TRUE(unilib_->IsPunctuation(0x10AF2)); // DOUBLE DOT WITHIN DOT
+ EXPECT_FALSE(unilib_->IsPunctuation(0x00A3)); // POUND SIGN
+ EXPECT_FALSE(unilib_->IsPunctuation(0xA838)); // NORTH INDIC RUPEE MARK
+ EXPECT_TRUE(unilib_->IsPercentage(0x0025)); // PERCENT SIGN
+ EXPECT_TRUE(unilib_->IsPercentage(0xFF05)); // FULLWIDTH PERCENT SIGN
+ EXPECT_TRUE(unilib_->IsSlash(0x002F)); // SOLIDUS
+ EXPECT_TRUE(unilib_->IsSlash(0xFF0F)); // FULLWIDTH SOLIDUS
+ EXPECT_TRUE(unilib_->IsMinus(0x002D)); // HYPHEN-MINUS
+ EXPECT_TRUE(unilib_->IsMinus(0xFF0D)); // FULLWIDTH HYPHEN-MINUS
+ EXPECT_TRUE(unilib_->IsNumberSign(0x0023)); // NUMBER SIGN
+ EXPECT_TRUE(unilib_->IsNumberSign(0xFF03)); // FULLWIDTH NUMBER SIGN
+ EXPECT_TRUE(unilib_->IsDot(0x002E)); // FULL STOP
+ EXPECT_TRUE(unilib_->IsDot(0xFF0E)); // FULLWIDTH FULL STOP
+
+ EXPECT_TRUE(unilib_->IsLatinLetter(0x0041)); // LATIN CAPITAL LETTER A
+ EXPECT_TRUE(unilib_->IsArabicLetter(0x0628)); // ARABIC LETTER BEH
+ EXPECT_TRUE(
+ unilib_->IsCyrillicLetter(0x1C80)); // CYRILLIC SMALL LETTER ROUNDED VE
+ EXPECT_TRUE(unilib_->IsChineseLetter(0xF900)); // CJK COMPATIBILITY IDEOGRAPH
+ EXPECT_TRUE(unilib_->IsJapaneseLetter(0x3041)); // HIRAGANA LETTER SMALL A
+ EXPECT_TRUE(unilib_->IsKoreanLetter(0x3131)); // HANGUL LETTER KIYEOK
+ EXPECT_TRUE(unilib_->IsThaiLetter(0x0E01)); // THAI CHARACTER KO KAI
+ EXPECT_TRUE(unilib_->IsCJTletter(0x0E01)); // THAI CHARACTER KO KAI
+ EXPECT_FALSE(unilib_->IsCJTletter(0x0041)); // LATIN CAPITAL LETTER A
+
+ EXPECT_TRUE(unilib_->IsLetter(0x0041)); // LATIN CAPITAL LETTER A
+ EXPECT_TRUE(unilib_->IsLetter(0xFF21)); // FULLWIDTH LATIN CAPITAL LETTER A
+ EXPECT_TRUE(unilib_->IsLetter(0x30C8)); // KATAKANA LETTER TO
+ EXPECT_TRUE(unilib_->IsLetter(0xFF84)); // HALFWIDTH KATAKANA LETTER TO
+ EXPECT_TRUE(unilib_->IsLetter(0xF900)); // CJK COMPATIBILITY IDEOGRAPH
+
+ EXPECT_EQ(unilib_->ToLower(0x0391), 0x03B1); // GREEK ALPHA
+ EXPECT_EQ(unilib_->ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA
+ EXPECT_EQ(unilib_->ToLower(0x03C0), 0x03C0); // GREEK SMALL PI
+ EXPECT_EQ(unilib_->ToLower(0x03A3), 0x03C3); // GREEK CAPITAL LETTER SIGMA
+ EXPECT_EQ(
+ unilib_->ToLowerText(UTF8ToUnicodeText("Κανένας άνθρωπος δεν ξέρει"))
+ .ToUTF8String(),
+ "κανένας άνθρωπος δεν ξέρει");
+ EXPECT_TRUE(unilib_->IsLowerText(UTF8ToUnicodeText("ξέρει")));
+ EXPECT_EQ(unilib_->ToUpper(0x03B1), 0x0391); // GREEK ALPHA
+ EXPECT_EQ(unilib_->ToUpper(0x03CB), 0x03AB); // GREEK UPSILON WITH DIALYTIKA
+ EXPECT_EQ(unilib_->ToUpper(0x0391), 0x0391); // GREEK CAPITAL ALPHA
+ EXPECT_EQ(unilib_->ToUpper(0x03C3), 0x03A3); // GREEK CAPITAL LETTER SIGMA
+ EXPECT_EQ(unilib_->ToUpper(0x03C2), 0x03A3); // GREEK CAPITAL LETTER SIGMA
+ EXPECT_EQ(
+ unilib_->ToUpperText(UTF8ToUnicodeText("Κανένας άνθρωπος δεν ξέρει"))
+ .ToUTF8String(),
+ "ΚΑΝΈΝΑΣ ΆΝΘΡΩΠΟΣ ΔΕΝ ΞΈΡΕΙ");
+ EXPECT_TRUE(unilib_->IsUpperText(UTF8ToUnicodeText("ΚΑΝΈΝΑΣ")));
+ EXPECT_EQ(unilib_->GetPairedBracket(0x0F3C), 0x0F3D);
+ EXPECT_EQ(unilib_->GetPairedBracket(0x0F3D), 0x0F3C);
+}
+
+TEST_F(UniLibTest, RegexInterface) {
+ const UnicodeText regex_pattern =
+ UTF8ToUnicodeText("[0-9]+", /*do_copy=*/true);
+ std::unique_ptr<UniLib::RegexPattern> pattern =
+ unilib_->CreateRegexPattern(regex_pattern);
+ const UnicodeText input = UTF8ToUnicodeText("hello 0123", /*do_copy=*/false);
+ int status;
+ std::unique_ptr<UniLib::RegexMatcher> matcher = pattern->Matcher(input);
+ TC3_LOG(INFO) << matcher->Matches(&status);
+ TC3_LOG(INFO) << matcher->Find(&status);
+ TC3_LOG(INFO) << matcher->Start(0, &status);
+ TC3_LOG(INFO) << matcher->End(0, &status);
+ TC3_LOG(INFO) << matcher->Group(0, &status).size_codepoints();
+}
+
+TEST_F(UniLibTest, Regex) {
+ // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
+ // test the regex functionality with it to verify we are handling the indices
+ // correctly.
+ const UnicodeText regex_pattern =
+ UTF8ToUnicodeText("[0-9]+😋", /*do_copy=*/false);
+ std::unique_ptr<UniLib::RegexPattern> pattern =
+ unilib_->CreateRegexPattern(regex_pattern);
+ int status;
+ std::unique_ptr<UniLib::RegexMatcher> matcher;
+
+ matcher = pattern->Matcher(UTF8ToUnicodeText("0123😋", /*do_copy=*/false));
+ EXPECT_TRUE(matcher->Matches(&status));
+ EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_TRUE(matcher->Matches(&status)); // Check that the state is reset.
+ EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+
+ matcher = pattern->Matcher(
+ UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
+ EXPECT_FALSE(matcher->Matches(&status));
+ EXPECT_FALSE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+
+ matcher = pattern->Matcher(
+ UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
+ EXPECT_TRUE(matcher->Find(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Start(0, &status), 8);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->End(0, &status), 13);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+}
+
+TEST_F(UniLibTest, RegexLazy) {
+ std::unique_ptr<UniLib::RegexPattern> pattern =
+ unilib_->CreateLazyRegexPattern(
+ UTF8ToUnicodeText("[a-z][0-9]", /*do_copy=*/false));
+ int status;
+ std::unique_ptr<UniLib::RegexMatcher> matcher;
+
+ matcher = pattern->Matcher(UTF8ToUnicodeText("a3", /*do_copy=*/false));
+ EXPECT_TRUE(matcher->Matches(&status));
+ EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_TRUE(matcher->Matches(&status)); // Check that the state is reset.
+ EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+
+ matcher = pattern->Matcher(UTF8ToUnicodeText("3a", /*do_copy=*/false));
+ EXPECT_FALSE(matcher->Matches(&status));
+ EXPECT_FALSE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+}
+
+TEST_F(UniLibTest, RegexGroups) {
+ // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
+ // test the regex functionality with it to verify we are handling the indices
+ // correctly.
+ const UnicodeText regex_pattern =
+ UTF8ToUnicodeText("([0-9])([0-9]+)😋", /*do_copy=*/false);
+ std::unique_ptr<UniLib::RegexPattern> pattern =
+ unilib_->CreateRegexPattern(regex_pattern);
+ int status;
+ std::unique_ptr<UniLib::RegexMatcher> matcher;
+
+ matcher = pattern->Matcher(
+ UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
+ EXPECT_TRUE(matcher->Find(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Start(0, &status), 8);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Start(1, &status), 8);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Start(2, &status), 9);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->End(0, &status), 13);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->End(1, &status), 9);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->End(2, &status), 12);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(1, &status).ToUTF8String(), "0");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(2, &status).ToUTF8String(), "123");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+}
+
+TEST_F(UniLibTest, RegexGroupsNotAllGroupsInvolved) {
+ const UnicodeText regex_pattern =
+ UTF8ToUnicodeText("([0-9])([a-z])?", /*do_copy=*/false);
+ std::unique_ptr<UniLib::RegexPattern> pattern =
+ unilib_->CreateRegexPattern(regex_pattern);
+ int status;
+ std::unique_ptr<UniLib::RegexMatcher> matcher;
+
+ matcher = pattern->Matcher(UTF8ToUnicodeText("7", /*do_copy=*/false));
+ EXPECT_TRUE(matcher->Find(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "7");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(1, &status).ToUTF8String(), "7");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(2, &status).ToUTF8String(), "");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+}
+
+TEST_F(UniLibTest, RegexGroupsEmptyResult) {
+ const UnicodeText regex_pattern =
+ UTF8ToUnicodeText("(.*)", /*do_copy=*/false);
+ std::unique_ptr<UniLib::RegexPattern> pattern =
+ unilib_->CreateRegexPattern(regex_pattern);
+ int status;
+ std::unique_ptr<UniLib::RegexMatcher> matcher;
+
+ matcher = pattern->Matcher(UTF8ToUnicodeText("", /*do_copy=*/false));
+ EXPECT_TRUE(matcher->Find(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(1, &status).ToUTF8String(), "");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+}
+
+TEST_F(UniLibTest, BreakIterator) {
+ const UnicodeText text = UTF8ToUnicodeText("some text", /*do_copy=*/false);
+ std::unique_ptr<UniLib::BreakIterator> iterator =
+ unilib_->CreateBreakIterator(text);
+ std::vector<int> break_indices;
+ int break_index = 0;
+ while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) {
+ break_indices.push_back(break_index);
+ }
+ EXPECT_THAT(break_indices, ElementsAre(4, 5, 9));
+}
+
+TEST_F(UniLibTest, BreakIterator4ByteUTF8) {
+ const UnicodeText text = UTF8ToUnicodeText("😀😂😋", /*do_copy=*/false);
+ std::unique_ptr<UniLib::BreakIterator> iterator =
+ unilib_->CreateBreakIterator(text);
+ std::vector<int> break_indices;
+ int break_index = 0;
+ while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) {
+ break_indices.push_back(break_index);
+ }
+ EXPECT_THAT(break_indices, ElementsAre(1, 2, 3));
+}
+
+TEST_F(UniLibTest, Integer32Parse) {
+ int result;
+ EXPECT_TRUE(unilib_->ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false),
+ &result));
+ EXPECT_EQ(result, 123);
+}
+
+TEST_F(UniLibTest, Integer32ParseFloatNumber) {
+ int result;
+ EXPECT_FALSE(unilib_->ParseInt32(UTF8ToUnicodeText("12.3", /*do_copy=*/false),
+ &result));
+}
+
+TEST_F(UniLibTest, Integer32ParseLongNumber) {
+ int32 result;
+ EXPECT_TRUE(unilib_->ParseInt32(
+ UTF8ToUnicodeText("1000000000", /*do_copy=*/false), &result));
+ EXPECT_EQ(result, 1000000000);
+}
+
+TEST_F(UniLibTest, Integer32ParseOverflowNumber) {
+ int32 result;
+ EXPECT_FALSE(unilib_->ParseInt32(
+ UTF8ToUnicodeText("9123456789", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, Integer32ParseEmptyString) {
+ int result;
+ EXPECT_FALSE(
+ unilib_->ParseInt32(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, Integer32ParseFullWidth) {
+ int result;
+ // The input string here is full width
+ EXPECT_TRUE(unilib_->ParseInt32(
+ UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
+ EXPECT_EQ(result, 123);
+}
+
+TEST_F(UniLibTest, Integer32ParseNotNumber) {
+ int result;
+ // The input string here is full width
+ EXPECT_FALSE(unilib_->ParseInt32(
+ UTF8ToUnicodeText("1a3", /*do_copy=*/false), &result));
+ // Strings starting with "nan" are not numbers.
+ EXPECT_FALSE(unilib_->ParseInt32(UTF8ToUnicodeText("Nancy",
+ /*do_copy=*/false),
+ &result));
+ // Strings starting with "inf" are not numbers
+ EXPECT_FALSE(unilib_->ParseInt32(
+ UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, Integer64Parse) {
+ int64 result;
+ EXPECT_TRUE(unilib_->ParseInt64(UTF8ToUnicodeText("123", /*do_copy=*/false),
+ &result));
+ EXPECT_EQ(result, 123);
+}
+
+TEST_F(UniLibTest, Integer64ParseFloatNumber) {
+ int64 result;
+ EXPECT_FALSE(unilib_->ParseInt64(UTF8ToUnicodeText("12.3", /*do_copy=*/false),
+ &result));
+}
+
+TEST_F(UniLibTest, Integer64ParseLongNumber) {
+ int64 result;
+ // The limitation comes from the javaicu implementation: parseDouble does not
+ // have ICU support and parseInt limit the size of the number.
+ EXPECT_TRUE(unilib_->ParseInt64(
+ UTF8ToUnicodeText("1000000000", /*do_copy=*/false), &result));
+ EXPECT_EQ(result, 1000000000);
+}
+
+TEST_F(UniLibTest, Integer64ParseOverflowNumber) {
+ int64 result;
+ EXPECT_FALSE(unilib_->ParseInt64(
+ UTF8ToUnicodeText("92233720368547758099", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, Integer64ParseOverflowNegativeNumber) {
+ int64 result;
+ EXPECT_FALSE(unilib_->ParseInt64(
+ UTF8ToUnicodeText("-92233720368547758099", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, Integer64ParseEmptyString) {
+ int64 result;
+ EXPECT_FALSE(
+ unilib_->ParseInt64(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, Integer64ParseFullWidth) {
+ int64 result;
+ // The input string here is full width
+ EXPECT_TRUE(unilib_->ParseInt64(
+ UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
+ EXPECT_EQ(result, 123);
+}
+
+TEST_F(UniLibTest, Integer64ParseNotNumber) {
+ int64 result;
+ // The input string here is full width
+ EXPECT_FALSE(unilib_->ParseInt64(
+ UTF8ToUnicodeText("1a4", /*do_copy=*/false), &result));
+ // Strings starting with "nan" are not numbers.
+ EXPECT_FALSE(unilib_->ParseInt64(UTF8ToUnicodeText("Nancy",
+ /*do_copy=*/false),
+ &result));
+ // Strings starting with "inf" are not numbers
+ EXPECT_FALSE(unilib_->ParseInt64(
+ UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, DoubleParse) {
+ double result;
+ EXPECT_TRUE(unilib_->ParseDouble(UTF8ToUnicodeText("1.23", /*do_copy=*/false),
+ &result));
+ EXPECT_EQ(result, 1.23);
+}
+
+TEST_F(UniLibTest, DoubleParseLongNumber) {
+ double result;
+ // The limitation comes from the javaicu implementation: parseDouble does not
+ // have ICU support and parseInt limit the size of the number.
+ EXPECT_TRUE(unilib_->ParseDouble(
+ UTF8ToUnicodeText("999999999.999999999", /*do_copy=*/false), &result));
+ EXPECT_EQ(result, 999999999.999999999);
+}
+
+TEST_F(UniLibTest, DoubleParseWithoutFractionalPart) {
+ double result;
+ EXPECT_TRUE(unilib_->ParseDouble(UTF8ToUnicodeText("123", /*do_copy=*/false),
+ &result));
+ EXPECT_EQ(result, 123);
+}
+
+TEST_F(UniLibTest, DoubleParseEmptyString) {
+ double result;
+ EXPECT_FALSE(
+ unilib_->ParseDouble(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, DoubleParsePrecedingDot) {
+ double result;
+ EXPECT_FALSE(unilib_->ParseDouble(
+ UTF8ToUnicodeText(".123", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, DoubleParseLeadingDot) {
+ double result;
+ EXPECT_FALSE(unilib_->ParseDouble(
+ UTF8ToUnicodeText("123.", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, DoubleParseMultipleDots) {
+ double result;
+ EXPECT_FALSE(unilib_->ParseDouble(
+ UTF8ToUnicodeText("1.2.3", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, DoubleParseFullWidth) {
+ double result;
+ // The input string here is full width
+ EXPECT_TRUE(unilib_->ParseDouble(
+ UTF8ToUnicodeText("1.23", /*do_copy=*/false), &result));
+ EXPECT_EQ(result, 1.23);
+}
+
+TEST_F(UniLibTest, DoubleParseNotNumber) {
+ double result;
+ // The input string here is full width
+ EXPECT_FALSE(unilib_->ParseDouble(
+ UTF8ToUnicodeText("1a5", /*do_copy=*/false), &result));
+ // Strings starting with "nan" are not numbers.
+ EXPECT_FALSE(unilib_->ParseDouble(
+ UTF8ToUnicodeText("Nancy", /*do_copy=*/false), &result));
+ // Strings starting with "inf" are not numbers
+ EXPECT_FALSE(unilib_->ParseDouble(
+ UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
+}
+
+TEST_F(UniLibTest, Length) {
+ EXPECT_EQ(unilib_->Length(UTF8ToUnicodeText("hello", /*do_copy=*/false))
+ .ValueOrDie(),
+ 5);
+ EXPECT_EQ(unilib_->Length(UTF8ToUnicodeText("ěščřž", /*do_copy=*/false))
+ .ValueOrDie(),
+ 5);
+ // Test Invalid UTF8.
+ // This testing condition needs to be != 1, as Apple character counting seems
+ // to return 0 when the input is invalid UTF8, while ICU will treat the
+ // invalid codepoint as 3 separate bytes.
+ EXPECT_NE(
+ unilib_->Length(UTF8ToUnicodeText("\xed\xa0\x80", /*do_copy=*/false))
+ .ValueOrDie(),
+ 1);
+}
+
+} // namespace test_internal
+} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib_test-include.h b/native/utils/utf8/unilib_test-include.h
new file mode 100644
index 0000000..8ae8a0f
--- /dev/null
+++ b/native/utils/utf8/unilib_test-include.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
+#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
+
+#include "utils/jvm-test-utils.h"
+#include "utils/utf8/unilib.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+class UniLibTest : public ::testing::Test {
+ protected:
+ UniLibTest() : unilib_(CreateUniLibForTesting()) {}
+ std::unique_ptr<UniLib> unilib_;
+};
+
+} // namespace test_internal
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
diff --git a/native/utils/variant.cc b/native/utils/variant.cc
index 0513440..ed39826 100644
--- a/native/utils/variant.cc
+++ b/native/utils/variant.cc
@@ -30,6 +30,9 @@
case Variant::TYPE_INT_VALUE:
return std::to_string(Value<int>());
break;
+ case Variant::TYPE_UINT_VALUE:
+ return std::to_string(Value<unsigned int>());
+ break;
case Variant::TYPE_INT64_VALUE:
return std::to_string(Value<int64>());
break;
diff --git a/native/utils/wordpiece_tokenizer.cc b/native/utils/wordpiece_tokenizer.cc
new file mode 100644
index 0000000..f4fcafc
--- /dev/null
+++ b/native/utils/wordpiece_tokenizer.cc
@@ -0,0 +1,247 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/wordpiece_tokenizer.h"
+
+#include "utils/utf8/unicodetext.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+
+namespace libtextclassifier3 {
+
+namespace {
+
+LookupStatus Lookup(int byte_start, int byte_end, const absl::string_view token,
+ const std::string& suffix_indicator,
+ const WordpieceVocab* vocab_map, bool* in_vocab) {
+ int byte_len = byte_end - byte_start;
+ absl::string_view substr(token.data() + byte_start, byte_len);
+ std::string lookup_value;
+ if (byte_start > 0) {
+ lookup_value = absl::StrCat(suffix_indicator, substr);
+ } else {
+ // absl::CopyToString
+ lookup_value.assign(substr.begin(), substr.end());
+ }
+ return vocab_map->Contains(lookup_value, in_vocab);
+}
+
+// Sets byte_end to the longest byte sequence which:
+// 1) is a proper UTF8 sequence
+// 2) is in the vocab OR if split_unknown_characters is true, is a single
+// UTF8 character.
+// If no match is found, found_match is set to false.
+LookupStatus LongestMatchStartingAt(
+ int byte_start, const absl::string_view token,
+ const std::string& suffix_indicator, const int max_chars_per_subtoken,
+ bool split_unknown_characters, const WordpieceVocab* vocab_map,
+ int* byte_end, bool* found_match, bool* match_is_unknown_character) {
+ *match_is_unknown_character = false;
+ *found_match = false;
+ const UnicodeText unicode_token =
+ UTF8ToUnicodeText(token.substr(byte_start), /*do_copy=*/false);
+ std::vector<int32_t> byte_ends;
+ int32_t codepoint_offset = byte_start;
+ for (auto it = unicode_token.begin(); it != unicode_token.end(); ++it) {
+ codepoint_offset += it.utf8_length();
+ byte_ends.push_back(codepoint_offset);
+ if (max_chars_per_subtoken > 0 &&
+ byte_ends.size() == max_chars_per_subtoken) {
+ // If the max bytes of a subtoken is known, do not search beyond that
+ // length.
+ break;
+ }
+ }
+ int n = byte_ends.size();
+ for (int i = n - 1; i >= 0; i--) {
+ bool in_vocab;
+ auto status = Lookup(byte_start, byte_ends[i], token, suffix_indicator,
+ vocab_map, &in_vocab);
+ if (!status.success) return status;
+ if (in_vocab) {
+ *byte_end = byte_ends[i];
+ *found_match = true;
+ return LookupStatus::OK();
+ }
+ if (i == 0 && split_unknown_characters) {
+ *byte_end = byte_ends[0];
+ *found_match = true;
+ *match_is_unknown_character = true;
+ return LookupStatus::OK();
+ }
+ }
+ return LookupStatus::OK();
+}
+
+// Sets the outputs 'begin_offset', 'end_offset' and 'num_word_pieces' when no
+// token is found.
+LookupStatus NoTokenFound(const absl::string_view token, bool use_unknown_token,
+ const std::string& unknown_token,
+ std::vector<std::string>* subwords,
+ std::vector<int>* begin_offset,
+ std::vector<int>* end_offset, int* num_word_pieces) {
+ begin_offset->push_back(0);
+ if (use_unknown_token) {
+ subwords->push_back(unknown_token);
+ end_offset->push_back(token.length());
+ } else {
+ subwords->emplace_back(token.data(), token.length());
+ end_offset->push_back(token.length());
+ }
+ ++(*num_word_pieces);
+
+ return LookupStatus::OK();
+}
+
+// When a subword is found, this helper function will add the outputs to
+// 'subwords', 'begin_offset' and 'end_offset'.
+void AddWord(const absl::string_view token, int byte_start, int byte_end,
+ const std::string& suffix_indicator,
+ std::vector<std::string>* subwords, std::vector<int>* begin_offset,
+ std::vector<int>* end_offset) {
+ begin_offset->push_back(byte_start);
+ int len = byte_end - byte_start;
+
+ if (byte_start > 0) {
+ // Prepend suffix_indicator if the token is within a word.
+ subwords->push_back(::absl::StrCat(
+ suffix_indicator, absl::string_view(token.data() + byte_start, len)));
+ } else {
+ subwords->emplace_back(token.data(), len);
+ }
+ end_offset->push_back(byte_end);
+}
+
+// Adds a single unknown character subword, found when split_unknown_characters
+// is true.
+void AddUnknownCharacter(const absl::string_view token, int byte_start,
+ int byte_end, const std::string& suffix_indicator,
+ bool use_unknown_token,
+ const std::string& unknown_token,
+ std::vector<std::string>* subwords,
+ std::vector<int>* begin_offset,
+ std::vector<int>* end_offset) {
+ begin_offset->push_back(byte_start);
+ end_offset->push_back(byte_end);
+ int len = byte_end - byte_start;
+ if (use_unknown_token) {
+ if (byte_start > 0) {
+ // Prepend suffix_indicator if the character is within a word.
+ subwords->push_back(::absl::StrCat(suffix_indicator, unknown_token));
+ } else {
+ subwords->push_back(unknown_token);
+ }
+ } else {
+ if (byte_start > 0) {
+ // Prepend suffix_indicator if the character is within a word.
+ subwords->push_back(::absl::StrCat(
+ suffix_indicator, absl::string_view(token.data() + byte_start, len)));
+ } else {
+ subwords->emplace_back(token.data(), len);
+ }
+ }
+}
+
+LookupStatus TokenizeL2RGreedy(
+ const absl::string_view token, const int max_bytes_per_token,
+ const int max_chars_per_subtoken, const std::string& suffix_indicator,
+ bool use_unknown_token, const std::string& unknown_token,
+ bool split_unknown_characters, const WordpieceVocab* vocab_map,
+ std::vector<std::string>* subwords, std::vector<int>* begin_offset,
+ std::vector<int>* end_offset, int* num_word_pieces) {
+ std::vector<std::string> candidate_subwords;
+ std::vector<int> candidate_begin_offsets;
+ std::vector<int> candidate_end_offsets;
+ const int token_len = token.length();
+ for (int byte_start = 0; byte_start < token_len;) {
+ int byte_end;
+ bool found_subword;
+ bool match_is_unknown_character;
+ auto status = LongestMatchStartingAt(
+ byte_start, token, suffix_indicator, max_chars_per_subtoken,
+ split_unknown_characters, vocab_map, &byte_end, &found_subword,
+ &match_is_unknown_character);
+ if (!status.success) return status;
+ if (found_subword) {
+ if (match_is_unknown_character) {
+ AddUnknownCharacter(token, byte_start, byte_end, suffix_indicator,
+ use_unknown_token, unknown_token,
+ &candidate_subwords, &candidate_begin_offsets,
+ &candidate_end_offsets);
+ } else {
+ AddWord(token, byte_start, byte_end, suffix_indicator,
+ &candidate_subwords, &candidate_begin_offsets,
+ &candidate_end_offsets);
+ }
+ byte_start = byte_end;
+ } else {
+ return NoTokenFound(token, use_unknown_token, unknown_token, subwords,
+ begin_offset, end_offset, num_word_pieces);
+ }
+ }
+
+ subwords->insert(subwords->end(), candidate_subwords.begin(),
+ candidate_subwords.end());
+ begin_offset->insert(begin_offset->end(), candidate_begin_offsets.begin(),
+ candidate_begin_offsets.end());
+ end_offset->insert(end_offset->end(), candidate_end_offsets.begin(),
+ candidate_end_offsets.end());
+ *num_word_pieces += candidate_subwords.size();
+ return LookupStatus::OK();
+}
+
+} // namespace
+
+LookupStatus WordpieceTokenize(
+ const absl::string_view token, const int max_bytes_per_token,
+ const int max_chars_per_subtoken, const std::string& suffix_indicator,
+ bool use_unknown_token, const std::string& unknown_token,
+ bool split_unknown_characters, const WordpieceVocab* vocab_map,
+ std::vector<std::string>* subwords, std::vector<int>* begin_offset,
+ std::vector<int>* end_offset, int* num_word_pieces) {
+ int token_len = token.size();
+ if (token_len > max_bytes_per_token) {
+ begin_offset->push_back(0);
+ *num_word_pieces = 1;
+ if (use_unknown_token) {
+ end_offset->push_back(unknown_token.size());
+ subwords->emplace_back(unknown_token);
+ } else {
+ subwords->emplace_back(token);
+ end_offset->push_back(token.size());
+ }
+ return LookupStatus::OK();
+ }
+ return TokenizeL2RGreedy(token, max_bytes_per_token, max_chars_per_subtoken,
+ suffix_indicator, use_unknown_token, unknown_token,
+ split_unknown_characters, vocab_map, subwords,
+ begin_offset, end_offset, num_word_pieces);
+}
+
+LookupStatus WordpieceTokenize(
+ const absl::string_view token, const int max_bytes_per_token,
+ const std::string& suffix_indicator, bool use_unknown_token,
+ const std::string& unknown_token, const WordpieceVocab* vocab_map,
+ std::vector<std::string>* subwords, std::vector<int>* begin_offset,
+ std::vector<int>* end_offset, int* num_word_pieces) {
+ return WordpieceTokenize(token, max_bytes_per_token,
+ /* max_chars_per_subtoken= */ 0, suffix_indicator,
+ use_unknown_token, unknown_token,
+ /* split_unknown_characters= */ false, vocab_map,
+ subwords, begin_offset, end_offset, num_word_pieces);
+}
+} // namespace libtextclassifier3
diff --git a/native/utils/wordpiece_tokenizer.h b/native/utils/wordpiece_tokenizer.h
new file mode 100644
index 0000000..a6eb8e0
--- /dev/null
+++ b/native/utils/wordpiece_tokenizer.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_WORDPIECE_TOKENIZER_H_
+#define LIBTEXTCLASSIFIER_UTILS_WORDPIECE_TOKENIZER_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+
+namespace libtextclassifier3 {
+
+struct LookupStatus {
+ LookupStatus() : error_msg(""), success(true) {}
+ explicit LookupStatus(const std::string& msg)
+ : error_msg(msg), success(false) {}
+ std::string error_msg;
+ bool success;
+
+ static LookupStatus OK() { return LookupStatus(); }
+};
+
+class WordpieceVocab {
+ public:
+ virtual ~WordpieceVocab() {}
+ virtual LookupStatus Contains(const absl::string_view key,
+ bool* value) const = 0;
+};
+
+LookupStatus WordpieceTokenize(
+ const absl::string_view token, const int max_bytes_per_token,
+ const int max_chars_per_subtoken, const std::string& suffix_indicator,
+ bool use_unknown_token, const std::string& unknown_token,
+ bool split_unknown_characters, const WordpieceVocab* vocab_map,
+ std::vector<std::string>* subwords, std::vector<int>* begin_offset,
+ std::vector<int>* end_offset, int* num_word_pieces);
+
+// As above but with `max_bytes_per_subtoken` unknown,
+// and split_unknown_characters=false. (For backwards compatibility.)
+LookupStatus WordpieceTokenize(
+ const absl::string_view token, const int max_bytes_per_token,
+ const std::string& suffix_indicator, bool use_unknown_token,
+ const std::string& unknown_token, const WordpieceVocab* vocab_map,
+ std::vector<std::string>* subwords, std::vector<int>* begin_offset,
+ std::vector<int>* end_offset, int* num_word_pieces);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_WORDPIECE_TOKENIZER_H_
diff --git a/native/utils/zlib/buffer.fbs b/native/utils/zlib/buffer.fbs
old mode 100755
new mode 100644
diff --git a/native/utils/zlib/zlib.cc b/native/utils/zlib/zlib.cc
index 4cb7760..6c8c5fd 100644
--- a/native/utils/zlib/zlib.cc
+++ b/native/utils/zlib/zlib.cc
@@ -16,22 +16,20 @@
#include "utils/zlib/zlib.h"
-#include "utils/flatbuffers.h"
+#include "utils/base/logging.h"
+#include "utils/flatbuffers/flatbuffers.h"
namespace libtextclassifier3 {
-std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance(
- const unsigned char* dictionary, const unsigned int dictionary_size) {
- std::unique_ptr<ZlibDecompressor> result(
- new ZlibDecompressor(dictionary, dictionary_size));
+std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance() {
+ std::unique_ptr<ZlibDecompressor> result(new ZlibDecompressor());
if (!result->initialized_) {
result.reset();
}
return result;
}
-ZlibDecompressor::ZlibDecompressor(const unsigned char* dictionary,
- const unsigned int dictionary_size) {
+ZlibDecompressor::ZlibDecompressor() {
memset(&stream_, 0, sizeof(stream_));
stream_.zalloc = Z_NULL;
stream_.zfree = Z_NULL;
@@ -40,11 +38,6 @@
TC3_LOG(ERROR) << "Could not initialize decompressor.";
return;
}
- if (dictionary != nullptr &&
- inflateSetDictionary(&stream_, dictionary, dictionary_size) != Z_OK) {
- TC3_LOG(ERROR) << "Could not set dictionary.";
- return;
- }
initialized_ = true;
}
@@ -61,7 +54,8 @@
return false;
}
out->resize(uncompressed_size);
- stream_.next_in = reinterpret_cast<const Bytef*>(buffer);
+ stream_.next_in =
+ const_cast<z_const Bytef*>(reinterpret_cast<const Bytef*>(buffer));
stream_.avail_in = buffer_size;
stream_.next_out = reinterpret_cast<Bytef*>(const_cast<char*>(out->c_str()));
stream_.avail_out = uncompressed_size;
@@ -110,19 +104,15 @@
return MaybeDecompress(compressed_buffer, out);
}
-std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance(
- const unsigned char* dictionary, const unsigned int dictionary_size) {
- std::unique_ptr<ZlibCompressor> result(
- new ZlibCompressor(dictionary, dictionary_size));
+std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance() {
+ std::unique_ptr<ZlibCompressor> result(new ZlibCompressor());
if (!result->initialized_) {
result.reset();
}
return result;
}
-ZlibCompressor::ZlibCompressor(const unsigned char* dictionary,
- const unsigned int dictionary_size,
- const int level, const int tmp_buffer_size) {
+ZlibCompressor::ZlibCompressor(const int level, const int tmp_buffer_size) {
memset(&stream_, 0, sizeof(stream_));
stream_.zalloc = Z_NULL;
stream_.zfree = Z_NULL;
@@ -133,11 +123,6 @@
TC3_LOG(ERROR) << "Could not initialize compressor.";
return;
}
- if (dictionary != nullptr &&
- deflateSetDictionary(&stream_, dictionary, dictionary_size) != Z_OK) {
- TC3_LOG(ERROR) << "Could not set dictionary.";
- return;
- }
initialized_ = true;
}
@@ -147,8 +132,8 @@
CompressedBufferT* out) {
out->uncompressed_size = uncompressed_content.size();
out->buffer.clear();
- stream_.next_in =
- reinterpret_cast<const Bytef*>(uncompressed_content.c_str());
+ stream_.next_in = const_cast<z_const Bytef*>(
+ reinterpret_cast<const Bytef*>(uncompressed_content.c_str()));
stream_.avail_in = uncompressed_content.size();
stream_.next_out = buffer_.get();
stream_.avail_out = buffer_size_;
@@ -177,14 +162,4 @@
} while (status == Z_OK);
}
-bool ZlibCompressor::GetDictionary(std::vector<unsigned char>* dictionary) {
- // Retrieve first the size of the dictionary.
- unsigned int size;
- if (deflateGetDictionary(&stream_, /*dictionary=*/Z_NULL, &size) != Z_OK) {
- return false;
- }
- dictionary->resize(size);
- return deflateGetDictionary(&stream_, dictionary->data(), &size) == Z_OK;
-}
-
} // namespace libtextclassifier3
diff --git a/native/utils/zlib/zlib.h b/native/utils/zlib/zlib.h
index f773c27..1f4d18a 100644
--- a/native/utils/zlib/zlib.h
+++ b/native/utils/zlib/zlib.h
@@ -29,9 +29,7 @@
class ZlibDecompressor {
public:
- static std::unique_ptr<ZlibDecompressor> Instance(
- const unsigned char* dictionary = nullptr,
- unsigned int dictionary_size = 0);
+ static std::unique_ptr<ZlibDecompressor> Instance();
~ZlibDecompressor();
bool Decompress(const uint8* buffer, const int buffer_size,
@@ -48,28 +46,21 @@
const CompressedBuffer* compressed_buffer, std::string* out);
private:
- ZlibDecompressor(const unsigned char* dictionary,
- const unsigned int dictionary_size);
+ explicit ZlibDecompressor();
z_stream stream_;
bool initialized_;
};
class ZlibCompressor {
public:
- static std::unique_ptr<ZlibCompressor> Instance(
- const unsigned char* dictionary = nullptr,
- unsigned int dictionary_size = 0);
+ static std::unique_ptr<ZlibCompressor> Instance();
~ZlibCompressor();
void Compress(const std::string& uncompressed_content,
CompressedBufferT* out);
- bool GetDictionary(std::vector<unsigned char>* dictionary);
-
private:
- explicit ZlibCompressor(const unsigned char* dictionary = nullptr,
- const unsigned int dictionary_size = 0,
- const int level = Z_BEST_COMPRESSION,
+ explicit ZlibCompressor(const int level = Z_BEST_COMPRESSION,
// Tmp. buffer size was set based on the current set
// of patterns to be compressed.
const int tmp_buffer_size = 64 * 1024);
diff --git a/native/utils/zlib/zlib_regex.cc b/native/utils/zlib/zlib_regex.cc
index 73b6d30..901bb91 100644
--- a/native/utils/zlib/zlib_regex.cc
+++ b/native/utils/zlib/zlib_regex.cc
@@ -19,7 +19,7 @@
#include <memory>
#include "utils/base/logging.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -48,7 +48,7 @@
}
unicode_regex_pattern =
UTF8ToUnicodeText(uncompressed_pattern->c_str(),
- uncompressed_pattern->Length(), /*do_copy=*/false);
+ uncompressed_pattern->size(), /*do_copy=*/false);
}
if (result_pattern_text != nullptr) {
diff --git a/notification/Android.bp b/notification/Android.bp
index 277985b..782d5cb 100644
--- a/notification/Android.bp
+++ b/notification/Android.bp
@@ -28,7 +28,7 @@
name: "TextClassifierNotificationLib",
static_libs: ["TextClassifierNotificationLibNoManifest"],
sdk_version: "system_current",
- min_sdk_version: "29",
+ min_sdk_version: "30",
manifest: "AndroidManifest.xml",
}
@@ -41,6 +41,6 @@
"guava",
],
sdk_version: "system_current",
- min_sdk_version: "29",
+ min_sdk_version: "30",
manifest: "LibNoManifest_AndroidManifest.xml",
}
diff --git a/notification/AndroidManifest.xml b/notification/AndroidManifest.xml
index 3153d1d..5a98ea3 100644
--- a/notification/AndroidManifest.xml
+++ b/notification/AndroidManifest.xml
@@ -1,7 +1,7 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.android.textclassifier.notification">
- <uses-sdk android:minSdkVersion="29" />
+ <uses-sdk android:minSdkVersion="30" />
<application>
<activity
@@ -10,4 +10,4 @@
android:theme="@android:style/Theme.NoDisplay" />
</application>
-</manifest>
\ No newline at end of file
+</manifest>
diff --git a/notification/LibNoManifest_AndroidManifest.xml b/notification/LibNoManifest_AndroidManifest.xml
index b9ebf7d..06e8da4 100644
--- a/notification/LibNoManifest_AndroidManifest.xml
+++ b/notification/LibNoManifest_AndroidManifest.xml
@@ -25,6 +25,6 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.android.textclassifier.notification">
- <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="29"/>
+ <uses-sdk android:minSdkVersion="30" android:targetSdkVersion="30"/>
</manifest>
diff --git a/notification/lint-baseline.xml b/notification/lint-baseline.xml
deleted file mode 100644
index 1f2ee2a..0000000
--- a/notification/lint-baseline.xml
+++ /dev/null
@@ -1,37 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<issues format="5" by="lint 4.1.0" client="cli" variant="all" version="4.1.0">
-
- <issue
- id="NewApi"
- message="Call requires API level R (current min is 29): `android.app.Notification#getContextualActions`"
- errorLine1=" boolean hasAppGeneratedContextualActions = !notification.getContextualActions().isEmpty();"
- errorLine2=" ~~~~~~~~~~~~~~~~~~~~">
- <location
- file="external/libtextclassifier/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java"
- line="248"
- column="62"/>
- </issue>
-
- <issue
- id="NewApi"
- message="Call requires API level R (current min is 29): `android.app.Notification#findRemoteInputActionPair`"
- errorLine1=" notification.findRemoteInputActionPair(/* requiresFreeform */ true);"
- errorLine2=" ~~~~~~~~~~~~~~~~~~~~~~~~~">
- <location
- file="external/libtextclassifier/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java"
- line="251"
- column="22"/>
- </issue>
-
- <issue
- id="NewApi"
- message="Call requires API level R (current min is 29): `android.app.Notification.MessagingStyle.Message#getMessagesFromBundleArray`"
- errorLine1=" Message.getMessagesFromBundleArray("
- errorLine2=" ~~~~~~~~~~~~~~~~~~~~~~~~~~">
- <location
- file="external/libtextclassifier/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java"
- line="434"
- column="17"/>
- </issue>
-
-</issues>
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
index 0a2cce7..9429b29 100644
--- a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
@@ -35,11 +35,9 @@
import android.util.Pair;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
-import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassificationManager;
import android.view.textclassifier.TextClassifier;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
@@ -172,10 +170,7 @@
} else {
SmartSuggestionsLogSession session =
new SmartSuggestionsLogSession(
- resultId,
- repliesScore,
- textClassifier,
- textClassificationContext);
+ resultId, repliesScore, textClassifier, textClassificationContext);
session.onSuggestionsGenerated(conversationActions);
// Store the session if we expect more logging from it, destroy it otherwise.
@@ -226,7 +221,10 @@
code,
contentDescription,
PendingIntent.getActivity(
- context, code.hashCode(), intent, PendingIntent.FLAG_UPDATE_CURRENT));
+ context,
+ code.hashCode(),
+ intent,
+ PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_IMMUTABLE));
return createNotificationActionFromRemoteAction(
remoteAction, TYPE_COPY, conversationAction.getConfidenceScore());
@@ -317,7 +315,15 @@
*/
public void onNotificationExpansionChanged(
StatusBarNotification statusBarNotification, boolean isExpanded) {
- SmartSuggestionsLogSession session = sessionCache.get(statusBarNotification.getKey());
+ onNotificationExpansionChanged(statusBarNotification.getKey(), isExpanded);
+ }
+
+ /**
+   * Similar to {@link onNotificationExpansionChanged}, except that this takes the notification key
+ * as input.
+ */
+ public void onNotificationExpansionChanged(String key, boolean isExpanded) {
+ SmartSuggestionsLogSession session = sessionCache.get(key);
if (session == null) {
return;
}
diff --git a/notification/tests/Android.bp b/notification/tests/Android.bp
index 7613496..48c6324 100644
--- a/notification/tests/Android.bp
+++ b/notification/tests/Android.bp
@@ -42,9 +42,10 @@
],
test_suites: [
- "device-tests", "mts-extservices"
+ "general-tests", "mts-extservices"
],
- instrumentation_for: "TextClassifierNotificationLib",
min_sdk_version: "30",
+
+ instrumentation_for: "TextClassifierNotificationLib",
}
diff --git a/notification/tests/AndroidManifest.xml b/notification/tests/AndroidManifest.xml
index 81308e3..d3da067 100644
--- a/notification/tests/AndroidManifest.xml
+++ b/notification/tests/AndroidManifest.xml
@@ -2,8 +2,8 @@
package="com.android.textclassifier.notification">
<uses-sdk
- android:minSdkVersion="29"
- android:targetSdkVersion="29" />
+ android:minSdkVersion="30"
+ android:targetSdkVersion="30" />
<application>
<uses-library android:name="android.test.runner"/>
diff --git a/notification/tests/AndroidTest.xml b/notification/tests/AndroidTest.xml
index 1890e75..0f60d10 100644
--- a/notification/tests/AndroidTest.xml
+++ b/notification/tests/AndroidTest.xml
@@ -13,8 +13,8 @@
See the License for the specific language governing permissions and
limitations under the License.
-->
-<!-- This test config file is auto-generated. -->
<configuration description="Runs TextClassifierNotificationTests.">
+ <option name="config-descriptor:metadata" key="mainline-param" value="com.google.android.extservices.apex" />
<option name="test-suite-tag" value="apct" />
<option name="test-suite-tag" value="apct-instrumentation" />
<target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
diff --git a/notification/tests/src/com/android/textclassifier/notification/NotificationTest.java b/notification/tests/src/com/android/textclassifier/notification/NotificationTest.java
new file mode 100644
index 0000000..215c3bc
--- /dev/null
+++ b/notification/tests/src/com/android/textclassifier/notification/NotificationTest.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.Notification;
+import android.app.Person;
+import android.content.Context;
+import androidx.test.core.app.ApplicationProvider;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests the {@link Notification} API. This is to fulfill the API coverage requirement for mainline.
+ *
+ * <p>Ideally, it should be a CTS. However, it is too late to add a CTS for R, and thus we test it
+ * on our side.
+ */
+@RunWith(JUnit4.class)
+public final class NotificationTest {
+
+ private Context context;
+
+ @Before
+ public void setup() {
+ context = ApplicationProvider.getApplicationContext();
+ }
+
+ @Test
+ public void getMessagesFromBundleArray() {
+ Person sender = new Person.Builder().setName("Sender").build();
+ Notification.MessagingStyle.Message firstExpectedMessage =
+ new Notification.MessagingStyle.Message("hello", /* timestamp= */ 123, sender);
+ Notification.MessagingStyle.Message secondExpectedMessage =
+ new Notification.MessagingStyle.Message("hello2", /* timestamp= */ 456, sender);
+
+ Notification.MessagingStyle messagingStyle =
+ new Notification.MessagingStyle("self name")
+ .addMessage(firstExpectedMessage)
+ .addMessage(secondExpectedMessage);
+ Notification notification =
+ new Notification.Builder(context, "test id")
+ .setSmallIcon(1)
+ .setContentTitle("test title")
+ .setStyle(messagingStyle)
+ .build();
+
+ List<Notification.MessagingStyle.Message> actualMessages =
+ Notification.MessagingStyle.Message.getMessagesFromBundleArray(
+ notification.extras.getParcelableArray(Notification.EXTRA_MESSAGES));
+
+ assertThat(actualMessages).hasSize(2);
+ assertMessageEquals(firstExpectedMessage, actualMessages.get(0));
+ assertMessageEquals(secondExpectedMessage, actualMessages.get(1));
+ }
+
+ private static void assertMessageEquals(
+ Notification.MessagingStyle.Message expected, Notification.MessagingStyle.Message actual) {
+ assertThat(actual.getText().toString()).isEqualTo(expected.getText().toString());
+ assertThat(actual.getTimestamp()).isEqualTo(expected.getTimestamp());
+ assertThat(actual.getSenderPerson()).isEqualTo(expected.getSenderPerson());
+ }
+}
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
index 9d0a720..84cf4fb 100644
--- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
@@ -42,6 +42,7 @@
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.LargeTest;
import com.google.common.collect.ImmutableList;
+import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -426,8 +427,8 @@
long expectedReferenceTime) {
assertThat(subject.getText().toString()).isEqualTo(expectedMessage);
assertThat(subject.getAuthor()).isEqualTo(expectedAuthor);
- assertThat(subject.getReferenceTime().toInstant().toEpochMilli())
- .isEqualTo(expectedReferenceTime);
+ assertThat(subject.getReferenceTime().toInstant())
+ .isEqualTo(Instant.ofEpochMilli(expectedReferenceTime));
}
private static void assertAdjustmentWithSmartReply(SmartSuggestions smartSuggestions) {