Export libtextclassifier

Move LabeledIntent and TemplateIntentFactory into the common.intent
package and have them build RemoteActionCompat/IconCompat instead of the
platform RemoteAction/Icon, skipping target activities that are not
exported or are permission-protected. Log statsd atoms directly through
StatsEvent/StatsLog with explicit atom IDs instead of the generated
TextClassifierStatsLog. Add string/float/int array support to
NamedVariant and template intent extras. Annotate conversation messages
up front in ActionsSuggestions via AnnotateConversation, rename
getNativeAnnotator to getNativeAnnotatorPointer, move FakeContextBuilder
into a testing package, and remove actions-suggestions_test.cc.

Test: atest TextClassifierService
Change-Id: I99b1b4e492f004f542914a3d022d048765260960
diff --git a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
index 0acca43..ca348bf 100644
--- a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
+++ b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
@@ -29,9 +29,9 @@
import android.view.textclassifier.ConversationActions.Message;
import com.android.textclassifier.ModelFileManager.ModelFile;
import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.intent.LabeledIntent;
+import com.android.textclassifier.common.intent.TemplateIntentFactory;
import com.android.textclassifier.common.statsd.ResultIdUtils;
-import com.android.textclassifier.intent.LabeledIntent;
-import com.android.textclassifier.intent.TemplateIntentFactory;
import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.RemoteActionTemplate;
import com.google.common.base.Equivalence;
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index 5e7f044..4d0b7cc 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -42,6 +42,8 @@
import androidx.collection.ArraySet;
import androidx.core.util.Pair;
import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.intent.LabeledIntent;
+import com.android.textclassifier.common.intent.TemplateIntentFactory;
import com.android.textclassifier.common.statsd.GenerateLinksLogger;
import com.android.textclassifier.common.statsd.ResultIdUtils;
import com.android.textclassifier.common.statsd.SelectionEventConverter;
@@ -49,10 +51,8 @@
import com.android.textclassifier.common.statsd.TextClassifierEventConverter;
import com.android.textclassifier.common.statsd.TextClassifierEventLogger;
import com.android.textclassifier.intent.ClassificationIntentFactory;
-import com.android.textclassifier.intent.LabeledIntent;
import com.android.textclassifier.intent.LegacyClassificationIntentFactory;
import com.android.textclassifier.intent.TemplateClassificationIntentFactory;
-import com.android.textclassifier.intent.TemplateIntentFactory;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.AnnotatorModel;
@@ -471,7 +471,7 @@
RemoteAction remoteAction = null;
Bundle extras = new Bundle();
if (labeledIntentResult != null) {
- remoteAction = labeledIntentResult.remoteAction;
+ remoteAction = labeledIntentResult.remoteAction.toRemoteAction();
ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent);
}
ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
@@ -641,7 +641,7 @@
final Bundle foreignLanguageBundle = languagesBundles.second;
boolean isPrimaryAction = true;
- final List<LabeledIntent> labeledIntents =
+ final ImmutableList<LabeledIntent> labeledIntents =
classificationIntentFactory.create(
context,
classifiedText,
@@ -660,7 +660,7 @@
}
final Intent intent = result.resolvedIntent;
- final RemoteAction action = result.remoteAction;
+ final RemoteAction action = result.remoteAction.toRemoteAction();
if (isPrimaryAction) {
// For O backwards compatibility, the first RemoteAction is also written to the
// legacy API fields.
diff --git a/java/src/com/android/textclassifier/intent/LabeledIntent.java b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
similarity index 74%
rename from java/src/com/android/textclassifier/intent/LabeledIntent.java
rename to java/src/com/android/textclassifier/common/intent/LabeledIntent.java
index e3b001b..d6041b9 100644
--- a/java/src/com/android/textclassifier/intent/LabeledIntent.java
+++ b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
@@ -14,30 +14,27 @@
* limitations under the License.
*/
-package com.android.textclassifier.intent;
+package com.android.textclassifier.common.intent;
import android.app.PendingIntent;
-import android.app.RemoteAction;
import android.content.ComponentName;
import android.content.Context;
import android.content.Intent;
+import android.content.pm.ActivityInfo;
import android.content.pm.PackageManager;
import android.content.pm.ResolveInfo;
-import android.graphics.drawable.Icon;
import android.os.Bundle;
import android.text.TextUtils;
import android.view.textclassifier.TextClassifier;
-import com.android.textclassifier.ExtrasUtils;
-import com.android.textclassifier.R;
+import androidx.annotation.DrawableRes;
+import androidx.core.app.RemoteActionCompat;
+import androidx.core.content.ContextCompat;
+import androidx.core.graphics.drawable.IconCompat;
import com.android.textclassifier.common.base.TcLog;
import com.google.common.base.Preconditions;
import javax.annotation.Nullable;
-/**
- * Helper class to store the information from which RemoteActions are built.
- *
- * @hide
- */
+/** Helper class to store the information from which RemoteActions are built. */
public final class LabeledIntent {
private static final String TAG = "LabeledIntent";
public static final int DEFAULT_REQUEST_CODE = 0;
@@ -104,6 +101,11 @@
TcLog.w(TAG, "resolveInfo or activityInfo is null");
return null;
}
+ if (!hasPermission(context, resolveInfo.activityInfo)) {
+ TcLog.d(TAG, "No permission to access: " + resolveInfo.activityInfo);
+ return null;
+ }
+
final String packageName = resolveInfo.activityInfo.packageName;
final String className = resolveInfo.activityInfo.name;
if (packageName == null || className == null) {
@@ -112,22 +114,24 @@
}
Intent resolvedIntent = new Intent(intent);
resolvedIntent.putExtra(
- TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER, getFromTextClassifierExtra(textLanguagesBundle));
+ TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER,
+ createFromTextClassifierExtra(textLanguagesBundle));
boolean shouldShowIcon = false;
- Icon icon = null;
+ IconCompat icon = null;
if (!"android".equals(packageName)) {
// We only set the component name when the package name is not resolved to "android"
// to workaround a bug that explicit intent with component name == ResolverActivity
// can't be launched on keyguard.
resolvedIntent.setComponent(new ComponentName(packageName, className));
if (resolveInfo.activityInfo.getIconResource() != 0) {
- icon = Icon.createWithResource(packageName, resolveInfo.activityInfo.getIconResource());
+ icon =
+ createIconFromPackage(context, packageName, resolveInfo.activityInfo.getIconResource());
shouldShowIcon = true;
}
}
if (icon == null) {
// RemoteAction requires that there be an icon.
- icon = Icon.createWithResource(context, R.drawable.tcs_app_icon);
+ icon = IconCompat.createWithResource(context, android.R.drawable.ic_menu_more);
}
final PendingIntent pendingIntent = createPendingIntent(context, resolvedIntent, requestCode);
titleChooser = titleChooser == null ? DEFAULT_TITLE_CHOOSER : titleChooser;
@@ -136,8 +140,8 @@
TcLog.w(TAG, "Custom titleChooser return null, fallback to the default titleChooser");
title = DEFAULT_TITLE_CHOOSER.chooseTitle(this, resolveInfo);
}
- final RemoteAction action =
- new RemoteAction(icon, title, resolveDescription(resolveInfo, pm), pendingIntent);
+ final RemoteActionCompat action =
+ new RemoteActionCompat(icon, title, resolveDescription(resolveInfo, pm), pendingIntent);
action.setShouldShowIcon(shouldShowIcon);
return new Result(resolvedIntent, action);
}
@@ -153,6 +157,29 @@
return description;
}
+ // TODO(b/149018167) Remove this once we have moved this to C++.
+ private static void putTextLanguagesExtra(Bundle container, Bundle extra) {
+ container.putBundle("text-languages", extra);
+ }
+
+ @Nullable
+ private static IconCompat createIconFromPackage(
+ Context context, String packageName, @DrawableRes int iconRes) {
+ try {
+ Context packageContext = context.createPackageContext(packageName, 0);
+ return IconCompat.createWithResource(packageContext, iconRes);
+ } catch (PackageManager.NameNotFoundException e) {
+ TcLog.e(TAG, "createIconFromPackage: failed to create package context", e);
+ }
+ return null;
+ }
+
+ private static PendingIntent createPendingIntent(
+ final Context context, final Intent intent, int requestCode) {
+ return PendingIntent.getActivity(
+ context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ }
+
@Nullable
private static String getApplicationName(ResolveInfo resolveInfo, PackageManager packageManager) {
if (resolveInfo.activityInfo == null) {
@@ -164,31 +191,36 @@
if (resolveInfo.activityInfo.applicationInfo == null) {
return null;
}
- return (String) packageManager.getApplicationLabel(resolveInfo.activityInfo.applicationInfo);
+ return packageManager.getApplicationLabel(resolveInfo.activityInfo.applicationInfo).toString();
}
- private static Bundle getFromTextClassifierExtra(@Nullable Bundle textLanguagesBundle) {
- if (textLanguagesBundle != null) {
- final Bundle bundle = new Bundle();
- ExtrasUtils.putTextLanguagesExtra(bundle, textLanguagesBundle);
- return bundle;
- } else {
+ private static Bundle createFromTextClassifierExtra(@Nullable Bundle textLanguagesBundle) {
+ if (textLanguagesBundle == null) {
return Bundle.EMPTY;
+ } else {
+ Bundle bundle = new Bundle();
+ putTextLanguagesExtra(bundle, textLanguagesBundle);
+ return bundle;
}
}
- private static PendingIntent createPendingIntent(
- final Context context, final Intent intent, int requestCode) {
- return PendingIntent.getActivity(
- context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ private static boolean hasPermission(Context context, ActivityInfo info) {
+ if (!info.exported) {
+ return false;
+ }
+ if (info.permission == null) {
+ return true;
+ }
+ return ContextCompat.checkSelfPermission(context, info.permission)
+ == PackageManager.PERMISSION_GRANTED;
}
/** Data class that holds the result. */
public static final class Result {
public final Intent resolvedIntent;
- public final RemoteAction remoteAction;
+ public final RemoteActionCompat remoteAction;
- public Result(Intent resolvedIntent, RemoteAction remoteAction) {
+ public Result(Intent resolvedIntent, RemoteActionCompat remoteAction) {
this.resolvedIntent = Preconditions.checkNotNull(resolvedIntent);
this.remoteAction = Preconditions.checkNotNull(remoteAction);
}
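As a usage sketch of the new Result type: it now carries a RemoteActionCompat, which callers such as TextClassifierImpl above convert back to a platform RemoteAction with toRemoteAction(). The helper class and method below are illustrative only and not part of this change.

    import android.app.RemoteAction;
    import androidx.core.app.RemoteActionCompat;
    import com.android.textclassifier.common.intent.LabeledIntent;
    import javax.annotation.Nullable;

    final class LabeledIntentResultExample {
      private LabeledIntentResultExample() {}

      /** Converts the RemoteActionCompat held by a LabeledIntent.Result to a platform RemoteAction. */
      @Nullable
      static RemoteAction toPlatformAction(@Nullable LabeledIntent.Result result) {
        if (result == null) {
          // LabeledIntent returns no result when the intent does not resolve or the target
          // activity is not exported / not permitted (see hasPermission above).
          return null;
        }
        RemoteActionCompat compatAction = result.remoteAction;
        return compatAction.toRemoteAction();
      }
    }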
diff --git a/java/src/com/android/textclassifier/intent/TemplateIntentFactory.java b/java/src/com/android/textclassifier/common/intent/TemplateIntentFactory.java
similarity index 83%
rename from java/src/com/android/textclassifier/intent/TemplateIntentFactory.java
rename to java/src/com/android/textclassifier/common/intent/TemplateIntentFactory.java
index d74c276..ec95e4e 100644
--- a/java/src/com/android/textclassifier/intent/TemplateIntentFactory.java
+++ b/java/src/com/android/textclassifier/common/intent/TemplateIntentFactory.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package com.android.textclassifier.intent;
+package com.android.textclassifier.common.intent;
import android.content.Intent;
import android.net.Uri;
@@ -23,25 +23,19 @@
import com.android.textclassifier.common.base.TcLog;
import com.google.android.textclassifier.NamedVariant;
import com.google.android.textclassifier.RemoteActionTemplate;
-import java.util.ArrayList;
-import java.util.List;
+import com.google.common.collect.ImmutableList;
import javax.annotation.Nullable;
-/**
- * Creates intents based on {@link RemoteActionTemplate} objects.
- *
- * @hide
- */
+/** Creates intents based on {@link RemoteActionTemplate} objects. */
public final class TemplateIntentFactory {
private static final String TAG = "TemplateIntentFactory";
/** Constructs and returns a list of {@link LabeledIntent} based on the given templates. */
- @Nullable
- public List<LabeledIntent> create(RemoteActionTemplate[] remoteActionTemplates) {
+ public ImmutableList<LabeledIntent> create(RemoteActionTemplate[] remoteActionTemplates) {
if (remoteActionTemplates.length == 0) {
- return new ArrayList<>();
+ return ImmutableList.of();
}
- final List<LabeledIntent> labeledIntents = new ArrayList<>();
+ final ImmutableList.Builder<LabeledIntent> labeledIntents = ImmutableList.builder();
for (RemoteActionTemplate remoteActionTemplate : remoteActionTemplates) {
if (!isValidTemplate(remoteActionTemplate)) {
TcLog.w(TAG, "Invalid RemoteActionTemplate skipped.");
@@ -58,7 +52,7 @@
? LabeledIntent.DEFAULT_REQUEST_CODE
: remoteActionTemplate.requestCode));
}
- return labeledIntents;
+ return labeledIntents.build();
}
private static boolean isValidTemplate(@Nullable RemoteActionTemplate remoteActionTemplate) {
@@ -98,6 +92,9 @@
: Intent.normalizeMimeType(remoteActionTemplate.type);
intent.setDataAndType(uri, type);
intent.setFlags(remoteActionTemplate.flags == null ? 0 : remoteActionTemplate.flags);
+ if (!TextUtils.isEmpty(remoteActionTemplate.packageName)) {
+ intent.setPackage(remoteActionTemplate.packageName);
+ }
if (remoteActionTemplate.category != null) {
for (String category : remoteActionTemplate.category) {
if (category != null) {
@@ -138,6 +135,16 @@
case NamedVariant.TYPE_STRING:
bundle.putString(namedVariant.getName(), namedVariant.getString());
break;
+ case NamedVariant.TYPE_STRING_ARRAY:
+ bundle.putStringArray(namedVariant.getName(), namedVariant.getStringArray());
+ break;
+ case NamedVariant.TYPE_FLOAT_ARRAY:
+ bundle.putFloatArray(namedVariant.getName(), namedVariant.getFloatArray());
+ break;
+ case NamedVariant.TYPE_INT_ARRAY:
+ bundle.putIntArray(namedVariant.getName(), namedVariant.getIntArray());
+ break;
+
default:
TcLog.w(
TAG, "Unsupported type found in nameVariantsToBundle : " + namedVariant.getType());
diff --git a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
index 7db5003..80321f7 100644
--- a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
+++ b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
@@ -16,10 +16,11 @@
package com.android.textclassifier.common.statsd;
+import android.util.StatsEvent;
+import android.util.StatsLog;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLinks;
import androidx.collection.ArrayMap;
-import com.android.textclassifier.TextClassifierStatsLog;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.TextClassifierEvent;
import com.google.common.annotations.VisibleForTesting;
@@ -127,19 +128,24 @@
LinkifyStats stats,
CharSequence text,
long latencyMs) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
- callId,
- TextClassifierEvent.TYPE_LINKS_GENERATED,
- /*modelName=*/ null,
- TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN,
- /*eventIndex=*/ 0,
- entityType,
- stats.numLinks,
- stats.numLinksTextLength,
- text.length(),
- latencyMs,
- callingPackageName);
+ StatsEvent statsEvent =
+ StatsEvent.newBuilder()
+ .setAtomId(TextClassifierEventLogger.TEXT_LINKIFY_EVENT_ATOM_ID)
+ .writeString(callId)
+ .writeInt(TextClassifierEvent.TYPE_LINKS_GENERATED)
+ .writeString(/* modelName */ null)
+ .writeInt(TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN)
+ .writeInt(/* eventIndex */ 0)
+ .writeString(entityType)
+ .writeInt(stats.numLinks)
+ .writeInt(stats.numLinksTextLength)
+ .writeInt(text.length())
+ .writeLong(latencyMs)
+ .writeString(callingPackageName)
+ .usePooledBuffer()
+ .build();
+ StatsLog.write(statsEvent);
+
if (TcLog.ENABLE_FULL_LOGGING) {
TcLog.v(
LOG_TAG,
diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
index f42ea82..64426ef 100644
--- a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
+++ b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
@@ -16,8 +16,9 @@
package com.android.textclassifier.common.statsd;
+import android.util.StatsEvent;
+import android.util.StatsLog;
import android.view.textclassifier.TextClassifier;
-import com.android.textclassifier.TextClassifierStatsLog;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.TextClassificationContext;
import com.android.textclassifier.common.logging.TextClassificationSessionId;
@@ -29,8 +30,12 @@
/** Logs {@link android.view.textclassifier.TextClassifierEvent}. */
public final class TextClassifierEventLogger {
-
private static final String TAG = "TCEventLogger";
+ // These constants are defined in atoms.proto.
+ private static final int TEXT_SELECTION_EVENT_ATOM_ID = 219;
+ static final int TEXT_LINKIFY_EVENT_ATOM_ID = 220;
+ private static final int CONVERSATION_ACTIONS_EVENT_ATOM_ID = 221;
+ private static final int LANGUAGE_DETECTION_EVENT_ATOM_ID = 222;
/** Emits a text classifier event to the logs. */
public void writeEvent(
@@ -58,71 +63,89 @@
private static void logTextSelectionEvent(
@Nullable TextClassificationSessionId sessionId,
TextClassifierEvent.TextSelectionEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.TEXT_SELECTION_EVENT,
- sessionId == null ? null : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- event.getEventIndex(),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- event.getRelativeWordStartIndex(),
- event.getRelativeWordEndIndex(),
- event.getRelativeSuggestedWordStartIndex(),
- event.getRelativeSuggestedWordEndIndex(),
- getPackageName(event));
+ StatsEvent statsEvent =
+ StatsEvent.newBuilder()
+ .setAtomId(TEXT_SELECTION_EVENT_ATOM_ID)
+ .writeString(sessionId == null ? null : sessionId.flattenToString())
+ .writeInt(event.getEventType())
+ .writeString(getModelName(event))
+ .writeInt(getWidgetType(event))
+ .writeInt(event.getEventIndex())
+ .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
+ .writeInt(event.getRelativeWordStartIndex())
+ .writeInt(event.getRelativeWordEndIndex())
+ .writeInt(event.getRelativeSuggestedWordStartIndex())
+ .writeInt(event.getRelativeSuggestedWordEndIndex())
+ .writeString(getPackageName(event))
+ .usePooledBuffer()
+ .build();
+ StatsLog.write(statsEvent);
}
private static void logTextLinkifyEvent(
TextClassificationSessionId sessionId, TextClassifierEvent.TextLinkifyEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
- sessionId == null ? null : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- event.getEventIndex(),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- /*numOfLinks=*/ 0,
- /*linkedTextLength=*/ 0,
- /*textLength=*/ 0,
- /*latencyInMillis=*/ 0L,
- getPackageName(event));
+ StatsEvent statsEvent =
+ StatsEvent.newBuilder()
+ .setAtomId(TEXT_LINKIFY_EVENT_ATOM_ID)
+ .writeString(sessionId == null ? null : sessionId.flattenToString())
+ .writeInt(event.getEventType())
+ .writeString(getModelName(event))
+ .writeInt(getWidgetType(event))
+ .writeInt(event.getEventIndex())
+ .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
+ .writeInt(/* numOfLinks */ 0)
+ .writeInt(/* linkedTextLength */ 0)
+ .writeInt(/* textLength */ 0)
+ .writeLong(/* latencyInMillis */ 0L)
+ .writeString(getPackageName(event))
+ .usePooledBuffer()
+ .build();
+ StatsLog.write(statsEvent);
}
private static void logConversationActionsEvent(
@Nullable TextClassificationSessionId sessionId,
TextClassifierEvent.ConversationActionsEvent event) {
ImmutableList<String> modelNames = ResultIdUtils.getModelNames(event.getResultId());
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.CONVERSATION_ACTIONS_EVENT,
- sessionId == null
- ? event.getResultId() // TODO: Update ExtServices to set the session id.
- : sessionId.flattenToString(),
- event.getEventType(),
- getItemAt(modelNames, 0, null),
- getWidgetType(event),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- getItemAt(event.getEntityTypes(), /* index= */ 1),
- getItemAt(event.getEntityTypes(), /* index= */ 2),
- getFloatAt(event.getScores(), /* index= */ 0),
- getPackageName(event),
- getItemAt(modelNames, 1, null));
+
+ StatsEvent statsEvent =
+ StatsEvent.newBuilder()
+ .setAtomId(CONVERSATION_ACTIONS_EVENT_ATOM_ID)
+ .writeString(
+ sessionId == null
+ ? event.getResultId() // TODO: Update ExtServices to set the session id.
+ : sessionId.flattenToString())
+ .writeInt(event.getEventType())
+ .writeString(getItemAt(modelNames, 0, null))
+ .writeInt(getWidgetType(event))
+ .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
+ .writeString(getItemAt(event.getEntityTypes(), /* index= */ 1))
+ .writeString(getItemAt(event.getEntityTypes(), /* index= */ 2))
+ .writeFloat(getFloatAt(event.getScores(), /* index= */ 0))
+ .writeString(getPackageName(event))
+ .writeString(getItemAt(modelNames, 1, null))
+ .usePooledBuffer()
+ .build();
+ StatsLog.write(statsEvent);
}
private static void logLanguageDetectionEvent(
@Nullable TextClassificationSessionId sessionId,
TextClassifierEvent.LanguageDetectionEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.LANGUAGE_DETECTION_EVENT,
- sessionId == null ? null : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- getFloatAt(event.getScores(), /* index= */ 0),
- getIntAt(event.getActionIndices(), /* index= */ 0),
- getPackageName(event));
+ StatsEvent statsEvent =
+ StatsEvent.newBuilder()
+ .setAtomId(LANGUAGE_DETECTION_EVENT_ATOM_ID)
+ .writeString(sessionId == null ? null : sessionId.flattenToString())
+ .writeInt(event.getEventType())
+ .writeString(getModelName(event))
+ .writeInt(getWidgetType(event))
+ .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
+ .writeFloat(getFloatAt(event.getScores(), /* index= */ 0))
+ .writeInt(getIntAt(event.getActionIndices(), /* index= */ 0))
+ .writeString(getPackageName(event))
+ .usePooledBuffer()
+ .build();
+ StatsLog.write(statsEvent);
}
@Nullable
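All four atoms above follow the same android.util.StatsEvent builder pattern that replaces the generated TextClassifierStatsLog wrappers. A condensed sketch with a placeholder atom ID and a reduced field list; in the real logger the write order must match the atom definition in atoms.proto.

    import android.util.StatsEvent;
    import android.util.StatsLog;

    final class StatsEventExample {
      // Placeholder for illustration; the real atoms use the IDs above (219-222) from atoms.proto.
      private static final int EXAMPLE_ATOM_ID = 0;

      private StatsEventExample() {}

      /** Builds and writes one statsd atom; fields are written in the order the atom declares them. */
      static void writeExampleAtom(String sessionId, int eventType, String modelName) {
        StatsEvent statsEvent =
            StatsEvent.newBuilder()
                .setAtomId(EXAMPLE_ATOM_ID)
                .writeString(sessionId) // may be null
                .writeInt(eventType)
                .writeString(modelName)
                .usePooledBuffer() // buffer is recycled once StatsLog.write() returns
                .build();
        StatsLog.write(statsEvent);
      }
    }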
diff --git a/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java b/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java
index 450ffc4..cb38c95 100644
--- a/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java
@@ -19,7 +19,9 @@
import android.content.Context;
import android.content.Intent;
import com.android.textclassifier.R;
+import com.android.textclassifier.common.intent.LabeledIntent;
import com.google.android.textclassifier.AnnotatorModel;
+import com.google.common.collect.ImmutableList;
import java.time.Instant;
import java.util.List;
import javax.annotation.Nullable;
@@ -28,7 +30,7 @@
public interface ClassificationIntentFactory {
/** Return a list of LabeledIntent from the classification result. */
- List<LabeledIntent> create(
+ ImmutableList<LabeledIntent> create(
Context context,
String text,
boolean foreignText,
diff --git a/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java b/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java
index 74aa099..f1659dc 100644
--- a/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java
@@ -31,7 +31,9 @@
import android.view.textclassifier.TextClassifier;
import com.android.textclassifier.R;
import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.intent.LabeledIntent;
import com.google.android.textclassifier.AnnotatorModel;
+import com.google.common.collect.ImmutableList;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.time.Instant;
@@ -53,7 +55,7 @@
private static final String TYPE_DICTIONARY = "dictionary";
@Override
- public List<LabeledIntent> create(
+ public ImmutableList<LabeledIntent> create(
Context context,
String text,
boolean foreignText,
@@ -102,7 +104,7 @@
if (foreignText) {
ClassificationIntentFactory.insertTranslateAction(actions, context, text);
}
- return actions;
+ return ImmutableList.copyOf(actions);
}
private static List<LabeledIntent> createForEmail(Context context, String text) {
diff --git a/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java b/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java
index fbd742e..f3f6fd3 100644
--- a/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java
@@ -18,11 +18,14 @@
import android.content.Context;
import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.intent.LabeledIntent;
+import com.android.textclassifier.common.intent.TemplateIntentFactory;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.RemoteActionTemplate;
import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
import java.time.Instant;
-import java.util.Collections;
+import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
@@ -46,14 +49,14 @@
* Returns a list of {@link LabeledIntent} that are constructed from the classification result.
*/
@Override
- public List<LabeledIntent> create(
+ public ImmutableList<LabeledIntent> create(
Context context,
String text,
boolean foreignText,
@Nullable Instant referenceTime,
@Nullable AnnotatorModel.ClassificationResult classification) {
if (classification == null) {
- return Collections.emptyList();
+ return ImmutableList.of();
}
RemoteActionTemplate[] remoteActionTemplates = classification.getRemoteActionTemplates();
if (remoteActionTemplates == null) {
@@ -63,10 +66,11 @@
"RemoteActionTemplate is missing, fallback to" + " LegacyClassificationIntentFactory.");
return fallback.create(context, text, foreignText, referenceTime, classification);
}
- final List<LabeledIntent> labeledIntents = templateIntentFactory.create(remoteActionTemplates);
+ final List<LabeledIntent> labeledIntents =
+ new ArrayList<>(templateIntentFactory.create(remoteActionTemplates));
if (foreignText) {
ClassificationIntentFactory.insertTranslateAction(labeledIntents, context, text.trim());
}
- return labeledIntents;
+ return ImmutableList.copyOf(labeledIntents);
}
}
diff --git a/java/tests/instrumentation/AndroidManifest.xml b/java/tests/instrumentation/AndroidManifest.xml
index 5de247c..129b909 100644
--- a/java/tests/instrumentation/AndroidManifest.xml
+++ b/java/tests/instrumentation/AndroidManifest.xml
@@ -1,8 +1,10 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
- package="com.android.textclassifier.tests">
+ package="com.android.textclassifier.common.tests"
+ android:versionCode="1"
+ android:versionName="1.0">
- <uses-sdk android:minSdkVersion="28"/>
+ <uses-sdk android:minSdkVersion="16"/>
<application>
<uses-library android:name="android.test.runner"/>
@@ -10,5 +12,5 @@
<instrumentation
android:name="androidx.test.runner.AndroidJUnitRunner"
- android:targetPackage="com.android.textclassifier.tests"/>
+ android:targetPackage="com.android.textclassifier.common.tests"/>
</manifest>
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
index d25d97e..59dc41a 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
@@ -33,8 +33,8 @@
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
-import com.android.textclassifier.intent.LabeledIntent;
-import com.android.textclassifier.intent.TemplateIntentFactory;
+import com.android.textclassifier.common.intent.LabeledIntent;
+import com.android.textclassifier.common.intent.TemplateIntentFactory;
import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.RemoteActionTemplate;
import com.google.common.collect.ImmutableList;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index a61ea36..9e31c7e 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -41,6 +41,7 @@
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
+import com.android.textclassifier.testing.FakeContextBuilder;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.Collections;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/intent/LabeledIntentTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
similarity index 97%
rename from java/tests/instrumentation/src/com/android/textclassifier/intent/LabeledIntentTest.java
rename to java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
index 3840823..fc0393a 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/intent/LabeledIntentTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package com.android.textclassifier.intent;
+package com.android.textclassifier.common.intent;
import static com.google.common.truth.Truth.assertThat;
import static org.testng.Assert.assertThrows;
@@ -27,7 +27,7 @@
import android.view.textclassifier.TextClassifier;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
-import com.android.textclassifier.FakeContextBuilder;
+import com.android.textclassifier.testing.FakeContextBuilder;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
similarity index 89%
rename from java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateIntentFactoryTest.java
rename to java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
index ee45f18..fc36f1a 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateIntentFactoryTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package com.android.textclassifier.intent;
+package com.android.textclassifier.common.intent;
import static com.google.common.truth.Truth.assertThat;
@@ -49,10 +49,20 @@
private static final String VALUE_ONE = "value1";
private static final String KEY_TWO = "key2";
private static final int VALUE_TWO = 42;
+ private static final String KEY_STRING_ARRAY = "string_array_key";
+ private static final String[] VALUE_STRING_ARRAY = new String[] {"a", "b"};
+ private static final String KEY_FLOAT_ARRAY = "float_array_key";
+ private static final float[] VALUE_FLOAT_ARRAY = new float[] {3.14f, 2.718f};
+ private static final String KEY_INT_ARRAY = "int_array_key";
+ private static final int[] VALUE_INT_ARRAY = new int[] {7, 2, 1};
private static final NamedVariant[] NAMED_VARIANTS =
new NamedVariant[] {
- new NamedVariant(KEY_ONE, VALUE_ONE), new NamedVariant(KEY_TWO, VALUE_TWO)
+ new NamedVariant(KEY_ONE, VALUE_ONE),
+ new NamedVariant(KEY_TWO, VALUE_TWO),
+ new NamedVariant(KEY_STRING_ARRAY, VALUE_STRING_ARRAY),
+ new NamedVariant(KEY_FLOAT_ARRAY, VALUE_FLOAT_ARRAY),
+ new NamedVariant(KEY_INT_ARRAY, VALUE_INT_ARRAY)
};
private static final Integer REQUEST_CODE = 10;
@@ -100,6 +110,9 @@
assertThat(intent.getPackage()).isNull();
assertThat(intent.getStringExtra(KEY_ONE)).isEqualTo(VALUE_ONE);
assertThat(intent.getIntExtra(KEY_TWO, 0)).isEqualTo(VALUE_TWO);
+ assertThat(intent.getStringArrayExtra(KEY_STRING_ARRAY)).isEqualTo(VALUE_STRING_ARRAY);
+ assertThat(intent.getFloatArrayExtra(KEY_FLOAT_ARRAY)).isEqualTo(VALUE_FLOAT_ARRAY);
+ assertThat(intent.getIntArrayExtra(KEY_INT_ARRAY)).isEqualTo(VALUE_INT_ARRAY);
}
@Test
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java
index b976849..33c5085 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java
@@ -23,6 +23,7 @@
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
+import com.android.textclassifier.common.intent.LabeledIntent;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.AnnotatorModel.ClassificationResult;
import com.google.android.textclassifier.AnnotatorModel.DatetimeResult;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java
index 42176bd..098d930 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java
@@ -29,6 +29,8 @@
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
+import com.android.textclassifier.common.intent.LabeledIntent;
+import com.android.textclassifier.common.intent.TemplateIntentFactory;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.RemoteActionTemplate;
import java.util.List;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/FakeContextBuilder.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java
similarity index 98%
rename from java/tests/instrumentation/src/com/android/textclassifier/FakeContextBuilder.java
rename to java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java
index 17b6e0a..f3ad833 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/FakeContextBuilder.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package com.android.textclassifier;
+package com.android.textclassifier.testing;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
diff --git a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
index 84b5c3d..3af04e8 100644
--- a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -70,7 +70,7 @@
actionsModelPtr,
conversation,
options,
- (annotator != null ? annotator.getNativeAnnotator() : 0),
+ (annotator != null ? annotator.getNativeAnnotatorPointer() : 0),
/* appContext= */ null,
/* deviceLocales= */ null,
/* generateAndroidIntents= */ false);
@@ -86,7 +86,7 @@
actionsModelPtr,
conversation,
options,
- (annotator != null ? annotator.getNativeAnnotator() : 0),
+ (annotator != null ? annotator.getNativeAnnotatorPointer() : 0),
appContext,
deviceLocales,
/* generateAndroidIntents= */ true);
diff --git a/jni/com/google/android/textclassifier/AnnotatorModel.java b/jni/com/google/android/textclassifier/AnnotatorModel.java
index a94e97e..39ad1ae 100644
--- a/jni/com/google/android/textclassifier/AnnotatorModel.java
+++ b/jni/com/google/android/textclassifier/AnnotatorModel.java
@@ -650,7 +650,7 @@
* Retrieves the pointer to the native object. Note: Need to keep the AnnotatorModel alive as long
* as the pointer is used.
*/
- long getNativeAnnotator() {
+ long getNativeAnnotatorPointer() {
return nativeGetNativeModelPtr(annotatorPtr);
}
diff --git a/jni/com/google/android/textclassifier/LangIdModel.java b/jni/com/google/android/textclassifier/LangIdModel.java
index 9701492..c9c47f1 100644
--- a/jni/com/google/android/textclassifier/LangIdModel.java
+++ b/jni/com/google/android/textclassifier/LangIdModel.java
@@ -113,6 +113,14 @@
return nativeGetMinTextSizeInBytes(modelPtr);
}
+ /**
+ * Returns the pointer to the native object. Note: Need to keep the LangIdModel alive as long as
+ * the pointer is used.
+ */
+ long getNativeLangIdPointer() {
+ return modelPtr;
+ }
+
private static native long nativeNew(int fd);
private static native long nativeNewFromPath(String path);
diff --git a/jni/com/google/android/textclassifier/NamedVariant.java b/jni/com/google/android/textclassifier/NamedVariant.java
index d04bb11..23f2616 100644
--- a/jni/com/google/android/textclassifier/NamedVariant.java
+++ b/jni/com/google/android/textclassifier/NamedVariant.java
@@ -29,6 +29,9 @@
public static final int TYPE_DOUBLE = 4;
public static final int TYPE_BOOL = 5;
public static final int TYPE_STRING = 6;
+ public static final int TYPE_STRING_ARRAY = 7;
+ public static final int TYPE_FLOAT_ARRAY = 8;
+ public static final int TYPE_INT_ARRAY = 9;
public NamedVariant(String name, int value) {
this.name = name;
@@ -66,6 +69,24 @@
this.type = TYPE_STRING;
}
+ public NamedVariant(String name, String[] value) {
+ this.name = name;
+ this.stringArrValue = value;
+ this.type = TYPE_STRING_ARRAY;
+ }
+
+ public NamedVariant(String name, float[] value) {
+ this.name = name;
+ this.floatArrValue = value;
+ this.type = TYPE_FLOAT_ARRAY;
+ }
+
+ public NamedVariant(String name, int[] value) {
+ this.name = name;
+ this.intArrValue = value;
+ this.type = TYPE_INT_ARRAY;
+ }
+
public String getName() {
return name;
}
@@ -104,6 +125,21 @@
return stringValue;
}
+ public String[] getStringArray() {
+ assert (type == TYPE_STRING_ARRAY);
+ return stringArrValue;
+ }
+
+ public float[] getFloatArray() {
+ assert (type == TYPE_FLOAT_ARRAY);
+ return floatArrValue;
+ }
+
+ public int[] getIntArray() {
+ assert (type == TYPE_INT_ARRAY);
+ return intArrValue;
+ }
+
private final String name;
private final int type;
private int intValue;
@@ -112,4 +148,7 @@
private double doubleValue;
private boolean boolValue;
private String stringValue;
+ private String[] stringArrValue;
+ private float[] floatArrValue;
+ private int[] intArrValue;
}
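A short sketch of the new array-typed NamedVariant constructors and their typed getters, mirroring how the TYPE_*_ARRAY cases added to TemplateIntentFactory.nameVariantsToBundle copy them into Bundle extras. The demo class is illustrative; the key names follow TemplateIntentFactoryTest above.

    import android.os.Bundle;
    import com.google.android.textclassifier.NamedVariant;

    final class NamedVariantArrayExample {
      private NamedVariantArrayExample() {}

      /** Builds a Bundle from array-typed NamedVariants the way nameVariantsToBundle now does. */
      static Bundle demoBundle() {
        NamedVariant strings = new NamedVariant("string_array_key", new String[] {"a", "b"});
        NamedVariant floats = new NamedVariant("float_array_key", new float[] {3.14f, 2.718f});
        NamedVariant ints = new NamedVariant("int_array_key", new int[] {7, 2, 1});

        Bundle bundle = new Bundle();
        // Each typed getter asserts that the stored type matches the requested one.
        bundle.putStringArray(strings.getName(), strings.getStringArray());
        bundle.putFloatArray(floats.getName(), floats.getFloatArray());
        bundle.putIntArray(ints.getName(), ints.getIntArray());
        return bundle;
      }
    }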
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index 711d10c..17ba96a 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -866,9 +866,45 @@
return options;
}
+// Run annotator on the messages of a conversation.
+Conversation ActionsSuggestions::AnnotateConversation(
+ const Conversation& conversation, const Annotator* annotator) const {
+ if (annotator == nullptr) {
+ return conversation;
+ }
+ const int num_messages_grammar =
+ ((model_->rules() && model_->rules()->grammar_rules() &&
+ model_->rules()->grammar_rules()->annotation_nonterminal())
+ ? 1
+ : 0);
+ const int num_messages_mapping =
+ (model_->annotation_actions_spec()
+ ? std::max(model_->annotation_actions_spec()
+ ->max_history_from_any_person(),
+ model_->annotation_actions_spec()
+ ->max_history_from_last_person())
+ : 0);
+ const int num_messages = std::max(num_messages_grammar, num_messages_mapping);
+ if (num_messages == 0) {
+ // No annotations are used.
+ return conversation;
+ }
+ Conversation annotated_conversation = conversation;
+ for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
+ i < num_messages && message_index >= 0; i++, message_index--) {
+ ConversationMessage* message =
+ &annotated_conversation.messages[message_index];
+ if (message->annotations.empty()) {
+ message->annotations = annotator->Annotate(
+ message->text, AnnotationOptionsForMessage(*message));
+ }
+ }
+ return annotated_conversation;
+}
+
void ActionsSuggestions::SuggestActionsFromAnnotations(
- const Conversation& conversation, const ActionSuggestionOptions& options,
- const Annotator* annotator, std::vector<ActionSuggestion>* actions) const {
+ const Conversation& conversation,
+ std::vector<ActionSuggestion>* actions) const {
if (model_->annotation_actions_spec() == nullptr ||
model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
@@ -914,10 +950,6 @@
}
}
- if (annotations.empty() && annotator != nullptr) {
- annotations = annotator->Annotate(message.text,
- AnnotationOptionsForMessage(message));
- }
std::vector<ActionSuggestionAnnotation> action_annotations;
action_annotations.reserve(annotations.size());
for (const AnnotatedSpan& annotation : annotations) {
@@ -1057,25 +1089,29 @@
return true;
}
+ // Run annotator against messages.
+ const Conversation annotated_conversation =
+ AnnotateConversation(conversation, annotator);
+
const int num_messages = NumMessagesToConsider(
- conversation, model_->max_conversation_history_length());
+ annotated_conversation, model_->max_conversation_history_length());
if (num_messages <= 0) {
TC3_LOG(INFO) << "No messages provided for actions suggestions.";
return false;
}
- SuggestActionsFromAnnotations(conversation, options, annotator,
- &response->actions);
+ SuggestActionsFromAnnotations(annotated_conversation, &response->actions);
int input_text_length = 0;
int num_matching_locales = 0;
- for (int i = conversation.messages.size() - num_messages;
- i < conversation.messages.size(); i++) {
- input_text_length += conversation.messages[i].text.length();
+ for (int i = annotated_conversation.messages.size() - num_messages;
+ i < annotated_conversation.messages.size(); i++) {
+ input_text_length += annotated_conversation.messages[i].text.length();
std::vector<Locale> message_languages;
- if (!ParseLocales(conversation.messages[i].detected_text_language_tags,
- &message_languages)) {
+ if (!ParseLocales(
+ annotated_conversation.messages[i].detected_text_language_tags,
+ &message_languages)) {
continue;
}
if (Locale::IsAnyLocaleSupported(
@@ -1105,17 +1141,18 @@
std::vector<const UniLib::RegexPattern*> post_check_rules;
if (preconditions_.suppress_on_low_confidence_input) {
if ((ngram_model_ != nullptr &&
- ngram_model_->EvalConversation(conversation, num_messages)) ||
- regex_actions_->IsLowConfidenceInput(conversation, num_messages,
- &post_check_rules)) {
+ ngram_model_->EvalConversation(annotated_conversation,
+ num_messages)) ||
+ regex_actions_->IsLowConfidenceInput(annotated_conversation,
+ num_messages, &post_check_rules)) {
response->output_filtered_low_confidence = true;
return true;
}
}
std::unique_ptr<tflite::Interpreter> interpreter;
- if (!SuggestActionsFromModel(conversation, num_messages, options, response,
- &interpreter)) {
+ if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
+ response, &interpreter)) {
TC3_LOG(ERROR) << "Could not run model.";
return false;
}
@@ -1127,21 +1164,23 @@
}
if (!SuggestActionsFromLua(
- conversation, model_executor_.get(), interpreter.get(),
+ annotated_conversation, model_executor_.get(), interpreter.get(),
annotator != nullptr ? annotator->entity_data_schema() : nullptr,
&response->actions)) {
TC3_LOG(ERROR) << "Could not suggest actions from script.";
return false;
}
- if (!regex_actions_->SuggestActions(conversation, entity_data_builder_.get(),
+ if (!regex_actions_->SuggestActions(annotated_conversation,
+ entity_data_builder_.get(),
&response->actions)) {
TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
return false;
}
if (grammar_actions_ != nullptr &&
- !grammar_actions_->SuggestActions(conversation, &response->actions)) {
+ !grammar_actions_->SuggestActions(annotated_conversation,
+ &response->actions)) {
TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
return false;
}
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 86d91f7..cd0714a 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -197,13 +197,17 @@
const ConversationMessage& message) const;
void SuggestActionsFromAnnotations(
- const Conversation& conversation, const ActionSuggestionOptions& options,
- const Annotator* annotator, std::vector<ActionSuggestion>* actions) const;
+ const Conversation& conversation,
+ std::vector<ActionSuggestion>* actions) const;
void SuggestActionsFromAnnotation(
const int message_index, const ActionSuggestionAnnotation& annotation,
std::vector<ActionSuggestion>* actions) const;
+ // Run annotator on the messages of a conversation.
+ Conversation AnnotateConversation(const Conversation& conversation,
+ const Annotator* annotator) const;
+
// Deduplicates equivalent annotations - annotations that have the same type
// and same span text.
// Returns the indices of the deduplicated annotations.
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
deleted file mode 100644
index 3dfefa3..0000000
--- a/native/actions/actions-suggestions_test.cc
+++ /dev/null
@@ -1,1512 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/actions-suggestions.h"
-
-#include <fstream>
-#include <iterator>
-#include <memory>
-
-#include "actions/actions_model_generated.h"
-#include "actions/test-utils.h"
-#include "actions/zlib-utils.h"
-#include "annotator/collections.h"
-#include "annotator/types.h"
-#include "utils/flatbuffers.h"
-#include "utils/flatbuffers_generated.h"
-#include "utils/grammar/utils/rules.h"
-#include "utils/hash/farmhash.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/reflection.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAre;
-
-constexpr char kModelFileName[] = "actions_suggestions_test.model";
-constexpr char kHashGramModelFileName[] =
- "actions_suggestions_test.hashgram.model";
-
-std::string ReadFile(const std::string& file_name) {
- std::ifstream file_stream(file_name);
- return std::string(std::istreambuf_iterator<char>(file_stream), {});
-}
-
-std::string GetModelPath() {
- return "";
-}
-
-class ActionsSuggestionsTest : public testing::Test {
- protected:
- ActionsSuggestionsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- std::unique_ptr<ActionsSuggestions> LoadTestModel() {
- return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName,
- &unilib_);
- }
- std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
- return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
- &unilib_);
- }
- UniLib unilib_;
-};
-
-TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
- EXPECT_THAT(LoadTestModel(), testing::NotNull());
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActions) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies*/);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestNoActionsForUnknownLocale) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"zz"}}});
- EXPECT_THAT(response.actions, testing::IsEmpty());
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotations) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {ClassificationResult("address", 1.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "are you at home?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions.front().type, "view_map");
- EXPECT_EQ(response.actions.front().score, 1.0);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotationsWithEntityData) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- SetTestEntityDataSchema(actions_model.get());
-
- // Set custom actions from annotations config.
- actions_model->annotation_actions_spec->annotation_mapping.clear();
- actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
- new AnnotationActionsSpec_::AnnotationMappingT);
- AnnotationActionsSpec_::AnnotationMappingT* mapping =
- actions_model->annotation_actions_spec->annotation_mapping.back().get();
- mapping->annotation_collection = "address";
- mapping->action.reset(new ActionSuggestionSpecT);
- mapping->action->type = "save_location";
- mapping->action->score = 1.0;
- mapping->action->priority_score = 2.0;
- mapping->entity_field.reset(new FlatbufferFieldPathT);
- mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
- mapping->entity_field->field.back()->field_name = "location";
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {ClassificationResult("address", 1.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "are you at home?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions.front().type, "save_location");
- EXPECT_EQ(response.actions.front().score, 1.0);
-
- // Check that the `location` entity field holds the text from the address
- // annotation.
- const flatbuffers::Table* entity =
- flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
- response.actions.front().serialized_entity_data.data()));
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
- "home");
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotationsNormalization) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- SetTestEntityDataSchema(actions_model.get());
-
- // Set custom actions from annotations config.
- actions_model->annotation_actions_spec->annotation_mapping.clear();
- actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
- new AnnotationActionsSpec_::AnnotationMappingT);
- AnnotationActionsSpec_::AnnotationMappingT* mapping =
- actions_model->annotation_actions_spec->annotation_mapping.back().get();
- mapping->annotation_collection = "address";
- mapping->action.reset(new ActionSuggestionSpecT);
- mapping->action->type = "save_location";
- mapping->action->score = 1.0;
- mapping->action->priority_score = 2.0;
- mapping->entity_field.reset(new FlatbufferFieldPathT);
- mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
- mapping->entity_field->field.back()->field_name = "location";
- mapping->normalization_options.reset(new NormalizationOptionsT);
- mapping->normalization_options->codepointwise_normalization =
- NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {ClassificationResult("address", 1.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "are you at home?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions.front().type, "save_location");
- EXPECT_EQ(response.actions.front().score, 1.0);
-
- // Check that the `location` entity field holds the normalized text of the
- // annotation.
- const flatbuffers::Table* entity =
- flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
- response.actions.front().serialized_entity_data.data()));
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
- "HOME");
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsFromDuplicatedAnnotations) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- AnnotatedSpan flight_annotation;
- flight_annotation.span = {11, 15};
- flight_annotation.classification = {ClassificationResult("flight", 2.5)};
- AnnotatedSpan flight_annotation2;
- flight_annotation2.span = {35, 39};
- flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
- AnnotatedSpan email_annotation;
- email_annotation.span = {43, 56};
- email_annotation.classification = {ClassificationResult("email", 2.0)};
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1,
- "call me at LX38 or send message to LX38 or test@test.com.",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {flight_annotation, flight_annotation2, email_annotation},
- /*locales=*/"en"}}});
-
- ASSERT_GE(response.actions.size(), 2);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[0].score, 3.0);
- EXPECT_EQ(response.actions[1].type, "send_email");
- EXPECT_EQ(response.actions[1].score, 2.0);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsAnnotationsNoDeduplication) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- // Disable deduplication.
- actions_model->annotation_actions_spec->deduplicate_annotations = false;
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
- AnnotatedSpan flight_annotation;
- flight_annotation.span = {11, 15};
- flight_annotation.classification = {ClassificationResult("flight", 2.5)};
- AnnotatedSpan flight_annotation2;
- flight_annotation2.span = {35, 39};
- flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
- AnnotatedSpan email_annotation;
- email_annotation.span = {43, 56};
- email_annotation.classification = {ClassificationResult("email", 2.0)};
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1,
- "call me at LX38 or send message to LX38 or test@test.com.",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {flight_annotation, flight_annotation2, email_annotation},
- /*locales=*/"en"}}});
-
- ASSERT_GE(response.actions.size(), 3);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[0].score, 3.0);
- EXPECT_EQ(response.actions[1].type, "track_flight");
- EXPECT_EQ(response.actions[1].score, 2.5);
- EXPECT_EQ(response.actions[2].type, "send_email");
- EXPECT_EQ(response.actions[2].score, 2.0);
-}
-
-ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
- const std::function<void(ActionsModelT*)>& set_config_fn,
- const UniLib* unilib = nullptr) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
-
- // Set custom config.
- set_config_fn(actions_model.get());
-
- // Disable smart reply for easier testing.
- actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), unilib);
-
- AnnotatedSpan flight_annotation;
- flight_annotation.span = {15, 19};
- flight_annotation.classification = {ClassificationResult("flight", 2.0)};
- AnnotatedSpan email_annotation;
- email_annotation.span = {0, 16};
- email_annotation.classification = {ClassificationResult("email", 1.0)};
-
- return actions_suggestions->SuggestActions(
- {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
- "hehe@android.com",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {email_annotation},
- /*locales=*/"en"},
- {/*user_id=*/2,
- "yoyo@android.com",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {email_annotation},
- /*locales=*/"en"},
- {/*user_id=*/1,
- "test@android.com",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {email_annotation},
- /*locales=*/"en"},
- {/*user_id=*/1,
- "I am on flight LX38.",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {flight_annotation},
- /*locales=*/"en"}}});
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastMessage) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- false;
- actions_model->annotation_actions_spec->only_until_last_sent = true;
- actions_model->annotation_actions_spec->max_history_from_any_person = 1;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 1;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].type, "track_flight");
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastPerson) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- false;
- actions_model->annotation_actions_spec->only_until_last_sent = true;
- actions_model->annotation_actions_spec->max_history_from_any_person = 1;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 3;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 2);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[1].type, "send_email");
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsFromAny) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- false;
- actions_model->annotation_actions_spec->only_until_last_sent = true;
- actions_model->annotation_actions_spec->max_history_from_any_person = 2;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 1;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 2);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[1].type, "send_email");
-}
-
-TEST_F(ActionsSuggestionsTest,
- SuggestActionsWithAnnotationsFromAnyManyMessages) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- false;
- actions_model->annotation_actions_spec->only_until_last_sent = true;
- actions_model->annotation_actions_spec->max_history_from_any_person = 3;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 1;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 3);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[1].type, "send_email");
- EXPECT_EQ(response.actions[2].type, "send_email");
-}
-
-TEST_F(ActionsSuggestionsTest,
- SuggestActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- false;
- actions_model->annotation_actions_spec->only_until_last_sent = true;
- actions_model->annotation_actions_spec->max_history_from_any_person = 5;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 1;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 3);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[1].type, "send_email");
- EXPECT_EQ(response.actions[2].type, "send_email");
-}
-
-TEST_F(ActionsSuggestionsTest,
- SuggestActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- true;
- actions_model->annotation_actions_spec->only_until_last_sent = false;
- actions_model->annotation_actions_spec->max_history_from_any_person = 5;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 1;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 4);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[1].type, "send_email");
- EXPECT_EQ(response.actions[2].type, "send_email");
- EXPECT_EQ(response.actions[3].type, "send_email");
-}
-
-void TestSuggestActionsWithThreshold(
- const std::function<void(ActionsModelT*)>& set_value_fn,
- const UniLib* unilib = nullptr, const int expected_size = 0,
- const std::string& preconditions_overwrite = "") {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- set_value_fn(actions_model.get());
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), unilib, preconditions_overwrite);
- ASSERT_TRUE(actions_suggestions);
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "I have the low-ground. Where are you?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_LE(response.actions.size(), expected_size);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithTriggeringScore) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
- },
- &unilib_,
- /*expected_size=*/1 /*no smart reply, only actions*/
- );
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinReplyScore) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->min_reply_score_threshold = 1.0;
- },
- &unilib_,
- /*expected_size=*/1 /*no smart reply, only actions*/
- );
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithSensitiveTopicScore) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->max_sensitive_topic_score = 0.0;
- },
- &unilib_,
- /*expected_size=*/4 /* no sensitive prediction in test model*/);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithMaxInputLength) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->max_input_length = 0;
- },
- &unilib_);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinInputLength) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->min_input_length = 100;
- },
- &unilib_);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithPreconditionsOverwrite) {
- TriggeringPreconditionsT preconditions_overwrite;
- preconditions_overwrite.max_input_length = 0;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(
- TriggeringPreconditions::Pack(builder, &preconditions_overwrite));
- TestSuggestActionsWithThreshold(
- // Keep model untouched.
- [](ActionsModelT* actions_model) {}, &unilib_,
- /*expected_size=*/0,
- std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize()));
-}
-
-#ifdef TC3_UNILIB_ICU
-TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidence) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->suppress_on_low_confidence_input = true;
- actions_model->low_confidence_rules.reset(new RulesModelT);
- actions_model->low_confidence_rules->regex_rule.emplace_back(
- new RulesModel_::RegexRuleT);
- actions_model->low_confidence_rules->regex_rule.back()->pattern =
- "low-ground";
- },
- &unilib_);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidenceInputOutput) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- // Add custom triggering rule.
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
- RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
- rule->pattern = "^(?i:hello\\s(there))$";
- {
- std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
- new RulesModel_::RuleActionSpecT);
- rule_action->action.reset(new ActionSuggestionSpecT);
- rule_action->action->type = "text_reply";
- rule_action->action->response_text = "General Desaster!";
- rule_action->action->score = 1.0f;
- rule_action->action->priority_score = 1.0f;
- rule->actions.push_back(std::move(rule_action));
- }
- {
- std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
- new RulesModel_::RuleActionSpecT);
- rule_action->action.reset(new ActionSuggestionSpecT);
- rule_action->action->type = "text_reply";
- rule_action->action->response_text = "General Kenobi!";
- rule_action->action->score = 1.0f;
- rule_action->action->priority_score = 1.0f;
- rule->actions.push_back(std::move(rule_action));
- }
-
- // Add input-output low confidence rule.
- actions_model->preconditions->suppress_on_low_confidence_input = true;
- actions_model->low_confidence_rules.reset(new RulesModelT);
- actions_model->low_confidence_rules->regex_rule.emplace_back(
- new RulesModel_::RegexRuleT);
- actions_model->low_confidence_rules->regex_rule.back()->pattern = "hello";
- actions_model->low_confidence_rules->regex_rule.back()->output_pattern =
- "(?i:desaster)";
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
- ASSERT_TRUE(actions_suggestions);
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "hello there",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
-}
-
-TEST_F(ActionsSuggestionsTest,
- SuggestActionsLowConfidenceInputOutputOverwrite) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- actions_model->low_confidence_rules.reset();
-
- // Add custom triggering rule.
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
- RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
- rule->pattern = "^(?i:hello\\s(there))$";
- {
- std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
- new RulesModel_::RuleActionSpecT);
- rule_action->action.reset(new ActionSuggestionSpecT);
- rule_action->action->type = "text_reply";
- rule_action->action->response_text = "General Desaster!";
- rule_action->action->score = 1.0f;
- rule_action->action->priority_score = 1.0f;
- rule->actions.push_back(std::move(rule_action));
- }
- {
- std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
- new RulesModel_::RuleActionSpecT);
- rule_action->action.reset(new ActionSuggestionSpecT);
- rule_action->action->type = "text_reply";
- rule_action->action->response_text = "General Kenobi!";
- rule_action->action->score = 1.0f;
- rule_action->action->priority_score = 1.0f;
- rule->actions.push_back(std::move(rule_action));
- }
-
- // Add custom triggering rule via overwrite.
- actions_model->preconditions->low_confidence_rules.reset();
- TriggeringPreconditionsT preconditions;
- preconditions.suppress_on_low_confidence_input = true;
- preconditions.low_confidence_rules.reset(new RulesModelT);
- preconditions.low_confidence_rules->regex_rule.emplace_back(
- new RulesModel_::RegexRuleT);
- preconditions.low_confidence_rules->regex_rule.back()->pattern = "hello";
- preconditions.low_confidence_rules->regex_rule.back()->output_pattern =
- "(?i:desaster)";
- flatbuffers::FlatBufferBuilder preconditions_builder;
- preconditions_builder.Finish(
- TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
- std::string serialize_preconditions = std::string(
- reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
- preconditions_builder.GetSize());
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, serialize_preconditions);
-
- ASSERT_TRUE(actions_suggestions);
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "hello there",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
-}
-#endif
-
-TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
-
- // Don't test if no sensitivity score is produced
- if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
- return;
- }
-
- actions_model->preconditions->max_sensitive_topic_score = 0.0;
- actions_model->preconditions->suppress_on_sensitive_topic = true;
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {
- ClassificationResult(Collections::Address(), 1.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "are you at home?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- EXPECT_THAT(response.actions, testing::IsEmpty());
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
-
- // Allow a larger conversation context.
- actions_model->max_conversation_history_length = 10;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {
- ClassificationResult(Collections::Address(), 1.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
- /*reference_time_ms_utc=*/10000,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"},
- {/*user_id=*/1, "good! are you at home?",
- /*reference_time_ms_utc=*/15000,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].type, "view_map");
- EXPECT_EQ(response.actions[0].score, 1.0);
-}
-
-TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- AnnotatedSpan annotation;
- annotation.span = {8, 12};
- annotation.classification = {
- ClassificationResult(Collections::Flight(), 1.0)};
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "I'm on LX38?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
-
- ASSERT_GE(response.actions.size(), 2);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[0].score, 1.0);
- EXPECT_EQ(response.actions[0].annotations.size(), 1);
- EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
- EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
-}
-
-#ifdef TC3_UNILIB_ICU
-TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
-
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
- RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
- rule->pattern = "^(?i:hello\\s(there))$";
- rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
- rule->actions.back()->action.reset(new ActionSuggestionSpecT);
- ActionSuggestionSpecT* action = rule->actions.back()->action.get();
- action->type = "text_reply";
- action->response_text = "General Kenobi!";
- action->score = 1.0f;
- action->priority_score = 1.0f;
-
- // Set capturing groups for entity data.
- rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
- rule->actions.back()->capturing_group.back().get();
- greeting_group->group_id = 0;
- greeting_group->entity_field.reset(new FlatbufferFieldPathT);
- greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
- greeting_group->entity_field->field.back()->field_name = "greeting";
- rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::RuleActionSpec_::RuleCapturingGroupT* location_group =
- rule->actions.back()->capturing_group.back().get();
- location_group->group_id = 1;
- location_group->entity_field.reset(new FlatbufferFieldPathT);
- location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
- location_group->entity_field->field.back()->field_name = "location";
-
- // Set test entity data schema.
- SetTestEntityDataSchema(actions_model.get());
-
- // Use meta data to generate custom serialized entity data.
- ReflectiveFlatbufferBuilder entity_data_builder(
- flatbuffers::GetRoot<reflection::Schema>(
- actions_model->actions_entity_data_schema.data()));
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder.NewRoot();
- entity_data->Set("person", "Kenobi");
- action->serialized_entity_data = entity_data->Serialize();
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
-
- // Check entity data.
- const flatbuffers::Table* entity =
- flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
- response.actions[0].serialized_entity_data.data()));
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
- "hello there");
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
- "there");
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
- "Kenobi");
-}
-
-TEST_F(ActionsSuggestionsTest, CreateActionsFromRulesWithNormalization) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
-
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
- RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
- rule->pattern = "^(?i:hello\\sthere)$";
- rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
- rule->actions.back()->action.reset(new ActionSuggestionSpecT);
- ActionSuggestionSpecT* action = rule->actions.back()->action.get();
- action->type = "text_reply";
- action->response_text = "General Kenobi!";
- action->score = 1.0f;
- action->priority_score = 1.0f;
-
- // Set capturing groups for entity data.
- rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
- rule->actions.back()->capturing_group.back().get();
- greeting_group->group_id = 0;
- greeting_group->entity_field.reset(new FlatbufferFieldPathT);
- greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
- greeting_group->entity_field->field.back()->field_name = "greeting";
- greeting_group->normalization_options.reset(new NormalizationOptionsT);
- greeting_group->normalization_options->codepointwise_normalization =
- NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
- NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
-
- // Set test entity data schema.
- SetTestEntityDataSchema(actions_model.get());
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
-
- // Check entity data.
- const flatbuffers::Table* entity =
- flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
- response.actions[0].serialized_entity_data.data()));
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
- "HELLOTHERE");
-}
-
-TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
-
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
- RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
- rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
- rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
-
- // Set capturing groups for entity data.
- rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
- rule->actions.back()->capturing_group.back().get();
- code_group->group_id = 1;
- code_group->text_reply.reset(new ActionSuggestionSpecT);
- code_group->text_reply->score = 1.0f;
- code_group->text_reply->priority_score = 1.0f;
- code_group->normalization_options.reset(new NormalizationOptionsT);
- code_group->normalization_options->codepointwise_normalization =
- NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1,
- "visit test.com or reply STOP to cancel your subscription",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].response_text, "stop");
-}
-
-TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
-
- actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
-
- // Set tokenizer options.
- RulesModel_::GrammarRulesT* action_grammar_rules =
- actions_model->rules->grammar_rules.get();
- action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
- action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
- action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
- true;
-
- // Setup test rules.
- action_grammar_rules->rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
- rules.Add("<knock>", {"<^>", "ventura", "!?", "<$>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/0);
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- action_grammar_rules->rules.get());
- action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
- RulesModel_::RuleActionSpecT* actions_spec =
- action_grammar_rules->actions.back().get();
- actions_spec->action.reset(new ActionSuggestionSpecT);
- actions_spec->action->response_text = "Yes, Satan?";
- actions_spec->action->priority_score = 1.0;
- actions_spec->action->score = 1.0;
- actions_spec->action->type = "text_reply";
- action_grammar_rules->rule_match.emplace_back(
- new RulesModel_::GrammarRules_::RuleMatchT);
- action_grammar_rules->rule_match.back()->action_id.push_back(0);
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "Ventura!",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
-
- EXPECT_THAT(response.actions, ElementsAre(IsSmartReply("Yes, Satan?")));
-}
-
-TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
-
- // Check that the location sharing model triggered.
- bool has_location_sharing_action = false;
- for (const ActionSuggestion action : response.actions) {
- if (action.type == ActionsSuggestions::kShareLocation) {
- has_location_sharing_action = true;
- break;
- }
- }
- EXPECT_TRUE(has_location_sharing_action);
- const int num_actions = response.actions.size();
-
- // Add custom rule for location sharing.
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
-
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
- actions_model->rules->regex_rule.back()->pattern =
- "^(?i:where are you[.?]?)$";
- actions_model->rules->regex_rule.back()->actions.emplace_back(
- new RulesModel_::RuleActionSpecT);
- actions_model->rules->regex_rule.back()->actions.back()->action.reset(
- new ActionSuggestionSpecT);
- ActionSuggestionSpecT* action =
- actions_model->rules->regex_rule.back()->actions.back()->action.get();
- action->score = 1.0f;
- action->type = ActionsSuggestions::kShareLocation;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- response = actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_EQ(response.actions.size(), num_actions);
-}
-
-TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- AnnotatedSpan annotation;
- annotation.span = {7, 11};
- annotation.classification = {
- ClassificationResult(Collections::Flight(), 1.0)};
- ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "I'm on LX38",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
-
- // Check that the phone actions are present.
- EXPECT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].type, "track_flight");
-
- // Add custom rule.
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
-
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
- RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
- rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
- rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
- rule->actions.back()->action.reset(new ActionSuggestionSpecT);
- ActionSuggestionSpecT* action = rule->actions.back()->action.get();
- action->score = 1.0f;
- action->priority_score = 2.0f;
- action->type = "test_code";
- rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
- rule->actions.back()->capturing_group.back().get();
- code_group->group_id = 1;
- code_group->annotation_name = "code";
- code_group->annotation_type = "code";
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- response = actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "I'm on LX38",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- EXPECT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].type, "test_code");
-}
-#endif
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsRanking) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- std::vector<AnnotatedSpan> annotations(2);
- annotations[0].span = {11, 15};
- annotations[0].classification = {ClassificationResult("address", 1.0)};
- annotations[1].span = {19, 23};
- annotations[1].classification = {ClassificationResult("address", 2.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "are you at home or work?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/annotations,
- /*locales=*/"en"}}});
- EXPECT_GE(response.actions.size(), 2);
- EXPECT_EQ(response.actions[0].type, "view_map");
- EXPECT_EQ(response.actions[0].score, 2.0);
- EXPECT_EQ(response.actions[1].type, "view_map");
- EXPECT_EQ(response.actions[1].score, 1.0);
-}
-
-TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
- EXPECT_TRUE(VisitActionsModel<bool>(GetModelPath() + kModelFileName,
- [](const ActionsModel* model) {
- if (model == nullptr) {
- return false;
- }
- return true;
- }));
- EXPECT_FALSE(VisitActionsModel<bool>(GetModelPath() + "non_existing_model.fb",
- [](const ActionsModel* model) {
- if (model == nullptr) {
- return false;
- }
- return true;
- }));
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithHashGramModel) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- LoadHashGramTestModel();
- ASSERT_TRUE(actions_suggestions != nullptr);
- {
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "hello",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{},
- /*locales=*/"en"}}});
- EXPECT_THAT(response.actions, testing::IsEmpty());
- }
- {
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "where are you",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{},
- /*locales=*/"en"}}});
- EXPECT_THAT(
- response.actions,
- ElementsAre(testing::Field(&ActionSuggestion::type, "share_location")));
- }
- {
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "do you know johns number",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{},
- /*locales=*/"en"}}});
- EXPECT_THAT(
- response.actions,
- ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact")));
- }
-}
-
-// Test class to expose token embedding methods for testing.
-class TestingMessageEmbedder : private ActionsSuggestions {
- public:
- explicit TestingMessageEmbedder(const ActionsModel* model);
-
- using ActionsSuggestions::EmbedAndFlattenTokens;
- using ActionsSuggestions::EmbedTokensPerMessage;
-
- protected:
- // EmbeddingExecutor that always returns features based on
- // the id of the sparse features.
- class FakeEmbeddingExecutor : public EmbeddingExecutor {
- public:
- bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- const int dest_size) const override {
- TC3_CHECK_GE(dest_size, 1);
- EXPECT_EQ(sparse_features.size(), 1);
- dest[0] = sparse_features.data()[0];
- return true;
- }
- };
-};
-
-TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model) {
- model_ = model;
- const ActionsTokenFeatureProcessorOptions* options =
- model->feature_processor_options();
- feature_processor_.reset(
- new ActionsFeatureProcessor(options, /*unilib=*/nullptr));
- embedding_executor_.reset(new FakeEmbeddingExecutor());
- EXPECT_TRUE(
- EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
- EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
- EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
- token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
- EXPECT_EQ(token_embedding_size_, 1);
-}
-
-class EmbeddingTest : public testing::Test {
- protected:
- EmbeddingTest() {
- model_.feature_processor_options.reset(
- new ActionsTokenFeatureProcessorOptionsT);
- options_ = model_.feature_processor_options.get();
- options_->chargram_orders = {1};
- options_->num_buckets = 1000;
- options_->embedding_size = 1;
- options_->start_token_id = 0;
- options_->end_token_id = 1;
- options_->padding_token_id = 2;
- options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
- }
-
- TestingMessageEmbedder CreateTestingMessageEmbedder() {
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
- buffer_ = builder.Release();
- return TestingMessageEmbedder(
- flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
- }
-
- flatbuffers::DetachedBuffer buffer_;
- ActionsModelT model_;
- ActionsTokenFeatureProcessorOptionsT* options_;
-};
-
-TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int max_num_tokens_per_message = 0;
-
- EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
- &max_num_tokens_per_message));
-
- EXPECT_EQ(max_num_tokens_per_message, 3);
- EXPECT_EQ(embeddings.size(), 3);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
-}
-
-TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
- options_->min_num_tokens_per_message = 5;
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int max_num_tokens_per_message = 0;
-
- EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
- &max_num_tokens_per_message));
-
- EXPECT_EQ(max_num_tokens_per_message, 5);
- EXPECT_EQ(embeddings.size(), 5);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[3], testing::FloatEq(options_->padding_token_id));
- EXPECT_THAT(embeddings[4], testing::FloatEq(options_->padding_token_id));
-}
-
-TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
- options_->max_num_tokens_per_message = 2;
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int max_num_tokens_per_message = 0;
-
- EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
- &max_num_tokens_per_message));
-
- EXPECT_EQ(max_num_tokens_per_message, 2);
- EXPECT_EQ(embeddings.size(), 2);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
-}
-
-TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
- {Token("d", 0, 1), Token("e", 2, 3)}};
- std::vector<float> embeddings;
- int max_num_tokens_per_message = 0;
-
- EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
- &max_num_tokens_per_message));
-
- EXPECT_EQ(max_num_tokens_per_message, 3);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[3],
- testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[4],
- testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
-}
-
-TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int total_token_count = 0;
-
- EXPECT_TRUE(
- embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
-
- EXPECT_EQ(total_token_count, 5);
- EXPECT_EQ(embeddings.size(), 5);
- EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[3],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
-}
-
-TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
- options_->min_num_total_tokens = 7;
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int total_token_count = 0;
-
- EXPECT_TRUE(
- embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
-
- EXPECT_EQ(total_token_count, 7);
- EXPECT_EQ(embeddings.size(), 7);
- EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[3],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
- EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
- EXPECT_THAT(embeddings[6], testing::FloatEq(options_->padding_token_id));
-}
-
-TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
- options_->max_num_total_tokens = 3;
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int total_token_count = 0;
-
- EXPECT_TRUE(
- embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
-
- EXPECT_EQ(total_token_count, 3);
- EXPECT_EQ(embeddings.size(), 3);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2], testing::FloatEq(options_->end_token_id));
-}
-
-TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
- {Token("d", 0, 1), Token("e", 2, 3)}};
- std::vector<float> embeddings;
- int total_token_count = 0;
-
- EXPECT_TRUE(
- embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
-
- EXPECT_EQ(total_token_count, 9);
- EXPECT_EQ(embeddings.size(), 9);
- EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[3],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
- EXPECT_THAT(embeddings[5], testing::FloatEq(options_->start_token_id));
- EXPECT_THAT(embeddings[6],
- testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[7],
- testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[8], testing::FloatEq(options_->end_token_id));
-}
-
-TEST_F(EmbeddingTest,
- EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
- options_->max_num_total_tokens = 7;
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
- {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
- std::vector<float> embeddings;
- int total_token_count = 0;
-
- EXPECT_TRUE(
- embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
-
- EXPECT_EQ(total_token_count, 7);
- EXPECT_EQ(embeddings.size(), 7);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1], testing::FloatEq(options_->end_token_id));
- EXPECT_THAT(embeddings[2], testing::FloatEq(options_->start_token_id));
- EXPECT_THAT(embeddings[3],
- testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[4],
- testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[5],
- testing::FloatEq(tc3farmhash::Fingerprint64("f", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[6], testing::FloatEq(options_->end_token_id));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 51fe2b4..6f666eb 100755
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -391,6 +391,10 @@
// If set, normalization to apply to the capturing group text.
normalization_options:NormalizationOptions;
+
+ // If set to true, an existing annotator annotation will be used to
+ // create the action suggestion's text annotation.
+ use_annotation_match:bool;
}
// The actions to produce upon triggering.
@@ -426,6 +430,12 @@
action_id:[uint];
}
+namespace libtextclassifier3.RulesModel_.GrammarRules_;
+table AnnotationNonterminalEntry {
+ key:string (key, shared);
+ value:int;
+}
+
// Configuration for actions based on context-free grammars.
namespace libtextclassifier3.RulesModel_;
table GrammarRules {
@@ -439,6 +449,10 @@
// The action specifications used by the rule matches.
actions:[RuleActionSpec];
+
+ // Predefined nonterminals for annotations.
+ // Maps annotation/collection names to non-terminal ids.
+ annotation_nonterminal:[GrammarRules_.AnnotationNonterminalEntry];
}
// Rule based actions.
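
Note: the new annotation_nonterminal map above lets a model expose annotator collections as predefined grammar nonterminals. A minimal sketch of how such an entry might be populated on an unpacked test model, in the object-API style of the tests touched by this change; AnnotationNonterminalEntryT follows the generated object API, and flight_nonterminal_id is a hypothetical id reserved for the corresponding nonterminal when the grammar rules were built:

  // Sketch only, not part of this change.
  RulesModel_::GrammarRulesT* action_grammar_rules =
      actions_model->rules->grammar_rules.get();
  action_grammar_rules->annotation_nonterminal.emplace_back(
      new RulesModel_::GrammarRules_::AnnotationNonterminalEntryT);
  // Key is the annotator collection name; value is the nonterminal id.
  action_grammar_rules->annotation_nonterminal.back()->key = "flight";
  action_grammar_rules->annotation_nonterminal.back()->value =
      flight_nonterminal_id;

At runtime, grammar-actions.cc looks up each annotation's collection in this map and feeds a corresponding match into the matcher (see the grammar-actions.cc hunk below).
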
diff --git a/native/actions/feature-processor_test.cc b/native/actions/feature-processor_test.cc
deleted file mode 100644
index 0a1e3ac..0000000
--- a/native/actions/feature-processor_test.cc
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/feature-processor.h"
-
-#include "actions/actions_model_generated.h"
-#include "annotator/model-executor.h"
-#include "utils/tensor-view.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::FloatEq;
-
-// EmbeddingExecutor that always returns features based on
-// the id of the sparse features.
-class FakeEmbeddingExecutor : public EmbeddingExecutor {
- public:
- bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- const int dest_size) const override {
- TC3_CHECK_GE(dest_size, 4);
- EXPECT_EQ(sparse_features.size(), 1);
- dest[0] = sparse_features.data()[0];
- dest[1] = sparse_features.data()[0];
- dest[2] = -sparse_features.data()[0];
- dest[3] = -sparse_features.data()[0];
- return true;
- }
-
- private:
- std::vector<float> storage_;
-};
-
-class FeatureProcessorTest : public ::testing::Test {
- protected:
- FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
-
- flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
- ActionsTokenFeatureProcessorOptionsT* options) const {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options));
- return builder.Release();
- }
-
- FakeEmbeddingExecutor embedding_executor_;
- UniLib unilib_;
-};
-
-TEST_F(FeatureProcessorTest, TokenEmbeddings) {
- ActionsTokenFeatureProcessorOptionsT options;
- options.embedding_size = 4;
- options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
-
- flatbuffers::DetachedBuffer options_fb =
- PackFeatureProcessorOptions(&options);
- ActionsFeatureProcessor feature_processor(
- flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
- options_fb.data()),
- &unilib_);
-
- Token token("aaa", 0, 3);
- std::vector<float> token_features;
- EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
- &token_features));
- EXPECT_EQ(token_features.size(), 4);
-}
-
-TEST_F(FeatureProcessorTest, TokenEmbeddingsCaseFeature) {
- ActionsTokenFeatureProcessorOptionsT options;
- options.embedding_size = 4;
- options.extract_case_feature = true;
- options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
-
- flatbuffers::DetachedBuffer options_fb =
- PackFeatureProcessorOptions(&options);
- ActionsFeatureProcessor feature_processor(
- flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
- options_fb.data()),
- &unilib_);
-
- Token token("Aaa", 0, 3);
- std::vector<float> token_features;
- EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
- &token_features));
- EXPECT_EQ(token_features.size(), 5);
- EXPECT_THAT(token_features[4], FloatEq(1.0));
-}
-
-TEST_F(FeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
- ActionsTokenFeatureProcessorOptionsT options;
- options.embedding_size = 4;
- options.extract_case_feature = true;
- options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
-
- flatbuffers::DetachedBuffer options_fb =
- PackFeatureProcessorOptions(&options);
- ActionsFeatureProcessor feature_processor(
- flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
- options_fb.data()),
- &unilib_);
-
- const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7),
- Token("Cccc", 8, 12)};
- std::vector<float> token_features;
- EXPECT_TRUE(feature_processor.AppendTokenFeatures(
- tokens, &embedding_executor_, &token_features));
- EXPECT_EQ(token_features.size(), 15);
- EXPECT_THAT(token_features[4], FloatEq(1.0));
- EXPECT_THAT(token_features[9], FloatEq(-1.0));
- EXPECT_THAT(token_features[14], FloatEq(1.0));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/actions/grammar-actions.cc b/native/actions/grammar-actions.cc
index 84208aa..05cae46 100644
--- a/native/actions/grammar-actions.cc
+++ b/native/actions/grammar-actions.cc
@@ -21,6 +21,7 @@
#include "actions/feature-processor.h"
#include "actions/utils.h"
+#include "annotator/types.h"
#include "utils/grammar/callback-delegate.h"
#include "utils/grammar/match.h"
#include "utils/grammar/matcher.h"
@@ -31,6 +32,12 @@
namespace libtextclassifier3 {
namespace {
+// Represents an annotator-annotated span in the grammar.
+struct AnnotationMatch : public grammar::Match {
+ static const int16 kType = 1;
+ ClassificationResult annotation;
+};
+
class GrammarActionsCallbackDelegate : public grammar::CallbackDelegate {
public:
GrammarActionsCallbackDelegate(const UniLib* unilib,
@@ -180,6 +187,16 @@
if (FillAnnotationFromCapturingMatch(
/*span=*/capturing_match->codepoint_span, group,
/*message_index=*/message_index, match_text, &annotation)) {
+ if (group->use_annotation_match()) {
+ const AnnotationMatch* annotation_match =
+ grammar::SelectFirstOfType<AnnotationMatch>(
+ capturing_match, AnnotationMatch::kType);
+ if (!annotation_match) {
+ TC3_LOG(ERROR) << "Could not get annotation for match.";
+ return false;
+ }
+ annotation.entity = annotation_match->annotation;
+ }
annotations.push_back(annotation);
}
}
@@ -247,11 +264,42 @@
{grammar_rules_->rules()->rules()->Get(i), &callback_handler});
}
}
+
+ if (locale_rules.empty()) {
+ // Nothing to do.
+ return true;
+ }
+
+ std::vector<AnnotationMatch> matches;
+ if (auto annotation_nonterminals = grammar_rules_->annotation_nonterminal()) {
+ for (const AnnotatedSpan& annotation :
+ conversation.messages.back().annotations) {
+ if (annotation.classification.empty()) {
+ continue;
+ }
+ const ClassificationResult& classification =
+ annotation.classification.front();
+ if (auto entry = annotation_nonterminals->LookupByKey(
+ classification.collection.c_str())) {
+ AnnotationMatch match;
+ match.Init(entry->value(), annotation.span, annotation.span.first,
+ AnnotationMatch::kType);
+ match.annotation = classification;
+ matches.push_back(match);
+ }
+ }
+ }
+
+ std::vector<grammar::Match*> annotation_matches(matches.size());
+ for (int i = 0; i < matches.size(); i++) {
+ annotation_matches[i] = &matches[i];
+ }
+
grammar::Matcher matcher(*unilib_, grammar_rules_->rules(), locale_rules);
// Run grammar on last message.
lexer_.Process(tokenizer_->Tokenize(conversation.messages.back().text),
- &matcher);
+ /*matches=*/annotation_matches, &matcher);
// Populate results.
return callback_handler.GetActions(conversation, smart_reply_action_type_,
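The grammar-actions.cc hunks above thread annotator results into the grammar run: classifications on the last message are looked up in the annotation_nonterminal map, wrapped as AnnotationMatch objects, and handed to lexer_.Process together with the tokens so that rules (and later FillAnnotationFromCapturingMatch) can pick them up. A minimal, self-contained sketch of that handoff follows; Span, Match and the map literal are simplified stand-ins, not the real libtextclassifier3 grammar API.

// Sketch only: simplified stand-in types showing how pre-computed annotations
// can be wrapped as matches and fed to a matcher alongside the token stream.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Span { int begin; int end; };

struct Match {                    // stand-in for grammar::Match
  int nonterminal = -1;
  Span span{0, 0};
  int16_t type = 0;
};

struct AnnotationMatch : Match {  // mirrors the struct added in the diff
  static constexpr int16_t kType = 1;
  std::string collection;         // simplified ClassificationResult
};

int main() {
  // Hypothetical collection -> nonterminal mapping, analogous to
  // grammar_rules_->annotation_nonterminal().
  const std::unordered_map<std::string, int> annotation_nonterminals = {
      {"phone", 7}, {"address", 8}};

  // Pretend the annotator flagged a phone number in the last message.
  const std::string collection = "phone";
  const Span annotated_span{12, 28};

  std::vector<AnnotationMatch> matches;
  const auto it = annotation_nonterminals.find(collection);
  if (it != annotation_nonterminals.end()) {
    AnnotationMatch match;
    match.nonterminal = it->second;
    match.span = annotated_span;
    match.type = AnnotationMatch::kType;
    match.collection = collection;
    matches.push_back(match);
  }

  // The real code collects raw pointers and passes them to lexer_.Process()
  // together with the tokenized message; here we only show the handoff shape.
  std::vector<Match*> annotation_matches;
  annotation_matches.reserve(matches.size());
  for (AnnotationMatch& match : matches) annotation_matches.push_back(&match);

  std::cout << "feeding " << annotation_matches.size()
            << " pre-annotated match(es) to the matcher" << std::endl;
  return 0;
}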
diff --git a/native/actions/grammar-actions_test.cc b/native/actions/grammar-actions_test.cc
deleted file mode 100644
index 3df1d63..0000000
--- a/native/actions/grammar-actions_test.cc
+++ /dev/null
@@ -1,575 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/grammar-actions.h"
-
-#include <iostream>
-#include <memory>
-
-#include "actions/actions_model_generated.h"
-#include "actions/test-utils.h"
-#include "actions/types.h"
-#include "utils/flatbuffers.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/utils/rules.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAre;
-
-class TestGrammarActions : public GrammarActions {
- public:
- explicit TestGrammarActions(
- const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
- const ReflectiveFlatbufferBuilder* entity_data_builder = nullptr)
- : GrammarActions(unilib, grammar_rules, entity_data_builder,
-
- /*smart_reply_action_type=*/"text_reply") {}
-};
-
-class GrammarActionsTest : public testing::Test {
- protected:
- GrammarActionsTest()
- : INIT_UNILIB_FOR_TESTING(unilib_),
- serialized_entity_data_schema_(TestEntityDataSchema()),
- entity_data_builder_(new ReflectiveFlatbufferBuilder(
- flatbuffers::GetRoot<reflection::Schema>(
- serialized_entity_data_schema_.data()))) {}
-
- void SetTokenizerOptions(
- RulesModel_::GrammarRulesT* action_grammar_rules) const {
- action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
- action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
- action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
- true;
- }
-
- flatbuffers::DetachedBuffer PackRules(
- const RulesModel_::GrammarRulesT& action_grammar_rules) const {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(
- RulesModel_::GrammarRules::Pack(builder, &action_grammar_rules));
- return builder.Release();
- }
-
- int AddActionSpec(const std::string& type, const std::string& response_text,
- const std::vector<std::pair<int, std::string>>& annotations,
- RulesModel_::GrammarRulesT* action_grammar_rules) const {
- const int action_id = action_grammar_rules->actions.size();
- action_grammar_rules->actions.emplace_back(
- new RulesModel_::RuleActionSpecT);
- RulesModel_::RuleActionSpecT* actions_spec =
- action_grammar_rules->actions.back().get();
- actions_spec->action.reset(new ActionSuggestionSpecT);
- actions_spec->action->response_text = response_text;
- actions_spec->action->priority_score = 1.0;
- actions_spec->action->score = 1.0;
- actions_spec->action->type = type;
- // Create annotations for specified capturing groups.
- for (const auto& it : annotations) {
- actions_spec->capturing_group.emplace_back(
- new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
- actions_spec->capturing_group.back()->group_id = it.first;
- actions_spec->capturing_group.back()->annotation_name = it.second;
- actions_spec->capturing_group.back()->annotation_type = it.second;
- }
-
- return action_id;
- }
-
- int AddSmartReplySpec(
- const std::string& response_text,
- RulesModel_::GrammarRulesT* action_grammar_rules) const {
- return AddActionSpec("text_reply", response_text, {}, action_grammar_rules);
- }
-
- int AddCapturingMatchSmartReplySpec(
- const int match_id,
- RulesModel_::GrammarRulesT* action_grammar_rules) const {
- const int action_id = action_grammar_rules->actions.size();
- action_grammar_rules->actions.emplace_back(
- new RulesModel_::RuleActionSpecT);
- RulesModel_::RuleActionSpecT* actions_spec =
- action_grammar_rules->actions.back().get();
- actions_spec->capturing_group.emplace_back(
- new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
- actions_spec->capturing_group.back()->group_id = 0;
- actions_spec->capturing_group.back()->text_reply.reset(
- new ActionSuggestionSpecT);
- actions_spec->capturing_group.back()->text_reply->priority_score = 1.0;
- actions_spec->capturing_group.back()->text_reply->score = 1.0;
- return action_id;
- }
-
- int AddRuleMatch(const std::vector<int>& action_ids,
- RulesModel_::GrammarRulesT* action_grammar_rules) const {
- const int rule_match_id = action_grammar_rules->rule_match.size();
- action_grammar_rules->rule_match.emplace_back(
- new RulesModel_::GrammarRules_::RuleMatchT);
- action_grammar_rules->rule_match.back()->action_id.insert(
- action_grammar_rules->rule_match.back()->action_id.end(),
- action_ids.begin(), action_ids.end());
- return rule_match_id;
- }
-
- UniLib unilib_;
- const std::string serialized_entity_data_schema_;
- std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
-};
-
-TEST_F(GrammarActionsTest, ProducesSmartReplies) {
- // Create test rules.
- // Rule: ^knock knock.?$ -> "Who's there?", "Yes?"
- RulesModel_::GrammarRulesT action_grammar_rules;
- SetTokenizerOptions(&action_grammar_rules);
- action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
- rules.Add(
- "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules),
- AddSmartReplySpec("Yes?", &action_grammar_rules)},
- &action_grammar_rules));
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- &unilib_,
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
-
- std::vector<ActionSuggestion> result;
- EXPECT_TRUE(grammar_actions.SuggestActions(
- {/*messages=*/{{/*user_id=*/0, /*text=*/"Knock knock"}}}, &result));
-
- EXPECT_THAT(result,
- ElementsAre(IsSmartReply("Who's there?"), IsSmartReply("Yes?")));
-}
-
-TEST_F(GrammarActionsTest, ProducesSmartRepliesFromCapturingMatches) {
- // Create test rules.
- // Rule: ^Text <reply> to <command>
- RulesModel_::GrammarRulesT action_grammar_rules;
- SetTokenizerOptions(&action_grammar_rules);
- action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
-
- // Capturing matches will create their own match objects to keep track of
- // match ids, so we declare the handler as a filter so that the grammar system
- // knows that we handle this ourselves.
- rules.DefineFilter(static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kCapturingMatch));
-
- rules.Add("<scripted_reply>",
- {"<^>", "text", "<captured_reply>", "to", "<command>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({AddCapturingMatchSmartReplySpec(
- /*match_id=*/0, &action_grammar_rules)},
- &action_grammar_rules));
-
- // <command> ::= unsubscribe | cancel | confirm | receive
- rules.Add("<command>", {"unsubscribe"});
- rules.Add("<command>", {"cancel"});
- rules.Add("<command>", {"confirm"});
- rules.Add("<command>", {"receive"});
-
- // <reply> ::= help | stop | cancel | yes
- rules.Add("<reply>", {"help"});
- rules.Add("<reply>", {"stop"});
- rules.Add("<reply>", {"cancel"});
- rules.Add("<reply>", {"yes"});
- rules.Add("<captured_reply>", {"<reply>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kCapturingMatch),
- /*callback_param=*/0 /*match_id*/);
-
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- &unilib_,
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
-
- {
- std::vector<ActionSuggestion> result;
- EXPECT_TRUE(grammar_actions.SuggestActions(
- {/*messages=*/{{/*user_id=*/0,
- /*text=*/"Text YES to confirm your subscription"}}},
- &result));
- EXPECT_THAT(result, ElementsAre(IsSmartReply("YES")));
- }
-
- {
- std::vector<ActionSuggestion> result;
- EXPECT_TRUE(grammar_actions.SuggestActions(
- {/*messages=*/{{/*user_id=*/0,
- /*text=*/"text Stop to cancel your order"}}},
- &result));
- EXPECT_THAT(result, ElementsAre(IsSmartReply("Stop")));
- }
-}
-
-TEST_F(GrammarActionsTest, ProducesAnnotationsForActions) {
- // Create test rules.
- // Rule: please dial <phone>
- RulesModel_::GrammarRulesT action_grammar_rules;
- SetTokenizerOptions(&action_grammar_rules);
- action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
-
- // Capturing matches will create their own match objects to keep track of
- // match ids, so we declare the handler as a filter so that the grammar system
- // knows that we handle this ourselves.
- rules.DefineFilter(static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kCapturingMatch));
-
- rules.Add(
- "<call_phone>", {"please", "dial", "<phone>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
- /*annotations=*/{{0 /*match_id*/, "phone"}},
- &action_grammar_rules)},
- &action_grammar_rules));
- // phone ::= +00 00 000 00 00
- rules.Add("<phone>",
- {"+", "<2_digits>", "<2_digits>", "<3_digits>", "<2_digits>",
- "<2_digits>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kCapturingMatch),
- /*callback_param=*/0 /*match_id*/);
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- &unilib_,
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
-
- std::vector<ActionSuggestion> result;
- EXPECT_TRUE(grammar_actions.SuggestActions(
- {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
- &result));
-
- EXPECT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
- EXPECT_THAT(result.front().annotations,
- ElementsAre(IsActionSuggestionAnnotation(
- "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
-}
-
-TEST_F(GrammarActionsTest, HandlesLocales) {
- // Create test rules.
- // Rule: ^knock knock.?$ -> "Who's there?"
- RulesModel_::GrammarRulesT action_grammar_rules;
- SetTokenizerOptions(&action_grammar_rules);
- action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules(/*num_shards=*/2);
- rules.Add(
- "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules)},
- &action_grammar_rules));
- rules.Add(
- "<toc>", {"<knock>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({AddSmartReplySpec("Qui est là?", &action_grammar_rules)},
- &action_grammar_rules),
- /*max_whitespace_gap=*/-1,
- /*case_sensitive=*/false,
- /*shard=*/1);
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- action_grammar_rules.rules.get());
- // Set locales for rules.
- action_grammar_rules.rules->rules.back()->locale.emplace_back(
- new LanguageTagT);
- action_grammar_rules.rules->rules.back()->locale.back()->language = "fr";
-
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- &unilib_,
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
-
- // Check default.
- {
- std::vector<ActionSuggestion> result;
- EXPECT_TRUE(grammar_actions.SuggestActions(
- {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"UTC", /*annotations=*/{},
- /*detected_text_language_tags=*/"en"}}},
- &result));
-
- EXPECT_THAT(result, ElementsAre(IsSmartReply("Who's there?")));
- }
-
- // Check fr.
- {
- std::vector<ActionSuggestion> result;
- EXPECT_TRUE(grammar_actions.SuggestActions(
- {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"UTC", /*annotations=*/{},
- /*detected_text_language_tags=*/"fr-CH"}}},
- &result));
-
- EXPECT_THAT(result, ElementsAre(IsSmartReply("Who's there?"),
- IsSmartReply("Qui est là?")));
- }
-}
-
-TEST_F(GrammarActionsTest, HandlesAssertions) {
- // Create test rules.
- // Rule: <flight> -> Track flight.
- RulesModel_::GrammarRulesT action_grammar_rules;
- SetTokenizerOptions(&action_grammar_rules);
- action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
-
- // Capturing matches will create their own match objects to keep track of
- // match ids, so we declare the handler as a filter so that the grammar system
- // knows that we handle this ourselves.
- rules.DefineFilter(static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kCapturingMatch));
-
- // Assertion matches will create their own match objects.
- // We declare the handler as a filter so that the grammar system knows that we
- // handle this ourselves.
- rules.DefineFilter(static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kAssertionMatch));
-
- rules.Add("<carrier>", {"lx"});
- rules.Add("<carrier>", {"aa"});
- rules.Add("<flight_code>", {"<2_digits>"});
- rules.Add("<flight_code>", {"<3_digits>"});
- rules.Add("<flight_code>", {"<4_digits>"});
-
- // Capture flight code.
- rules.Add("<flight>", {"<carrier>", "<flight_code>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kCapturingMatch),
- /*callback_param=*/0 /*match_id*/);
-
- // Flight: carrier + flight code; also check the right-hand context.
- rules.Add(
- "<track_flight>", {"<flight>", "<context_assertion>?"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({AddActionSpec("track_flight", /*response_text=*/"",
- /*annotations=*/{{0 /*match_id*/, "flight"}},
- &action_grammar_rules)},
- &action_grammar_rules));
-
- // Exclude matches like: LX 38.00 etc.
- rules.Add("<context_assertion>", {".?", "<digits>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kAssertionMatch),
- /*callback_param=*/true /*negative*/);
-
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- action_grammar_rules.rules.get());
-
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- &unilib_,
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
-
- std::vector<ActionSuggestion> result;
- EXPECT_TRUE(grammar_actions.SuggestActions(
- {/*messages=*/{{/*user_id=*/0, /*text=*/"LX38 aa 44 LX 38.38"}}},
- &result));
-
- EXPECT_THAT(result, ElementsAre(IsActionOfType("track_flight"),
- IsActionOfType("track_flight")));
- EXPECT_THAT(result[0].annotations,
- ElementsAre(IsActionSuggestionAnnotation("flight", "LX38",
- CodepointSpan{0, 4})));
- EXPECT_THAT(result[1].annotations,
- ElementsAre(IsActionSuggestionAnnotation("flight", "aa 44",
- CodepointSpan{5, 10})));
-}
-
-TEST_F(GrammarActionsTest, SetsStaticEntityData) {
- // Create test rules.
- // Rule: ^hello there$
- RulesModel_::GrammarRulesT action_grammar_rules;
- SetTokenizerOptions(&action_grammar_rules);
- action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
-
- // Create smart reply and static entity data.
- const int spec_id =
- AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder_->NewRoot();
- entity_data->Set("person", "Kenobi");
- action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
- entity_data->Serialize();
-
- rules.Add("<greeting>", {"<^>", "hello", "there", "<$>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({spec_id}, &action_grammar_rules));
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- &unilib_,
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()),
- entity_data_builder_.get());
-
- std::vector<ActionSuggestion> result;
- EXPECT_TRUE(grammar_actions.SuggestActions(
- {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));
-
- // Check the produced smart replies.
- EXPECT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));
-
- // Check entity data.
- const flatbuffers::Table* entity =
- flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
- result[0].serialized_entity_data.data()));
- EXPECT_THAT(
- entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
- "Kenobi");
-}
-
-TEST_F(GrammarActionsTest, SetsEntityDataFromCapturingMatches) {
- // Create test rules.
- // Rule: ^hello there$
- RulesModel_::GrammarRulesT action_grammar_rules;
- SetTokenizerOptions(&action_grammar_rules);
- action_grammar_rules.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
-
- // Capturing matches will create their own match objects to keep track of
- // match ids, so we declare the handler as a filter so that the grammar system
- // knows that we handle this ourselves.
- rules.DefineFilter(static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kCapturingMatch));
-
- // Create smart reply and static entity data.
- const int spec_id =
- AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder_->NewRoot();
- entity_data->Set("person", "Kenobi");
- action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
- entity_data->Serialize();
-
- // Specify results for capturing matches.
- const int greeting_match_id = 0;
- const int location_match_id = 1;
- {
- action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
- new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
- action_grammar_rules.actions[spec_id]->capturing_group.back().get();
- group->group_id = greeting_match_id;
- group->entity_field.reset(new FlatbufferFieldPathT);
- group->entity_field->field.emplace_back(new FlatbufferFieldT);
- group->entity_field->field.back()->field_name = "greeting";
- }
- {
- action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
- new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
- action_grammar_rules.actions[spec_id]->capturing_group.back().get();
- group->group_id = 1;
- group->entity_field.reset(new FlatbufferFieldPathT);
- group->entity_field->field.emplace_back(new FlatbufferFieldT);
- group->entity_field->field.back()->field_name = "location";
- }
-
- rules.Add("<location>", {"there"});
- rules.Add("<location>", {"here"});
- rules.Add("<captured_location>", {"<location>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kCapturingMatch),
- /*callback_param=*/location_match_id);
- rules.Add("<greeting>", {"hello", "<captured_location>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kCapturingMatch),
- /*callback_param=*/greeting_match_id);
- rules.Add("<test>", {"<^>", "<greeting>", "<$>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({spec_id}, &action_grammar_rules));
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- &unilib_,
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()),
- entity_data_builder_.get());
-
- std::vector<ActionSuggestion> result;
- EXPECT_TRUE(grammar_actions.SuggestActions(
- {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));
-
- // Check the produced smart replies.
- EXPECT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));
-
- // Check entity data.
- const flatbuffers::Table* entity =
- flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
- result[0].serialized_entity_data.data()));
- EXPECT_THAT(
- entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
- "Hello there");
- EXPECT_THAT(
- entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
- "there");
- EXPECT_THAT(
- entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
- "Kenobi");
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/actions/lua-actions_test.cc b/native/actions/lua-actions_test.cc
deleted file mode 100644
index b371387..0000000
--- a/native/actions/lua-actions_test.cc
+++ /dev/null
@@ -1,192 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/lua-actions.h"
-
-#include <map>
-#include <string>
-
-#include "actions/test-utils.h"
-#include "actions/types.h"
-#include "utils/tflite-model-executor.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAre;
-
-TEST(LuaActions, SimpleAction) {
- Conversation conversation;
- const std::string test_snippet = R"(
- return {{ type = "test_action" }}
- )";
- std::vector<ActionSuggestion> actions;
- EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
- test_snippet, conversation,
- /*model_executor=*/nullptr,
- /*model_spec=*/nullptr,
- /*interpreter=*/nullptr,
- /*actions_entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr)
- ->SuggestActions(&actions));
- EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
-}
-
-TEST(LuaActions, ConversationActions) {
- Conversation conversation;
- conversation.messages.push_back({/*user_id=*/0, "hello there!"});
- conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
- const std::string test_snippet = R"(
- local actions = {}
- for i, message in pairs(messages) do
- if i < #messages then
- if message.text == "hello there!" and
- messages[i+1].text == "general kenobi!" then
- table.insert(actions, {
- type = "text_reply",
- response_text = "you are a bold one!"
- })
- end
- if message.text == "i am the senate!" and
- messages[i+1].text == "not yet!" then
- table.insert(actions, {
- type = "text_reply",
- response_text = "it's treason then"
- })
- end
- end
- end
- return actions;
- )";
- std::vector<ActionSuggestion> actions;
- EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
- test_snippet, conversation,
- /*model_executor=*/nullptr,
- /*model_spec=*/nullptr,
- /*interpreter=*/nullptr,
- /*actions_entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr)
- ->SuggestActions(&actions));
- EXPECT_THAT(actions, ElementsAre(IsSmartReply("you are a bold one!")));
-}
-
-TEST(LuaActions, SimpleModelAction) {
- Conversation conversation;
- const std::string test_snippet = R"(
- if #model.actions_scores == 0 then
- return {{ type = "test_action" }}
- end
- return {}
- )";
- std::vector<ActionSuggestion> actions;
- EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
- test_snippet, conversation,
- /*model_executor=*/nullptr,
- /*model_spec=*/nullptr,
- /*interpreter=*/nullptr,
- /*actions_entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr)
- ->SuggestActions(&actions));
- EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
-}
-
-TEST(LuaActions, AnnotationActions) {
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {ClassificationResult("address", 1.0)};
- Conversation conversation = {{{/*user_id=*/1, "are you at home?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}};
- const std::string test_snippet = R"(
- local actions = {}
- local last_message = messages[#messages]
- for i, annotation in pairs(last_message.annotation) do
- if #annotation.classification > 0 then
- if annotation.classification[1].collection == "address" then
- local text = string.sub(last_message.text,
- annotation.span["begin"] + 1,
- annotation.span["end"])
- table.insert(actions, {
- type = "text_reply",
- response_text = "i am at " .. text,
- annotation = {{
- name = "location",
- span = {
- text = text
- },
- entity = annotation.classification[1]
- }},
- })
- end
- end
- end
- return actions;
- )";
- std::vector<ActionSuggestion> actions;
- EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
- test_snippet, conversation,
- /*model_executor=*/nullptr,
- /*model_spec=*/nullptr,
- /*interpreter=*/nullptr,
- /*actions_entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr)
- ->SuggestActions(&actions));
- EXPECT_THAT(actions, ElementsAre(IsSmartReply("i am at home")));
- EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
-}
-
-TEST(LuaActions, EntityData) {
- std::string test_schema = TestEntityDataSchema();
- Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
- const std::string test_snippet = R"(
- return {{
- type = "test",
- entity = {
- greeting = "hello",
- location = "there",
- person = "Kenobi",
- },
- }};
- )";
- std::vector<ActionSuggestion> actions;
- EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
- test_snippet, conversation,
- /*model_executor=*/nullptr,
- /*model_spec=*/nullptr,
- /*interpreter=*/nullptr,
- /*actions_entity_data_schema=*/
- flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
- /*annotations_entity_data_schema=*/nullptr)
- ->SuggestActions(&actions));
- EXPECT_THAT(actions, testing::SizeIs(1));
- EXPECT_EQ("test", actions.front().type);
- const flatbuffers::Table* entity =
- flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
- actions.front().serialized_entity_data.data()));
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
- "hello");
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
- "there");
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
- "Kenobi");
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/actions/lua-ranker_test.cc b/native/actions/lua-ranker_test.cc
deleted file mode 100644
index a790042..0000000
--- a/native/actions/lua-ranker_test.cc
+++ /dev/null
@@ -1,269 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/lua-ranker.h"
-
-#include <string>
-
-#include "actions/types.h"
-#include "utils/flatbuffers.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-MATCHER_P2(IsAction, type, response_text, "") {
- return testing::Value(arg.type, type) &&
- testing::Value(arg.response_text, response_text);
-}
-
-MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
-
-std::string TestEntitySchema() {
- // Create fake entity data schema metadata.
- // Cannot use the object-oriented API here as it is not available for the
- // reflection schema.
- flatbuffers::FlatBufferBuilder schema_builder;
- std::vector<flatbuffers::Offset<reflection::Field>> fields = {
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("test"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/0,
- /*offset=*/4)};
- std::vector<flatbuffers::Offset<reflection::Enum>> enums;
- std::vector<flatbuffers::Offset<reflection::Object>> objects = {
- reflection::CreateObject(
- schema_builder,
- /*name=*/schema_builder.CreateString("EntityData"),
- /*fields=*/
- schema_builder.CreateVectorOfSortedTables(&fields))};
- schema_builder.Finish(reflection::CreateSchema(
- schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
- schema_builder.CreateVectorOfSortedTables(&enums),
- /*(unused) file_ident=*/0,
- /*(unused) file_ext=*/0,
- /*root_table*/ objects[0]));
- return std::string(
- reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
- schema_builder.GetSize());
-}
-
-TEST(LuaRankingTest, PassThrough) {
- const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
- ActionsSuggestionsResponse response;
- response.actions = {
- {/*response_text=*/"hello there", /*type=*/"text_reply",
- /*score=*/1.0},
- {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
- {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
- const std::string test_snippet = R"(
- local result = {}
- for i=1,#actions do
- table.insert(result, i)
- end
- return result
- )";
-
- EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
- conversation, test_snippet, /*entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr, &response)
- ->RankActions());
- EXPECT_THAT(response.actions,
- testing::ElementsAreArray({IsActionType("text_reply"),
- IsActionType("share_location"),
- IsActionType("add_to_collection")}));
-}
-
-TEST(LuaRankingTest, Filtering) {
- const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
- ActionsSuggestionsResponse response;
- response.actions = {
- {/*response_text=*/"hello there", /*type=*/"text_reply",
- /*score=*/1.0},
- {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
- {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
- const std::string test_snippet = R"(
- return {}
- )";
-
- EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
- conversation, test_snippet, /*entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr, &response)
- ->RankActions());
- EXPECT_THAT(response.actions, testing::IsEmpty());
-}
-
-TEST(LuaRankingTest, Duplication) {
- const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
- ActionsSuggestionsResponse response;
- response.actions = {
- {/*response_text=*/"hello there", /*type=*/"text_reply",
- /*score=*/1.0},
- {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
- {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
- const std::string test_snippet = R"(
- local result = {}
- for i=1,#actions do
- table.insert(result, 1)
- end
- return result
- )";
-
- EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
- conversation, test_snippet, /*entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr, &response)
- ->RankActions());
- EXPECT_THAT(response.actions,
- testing::ElementsAreArray({IsActionType("text_reply"),
- IsActionType("text_reply"),
- IsActionType("text_reply")}));
-}
-
-TEST(LuaRankingTest, SortByScore) {
- const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
- ActionsSuggestionsResponse response;
- response.actions = {
- {/*response_text=*/"hello there", /*type=*/"text_reply",
- /*score=*/1.0},
- {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
- {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
- const std::string test_snippet = R"(
- function testScoreSorter(a, b)
- return actions[a].score < actions[b].score
- end
- local result = {}
- for i=1,#actions do
- result[i] = i
- end
- table.sort(result, testScoreSorter)
- return result
- )";
-
- EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
- conversation, test_snippet, /*entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr, &response)
- ->RankActions());
- EXPECT_THAT(response.actions,
- testing::ElementsAreArray({IsActionType("add_to_collection"),
- IsActionType("share_location"),
- IsActionType("text_reply")}));
-}
-
-TEST(LuaRankingTest, SuppressType) {
- const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
- ActionsSuggestionsResponse response;
- response.actions = {
- {/*response_text=*/"hello there", /*type=*/"text_reply",
- /*score=*/1.0},
- {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
- {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
- const std::string test_snippet = R"(
- local result = {}
- for id, action in pairs(actions) do
- if action.type ~= "text_reply" then
- table.insert(result, id)
- end
- end
- return result
- )";
-
- EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
- conversation, test_snippet, /*entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr, &response)
- ->RankActions());
- EXPECT_THAT(response.actions,
- testing::ElementsAreArray({IsActionType("share_location"),
- IsActionType("add_to_collection")}));
-}
-
-TEST(LuaRankingTest, HandlesConversation) {
- const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
- ActionsSuggestionsResponse response;
- response.actions = {
- {/*response_text=*/"hello there", /*type=*/"text_reply",
- /*score=*/1.0},
- {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
- {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
- const std::string test_snippet = R"(
- local result = {}
- if messages[1].text ~= "hello hello" then
- return result
- end
- for id, action in pairs(actions) do
- if action.type ~= "text_reply" then
- table.insert(result, id)
- end
- end
- return result
- )";
-
- EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
- conversation, test_snippet, /*entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr, &response)
- ->RankActions());
- EXPECT_THAT(response.actions,
- testing::ElementsAreArray({IsActionType("share_location"),
- IsActionType("add_to_collection")}));
-}
-
-TEST(LuaRankingTest, HandlesEntityData) {
- std::string serialized_schema = TestEntitySchema();
- const reflection::Schema* entity_data_schema =
- flatbuffers::GetRoot<reflection::Schema>(serialized_schema.data());
-
- // Create test entity data.
- ReflectiveFlatbufferBuilder builder(entity_data_schema);
- std::unique_ptr<ReflectiveFlatbuffer> buffer = builder.NewRoot();
- buffer->Set("test", "value_a");
- const std::string serialized_entity_data_a = buffer->Serialize();
- buffer->Set("test", "value_b");
- const std::string serialized_entity_data_b = buffer->Serialize();
-
- const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
- ActionsSuggestionsResponse response;
- response.actions = {
- {/*response_text=*/"", /*type=*/"test",
- /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
- /*serialized_entity_data=*/serialized_entity_data_a},
- {/*response_text=*/"", /*type=*/"test",
- /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
- /*serialized_entity_data=*/serialized_entity_data_b},
- {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
- {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
- const std::string test_snippet = R"(
- local result = {}
- for id, action in pairs(actions) do
- if action.type == "test" and action.test == "value_a" then
- table.insert(result, id)
- end
- end
- return result
- )";
-
- EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
- conversation, test_snippet, entity_data_schema,
- /*annotations_entity_data_schema=*/nullptr, &response)
- ->RankActions());
- EXPECT_THAT(response.actions,
- testing::ElementsAreArray({IsActionType("test")}));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/actions/ranker_test.cc b/native/actions/ranker_test.cc
deleted file mode 100644
index b52cf45..0000000
--- a/native/actions/ranker_test.cc
+++ /dev/null
@@ -1,382 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/ranker.h"
-
-#include <string>
-
-#include "actions/types.h"
-#include "utils/zlib/zlib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-MATCHER_P3(IsAction, type, response_text, score, "") {
- return testing::Value(arg.type, type) &&
- testing::Value(arg.response_text, response_text) &&
- testing::Value(arg.score, score);
-}
-
-MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
-
-TEST(RankingTest, DeduplicationSmartReply) {
- const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
- ActionsSuggestionsResponse response;
- response.actions = {
- {/*response_text=*/"hello there", /*type=*/"text_reply",
- /*score=*/1.0},
- {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5}};
-
- RankingOptionsT options;
- options.deduplicate_suggestions = true;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(RankingOptions::Pack(builder, &options));
- auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
- flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
- /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
-
- ranker->RankActions(conversation, &response);
- EXPECT_THAT(
- response.actions,
- testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0)}));
-}
-
-TEST(RankingTest, DeduplicationExtraData) {
- const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
- ActionsSuggestionsResponse response;
- response.actions = {
- {/*response_text=*/"hello there", /*type=*/"text_reply",
- /*score=*/1.0, /*priority_score=*/0.0},
- {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5,
- /*priority_score=*/0.0},
- {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.6,
- /*priority_score=*/0.0,
- /*annotations=*/{}, /*serialized_entity_data=*/"test"},
- };
-
- RankingOptionsT options;
- options.deduplicate_suggestions = true;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(RankingOptions::Pack(builder, &options));
- auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
- flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
- /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
-
- ranker->RankActions(conversation, &response);
- EXPECT_THAT(
- response.actions,
- testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0),
- // Is kept as it has different entity data.
- IsAction("text_reply", "hello there", 0.6)}));
-}
-
-TEST(RankingTest, DeduplicationAnnotations) {
- const Conversation conversation = {
- {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
- ActionsSuggestionsResponse response;
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
- /*text=*/"742 Evergreen Terrace"};
- annotation.entity = ClassificationResult("address", 0.5);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"view_map",
- /*score=*/0.5,
- /*priority_score=*/1.0,
- /*annotations=*/{annotation}});
- }
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
- /*text=*/"742 Evergreen Terrace"};
- annotation.entity = ClassificationResult("address", 1.0);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"view_map",
- /*score=*/1.0,
- /*priority_score=*/2.0,
- /*annotations=*/{annotation}});
- }
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
- /*text=*/"1-800-TESTING"};
- annotation.entity = ClassificationResult("phone", 0.5);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"call_phone",
- /*score=*/0.5,
- /*priority_score=*/1.0,
- /*annotations=*/{annotation}});
- }
-
- RankingOptionsT options;
- options.deduplicate_suggestions = true;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(RankingOptions::Pack(builder, &options));
- auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
- flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
- /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
-
- ranker->RankActions(conversation, &response);
- EXPECT_THAT(response.actions,
- testing::ElementsAreArray({IsAction("view_map", "", 1.0),
- IsAction("call_phone", "", 0.5)}));
-}
-
-TEST(RankingTest, DeduplicationAnnotationsByPriorityScore) {
- const Conversation conversation = {
- {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
- ActionsSuggestionsResponse response;
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
- /*text=*/"742 Evergreen Terrace"};
- annotation.entity = ClassificationResult("address", 0.5);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"view_map",
- /*score=*/0.6,
- /*priority_score=*/2.0,
- /*annotations=*/{annotation}});
- }
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
- /*text=*/"742 Evergreen Terrace"};
- annotation.entity = ClassificationResult("address", 1.0);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"view_map",
- /*score=*/1.0,
- /*priority_score=*/1.0,
- /*annotations=*/{annotation}});
- }
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
- /*text=*/"1-800-TESTING"};
- annotation.entity = ClassificationResult("phone", 0.5);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"call_phone",
- /*score=*/0.5,
- /*priority_score=*/1.0,
- /*annotations=*/{annotation}});
- }
-
- RankingOptionsT options;
- options.deduplicate_suggestions = true;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(RankingOptions::Pack(builder, &options));
- auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
- flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
- /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
-
- ranker->RankActions(conversation, &response);
- EXPECT_THAT(
- response.actions,
- testing::ElementsAreArray(
- {IsAction("view_map", "",
- 0.6), // the lower-scored action wins because its priority score is higher
- IsAction("call_phone", "", 0.5)}));
-}
-
-TEST(RankingTest, DeduplicatesConflictingActions) {
- const Conversation conversation = {{{/*user_id=*/1, "code A-911"}}};
- ActionsSuggestionsResponse response;
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{7, 10},
- /*text=*/"911"};
- annotation.entity = ClassificationResult("phone", 1.0);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"call_phone",
- /*score=*/1.0,
- /*priority_score=*/1.0,
- /*annotations=*/{annotation}});
- }
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{5, 10},
- /*text=*/"A-911"};
- annotation.entity = ClassificationResult("code", 1.0);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"copy_code",
- /*score=*/1.0,
- /*priority_score=*/2.0,
- /*annotations=*/{annotation}});
- }
- RankingOptionsT options;
- options.deduplicate_suggestions = true;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(RankingOptions::Pack(builder, &options));
- auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
- flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
- /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
-
- ranker->RankActions(conversation, &response);
- EXPECT_THAT(response.actions,
- testing::ElementsAreArray({IsAction("copy_code", "", 1.0)}));
-}
-
-TEST(RankingTest, HandlesCompressedLuaScript) {
- const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
- ActionsSuggestionsResponse response;
- response.actions = {
- {/*response_text=*/"hello there", /*type=*/"text_reply",
- /*score=*/1.0},
- {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
- {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
- const std::string test_snippet = R"(
- local result = {}
- for id, action in pairs(actions) do
- if action.type ~= "text_reply" then
- table.insert(result, id)
- end
- end
- return result
- )";
- RankingOptionsT options;
- options.compressed_lua_ranking_script.reset(new CompressedBufferT);
- std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
- compressor->Compress(test_snippet,
- options.compressed_lua_ranking_script.get());
- options.deduplicate_suggestions = true;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(RankingOptions::Pack(builder, &options));
-
- std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
- auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
- flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
- decompressor.get(), /*smart_reply_action_type=*/"text_reply");
-
- ranker->RankActions(conversation, &response);
- EXPECT_THAT(response.actions,
- testing::ElementsAreArray({IsActionType("share_location"),
- IsActionType("add_to_collection")}));
-}
-
-TEST(RankingTest, SuppressSmartRepliesWithAction) {
- const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
- ActionsSuggestionsResponse response;
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
- /*text=*/"911"};
- annotation.entity = ClassificationResult("phone", 1.0);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"call_phone",
- /*score=*/1.0,
- /*priority_score=*/1.0,
- /*annotations=*/{annotation}});
- }
- response.actions.push_back({/*response_text=*/"How are you?",
- /*type=*/"text_reply"});
- RankingOptionsT options;
- options.suppress_smart_replies_with_actions = true;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(RankingOptions::Pack(builder, &options));
- auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
- flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
- /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
-
- ranker->RankActions(conversation, &response);
-
- EXPECT_THAT(response.actions,
- testing::ElementsAreArray({IsAction("call_phone", "", 1.0)}));
-}
-
-TEST(RankingTest, GroupsActionsByAnnotations) {
- const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
- ActionsSuggestionsResponse response;
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
- /*text=*/"911"};
- annotation.entity = ClassificationResult("phone", 1.0);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"call_phone",
- /*score=*/1.0,
- /*priority_score=*/1.0,
- /*annotations=*/{annotation}});
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"add_contact",
- /*score=*/0.0,
- /*priority_score=*/0.0,
- /*annotations=*/{annotation}});
- }
- response.actions.push_back({/*response_text=*/"How are you?",
- /*type=*/"text_reply",
- /*score=*/0.5});
- RankingOptionsT options;
- options.group_by_annotations = true;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(RankingOptions::Pack(builder, &options));
- auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
- flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
- /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
-
- ranker->RankActions(conversation, &response);
-
- // The text reply should be last, even though it has a higher score than the
- // `add_contact` action.
- EXPECT_THAT(
- response.actions,
- testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
- IsAction("add_contact", "", 0.0),
- IsAction("text_reply", "How are you?", 0.5)}));
-}
-
-TEST(RankingTest, SortsActionsByScore) {
- const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
- ActionsSuggestionsResponse response;
- {
- ActionSuggestionAnnotation annotation;
- annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
- /*text=*/"911"};
- annotation.entity = ClassificationResult("phone", 1.0);
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"call_phone",
- /*score=*/1.0,
- /*priority_score=*/1.0,
- /*annotations=*/{annotation}});
- response.actions.push_back({/*response_text=*/"",
- /*type=*/"add_contact",
- /*score=*/0.0,
- /*priority_score=*/0.0,
- /*annotations=*/{annotation}});
- }
- response.actions.push_back({/*response_text=*/"How are you?",
- /*type=*/"text_reply",
- /*score=*/0.5});
- RankingOptionsT options;
- // Don't group by annotation.
- options.group_by_annotations = false;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(RankingOptions::Pack(builder, &options));
- auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
- flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
- /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
-
- ranker->RankActions(conversation, &response);
-
- EXPECT_THAT(
- response.actions,
- testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
- IsAction("text_reply", "How are you?", 0.5),
- IsAction("add_contact", "", 0.0)}));
-}
-
-} // namespace
-} // namespace libtextclassifier3
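The deleted ranker tests above document the deduplication rule for actions that point at the same annotation span: the candidate with the higher priority_score survives and score only breaks ties, which is why the 0.6-score view_map outlives the 1.0-score one. The sketch below is a hedged, self-contained illustration of that rule with simplified stand-in types; the real ActionsSuggestionsRanker additionally considers action type and serialized entity data, which this sketch ignores.

// Sketch only: keeps one action per annotation span, preferring higher
// priority_score and falling back to score (simplified vs. the real ranker).
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct Action {
  std::string type;
  float score = 0.0f;
  float priority_score = 0.0f;
  int span_begin = -1;
  int span_end = -1;
};

bool Wins(const Action& a, const Action& b) {
  if (a.priority_score != b.priority_score)
    return a.priority_score > b.priority_score;
  return a.score > b.score;
}

int main() {
  const std::vector<Action> candidates = {
      {"view_map", 0.6f, 2.0f, 0, 21},
      {"view_map", 1.0f, 1.0f, 0, 21},    // same span, lower priority: dropped
      {"call_phone", 0.5f, 1.0f, 37, 50}};

  std::vector<Action> result;
  for (const Action& candidate : candidates) {
    const auto it = std::find_if(
        result.begin(), result.end(), [&candidate](const Action& kept) {
          return kept.span_begin == candidate.span_begin &&
                 kept.span_end == candidate.span_end;
        });
    if (it == result.end()) {
      result.push_back(candidate);
    } else if (Wins(candidate, *it)) {
      *it = candidate;
    }
  }

  // Prints: view_map score=0.6, then call_phone score=0.5.
  for (const Action& action : result)
    std::cout << action.type << " score=" << action.score << std::endl;
  return 0;
}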
diff --git a/native/actions/test-utils.cc b/native/actions/test-utils.cc
deleted file mode 100644
index 9b003dd..0000000
--- a/native/actions/test-utils.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/test-utils.h"
-
-namespace libtextclassifier3 {
-
-std::string TestEntityDataSchema() {
- // Create fake entity data schema metadata.
- // Cannot use the object-oriented API here as it is not available for the
- // reflection schema.
- flatbuffers::FlatBufferBuilder schema_builder;
- std::vector<flatbuffers::Offset<reflection::Field>> fields = {
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("greeting"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/0,
- /*offset=*/4),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("location"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/1,
- /*offset=*/6),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("person"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/2,
- /*offset=*/8)};
- std::vector<flatbuffers::Offset<reflection::Enum>> enums;
- std::vector<flatbuffers::Offset<reflection::Object>> objects = {
- reflection::CreateObject(
- schema_builder,
- /*name=*/schema_builder.CreateString("EntityData"),
- /*fields=*/
- schema_builder.CreateVectorOfSortedTables(&fields))};
- schema_builder.Finish(reflection::CreateSchema(
- schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
- schema_builder.CreateVectorOfSortedTables(&enums),
- /*(unused) file_ident=*/0,
- /*(unused) file_ext=*/0,
- /*root_table*/ objects[0]));
-
- return std::string(
- reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
- schema_builder.GetSize());
-}
-
-void SetTestEntityDataSchema(ActionsModelT* test_model) {
- const std::string serialized_schema = TestEntityDataSchema();
-
- test_model->actions_entity_data_schema.assign(
- serialized_schema.data(),
- serialized_schema.data() + serialized_schema.size());
-}
-
-} // namespace libtextclassifier3
diff --git a/native/actions/test-utils.h b/native/actions/test-utils.h
deleted file mode 100644
index c05d6a9..0000000
--- a/native/actions/test-utils.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
-
-#include <string>
-
-#include "actions/actions_model_generated.h"
-#include "utils/flatbuffers.h"
-#include "gmock/gmock.h"
-
-namespace libtextclassifier3 {
-
-using testing::ExplainMatchResult;
-using testing::Value;
-
-// Create test entity data schema.
-std::string TestEntityDataSchema();
-void SetTestEntityDataSchema(ActionsModelT* test_model);
-
-MATCHER_P(IsActionOfType, type, "") { return Value(arg.type, type); }
-MATCHER_P(IsSmartReply, response_text, "") {
- return ExplainMatchResult(IsActionOfType("text_reply"), arg,
- result_listener) &&
- Value(arg.response_text, response_text);
-}
-MATCHER_P(IsSpan, span, "") {
- return Value(arg.first, span.first) && Value(arg.second, span.second);
-}
-MATCHER_P3(IsActionSuggestionAnnotation, name, text, span, "") {
- return Value(arg.name, name) && Value(arg.span.text, text) &&
- ExplainMatchResult(IsSpan(span), arg.span.span, result_listener);
-}
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
diff --git a/native/actions/zlib-utils_test.cc b/native/actions/zlib-utils_test.cc
deleted file mode 100644
index befee31..0000000
--- a/native/actions/zlib-utils_test.cc
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/zlib-utils.h"
-
-#include <memory>
-
-#include "actions/actions_model_generated.h"
-#include "utils/zlib/zlib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAre;
-using testing::Field;
-using testing::Pointee;
-
-TEST(ZlibUtilsTest, CompressModel) {
- ActionsModelT model;
- constexpr char kTestPattern1[] = "this is a test pattern";
- constexpr char kTestPattern2[] = "this is a second test pattern";
- model.rules.reset(new RulesModelT);
- model.rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
- model.rules->regex_rule.back()->pattern = kTestPattern1;
- model.rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
- model.rules->regex_rule.back()->pattern = kTestPattern2;
-
- // Compress the model.
- EXPECT_TRUE(CompressActionsModel(&model));
-
- // Sanity check that uncompressed field is removed.
- const auto is_empty_pattern =
- Pointee(Field(&libtextclassifier3::RulesModel_::RegexRuleT::pattern,
- testing::IsEmpty()));
- EXPECT_THAT(model.rules->regex_rule,
- ElementsAre(is_empty_pattern, is_empty_pattern));
- // Pack and load the model.
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model));
- const ActionsModel* compressed_model = GetActionsModel(
- reinterpret_cast<const char*>(builder.GetBufferPointer()));
- ASSERT_TRUE(compressed_model != nullptr);
-
- // Decompress the fields again and check that they match the original.
- std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
- ASSERT_TRUE(decompressor != nullptr);
- std::string uncompressed_pattern;
- EXPECT_TRUE(decompressor->MaybeDecompress(
- compressed_model->rules()->regex_rule()->Get(0)->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, kTestPattern1);
- EXPECT_TRUE(decompressor->MaybeDecompress(
- compressed_model->rules()->regex_rule()->Get(1)->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, kTestPattern2);
- EXPECT_TRUE(DecompressActionsModel(&model));
- EXPECT_EQ(model.rules->regex_rule[0]->pattern, kTestPattern1);
- EXPECT_EQ(model.rules->regex_rule[1]->pattern, kTestPattern2);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index d992e7a..d9c69a2 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -17,7 +17,6 @@
#include "annotator/annotator.h"
#include <algorithm>
-#include <cctype>
#include <cmath>
#include <iterator>
#include <numeric>
@@ -38,6 +37,7 @@
#include "utils/strings/numbers.h"
#include "utils/strings/split.h"
#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib-common.h"
#include "utils/zlib/zlib_regex.h"
namespace libtextclassifier3 {
@@ -412,7 +412,9 @@
/*tokenizer_options=*/
model_->grammar_datetime_model()->grammar_tokenizer_options(),
*calendarlib_,
- /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules()));
+ /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules(),
+ model_->grammar_datetime_model()->target_classification_score(),
+ model_->grammar_datetime_model()->priority_score()));
if (!cfg_datetime_parser_) {
TC3_LOG(ERROR) << "Could not initialize context free grammar based "
"datetime parser.";
@@ -594,17 +596,15 @@
}
}
-bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
- int size) {
- std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
-
- if (!mmap->handle().ok()) {
+bool Annotator::InitializePersonNameEngineFromScopedMmap(
+ const ScopedMmap& mmap) {
+ if (!mmap.handle().ok()) {
TC3_LOG(ERROR) << "Mmap for person name model failed.";
return false;
}
const PersonNameModel* person_name_model = LoadAndVerifyPersonNameModel(
- mmap->handle().start(), mmap->handle().num_bytes());
+ mmap.handle().start(), mmap.handle().num_bytes());
if (person_name_model == nullptr) {
TC3_LOG(ERROR) << "Person name model verification failed.";
@@ -616,7 +616,7 @@
}
std::unique_ptr<PersonNameEngine> person_name_engine(
- new PersonNameEngine(unilib_));
+ new PersonNameEngine(selection_feature_processor_.get(), unilib_));
if (!person_name_engine->Initialize(person_name_model)) {
TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
return false;
@@ -625,6 +625,17 @@
return true;
}
+bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
+ return InitializePersonNameEngineFromScopedMmap(*mmap);
+}
+
+bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
+ int size) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
+ return InitializePersonNameEngineFromScopedMmap(*mmap);
+}
+
namespace {
int CountDigits(const std::string& str, CodepointSpan selection_indices) {
@@ -633,7 +644,7 @@
const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
if (i >= selection_indices.first && i < selection_indices.second &&
- isdigit(*it)) {
+ IsDigit(*it)) {
++count;
}
}
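
A minimal sketch of the digit-counting idea above, assuming a pre-decoded
codepoint sequence; IsAsciiDigit() is a hypothetical stand-in for the old
std::isdigit()-based check, while the real CountDigits() walks a UnicodeText
and calls IsDigit() from the newly included utils/utf8/unilib-common.h, which
operates on codepoints rather than raw bytes.

#include <cstdint>
#include <utility>
#include <vector>

using CodepointSpan = std::pair<int, int>;  // [first, second) codepoint range

// ASCII '0'-'9' only -- roughly what the previous isdigit() call could cover.
inline bool IsAsciiDigit(char32_t cp) { return cp >= U'0' && cp <= U'9'; }

// Counts digits that fall inside the selection, with a pluggable predicate so
// an ASCII-only and a codepoint-aware check can be compared side by side.
int CountDigitsInSpan(const std::vector<char32_t>& codepoints,
                      CodepointSpan selection_indices,
                      bool (*is_digit)(char32_t) = &IsAsciiDigit) {
  int count = 0;
  for (int i = 0; i < static_cast<int>(codepoints.size()); ++i) {
    if (i >= selection_indices.first && i < selection_indices.second &&
        is_digit(codepoints[i])) {
      ++count;
    }
  }
  return count;
}
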
@@ -2259,14 +2270,14 @@
std::unique_ptr<EntityDataT> data =
LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
*serialized_entity_data);
- if (data == nullptr) {
+ if (data == nullptr || data->money->unnormalized_amount.empty()) {
return false;
}
UnicodeText amount =
UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
int separator_back_index = 0;
- auto it_decimal_separator = amount.end();
+ auto it_decimal_separator = --amount.end();
for (; it_decimal_separator != amount.begin();
--it_decimal_separator, ++separator_back_index) {
if (std::find(money_separators_.begin(), money_separators_.end(),
@@ -2279,7 +2290,7 @@
// If there are 3 digits after the last separator, we consider that a
// thousands separator => the number is an int (e.g. 1.234 is considered int).
// If there is no separator in the number, it is also an int.
- if (separator_back_index == 4 || it_decimal_separator == amount.begin()) {
+ if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
it_decimal_separator = amount.end();
}
@@ -2295,7 +2306,7 @@
const int amount_codepoints_size = amount.size_codepoints();
if (!unilib_->ParseInt32(
UnicodeText::Substring(
- amount, amount_codepoints_size - separator_back_index + 1,
+ amount, amount_codepoints_size - separator_back_index,
amount_codepoints_size, /*do_copy=*/false),
&data->money->amount_decimal_part)) {
TC3_LOG(ERROR) << "Could not parse the money decimal part as int32.";
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
index 67ae92c..c20eb9e 100644
--- a/native/annotator/annotator.h
+++ b/native/annotator/annotator.h
@@ -139,6 +139,14 @@
// Initializes the installed app engine with the given config.
bool InitializeInstalledAppEngine(const std::string& serialized_config);
+ // Initializes the person name engine with the given person name model from
+ // the provided mmap.
+ bool InitializePersonNameEngineFromScopedMmap(const ScopedMmap& mmap);
+
+ // Initializes the person name engine with the given person name model in the
+ // provided file path.
+ bool InitializePersonNameEngineFromPath(const std::string& path);
+
// Initializes the person name engine with the given person name model in the
// provided file descriptor.
bool InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
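
A caller-side sketch of the three initialization entry points declared above;
the model path, offset, and size below are placeholders, and only the
Annotator method names come from this change.

#include <fcntl.h>
#include <unistd.h>

#include <string>

#include "annotator/annotator.h"

bool LoadPersonNames(libtextclassifier3::Annotator* annotator,
                     const std::string& path /* placeholder */) {
  // Preferred: let the annotator mmap the whole file itself; internally this
  // routes through InitializePersonNameEngineFromScopedMmap().
  if (annotator->InitializePersonNameEngineFromPath(path)) {
    return true;
  }
  // Alternative: hand over a file-descriptor region, e.g. when the model is
  // embedded inside a larger asset file (offset/size are placeholders).
  const int fd = open(path.c_str(), O_RDONLY);
  if (fd < 0) {
    return false;
  }
  const bool ok = annotator->InitializePersonNameEngineFromFileDescriptor(
      fd, /*offset=*/0, /*size=*/1024);
  close(fd);
  return ok;
}
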
diff --git a/native/annotator/annotator_jni_test.cc b/native/annotator/annotator_jni_test.cc
deleted file mode 100644
index 929fb59..0000000
--- a/native/annotator/annotator_jni_test.cc
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/annotator_jni.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(Annotator, ConvertIndicesBMPUTF8) {
- // Test boundary cases.
- EXPECT_EQ(ConvertIndicesBMPToUTF8("hello", {0, 5}), std::make_pair(0, 5));
- EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello", {0, 5}), std::make_pair(0, 5));
-
- EXPECT_EQ(ConvertIndicesBMPToUTF8("hello world", {0, 5}),
- std::make_pair(0, 5));
- EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello world", {0, 5}),
- std::make_pair(0, 5));
- EXPECT_EQ(ConvertIndicesBMPToUTF8("😁ello world", {0, 6}),
- std::make_pair(0, 5));
- EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁ello world", {0, 5}),
- std::make_pair(0, 6));
-
- EXPECT_EQ(ConvertIndicesBMPToUTF8("hello world", {6, 11}),
- std::make_pair(6, 11));
- EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello world", {6, 11}),
- std::make_pair(6, 11));
- EXPECT_EQ(ConvertIndicesBMPToUTF8("hello worl😁", {6, 12}),
- std::make_pair(6, 11));
- EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello worl😁", {6, 11}),
- std::make_pair(6, 12));
-
- // Simple example where the longer character is before the selection.
- // character 😁 is 0x1f601
- EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hello World.", {3, 8}),
- std::make_pair(2, 7));
-
- EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hello World.", {2, 7}),
- std::make_pair(3, 8));
-
- // Longer character is before and in selection.
- EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hell😁 World.", {3, 9}),
- std::make_pair(2, 7));
-
- EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hell😁 World.", {2, 7}),
- std::make_pair(3, 9));
-
- // Longer character is before and after selection.
- EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hello😁World.", {3, 8}),
- std::make_pair(2, 7));
-
- EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hello😁World.", {2, 7}),
- std::make_pair(3, 8));
-
- // Longer character is before, in, and after selection.
- EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hell😁😁World.", {3, 9}),
- std::make_pair(2, 7));
-
- EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hell😁😁World.", {2, 7}),
- std::make_pair(3, 9));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/cached-features_test.cc b/native/annotator/cached-features_test.cc
deleted file mode 100644
index 702f3ca..0000000
--- a/native/annotator/cached-features_test.cc
+++ /dev/null
@@ -1,157 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/cached-features.h"
-
-#include "annotator/model-executor.h"
-#include "utils/tensor-view.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-using testing::ElementsAreArray;
-using testing::FloatEq;
-using testing::Matcher;
-
-namespace libtextclassifier3 {
-namespace {
-
-Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
- std::vector<Matcher<float>> matchers;
- for (const float value : values) {
- matchers.push_back(FloatEq(value));
- }
- return ElementsAreArray(matchers);
-}
-
-std::unique_ptr<std::vector<float>> MakeFeatures(int num_tokens) {
- std::unique_ptr<std::vector<float>> features(new std::vector<float>());
- for (int i = 1; i <= num_tokens; ++i) {
- features->push_back(i * 11.0f);
- features->push_back(-i * 11.0f);
- features->push_back(i * 0.1f);
- }
- return features;
-}
-
-std::vector<float> GetCachedClickContextFeatures(
- const CachedFeatures& cached_features, int click_pos) {
- std::vector<float> output_features;
- cached_features.AppendClickContextFeaturesForClick(click_pos,
- &output_features);
- return output_features;
-}
-
-std::vector<float> GetCachedBoundsSensitiveFeatures(
- const CachedFeatures& cached_features, TokenSpan selected_span) {
- std::vector<float> output_features;
- cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span,
- &output_features);
- return output_features;
-}
-
-TEST(CachedFeaturesTest, ClickContext) {
- FeatureProcessorOptionsT options;
- options.context_size = 2;
- options.feature_version = 1;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateFeatureProcessorOptions(builder, &options));
- flatbuffers::DetachedBuffer options_fb = builder.Release();
-
- std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
- std::unique_ptr<std::vector<float>> padding_features(
- new std::vector<float>{112233.0, -112233.0, 321.0});
-
- const std::unique_ptr<CachedFeatures> cached_features =
- CachedFeatures::Create(
- {3, 10}, std::move(features), std::move(padding_features),
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- /*feature_vector_size=*/3);
- ASSERT_TRUE(cached_features);
-
- EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
- ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0, -33.0,
- 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5}));
-
- EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6),
- ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0, -44.0,
- 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6}));
-
- EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7),
- ElementsAreFloat({33.0, -33.0, 0.3, 44.0, -44.0, 0.4, 55.0, -55.0,
- 0.5, 66.0, -66.0, 0.6, 77.0, -77.0, 0.7}));
-}
-
-TEST(CachedFeaturesTest, BoundsSensitive) {
- std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config(
- new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
- config->enabled = true;
- config->num_tokens_before = 2;
- config->num_tokens_inside_left = 2;
- config->num_tokens_inside_right = 2;
- config->num_tokens_after = 2;
- config->include_inside_bag = true;
- config->include_inside_length = true;
- FeatureProcessorOptionsT options;
- options.bounds_sensitive_features = std::move(config);
- options.feature_version = 2;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateFeatureProcessorOptions(builder, &options));
- flatbuffers::DetachedBuffer options_fb = builder.Release();
-
- std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
- std::unique_ptr<std::vector<float>> padding_features(
- new std::vector<float>{112233.0, -112233.0, 321.0});
-
- const std::unique_ptr<CachedFeatures> cached_features =
- CachedFeatures::Create(
- {3, 9}, std::move(features), std::move(padding_features),
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- /*feature_vector_size=*/3);
- ASSERT_TRUE(cached_features);
-
- EXPECT_THAT(
- GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}),
- ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
- -33.0, 0.3, 44.0, -44.0, 0.4, 44.0, -44.0,
- 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
- 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4, 3.0}));
-
- EXPECT_THAT(
- GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}),
- ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
- -33.0, 0.3, 44.0, -44.0, 0.4, 33.0, -33.0,
- 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
- 66.0, -66.0, 0.6, 38.5, -38.5, 0.35, 2.0}));
-
- EXPECT_THAT(
- GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}),
- ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0,
- -44.0, 0.4, 55.0, -55.0, 0.5, 44.0, -44.0,
- 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
- 112233.0, -112233.0, 321.0, 49.5, -49.5, 0.45, 2.0}));
-
- EXPECT_THAT(
- GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}),
- ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3,
- 44.0, -44.0, 0.4, 112233.0, -112233.0, 321.0,
- 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4,
- 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
- 44.0, -44.0, 0.4, 1.0}));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/parser_test.cc b/native/annotator/datetime/parser_test.cc
deleted file mode 100644
index 1ddcf50..0000000
--- a/native/annotator/datetime/parser_test.cc
+++ /dev/null
@@ -1,1439 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/datetime/parser.h"
-
-#include <time.h>
-
-#include <fstream>
-#include <iostream>
-#include <memory>
-#include <string>
-
-#include "annotator/annotator.h"
-#include "annotator/model_generated.h"
-#include "annotator/types-test-util.h"
-#include "utils/testing/annotator.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-using std::vector;
-using testing::ElementsAreArray;
-
-namespace libtextclassifier3 {
-namespace {
-// Builder class to construct the DatetimeComponents and make the test readable.
-class DatetimeComponentsBuilder {
- public:
- DatetimeComponentsBuilder Add(DatetimeComponent::ComponentType type,
- int value) {
- DatetimeComponent component;
- component.component_type = type;
- component.value = value;
- return AddComponent(component);
- }
-
- DatetimeComponentsBuilder Add(
- DatetimeComponent::ComponentType type, int value,
- DatetimeComponent::RelativeQualifier relative_qualifier,
- int relative_count) {
- DatetimeComponent component;
- component.component_type = type;
- component.value = value;
- component.relative_qualifier = relative_qualifier;
- component.relative_count = relative_count;
- return AddComponent(component);
- }
-
- std::vector<DatetimeComponent> Build() {
- std::vector<DatetimeComponent> result(datetime_components_);
- datetime_components_.clear();
- return result;
- }
-
- private:
- DatetimeComponentsBuilder AddComponent(
- const DatetimeComponent& datetime_component) {
- datetime_components_.push_back(datetime_component);
- return *this;
- }
- std::vector<DatetimeComponent> datetime_components_;
-};
-
-std::string GetModelPath() {
- return TC3_TEST_DATA_DIR;
-}
-
-std::string ReadFile(const std::string& file_name) {
- std::ifstream file_stream(file_name);
- return std::string(std::istreambuf_iterator<char>(file_stream), {});
-}
-
-class ParserTest : public testing::Test {
- public:
- void SetUp() override {
- // Loads default unmodified model. Individual tests can call LoadModel to
- // make changes.
- LoadModel([](ModelT* model) {});
- }
-
- template <typename Fn>
- void LoadModel(Fn model_visitor_fn) {
- std::string model_buffer = ReadFile(GetModelPath() + "test_model.fb");
- model_buffer_ = ModifyAnnotatorModel(model_buffer, model_visitor_fn);
- classifier_ = Annotator::FromUnownedBuffer(model_buffer_.data(),
- model_buffer_.size(), &unilib_);
- TC3_CHECK(classifier_);
- parser_ = classifier_->DatetimeParserForTests();
- TC3_CHECK(parser_);
- }
-
- bool HasNoResult(const std::string& text, bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- std::vector<DatetimeParseResultSpan> results;
- if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
- annotation_usecase, anchor_start_end, &results)) {
- TC3_LOG(ERROR) << text;
- TC3_CHECK(false);
- }
- return results.empty();
- }
-
- bool ParsesCorrectly(const std::string& marked_text,
- const vector<int64>& expected_ms_utcs,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components,
- bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- const UnicodeText marked_text_unicode =
- UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
- auto brace_open_it =
- std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
- auto brace_end_it =
- std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
- TC3_CHECK(brace_open_it != marked_text_unicode.end());
- TC3_CHECK(brace_end_it != marked_text_unicode.end());
-
- std::string text;
- text +=
- UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
- text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
- text += UnicodeText::UTF8Substring(std::next(brace_end_it),
- marked_text_unicode.end());
-
- std::vector<DatetimeParseResultSpan> results;
-
- if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION,
- annotation_usecase, anchor_start_end, &results)) {
- TC3_LOG(ERROR) << text;
- TC3_CHECK(false);
- }
- if (results.empty()) {
- TC3_LOG(ERROR) << "No results.";
- return false;
- }
-
- const int expected_start_index =
- std::distance(marked_text_unicode.begin(), brace_open_it);
- // The -1 below is to account for the opening brace character.
- const int expected_end_index =
- std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
-
- std::vector<DatetimeParseResultSpan> filtered_results;
- for (const DatetimeParseResultSpan& result : results) {
- if (SpansOverlap(result.span,
- {expected_start_index, expected_end_index})) {
- filtered_results.push_back(result);
- }
- }
- std::vector<DatetimeParseResultSpan> expected{
- {{expected_start_index, expected_end_index},
- {},
- /*target_classification_score=*/1.0,
- /*priority_score=*/0.1}};
- expected[0].data.resize(expected_ms_utcs.size());
- for (int i = 0; i < expected_ms_utcs.size(); i++) {
- expected[0].data[i] = {expected_ms_utcs[i], expected_granularity,
- datetime_components[i]};
- }
-
- const bool matches =
- testing::Matches(ElementsAreArray(expected))(filtered_results);
- if (!matches) {
- TC3_LOG(ERROR) << "Expected: " << expected[0];
- if (filtered_results.empty()) {
- TC3_LOG(ERROR) << "But got no results.";
- }
- TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
- }
-
- return matches;
- }
-
- bool ParsesCorrectly(const std::string& marked_text,
- const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components,
- bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- return ParsesCorrectly(marked_text, vector<int64>{expected_ms_utc},
- expected_granularity, datetime_components,
- anchor_start_end, timezone, locales,
- annotation_usecase);
- }
-
- bool ParsesCorrectlyGerman(
- const std::string& marked_text, const vector<int64>& expected_ms_utcs,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components) {
- return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
- datetime_components,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"de");
- }
-
- bool ParsesCorrectlyGerman(
- const std::string& marked_text, const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components) {
- return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
- datetime_components,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"de");
- }
-
- protected:
- std::string model_buffer_;
- std::unique_ptr<Annotator> classifier_;
- const DatetimeParser* parser_;
- UniLib unilib_;
-};
-
-// Test with just a few cases to make debugging of general failures easier.
-TEST_F(ParserTest, ParseShort) {
- EXPECT_TRUE(ParsesCorrectly(
- "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 1988)
- .Build()}));
-}
-
-TEST_F(ParserTest, Parse) {
- EXPECT_TRUE(ParsesCorrectly(
- "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 1988)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{january 31 2018}", 1517353200000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 31)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "lorem {1 january 2018} ipsum", 1514761200000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{09/Mar/2004 22:02:40}", 1078866160000, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::SECOND, 40)
- .Add(DatetimeComponent::ComponentType::MINUTE, 02)
- .Add(DatetimeComponent::ComponentType::HOUR, 22)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
- .Add(DatetimeComponent::ComponentType::MONTH, 3)
- .Add(DatetimeComponent::ComponentType::YEAR, 2004)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{Dec 2, 2010 2:39:58 AM}", 1291253998000, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 58)
- .Add(DatetimeComponent::ComponentType::MINUTE, 39)
- .Add(DatetimeComponent::ComponentType::HOUR, 2)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
- .Add(DatetimeComponent::ComponentType::MONTH, 12)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{Jun 09 2011 15:28:14}", 1307626094000, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::SECOND, 14)
- .Add(DatetimeComponent::ComponentType::MINUTE, 28)
- .Add(DatetimeComponent::ComponentType::HOUR, 15)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
- .Add(DatetimeComponent::ComponentType::MONTH, 6)
- .Add(DatetimeComponent::ComponentType::YEAR, 2011)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{Mar 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 4)
- .Add(DatetimeComponent::ComponentType::MINUTE, 12)
- .Add(DatetimeComponent::ComponentType::HOUR, 8)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
- .Add(DatetimeComponent::ComponentType::MONTH, 3)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 4)
- .Add(DatetimeComponent::ComponentType::MINUTE, 12)
- .Add(DatetimeComponent::ComponentType::HOUR, 8)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
- .Add(DatetimeComponent::ComponentType::MONTH, 3)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{2010-06-26 02:31:29}", {1277512289000, 1277555489000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 29)
- .Add(DatetimeComponent::ComponentType::MINUTE, 31)
- .Add(DatetimeComponent::ComponentType::HOUR, 2)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
- .Add(DatetimeComponent::ComponentType::MONTH, 6)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 29)
- .Add(DatetimeComponent::ComponentType::MINUTE, 31)
- .Add(DatetimeComponent::ComponentType::HOUR, 2)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
- .Add(DatetimeComponent::ComponentType::MONTH, 6)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{2006/01/22 04:11:05}", {1137899465000, 1137942665000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 5)
- .Add(DatetimeComponent::ComponentType::MINUTE, 11)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2006)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 5)
- .Add(DatetimeComponent::ComponentType::MINUTE, 11)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2006)
- .Build()}));
- EXPECT_TRUE(
- ParsesCorrectly("{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{23/Apr/2015 11:42:35}", {1429782155000, 1429825355000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{23-Apr-2015 11:42:35}", {1429782155000, 1429825355000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{23 Apr 2015 11:42:35}", {1429782155000, 1429825355000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{04/23/15 11:42:35}", {1429782155000, 1429825355000}, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{04/23/2015 11:42:35}", {1429782155000, 1429825355000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{9/28/2011 2:23:15 PM}", 1317212595000, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 15)
- .Add(DatetimeComponent::ComponentType::MINUTE, 23)
- .Add(DatetimeComponent::ComponentType::HOUR, 2)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 28)
- .Add(DatetimeComponent::ComponentType::MONTH, 9)
- .Add(DatetimeComponent::ComponentType::YEAR, 2011)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "Are sentiments apartments decisively the especially alteration. "
- "Thrown shy denote ten ladies though ask saw. Or by to he going "
- "think order event music. Incommode so intention defective at "
- "convinced. Led income months itself and houses you. After nor "
- "you leave might share court balls. {19/apr/2010 06:36:15} Are "
- "sentiments apartments decisively the especially alteration. "
- "Thrown shy denote ten ladies though ask saw. Or by to he going "
- "think order event music. Incommode so intention defective at "
- "convinced. Led income months itself and houses you. After nor "
- "you leave might share court balls. ",
- {1271651775000, 1271694975000}, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 15)
- .Add(DatetimeComponent::ComponentType::MINUTE, 36)
- .Add(DatetimeComponent::ComponentType::HOUR, 6)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 15)
- .Add(DatetimeComponent::ComponentType::MINUTE, 36)
- .Add(DatetimeComponent::ComponentType::HOUR, 6)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{january 1 2018 at 4:30}", {1514777400000, 1514820600000},
- GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{january 1 2018 at 4:30 am}", 1514777400000, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{january 1 2018 at 4pm}", 1514818800000, GRANULARITY_HOUR,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
-
- EXPECT_TRUE(ParsesCorrectly(
- "{today at 0:00}", {-3600000, 39600000}, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 0)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::NOW, 0)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 0)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::NOW, 0)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{today at 0:00}", {-57600000, -14400000}, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 0)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::NOW, 0)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 0)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::NOW, 0)
- .Build()},
- /*anchor_start_end=*/false, "America/Los_Angeles"));
- EXPECT_TRUE(ParsesCorrectly(
- "{tomorrow at 4:00}", {97200000, 140400000}, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{tomorrow at 4am}", 97200000, GRANULARITY_HOUR,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{wednesday at 4am}", 529200000, GRANULARITY_HOUR,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 4,
- DatetimeComponent::RelativeQualifier::THIS, 0)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "last seen {today at 9:01 PM}", 72060000, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 1)
- .Add(DatetimeComponent::ComponentType::HOUR, 9)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::NOW, 0)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "set an alarm for {7am tomorrow}", 108000000, GRANULARITY_HOUR,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 7)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build()}));
- EXPECT_TRUE(
- ParsesCorrectly("set an alarm for {7 a.m}", 21600000, GRANULARITY_HOUR,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 7)
- .Build()}));
-}
-
-TEST_F(ParserTest, ParseWithAnchor) {
- EXPECT_TRUE(ParsesCorrectly(
- "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 1988)
- .Build()},
- /*anchor_start_end=*/false));
- EXPECT_TRUE(ParsesCorrectly(
- "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 1988)
- .Build()},
- /*anchor_start_end=*/true));
- EXPECT_TRUE(ParsesCorrectly(
- "lorem {1 january 2018} ipsum", 1514761200000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()},
- /*anchor_start_end=*/false));
- EXPECT_TRUE(HasNoResult("lorem 1 january 2018 ipsum",
- /*anchor_start_end=*/true));
-}
-
-TEST_F(ParserTest, ParseWithRawUsecase) {
- // Annotated for RAW usecase.
- EXPECT_TRUE(ParsesCorrectly(
- "{tomorrow}", 82800000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build()},
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
-
- EXPECT_TRUE(ParsesCorrectly(
- "call me {in two hours}", 7200000, GRANULARITY_HOUR,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::HOUR, 0,
- DatetimeComponent::RelativeQualifier::FUTURE, 2)
- .Build()},
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
-
- EXPECT_TRUE(ParsesCorrectly(
- "call me {next month}", 2674800000, GRANULARITY_MONTH,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MONTH, 0,
- DatetimeComponent::RelativeQualifier::NEXT, 1)
- .Build()},
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
- EXPECT_TRUE(ParsesCorrectly(
- "what's the time {now}", -3600000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::NOW, 0)
- .Build()},
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
-
- EXPECT_TRUE(ParsesCorrectly(
- "call me on {Saturday}", 169200000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 7,
- DatetimeComponent::RelativeQualifier::THIS, 0)
- .Build()},
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
-
- // Not annotated for Smart usecase.
- EXPECT_TRUE(HasNoResult(
- "{tomorrow}", /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_SMART));
-}
-
-TEST_F(ParserTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
- // ParsesCorrectly uses 0 as the reference time, which corresponds to:
- // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
- // it is in the past, and so the parser should move this to the next day ->
- // "Fri Jan 02 1970 00:30:00" Zurich time (b/139112907).
- EXPECT_TRUE(ParsesCorrectly(
- "{0:30am}", 84600000L /* 23.5 hours from reference time */,
- GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 0)
- .Build()}));
-}
-
-TEST_F(ParserTest, DoesNotAddADayWhenTimeInThePastAndDayNotSpecifiedDisabled) {
- // ParsesCorrectly uses 0 as the reference time, which corresponds to:
- // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
- // it is in the past. The parameter prefer_future_for_unspecified_date is
- // disabled, so the parser should annotate this to the same day: "Thu Jan 01
- // 1970 00:30:00" Zurich time.
- LoadModel([](ModelT* model) {
- // In the test model, the prefer_future_for_unspecified_date is true; make
- // it false only for this test.
- model->datetime_model->prefer_future_for_unspecified_date = false;
- });
-
- EXPECT_TRUE(ParsesCorrectly(
- "{0:30am}", -1800000L /* -30 minutes from reference time */,
- GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 0)
- .Build()}));
-}
-
-TEST_F(ParserTest, ParsesNoonAndMidnightCorrectly) {
- EXPECT_TRUE(ParsesCorrectly(
- "{January 1, 1988 12:30am}", 567991800000, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 12)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 1988)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{January 1, 1988 12:30pm}", 568035000000, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 12)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 1988)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{tomorrow at 12:00 am}", 82800000, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 12)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build()}));
-}
-
-TEST_F(ParserTest, ParseGerman) {
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{Januar 1 2018}", 1514761200000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{1 2 2018}", 1517439600000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 2)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "lorem {1 Januar 2018} ipsum", 1514761200000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{19/Apr/2010:06:36:15}", {1271651775000, 1271694975000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 15)
- .Add(DatetimeComponent::ComponentType::MINUTE, 36)
- .Add(DatetimeComponent::ComponentType::HOUR, 6)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 15)
- .Add(DatetimeComponent::ComponentType::MINUTE, 36)
- .Add(DatetimeComponent::ComponentType::HOUR, 6)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{09/März/2004 22:02:40}", 1078866160000, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::SECOND, 40)
- .Add(DatetimeComponent::ComponentType::MINUTE, 02)
- .Add(DatetimeComponent::ComponentType::HOUR, 22)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
- .Add(DatetimeComponent::ComponentType::MONTH, 3)
- .Add(DatetimeComponent::ComponentType::YEAR, 2004)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{Dez 2, 2010 2:39:58}", {1291253998000, 1291297198000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 58)
- .Add(DatetimeComponent::ComponentType::MINUTE, 39)
- .Add(DatetimeComponent::ComponentType::HOUR, 2)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
- .Add(DatetimeComponent::ComponentType::MONTH, 12)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 58)
- .Add(DatetimeComponent::ComponentType::MINUTE, 39)
- .Add(DatetimeComponent::ComponentType::HOUR, 2)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
- .Add(DatetimeComponent::ComponentType::MONTH, 12)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{Juni 09 2011 15:28:14}", 1307626094000, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::SECOND, 14)
- .Add(DatetimeComponent::ComponentType::MINUTE, 28)
- .Add(DatetimeComponent::ComponentType::HOUR, 15)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
- .Add(DatetimeComponent::ComponentType::MONTH, 6)
- .Add(DatetimeComponent::ComponentType::YEAR, 2011)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{März 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 4)
- .Add(DatetimeComponent::ComponentType::MINUTE, 12)
- .Add(DatetimeComponent::ComponentType::HOUR, 8)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
- .Add(DatetimeComponent::ComponentType::MONTH, 3)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 4)
- .Add(DatetimeComponent::ComponentType::MINUTE, 12)
- .Add(DatetimeComponent::ComponentType::HOUR, 8)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
- .Add(DatetimeComponent::ComponentType::MONTH, 3)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{2010-06-26 02:31:29}", {1277512289000, 1277555489000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 29)
- .Add(DatetimeComponent::ComponentType::MINUTE, 31)
- .Add(DatetimeComponent::ComponentType::HOUR, 2)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
- .Add(DatetimeComponent::ComponentType::MONTH, 6)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 29)
- .Add(DatetimeComponent::ComponentType::MINUTE, 31)
- .Add(DatetimeComponent::ComponentType::HOUR, 2)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
- .Add(DatetimeComponent::ComponentType::MONTH, 6)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{2006/01/22 04:11:05}", {1137899465000, 1137942665000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 5)
- .Add(DatetimeComponent::ComponentType::MINUTE, 11)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2006)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 5)
- .Add(DatetimeComponent::ComponentType::MINUTE, 11)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2006)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{23/Apr/2015:11:42:35}", {1429782155000, 1429825355000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{23/Apr/2015 11:42:35}", {1429782155000, 1429825355000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{23-Apr-2015 11:42:35}", {1429782155000, 1429825355000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{23 Apr 2015 11:42:35}", {1429782155000, 1429825355000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{04/23/15 11:42:35}", {1429782155000, 1429825355000}, GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{04/23/2015 11:42:35}", {1429782155000, 1429825355000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 35)
- .Add(DatetimeComponent::ComponentType::MINUTE, 42)
- .Add(DatetimeComponent::ComponentType::HOUR, 11)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{19/apr/2010:06:36:15}", {1271651775000, 1271694975000},
- GRANULARITY_SECOND,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::SECOND, 15)
- .Add(DatetimeComponent::ComponentType::MINUTE, 36)
- .Add(DatetimeComponent::ComponentType::HOUR, 6)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::SECOND, 15)
- .Add(DatetimeComponent::ComponentType::MINUTE, 36)
- .Add(DatetimeComponent::ComponentType::HOUR, 6)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
- .Add(DatetimeComponent::ComponentType::MONTH, 4)
- .Add(DatetimeComponent::ComponentType::YEAR, 2010)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{januar 1 2018 um 4:30}", {1514777400000, 1514820600000},
- GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{januar 1 2018 um 4:30 nachm}", 1514820600000, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{januar 1 2018 um 4 nachm}", 1514818800000, GRANULARITY_HOUR,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{14.03.2017}", 1489446000000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 14)
- .Add(DatetimeComponent::ComponentType::MONTH, 3)
- .Add(DatetimeComponent::ComponentType::YEAR, 2017)
- .Build()}));
-
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{morgen 0:00}", {82800000, 126000000}, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 0)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 0)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{morgen um 4:00}", {97200000, 140400000}, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
- DatetimeComponent::RelativeQualifier::TOMORROW, 1)
- .Build()}));
-}
-
-TEST_F(ParserTest, ParseNonUs) {
- auto first_may_2015 =
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 5)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build();
-
- EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
- {first_may_2015},
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich",
- /*locales=*/"en-GB"));
- EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
- {first_may_2015},
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en"));
-}
-
-TEST_F(ParserTest, ParseUs) {
- auto five_january_2015 =
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 5)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build();
-
- EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
- {five_january_2015},
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich",
- /*locales=*/"en-US"));
- EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
- {five_january_2015},
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich",
- /*locales=*/"es-US"));
-}
-
-TEST_F(ParserTest, ParseUnknownLanguage) {
- EXPECT_TRUE(ParsesCorrectly(
- "bylo to {31. 12. 2015} v 6 hodin", 1451516400000, GRANULARITY_DAY,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 31)
- .Add(DatetimeComponent::ComponentType::MONTH, 12)
- .Add(DatetimeComponent::ComponentType::YEAR, 2015)
- .Build()},
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
-}
-
-TEST_F(ParserTest, WhenAlternativesEnabledGeneratesAlternatives) {
- LoadModel([](ModelT* model) {
- model->datetime_model->generate_alternative_interpretations_when_ambiguous =
- true;
- });
-
- EXPECT_TRUE(ParsesCorrectly(
- "{january 1 2018 at 4:30}", {1514777400000, 1514820600000},
- GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{monday 3pm}", 396000000, GRANULARITY_HOUR,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::HOUR, 3)
- .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
- DatetimeComponent::RelativeQualifier::THIS, 0)
- .Build()}));
- EXPECT_TRUE(ParsesCorrectly(
- "{monday 3:00}", {352800000, 396000000}, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 3)
- .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
- DatetimeComponent::RelativeQualifier::THIS, 0)
- .Build(),
- DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
- .Add(DatetimeComponent::ComponentType::MINUTE, 0)
- .Add(DatetimeComponent::ComponentType::HOUR, 3)
- .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
- DatetimeComponent::RelativeQualifier::THIS, 0)
- .Build()}));
-}
-
-TEST_F(ParserTest, WhenAlternativesDisabledDoesNotGenerateAlternatives) {
- LoadModel([](ModelT* model) {
- model->datetime_model->generate_alternative_interpretations_when_ambiguous =
- false;
- });
-
- EXPECT_TRUE(ParsesCorrectly(
- "{january 1 2018 at 4:30}", 1514777400000, GRANULARITY_MINUTE,
- {DatetimeComponentsBuilder()
- .Add(DatetimeComponent::ComponentType::MINUTE, 30)
- .Add(DatetimeComponent::ComponentType::HOUR, 4)
- .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
- .Add(DatetimeComponent::ComponentType::MONTH, 1)
- .Add(DatetimeComponent::ComponentType::YEAR, 2018)
- .Build()}));
-}
-
-class ParserLocaleTest : public testing::Test {
- public:
- void SetUp() override;
- bool HasResult(const std::string& input, const std::string& locales);
-
- protected:
- UniLib unilib_;
- CalendarLib calendarlib_;
- flatbuffers::FlatBufferBuilder builder_;
- std::unique_ptr<DatetimeParser> parser_;
-};
-
-void AddPattern(const std::string& regex, int locale,
- std::vector<std::unique_ptr<DatetimeModelPatternT>>* patterns) {
- patterns->emplace_back(new DatetimeModelPatternT);
- patterns->back()->regexes.emplace_back(new DatetimeModelPattern_::RegexT);
- patterns->back()->regexes.back()->pattern = regex;
- patterns->back()->regexes.back()->groups.push_back(
- DatetimeGroupType_GROUP_UNUSED);
- patterns->back()->locales.push_back(locale);
-}
-
-void ParserLocaleTest::SetUp() {
- DatetimeModelT model;
- model.use_extractors_for_locating = false;
- model.locales.clear();
- model.locales.push_back("en-US");
- model.locales.push_back("en-CH");
- model.locales.push_back("zh-Hant");
- model.locales.push_back("en-*");
- model.locales.push_back("zh-Hant-*");
- model.locales.push_back("*-CH");
- model.locales.push_back("default");
- model.default_locales.push_back(6);
-
- AddPattern(/*regex=*/"en-US", /*locale=*/0, &model.patterns);
- AddPattern(/*regex=*/"en-CH", /*locale=*/1, &model.patterns);
- AddPattern(/*regex=*/"zh-Hant", /*locale=*/2, &model.patterns);
- AddPattern(/*regex=*/"en-all", /*locale=*/3, &model.patterns);
- AddPattern(/*regex=*/"zh-Hant-all", /*locale=*/4, &model.patterns);
- AddPattern(/*regex=*/"all-CH", /*locale=*/5, &model.patterns);
- AddPattern(/*regex=*/"default", /*locale=*/6, &model.patterns);
-
- builder_.Finish(DatetimeModel::Pack(builder_, &model));
- const DatetimeModel* model_fb =
- flatbuffers::GetRoot<DatetimeModel>(builder_.GetBufferPointer());
- ASSERT_TRUE(model_fb);
-
- parser_ = DatetimeParser::Instance(model_fb, unilib_, calendarlib_,
- /*decompressor=*/nullptr);
- ASSERT_TRUE(parser_);
-}
-
-bool ParserLocaleTest::HasResult(const std::string& input,
- const std::string& locales) {
- std::vector<DatetimeParseResultSpan> results;
- EXPECT_TRUE(parser_->Parse(
- input, /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"", locales, ModeFlag_ANNOTATION,
- AnnotationUsecase_ANNOTATION_USECASE_SMART, false, &results));
- return results.size() == 1;
-}
-
-TEST_F(ParserLocaleTest, English) {
- EXPECT_TRUE(HasResult("en-US", /*locales=*/"en-US"));
- EXPECT_FALSE(HasResult("en-CH", /*locales=*/"en-US"));
- EXPECT_FALSE(HasResult("en-US", /*locales=*/"en-CH"));
- EXPECT_TRUE(HasResult("en-CH", /*locales=*/"en-CH"));
- EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
-}
-
-TEST_F(ParserLocaleTest, TraditionalChinese) {
- EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant"));
- EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-TW"));
- EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-SG"));
- EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh-SG"));
- EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh"));
- EXPECT_TRUE(HasResult("default", /*locales=*/"zh"));
- EXPECT_TRUE(HasResult("default", /*locales=*/"zh-Hant-SG"));
-}
-
-TEST_F(ParserLocaleTest, SwissEnglish) {
- EXPECT_TRUE(HasResult("all-CH", /*locales=*/"de-CH"));
- EXPECT_TRUE(HasResult("all-CH", /*locales=*/"en-CH"));
- EXPECT_TRUE(HasResult("en-all", /*locales=*/"en-CH"));
- EXPECT_FALSE(HasResult("all-CH", /*locales=*/"de-DE"));
- EXPECT_TRUE(HasResult("default", /*locales=*/"de-CH"));
- EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
-}
-
-} // namespace
-} // namespace libtextclassifier3
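
The ParseNonUs and ParseUs cases above pin the same surface string "{1/5/15}" to two
different instants depending on locale, with Europe/Zurich as the reference timezone.
A minimal arithmetic sketch of those expected epoch values (the CET/CEST offsets and
day counts are the only added inputs; the constants come from the test expectations):

    #include <cstdint>

    constexpr int64_t kMsPerHour = 60LL * 60 * 1000;
    constexpr int64_t kMsPerDay = 24 * kMsPerHour;
    // Days from 1970-01-01 to 2015-01-01: 45 years, 11 of them leap years.
    constexpr int64_t kJan1_2015_Utc = (45 * 365 + 11) * kMsPerDay;
    // en-US: 5 January 2015 00:00 CET (UTC+1) == 4 January 2015 23:00 UTC.
    static_assert(kJan1_2015_Utc + 4 * kMsPerDay - 1 * kMsPerHour == 1420412400000,
                  "ParseUs expectation");
    // en-GB: 1 May 2015 00:00 CEST (UTC+2) == 30 April 2015 22:00 UTC
    // (31 + 28 + 31 + 30 = 120 days into the year).
    static_assert(kJan1_2015_Utc + 120 * kMsPerDay - 2 * kMsPerHour == 1430431200000,
                  "ParseNonUs expectation");
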
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
deleted file mode 100644
index 1afd701..0000000
--- a/native/annotator/duration/duration_test.cc
+++ /dev/null
@@ -1,567 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/duration/duration.h"
-
-#include <string>
-#include <vector>
-
-#include "annotator/collections.h"
-#include "annotator/model_generated.h"
-#include "annotator/types-test-util.h"
-#include "annotator/types.h"
-#include "utils/test-utils.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::AllOf;
-using testing::ElementsAre;
-using testing::Field;
-using testing::IsEmpty;
-
-const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- DurationAnnotatorOptionsT options;
- options.enabled = true;
-
- options.week_expressions.push_back("week");
- options.week_expressions.push_back("weeks");
-
- options.day_expressions.push_back("day");
- options.day_expressions.push_back("days");
-
- options.hour_expressions.push_back("hour");
- options.hour_expressions.push_back("hours");
-
- options.minute_expressions.push_back("minute");
- options.minute_expressions.push_back("minutes");
-
- options.second_expressions.push_back("second");
- options.second_expressions.push_back("seconds");
-
- options.filler_expressions.push_back("and");
- options.filler_expressions.push_back("a");
- options.filler_expressions.push_back("an");
- options.filler_expressions.push_back("one");
-
- options.half_expressions.push_back("half");
-
- options.sub_token_separator_codepoints.push_back('-');
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
-}
-
-std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- FeatureProcessorOptionsT options;
- options.context_size = 1;
- options.max_selection_span = 1;
- options.snap_label_span_boundaries_to_containing_tokens = false;
- options.ignored_span_boundary_codepoints.push_back(',');
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- const FeatureProcessorOptions* feature_processor_options =
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
-
- return std::unique_ptr<FeatureProcessor>(
- new FeatureProcessor(feature_processor_options, unilib));
-}
-
-class DurationAnnotatorTest : public ::testing::Test {
- protected:
- DurationAnnotatorTest()
- : INIT_UNILIB_FOR_TESTING(unilib_),
- feature_processor_(BuildFeatureProcessor(&unilib_)),
- duration_annotator_(TestingDurationAnnotatorOptions(),
- feature_processor_.get(), &unilib_) {}
-
- std::vector<Token> Tokenize(const UnicodeText& text) {
- return feature_processor_->Tokenize(text);
- }
-
- UniLib unilib_;
- std::unique_ptr<FeatureProcessor> feature_processor_;
- DurationAnnotator duration_annotator_;
-};
-
-TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
- ClassificationResult classification;
- EXPECT_TRUE(duration_annotator_.ClassifyText(
- UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
-
- EXPECT_THAT(classification,
- AllOf(Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
-}
-
-TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
- ClassificationResult classification;
- EXPECT_TRUE(duration_annotator_.ClassifyText(
- UTF8ToUnicodeText("Wake me up in15 minutesok?"), {13, 23},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
-
- EXPECT_THAT(classification,
- AllOf(Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
-}
-
-TEST_F(DurationAnnotatorTest, DoNotClassifyWhenInputIsInvalid) {
- ClassificationResult classification;
- EXPECT_FALSE(duration_annotator_.ClassifyText(
- UTF8ToUnicodeText("Weird space"), {5, 6},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
-}
-
-TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
- const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 15 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 3.5 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
- const UnicodeText text =
- UTF8ToUnicodeText("Wake me up in 3 hours and 5 seconds ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 35)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 3 * 60 * 60 * 1000 + 5 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) {
- const UnicodeText text = UTF8ToUnicodeText(
- "See you in a week and a day and an hour and a minute and a second");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(13, 65)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 7 * 24 * 60 * 60 * 1000 + 24 * 60 * 60 * 1000 +
- 60 * 60 * 1000 + 60 * 1000 + 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
- const UnicodeText text = UTF8ToUnicodeText("Set a timer for half an hour");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 28)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 0.5 * 60 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for 1 hour and a half");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 33)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 1.5 * 60 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for an hour and a half");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(19, 34)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 1.5 * 60 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest,
- FindsCorrectlyWhenSecondsComeSecondAndDontHaveNumber) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for 10 minutes and a second ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 39)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 10 * 60 * 1000 + 1 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
- const UnicodeText text = UTF8ToUnicodeText(
- "Set a timer for a a a 10 minutes and 2 seconds an and an ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(22, 46)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 10 * 60 * 1000 + 2 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
- const UnicodeText text = UTF8ToUnicodeText("Set a timer for half ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- ASSERT_EQ(result.size(), 0);
-}
-
-TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for 10 ,minutes, ,and, ,2, seconds, ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 46)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 10 * 60 * 1000 + 2 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) {
- const UnicodeText text = UTF8ToUnicodeText("Show 5-minute timer.");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(5, 13)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 5 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest,
- DoesNotIntOverflowWithDurationThatHasMoreThanInt32Millis) {
- ClassificationResult classification;
- EXPECT_TRUE(duration_annotator_.ClassifyText(
- UTF8ToUnicodeText("1400 hours"), {0, 10},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
-
- EXPECT_THAT(classification,
- AllOf(Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 1400L * 60L * 60L * 1000L)));
-}
-
-TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) {
- const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 MiNuTeS ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 15 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for 3 and HaLf minutes ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 3.5 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest,
- FindsDurationWithHalfExpressionIgnoringFillerWordCase) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for 3 AnD half minutes ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 3.5 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, CorrectlyAnnotatesSpanWithDanglingQuantity) {
- const UnicodeText text = UTF8ToUnicodeText("20 minutes 10");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- // TODO(b/144752747) Include test for duration_ms.
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 13)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(Field(&ClassificationResult::collection,
- "duration")))))));
-}
-
-const DurationAnnotatorOptions* TestingJapaneseDurationAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- DurationAnnotatorOptionsT options;
- options.enabled = true;
-
- options.week_expressions.push_back("週間");
-
- options.day_expressions.push_back("日間");
-
- options.hour_expressions.push_back("時間");
-
- options.minute_expressions.push_back("分");
- options.minute_expressions.push_back("分間");
-
- options.second_expressions.push_back("秒");
- options.second_expressions.push_back("秒間");
-
- options.half_expressions.push_back("半");
-
- options.require_quantity = true;
- options.enable_dangling_quantity_interpretation = false;
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
-}
-
-class JapaneseDurationAnnotatorTest : public ::testing::Test {
- protected:
- JapaneseDurationAnnotatorTest()
- : INIT_UNILIB_FOR_TESTING(unilib_),
- feature_processor_(BuildFeatureProcessor(&unilib_)),
- duration_annotator_(TestingJapaneseDurationAnnotatorOptions(),
- feature_processor_.get(), &unilib_) {}
-
- std::vector<Token> Tokenize(const UnicodeText& text) {
- return feature_processor_->Tokenize(text);
- }
-
- UniLib unilib_;
- std::unique_ptr<FeatureProcessor> feature_processor_;
- DurationAnnotator duration_annotator_;
-};
-
-TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) {
- const UnicodeText text = UTF8ToUnicodeText("10 分 の アラーム");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 4)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 10 * 60 * 1000)))))));
-}
-
-TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) {
- const UnicodeText text = UTF8ToUnicodeText("2 分 半 の アラーム");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 5)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 2.5 * 60 * 1000)))))));
-}
-
-TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) {
- const UnicodeText text = UTF8ToUnicodeText("分 の アラーム");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(result, IsEmpty());
-}
-
-TEST_F(JapaneseDurationAnnotatorTest, IgnoresDanglingQuantity) {
- const UnicodeText text = UTF8ToUnicodeText("2 分 10 の アラーム");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 3)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 2 * 60 * 1000)))))));
-}
-
-} // namespace
-} // namespace libtextclassifier3
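
The AllUnitsAreCovered case above expects a single combined annotation whose
duration_ms is simply the sum of the per-unit constants. The arithmetic, spelled out:

    #include <cstdint>

    constexpr int64_t kMsPerSecond = 1000;
    constexpr int64_t kMsPerMinute = 60 * kMsPerSecond;
    constexpr int64_t kMsPerHour = 60 * kMsPerMinute;
    constexpr int64_t kMsPerDay = 24 * kMsPerHour;
    constexpr int64_t kMsPerWeek = 7 * kMsPerDay;
    static_assert(kMsPerWeek + kMsPerDay + kMsPerHour + kMsPerMinute + kMsPerSecond ==
                      694861000,
                  "a week and a day and an hour and a minute and a second");
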
diff --git a/native/annotator/grammar/dates/cfg-datetime-annotator.cc b/native/annotator/grammar/dates/cfg-datetime-annotator.cc
index 5b24b31..78b8dad 100644
--- a/native/annotator/grammar/dates/cfg-datetime-annotator.cc
+++ b/native/annotator/grammar/dates/cfg-datetime-annotator.cc
@@ -256,17 +256,23 @@
options->tokenization_codepoint_config() != nullptr &&
options->tokenize_on_script_change();
return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
- internal_codepoint_config, tokenize_on_script_change, true);
+ internal_codepoint_config, tokenize_on_script_change,
+ /*icu_preserve_whitespace_tokens=*/false);
}
} // namespace
CfgDatetimeAnnotator::CfgDatetimeAnnotator(
const UniLib& unilib, const GrammarTokenizerOptions* tokenizer_options,
- const CalendarLib& calendar_lib, const DatetimeRules* datetime_rules)
+ const CalendarLib& calendar_lib, const DatetimeRules* datetime_rules,
+ const float annotator_target_classification_score,
+ const float annotator_priority_score)
: calendar_lib_(calendar_lib),
tokenizer_(BuildTokenizer(&unilib, tokenizer_options)),
- parser_(unilib, datetime_rules) {}
+ parser_(unilib, datetime_rules),
+ annotator_target_classification_score_(
+ annotator_target_classification_score),
+ annotator_priority_score_(annotator_priority_score) {}
// Helper method to convert the Thing into DatetimeParseResult.
// Thing contains the annotation instance, i.e. the type of the annotation and its
@@ -306,8 +312,10 @@
DatetimeParseResultSpan datetime_parse_result_span;
datetime_parse_result_span.span =
CodepointSpan{annotation.begin, annotation.end};
- datetime_parse_result_span.priority_score = 0.1;
- datetime_parse_result_span.target_classification_score = 1.0;
+ datetime_parse_result_span.target_classification_score =
+ annotator_target_classification_score_;
+ datetime_parse_result_span.priority_score = annotator_priority_score_;
+
// Though the data structure allows multiple DatetimeParseResults per span,
// for the grammar-based annotator there will be just one.
DatetimeParseResult datetime_parse_result;
diff --git a/native/annotator/grammar/dates/cfg-datetime-annotator.h b/native/annotator/grammar/dates/cfg-datetime-annotator.h
index 1980aae..00ed447 100644
--- a/native/annotator/grammar/dates/cfg-datetime-annotator.h
+++ b/native/annotator/grammar/dates/cfg-datetime-annotator.h
@@ -36,7 +36,9 @@
CfgDatetimeAnnotator(const UniLib& unilib,
const GrammarTokenizerOptions* tokenizer_options,
const CalendarLib& calendar_lib,
- const DatetimeRules* datetime_rules);
+ const DatetimeRules* datetime_rules,
+ const float annotator_target_classification_score,
+ const float annotator_priority_score);
// CfgDatetimeAnnotator is neither copyable nor movable.
CfgDatetimeAnnotator(const CfgDatetimeAnnotator&) = delete;
@@ -71,6 +73,8 @@
const CalendarLib& calendar_lib_;
const Tokenizer tokenizer_;
DateParser parser_;
+ const float annotator_target_classification_score_;
+ const float annotator_priority_score_;
};
} // namespace libtextclassifier3::dates
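
With this change the two scores are supplied at construction time instead of being
hard-coded (previously 1.0 for the target classification score and 0.1 for the
priority score). A minimal call-site sketch; unilib, tokenizer_options, calendar_lib
and datetime_rules are placeholders for whatever the caller already holds:

    dates::CfgDatetimeAnnotator annotator(
        unilib, tokenizer_options, calendar_lib, datetime_rules,
        /*annotator_target_classification_score=*/1.0f,
        /*annotator_priority_score=*/0.1f);
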
diff --git a/native/annotator/grammar/dates/parser.cc b/native/annotator/grammar/dates/parser.cc
index e745e2e..bf6e428 100644
--- a/native/annotator/grammar/dates/parser.cc
+++ b/native/annotator/grammar/dates/parser.cc
@@ -928,7 +928,7 @@
}
}
grammar::Matcher matcher(unilib_, datetime_rules_->rules(), locale_rules);
- lexer_.Process(tokens, &matcher);
+ lexer_.Process(tokens, /*matches=*/{}, &matcher);
return GetOutputAsAnnotationList(unilib_, extractor, codepoint_offsets,
options);
}
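
The lexer's Process() call now takes a middle argument for matches that already exist
before lexing starts; both updated call sites in this change pass an empty collection.
The exact parameter type is not visible in this diff, so the shape below is only an
assumption:

    // Assumed shape (the authoritative declaration lives in the lexer header):
    // void Process(const std::vector<Token>& tokens,
    //              const std::vector<grammar::Match*>& matches,
    //              grammar::Matcher* matcher) const;
    lexer_.Process(tokens, /*matches=*/{}, &matcher);
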
diff --git a/native/annotator/grammar/dates/testing/equals-proto.h b/native/annotator/grammar/dates/testing/equals-proto.h
deleted file mode 100644
index fba4de4..0000000
--- a/native/annotator/grammar/dates/testing/equals-proto.h
+++ /dev/null
@@ -1,28 +0,0 @@
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_TESTING_EQUALS_PROTO_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_TESTING_EQUALS_PROTO_H_
-
-#include "net/proto2/compat/public/message_lite.h" // IWYU pragma: export
-#include "gmock/gmock.h" // IWYU pragma: export
-
-#if defined(__ANDROID__) || defined(__APPLE__)
-namespace libtextclassifier3 {
-namespace portable_equals_proto {
-MATCHER_P(EqualsProto, other, "Compare MessageLite by serialized string") {
- return ::testing::ExplainMatchResult(::testing::Eq(other.SerializeAsString()),
- arg.SerializeAsString(),
- result_listener);
-} // MATCHER_P
-} // namespace portable_equals_proto
-} // namespace libtextclassifier3
-#else
-namespace libtextclassifier3 {
-namespace portable_equals_proto {
-// Leverage the powerful matcher when available, for human readable
-// differences.
-using ::testing::EqualsProto;
-} // namespace portable_equals_proto
-} // namespace libtextclassifier3
-#endif // defined(__ANDROID__) || defined(__APPLE__)
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_TESTING_EQUALS_PROTO_H_
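
The deleted header provided a portable EqualsProto: on Android and Apple builds it
compares MessageLite values by their serialized bytes, elsewhere it aliases the full
gMock matcher. A usage sketch, with the message type and helpers purely illustrative:

    using ::libtextclassifier3::portable_equals_proto::EqualsProto;

    // MyMessage, BuildExpected and RunCodeUnderTest are placeholders, not real APIs.
    MyMessage expected = BuildExpected();
    MyMessage actual = RunCodeUnderTest();
    EXPECT_THAT(actual, EqualsProto(expected));
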
diff --git a/native/annotator/grammar/dates/utils/date-match_test.cc b/native/annotator/grammar/dates/utils/date-match_test.cc
deleted file mode 100644
index d9c0af7..0000000
--- a/native/annotator/grammar/dates/utils/date-match_test.cc
+++ /dev/null
@@ -1,421 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/utils/date-match.h"
-
-#include <stdint.h>
-
-#include <string>
-
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/timezone-code_generated.h"
-#include "annotator/grammar/dates/utils/date-utils.h"
-#include "utils/strings/append.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-namespace {
-
-class DateMatchTest : public ::testing::Test {
- protected:
- enum {
- X = NO_VAL,
- };
-
- static DayOfWeek DOW_X() { return DayOfWeek_DOW_NONE; }
- static DayOfWeek SUN() { return DayOfWeek_SUNDAY; }
-
- static BCAD BCAD_X() { return BCAD_BCAD_NONE; }
- static BCAD BC() { return BCAD_BC; }
-
- DateMatch& SetDate(DateMatch* date, int year, int8 month, int8 day,
- DayOfWeek day_of_week = DOW_X(), BCAD bc_ad = BCAD_X()) {
- date->year = year;
- date->month = month;
- date->day = day;
- date->day_of_week = day_of_week;
- date->bc_ad = bc_ad;
- return *date;
- }
-
- DateMatch& SetTimeValue(DateMatch* date, int8 hour, int8 minute = X,
- int8 second = X, double fraction_second = X) {
- date->hour = hour;
- date->minute = minute;
- date->second = second;
- date->fraction_second = fraction_second;
- return *date;
- }
-
- DateMatch& SetTimeSpan(DateMatch* date, TimespanCode time_span_code) {
- date->time_span_code = time_span_code;
- return *date;
- }
-
- DateMatch& SetTimeZone(DateMatch* date, TimezoneCode time_zone_code,
- int16 time_zone_offset = INT16_MIN) {
- date->time_zone_code = time_zone_code;
- date->time_zone_offset = time_zone_offset;
- return *date;
- }
-
- bool SameDate(const DateMatch& a, const DateMatch& b) {
- return (a.day == b.day && a.month == b.month && a.year == b.year &&
- a.day_of_week == b.day_of_week);
- }
-
- DateMatch& SetDayOfWeek(DateMatch* date, DayOfWeek dow) {
- date->day_of_week = dow;
- return *date;
- }
-};
-
-TEST_F(DateMatchTest, BitFieldWidth) {
- // For DateMatch::day_of_week (:8).
- EXPECT_GE(DayOfWeek_MIN, INT8_MIN);
- EXPECT_LE(DayOfWeek_MAX, INT8_MAX);
-
- // For DateMatch::bc_ad (:8).
- EXPECT_GE(BCAD_MIN, INT8_MIN);
- EXPECT_LE(BCAD_MAX, INT8_MAX);
-
- // For DateMatch::time_span_code (:16).
- EXPECT_GE(TimespanCode_MIN, INT16_MIN);
- EXPECT_LE(TimespanCode_MAX, INT16_MAX);
-}
-
-TEST_F(DateMatchTest, IsValid) {
- // Valid: dates.
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, 1, X);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, X, X);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, 26);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, X);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, X, 26);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26, SUN());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, 26, SUN());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, X, 26, SUN());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26, DOW_X(), BC());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- // Valid: times.
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30, 59, 0.99);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30, 59);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- // Valid: mixed.
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26);
- SetTimeValue(&d, 12, 30, 59, 0.99);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, 26);
- SetTimeValue(&d, 12, 30, 59);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, X, X, SUN());
- SetTimeValue(&d, 12, 30);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- // Invalid: dates.
- {
- DateMatch d;
- SetDate(&d, X, 1, 26, DOW_X(), BC());
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, X, 26);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, X, X, SUN());
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, X, SUN());
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- // Invalid: times.
- {
- DateMatch d;
- SetTimeValue(&d, 12, X, 59);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, X, X, 0.99);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30, X, 0.99);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, X, 30);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- // Invalid: mixed.
- {
- DateMatch d;
- SetDate(&d, 2014, 1, X);
- SetTimeValue(&d, 12);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- // Invalid: empty.
- {
- DateMatch d;
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
-}
-
-std::string DebugStrings(const std::vector<DateMatch>& instances) {
- std::string res;
- for (int i = 0; i < instances.size(); ++i) {
- ::libtextclassifier3::strings::SStringAppendF(
- &res, 0, "[%d] == %s\n", i, instances[i].DebugString().c_str());
- }
- return res;
-}
-
-TEST_F(DateMatchTest, IncrementOneDay) {
- {
- DateMatch date;
- SetDate(&date, 2014, 2, 2);
- DateMatch expected;
- SetDate(&expected, 2014, 2, 3);
- IncrementOneDay(&date);
- EXPECT_TRUE(SameDate(date, expected));
- }
- {
- DateMatch date;
- SetDate(&date, 2015, 2, 28);
- DateMatch expected;
- SetDate(&expected, 2015, 3, 1);
- IncrementOneDay(&date);
- EXPECT_TRUE(SameDate(date, expected));
- }
- {
- DateMatch date;
- SetDate(&date, 2016, 2, 28);
- DateMatch expected;
- SetDate(&expected, 2016, 2, 29);
- IncrementOneDay(&date);
- EXPECT_TRUE(SameDate(date, expected));
- }
- {
- DateMatch date;
- SetDate(&date, 2017, 7, 16);
- SetDayOfWeek(&date, DayOfWeek_SUNDAY);
- DateMatch expected;
- SetDate(&expected, 2017, 7, 17);
- SetDayOfWeek(&expected, DayOfWeek_MONDAY);
- IncrementOneDay(&date);
- EXPECT_TRUE(SameDate(date, expected));
- }
- {
- DateMatch date;
- SetDate(&date, X, 7, 16);
- SetDayOfWeek(&date, DayOfWeek_MONDAY);
- DateMatch expected;
- SetDate(&expected, X, 7, 17);
- SetDayOfWeek(&expected, DayOfWeek_TUESDAY);
- IncrementOneDay(&date);
- EXPECT_TRUE(SameDate(date, expected));
- }
-}
-
-TEST_F(DateMatchTest, IsRefinement) {
- {
- DateMatch a;
- SetDate(&a, 2014, 2, X);
- DateMatch b;
- SetDate(&b, 2014, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- DateMatch b;
- SetDate(&b, 2014, 2, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- DateMatch b;
- SetDate(&b, X, 2, 24);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, 0, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, 0, 0);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, 0, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- SetTimeSpan(&a, TimespanCode_AM);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- SetTimeZone(&a, TimezoneCode_PST8PDT);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- a.priority += 10;
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, X, 2, 24);
- SetTimeValue(&b, 9, 0, X);
- EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, X, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetTimeValue(&a, 9, 0, 0);
- DateMatch b;
- SetTimeValue(&b, 9, X, X);
- SetTimeSpan(&b, TimespanCode_AM);
- EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
-}
-
-} // namespace
-} // namespace dates
-} // namespace libtextclassifier3
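
The IncrementOneDay cases above exercise month-end and leap-day rollover
(2015-02-28 -> 2015-03-01 but 2016-02-28 -> 2016-02-29); those expectations rely on
the standard Gregorian leap-year rule, sketched here for reference:

    // Gregorian rule assumed by the 2016-02-29 expectation above.
    constexpr bool IsLeapYear(int year) {
      return (year % 4 == 0 && year % 100 != 0) || year % 400 == 0;
    }
    static_assert(IsLeapYear(2016) && !IsLeapYear(2015), "matches the test cases");
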
diff --git a/native/annotator/grammar/grammar-annotator.cc b/native/annotator/grammar/grammar-annotator.cc
index d39e33c..71d0b43 100644
--- a/native/annotator/grammar/grammar-annotator.cc
+++ b/native/annotator/grammar/grammar-annotator.cc
@@ -172,7 +172,7 @@
// Run the grammar.
grammar::Matcher matcher(*unilib_, model_->rules(), locale_rules);
- lexer_.Process(tokens, &matcher);
+ lexer_.Process(tokens, /*matches=*/{}, &matcher);
// Populate results.
return callback_handler.GetAnnotations(result);
diff --git a/native/annotator/grammar/grammar-annotator_test.cc b/native/annotator/grammar/grammar-annotator_test.cc
deleted file mode 100644
index 15dcc70..0000000
--- a/native/annotator/grammar/grammar-annotator_test.cc
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/grammar-annotator.h"
-
-#include "annotator/model_generated.h"
-#include "utils/grammar/utils/rules.h"
-#include "utils/tokenizer.h"
-#include "utils/utf8/unilib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAre;
-using testing::Pair;
-using testing::Value;
-
-MATCHER_P3(IsAnnotatedSpan, start, end, collection, "") {
- return Value(arg.span, Pair(start, end)) &&
- Value(arg.classification.front().collection, collection);
-}
-
-class GrammarAnnotatorTest : public testing::Test {
- protected:
- GrammarAnnotatorTest()
- : INIT_UNILIB_FOR_TESTING(unilib_),
- tokenizer_(TokenizationType_ICU, &unilib_, {}, {},
- /*split_on_script_change=*/true,
- /*icu_preserve_whitespace_tokens=*/true) {}
-
- flatbuffers::DetachedBuffer PackModel(const GrammarModelT& model) const {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(GrammarModel::Pack(builder, &model));
- return builder.Release();
- }
-
- int AddRuleClassificationResult(const std::string& collection,
- GrammarModelT* model) const {
- const int result_id = model->rule_classification_result.size();
- model->rule_classification_result.emplace_back(
- new GrammarModel_::RuleClassificationResultT);
- GrammarModel_::RuleClassificationResultT* result =
- model->rule_classification_result.back().get();
- result->collection_name = collection;
- return result_id;
- }
-
- UniLib unilib_;
- Tokenizer tokenizer_;
-};
-
-TEST_F(GrammarAnnotatorTest, AnnotesWithGrammarRules) {
- // Create test rules.
- GrammarModelT grammar_model;
- grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
- rules.Add("<carrier>", {"lx"});
- rules.Add("<carrier>", {"aa"});
- rules.Add("<flight_code>", {"<2_digits>"});
- rules.Add("<flight_code>", {"<3_digits>"});
- rules.Add("<flight_code>", {"<4_digits>"});
- rules.Add(
- "<flight>", {"<carrier>", "<flight_code>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
- /*callback_param=*/
- AddRuleClassificationResult("flight", &grammar_model));
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- grammar_model.rules.get());
- flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
- GrammarAnnotator annotator(
- &unilib_, flatbuffers::GetRoot<GrammarModel>(serialized_model.data()));
-
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(annotator.Annotate(
- {Locale::FromBCP47("en")},
- tokenizer_.Tokenize(
- "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014"),
- &result));
-
- EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(11, 16, "flight"),
- IsAnnotatedSpan(51, 57, "flight")));
-}
-
-TEST_F(GrammarAnnotatorTest, HandlesAssertions) {
- // Create test rules.
- GrammarModelT grammar_model;
- grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
- rules.Add("<carrier>", {"lx"});
- rules.Add("<carrier>", {"aa"});
- rules.Add("<flight_code>", {"<2_digits>"});
- rules.Add("<flight_code>", {"<3_digits>"});
- rules.Add("<flight_code>", {"<4_digits>"});
-
- // Flight: carrier + flight code and check right context.
- rules.Add(
- "<flight>", {"<carrier>", "<flight_code>", "<context_assertion>?"},
- /*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
- /*callback_param=*/
- AddRuleClassificationResult("flight", &grammar_model));
-
- // Exclude matches like: LX 38.00 etc.
- rules.Add("<context_assertion>", {".?", "<digits>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarAnnotator::Callback::kAssertionMatch),
- /*callback_param=*/true /*negative*/);
-
- // Assertion matches will create their own match objects.
- // We declare the handler as a filter so that the grammar system knows that we
- // handle this ourselves.
- rules.DefineFilter(static_cast<grammar::CallbackId>(
- GrammarAnnotator::Callback::kAssertionMatch));
-
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- grammar_model.rules.get());
- flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
- GrammarAnnotator annotator(
- &unilib_, flatbuffers::GetRoot<GrammarModel>(serialized_model.data()));
-
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(annotator.Annotate(
- {Locale::FromBCP47("en")},
- tokenizer_.Tokenize("My flight: LX 38 arriving at 4pm, I'll fly back on "
- "AA2014 on LX 38.00"),
- &result));
-
- EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(11, 16, "flight"),
- IsAnnotatedSpan(51, 57, "flight")));
-}
-
-TEST_F(GrammarAnnotatorTest, HandlesCapturingGroups) {
- // Create test rules.
- GrammarModelT grammar_model;
- grammar_model.rules.reset(new grammar::RulesSetT);
- grammar::Rules rules;
- rules.Add("<low_confidence_phone>", {"<digits>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarAnnotator::Callback::kCapturingMatch),
- /*callback_param=*/0);
-
- // Create rule result.
- const int classification_result_id =
- AddRuleClassificationResult("phone", &grammar_model);
- grammar_model.rule_classification_result[classification_result_id]
- ->capturing_group.emplace_back(new CapturingGroupT);
- grammar_model.rule_classification_result[classification_result_id]
- ->capturing_group.back()
- ->extend_selection = true;
-
- rules.Add(
- "<phone>", {"please", "call", "<low_confidence_phone>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
- /*callback_param=*/classification_result_id);
-
- // Capturing matches will create their own match objects to keep track of
- // match ids, so we declare the handler as a filter so that the grammar system
- // knows that we handle this ourselves.
- rules.DefineFilter(static_cast<grammar::CallbackId>(
- GrammarAnnotator::Callback::kCapturingMatch));
-
- rules.Finalize().Serialize(/*include_debug_information=*/false,
- grammar_model.rules.get());
- flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
- GrammarAnnotator annotator(
- &unilib_, flatbuffers::GetRoot<GrammarModel>(serialized_model.data()));
-
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(annotator.Annotate(
- {Locale::FromBCP47("en")},
- tokenizer_.Tokenize("Please call 911 before 10 am!"), &result));
- EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(12, 15, "phone")));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index aa99b21..3e616b6 100755
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -380,6 +380,12 @@
// The datetime grammar rules.
datetime_rules:dates.DatetimeRules;
+
+ // The final score to assign to the results of the grammar model.
+ target_classification_score:float = 1;
+
+ // The priority score used for conflict resolution with the other models.
+ priority_score:float = 1;
}
namespace libtextclassifier3.DatetimeModelLibrary_;
diff --git a/native/annotator/number/number.cc b/native/annotator/number/number.cc
index f01ba3f..fe986ae 100644
--- a/native/annotator/number/number.cc
+++ b/native/annotator/number/number.cc
@@ -285,7 +285,6 @@
void NumberAnnotator::FindPercentages(
const UnicodeText& context, std::vector<AnnotatedSpan>* result) const {
- std::vector<AnnotatedSpan> percentage_annotations;
const int initial_result_size = result->size();
for (int i = 0; i < initial_result_size; ++i) {
AnnotatedSpan annotated_span = (*result)[i];
diff --git a/native/annotator/person_name/person-name-engine-dummy.h b/native/annotator/person_name/person-name-engine-dummy.h
index 91ae2e5..9c83241 100644
--- a/native/annotator/person_name/person-name-engine-dummy.h
+++ b/native/annotator/person_name/person-name-engine-dummy.h
@@ -32,7 +32,8 @@
// A dummy implementation of the person name engine.
class PersonNameEngine {
public:
- explicit PersonNameEngine(const UniLib* unilib) {}
+ explicit PersonNameEngine(const FeatureProcessor* feature_processor,
+ const UniLib* unilib) {}
bool Initialize(const PersonNameModel* model) {
TC3_LOG(ERROR) << "No person name engine to initialize.";
diff --git a/native/annotator/person_name/person_name_model.fbs b/native/annotator/person_name/person_name_model.fbs
index 6421341..091ad31 100755
--- a/native/annotator/person_name/person_name_model.fbs
+++ b/native/annotator/person_name/person_name_model.fbs
@@ -26,7 +26,7 @@
person_name:string (shared);
}
-// Next ID: 3
+// Next ID: 5
namespace libtextclassifier3;
table PersonNameModel {
// Decides if the person name annotator is enabled.
@@ -35,6 +35,15 @@
// List of all person names which are considered by the person name annotator.
person_names:[PersonNameModel_.PersonName];
+
+ // Decides if the English genitive ending 's is stripped, e.g., if Peter's is
+ // stripped to Peter before looking for the name in the dictionary. required
+ strip_english_genitive_ending:bool;
+
+ // List of codepoints that are considered 'end of person name' indicators in
+ // the heuristic to find the longest person name match.
+ // required
+ end_of_person_name_indicators:[int];
}
root_type libtextclassifier3.PersonNameModel;
diff --git a/native/annotator/quantization_test.cc b/native/annotator/quantization_test.cc
deleted file mode 100644
index b995096..0000000
--- a/native/annotator/quantization_test.cc
+++ /dev/null
@@ -1,163 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/quantization.h"
-
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-using testing::ElementsAreArray;
-using testing::FloatEq;
-using testing::Matcher;
-
-namespace libtextclassifier3 {
-namespace {
-
-Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
- std::vector<Matcher<float>> matchers;
- for (const float value : values) {
- matchers.push_back(FloatEq(value));
- }
- return ElementsAreArray(matchers);
-}
-
-TEST(QuantizationTest, DequantizeAdd8bit) {
- std::vector<float> scales{{0.1, 9.0, -7.0}};
- std::vector<uint8> embeddings{{/*0: */ 0x00, 0xFF, 0x09, 0x00,
- /*1: */ 0xFF, 0x09, 0x00, 0xFF,
- /*2: */ 0x09, 0x00, 0xFF, 0x09}};
-
- const int quantization_bits = 8;
- const int bytes_per_embedding = 4;
- const int num_sparse_features = 7;
- {
- const int bucket_id = 0;
- std::vector<float> dest(4, 0.0);
- DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
- num_sparse_features, quantization_bits, bucket_id,
- dest.data(), dest.size());
-
- EXPECT_THAT(dest,
- ElementsAreFloat(std::vector<float>{
- // clang-format off
- {1.0 / 7 * 0.1 * (0x00 - 128),
- 1.0 / 7 * 0.1 * (0xFF - 128),
- 1.0 / 7 * 0.1 * (0x09 - 128),
- 1.0 / 7 * 0.1 * (0x00 - 128)}
- // clang-format on
- }));
- }
-
- {
- const int bucket_id = 1;
- std::vector<float> dest(4, 0.0);
- DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
- num_sparse_features, quantization_bits, bucket_id,
- dest.data(), dest.size());
-
- EXPECT_THAT(dest,
- ElementsAreFloat(std::vector<float>{
- // clang-format off
- {1.0 / 7 * 9.0 * (0xFF - 128),
- 1.0 / 7 * 9.0 * (0x09 - 128),
- 1.0 / 7 * 9.0 * (0x00 - 128),
- 1.0 / 7 * 9.0 * (0xFF - 128)}
- // clang-format on
- }));
- }
-}
-
-TEST(QuantizationTest, DequantizeAdd1bitZeros) {
- const int bytes_per_embedding = 4;
- const int num_buckets = 3;
- const int num_sparse_features = 7;
- const int quantization_bits = 1;
- const int bucket_id = 1;
-
- std::vector<float> scales(num_buckets);
- std::vector<uint8> embeddings(bytes_per_embedding * num_buckets);
- std::fill(scales.begin(), scales.end(), 1);
- std::fill(embeddings.begin(), embeddings.end(), 0);
-
- std::vector<float> dest(32);
- DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
- num_sparse_features, quantization_bits, bucket_id, dest.data(),
- dest.size());
-
- std::vector<float> expected(32);
- std::fill(expected.begin(), expected.end(),
- 1.0 / num_sparse_features * (0 - 1));
- EXPECT_THAT(dest, ElementsAreFloat(expected));
-}
-
-TEST(QuantizationTest, DequantizeAdd1bitOnes) {
- const int bytes_per_embedding = 4;
- const int num_buckets = 3;
- const int num_sparse_features = 7;
- const int quantization_bits = 1;
- const int bucket_id = 1;
-
- std::vector<float> scales(num_buckets, 1.0);
- std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0xFF);
-
- std::vector<float> dest(32);
- DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
- num_sparse_features, quantization_bits, bucket_id, dest.data(),
- dest.size());
- std::vector<float> expected(32);
- std::fill(expected.begin(), expected.end(),
- 1.0 / num_sparse_features * (1 - 1));
- EXPECT_THAT(dest, ElementsAreFloat(expected));
-}
-
-TEST(QuantizationTest, DequantizeAdd3bit) {
- const int bytes_per_embedding = 4;
- const int num_buckets = 3;
- const int num_sparse_features = 7;
- const int quantization_bits = 3;
- const int bucket_id = 1;
-
- std::vector<float> scales(num_buckets, 1.0);
- scales[1] = 9.0;
- std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0);
- // For bucket_id=1, the embedding has values 0..9 for indices 0..9:
- embeddings[4] = (1 << 7) | (1 << 6) | (1 << 4) | 1;
- embeddings[5] = (1 << 6) | (1 << 4) | (1 << 3);
- embeddings[6] = (1 << 4) | (1 << 3) | (1 << 2) | (1 << 1) | 1;
-
- std::vector<float> dest(10);
- DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
- num_sparse_features, quantization_bits, bucket_id, dest.data(),
- dest.size());
-
- std::vector<float> expected;
- expected.push_back(1.0 / num_sparse_features * (1 - 4) * scales[bucket_id]);
- expected.push_back(1.0 / num_sparse_features * (2 - 4) * scales[bucket_id]);
- expected.push_back(1.0 / num_sparse_features * (3 - 4) * scales[bucket_id]);
- expected.push_back(1.0 / num_sparse_features * (4 - 4) * scales[bucket_id]);
- expected.push_back(1.0 / num_sparse_features * (5 - 4) * scales[bucket_id]);
- expected.push_back(1.0 / num_sparse_features * (6 - 4) * scales[bucket_id]);
- expected.push_back(1.0 / num_sparse_features * (7 - 4) * scales[bucket_id]);
- expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]);
- expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]);
- expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]);
- EXPECT_THAT(dest, ElementsAreFloat(expected));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/strip-unpaired-brackets_test.cc b/native/annotator/strip-unpaired-brackets_test.cc
deleted file mode 100644
index 32585ce..0000000
--- a/native/annotator/strip-unpaired-brackets_test.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/strip-unpaired-brackets.h"
-
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class StripUnpairedBracketsTest : public ::testing::Test {
- protected:
- StripUnpairedBracketsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- UniLib unilib_;
-};
-
-TEST_F(StripUnpairedBracketsTest, StripUnpairedBrackets) {
- // If the brackets match, nothing gets stripped.
- EXPECT_EQ(StripUnpairedBrackets("call me (123) 456 today", {8, 17}, unilib_),
- std::make_pair(8, 17));
- EXPECT_EQ(StripUnpairedBrackets("call me (123 456) today", {8, 17}, unilib_),
- std::make_pair(8, 17));
-
- // If the brackets don't match, they get stripped.
- EXPECT_EQ(StripUnpairedBrackets("call me (123 456 today", {8, 16}, unilib_),
- std::make_pair(9, 16));
- EXPECT_EQ(StripUnpairedBrackets("call me )123 456 today", {8, 16}, unilib_),
- std::make_pair(9, 16));
- EXPECT_EQ(StripUnpairedBrackets("call me 123 456) today", {8, 16}, unilib_),
- std::make_pair(8, 15));
- EXPECT_EQ(StripUnpairedBrackets("call me 123 456( today", {8, 16}, unilib_),
- std::make_pair(8, 15));
-
- // Strips brackets correctly from length-1 selections that consist of
- // a bracket only.
- EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}, unilib_),
- std::make_pair(12, 12));
- EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}, unilib_),
- std::make_pair(12, 12));
-
- // Handles invalid spans gracefully.
- EXPECT_EQ(StripUnpairedBrackets("call me at today", {11, 11}, unilib_),
- std::make_pair(11, 11));
- EXPECT_EQ(StripUnpairedBrackets("hello world", {0, 0}, unilib_),
- std::make_pair(0, 0));
- EXPECT_EQ(StripUnpairedBrackets("hello world", {11, 11}, unilib_),
- std::make_pair(11, 11));
- EXPECT_EQ(StripUnpairedBrackets("hello world", {-1, -1}, unilib_),
- std::make_pair(-1, -1));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/translate/translate_test.cc b/native/annotator/translate/translate_test.cc
deleted file mode 100644
index 8b2a8ef..0000000
--- a/native/annotator/translate/translate_test.cc
+++ /dev/null
@@ -1,181 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/translate/translate.h"
-
-#include <memory>
-
-#include "annotator/model_generated.h"
-#include "lang_id/fb_model/lang-id-from-fb.h"
-#include "lang_id/lang-id.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::AllOf;
-using testing::Field;
-
-const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- TranslateAnnotatorOptionsT options;
- options.enabled = true;
- options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
- options.backoff_options.reset(
- new TranslateAnnotatorOptions_::BackoffOptionsT());
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- return flatbuffers::GetRoot<TranslateAnnotatorOptions>(options_data->data());
-}
-
-class TestingTranslateAnnotator : public TranslateAnnotator {
- public:
- // Make these protected members public for tests.
- using TranslateAnnotator::BackoffDetectLanguages;
- using TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation;
- using TranslateAnnotator::TokenAlignedSubstringAroundSpan;
- using TranslateAnnotator::TranslateAnnotator;
-};
-
-std::string GetModelPath() {
- return "";
-}
-
-class TranslateAnnotatorTest : public ::testing::Test {
- protected:
- TranslateAnnotatorTest()
- : INIT_UNILIB_FOR_TESTING(unilib_),
- langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
- GetModelPath() + "lang_id.smfb")),
- translate_annotator_(TestingTranslateAnnotatorOptions(),
- langid_model_.get(), &unilib_) {}
-
- UniLib unilib_;
- std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model_;
- TestingTranslateAnnotator translate_annotator_;
-};
-
-TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) {
- ClassificationResult classification;
- EXPECT_TRUE(translate_annotator_.ClassifyText(
- UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček."), {18, 28},
- "en", &classification));
-
- EXPECT_THAT(classification,
- AllOf(Field(&ClassificationResult::collection, "translate")));
-}
-
-TEST_F(TranslateAnnotatorTest,
- WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) {
- ClassificationResult classification;
- EXPECT_FALSE(translate_annotator_.ClassifyText(
- UTF8ToUnicodeText("This is utterly unutterable."), {8, 15}, "en",
- &classification));
-}
-
-TEST_F(TranslateAnnotatorTest,
- WhenSpeaksMultipleAndNotCzechGetsTranslateActionForCzech) {
- ClassificationResult classification;
- EXPECT_TRUE(translate_annotator_.ClassifyText(
- UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček."), {8, 15},
- "de,en,ja", &classification));
-
- EXPECT_THAT(classification,
- AllOf(Field(&ClassificationResult::collection, "translate")));
-}
-
-TEST_F(TranslateAnnotatorTest,
- WhenSpeaksMultipleAndEnglishDoesntGetTranslateActionForEnglish) {
- ClassificationResult classification;
- EXPECT_FALSE(translate_annotator_.ClassifyText(
- UTF8ToUnicodeText("This is utterly unutterable."), {8, 15}, "cs,en,de,ja",
- &classification));
-}
-
-TEST_F(TranslateAnnotatorTest, FindIndexOfNextWhitespaceOrPunctuation) {
- const UnicodeText text =
- UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");
-
- EXPECT_EQ(
- translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 0, -1),
- text.begin());
- EXPECT_EQ(
- translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 35, 1),
- text.end());
- EXPECT_EQ(
- translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 10, -1),
- std::next(text.begin(), 6));
- EXPECT_EQ(
- translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 10, 1),
- std::next(text.begin(), 13));
-}
-
-TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringAroundSpan) {
- const UnicodeText text =
- UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");
-
- EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
- text, {35, 37}, /*minimum_length=*/100),
- text);
- EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
- text, {35, 37}, /*minimum_length=*/0),
- UTF8ToUnicodeText("ač"));
- EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
- text, {35, 37}, /*minimum_length=*/3),
- UTF8ToUnicodeText("stříkaček"));
- EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
- text, {35, 37}, /*minimum_length=*/10),
- UTF8ToUnicodeText("stříkaček"));
- EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
- text, {35, 37}, /*minimum_length=*/11),
- UTF8ToUnicodeText("stříbrných stříkaček"));
-
- const UnicodeText text_no_whitespace =
- UTF8ToUnicodeText("reallyreallylongstring");
- EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
- text_no_whitespace, {10, 11}, /*minimum_length=*/2),
- UTF8ToUnicodeText("reallyreallylongstring"));
-}
-
-TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringWhitespaceText) {
- const UnicodeText text = UTF8ToUnicodeText(" ");
-
- // Shouldn't modify the selection in case it's all whitespace.
- EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
- text, {5, 7}, /*minimum_length=*/3),
- UTF8ToUnicodeText(" "));
- EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
- text, {5, 5}, /*minimum_length=*/1),
- UTF8ToUnicodeText(""));
-}
-
-TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringMostlyWhitespaceText) {
- const UnicodeText text = UTF8ToUnicodeText("a a");
-
- // Should still select the whole text even if pointing to whitespace
- // initially.
- EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
- text, {5, 7}, /*minimum_length=*/11),
- UTF8ToUnicodeText("a a"));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/types-test-util.h b/native/annotator/types-test-util.h
deleted file mode 100644
index 1d018a1..0000000
--- a/native/annotator/types-test-util.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
-
-#include <ostream>
-
-#include "annotator/types.h"
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-#define TC3_DECLARE_PRINT_OPERATOR(TYPE_NAME) \
- inline std::ostream& operator<<(std::ostream& stream, \
- const TYPE_NAME& value) { \
- logging::LoggingStringStream tmp_stream; \
- tmp_stream << value; \
- return stream << tmp_stream.message; \
- }
-
-TC3_DECLARE_PRINT_OPERATOR(AnnotatedSpan)
-TC3_DECLARE_PRINT_OPERATOR(ClassificationResult)
-TC3_DECLARE_PRINT_OPERATOR(DatetimeParsedData)
-TC3_DECLARE_PRINT_OPERATOR(DatetimeParseResultSpan)
-TC3_DECLARE_PRINT_OPERATOR(Token)
-
-#undef TC3_DECLARE_PRINT_OPERATOR
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
diff --git a/native/annotator/zlib-utils_test.cc b/native/annotator/zlib-utils_test.cc
deleted file mode 100644
index 363c155..0000000
--- a/native/annotator/zlib-utils_test.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/zlib-utils.h"
-
-#include <memory>
-
-#include "annotator/model_generated.h"
-#include "utils/zlib/zlib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-
-TEST(ZlibUtilsTest, CompressModel) {
- ModelT model;
- model.regex_model.reset(new RegexModelT);
- model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
- model.regex_model->patterns.back()->pattern = "this is a test pattern";
- model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
- model.regex_model->patterns.back()->pattern = "this is a second test pattern";
-
- model.datetime_model.reset(new DatetimeModelT);
- model.datetime_model->patterns.emplace_back(new DatetimeModelPatternT);
- model.datetime_model->patterns.back()->regexes.emplace_back(
- new DatetimeModelPattern_::RegexT);
- model.datetime_model->patterns.back()->regexes.back()->pattern =
- "an example datetime pattern";
- model.datetime_model->extractors.emplace_back(new DatetimeModelExtractorT);
- model.datetime_model->extractors.back()->pattern =
- "an example datetime extractor";
-
- model.intent_options.reset(new IntentFactoryModelT);
- model.intent_options->generator.emplace_back(
- new IntentFactoryModel_::IntentGeneratorT);
- const std::string intent_generator1 = "lua generator 1";
- model.intent_options->generator.back()->lua_template_generator =
- std::vector<uint8_t>(intent_generator1.begin(), intent_generator1.end());
- model.intent_options->generator.emplace_back(
- new IntentFactoryModel_::IntentGeneratorT);
- const std::string intent_generator2 = "lua generator 2";
- model.intent_options->generator.back()->lua_template_generator =
- std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end());
-
- // NOTE: The resource strings contain some repetition, so that the compressed
- // version is smaller than the uncompressed one. Because the compression code
- // looks at that as well.
- model.resources.reset(new ResourcePoolT);
- model.resources->resource_entry.emplace_back(new ResourceEntryT);
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr1.1";
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr1.2";
- model.resources->resource_entry.emplace_back(new ResourceEntryT);
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr2.1";
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr2.2";
-
- // Compress the model.
- EXPECT_TRUE(CompressModel(&model));
-
- // Sanity check that uncompressed field is removed.
- EXPECT_TRUE(model.regex_model->patterns[0]->pattern.empty());
- EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty());
- EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty());
- EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty());
- EXPECT_TRUE(
- model.intent_options->generator[0]->lua_template_generator.empty());
- EXPECT_TRUE(
- model.intent_options->generator[1]->lua_template_generator.empty());
- EXPECT_TRUE(model.resources->resource_entry[0]->resource[0]->content.empty());
- EXPECT_TRUE(model.resources->resource_entry[0]->resource[1]->content.empty());
- EXPECT_TRUE(model.resources->resource_entry[1]->resource[0]->content.empty());
- EXPECT_TRUE(model.resources->resource_entry[1]->resource[1]->content.empty());
-
- // Pack and load the model.
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, &model));
- const Model* compressed_model =
- GetModel(reinterpret_cast<const char*>(builder.GetBufferPointer()));
- ASSERT_TRUE(compressed_model != nullptr);
-
- // Decompress the fields again and check that they match the original.
- std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
- ASSERT_TRUE(decompressor != nullptr);
- std::string uncompressed_pattern;
- EXPECT_TRUE(decompressor->MaybeDecompress(
- compressed_model->regex_model()->patterns()->Get(0)->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, "this is a test pattern");
- EXPECT_TRUE(decompressor->MaybeDecompress(
- compressed_model->regex_model()->patterns()->Get(1)->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, "this is a second test pattern");
- EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model()
- ->patterns()
- ->Get(0)
- ->regexes()
- ->Get(0)
- ->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, "an example datetime pattern");
- EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model()
- ->extractors()
- ->Get(0)
- ->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, "an example datetime extractor");
-
- EXPECT_TRUE(DecompressModel(&model));
- EXPECT_EQ(model.regex_model->patterns[0]->pattern, "this is a test pattern");
- EXPECT_EQ(model.regex_model->patterns[1]->pattern,
- "this is a second test pattern");
- EXPECT_EQ(model.datetime_model->patterns[0]->regexes[0]->pattern,
- "an example datetime pattern");
- EXPECT_EQ(model.datetime_model->extractors[0]->pattern,
- "an example datetime extractor");
- EXPECT_EQ(
- model.intent_options->generator[0]->lua_template_generator,
- std::vector<uint8_t>(intent_generator1.begin(), intent_generator1.end()));
- EXPECT_EQ(
- model.intent_options->generator[1]->lua_template_generator,
- std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end()));
- EXPECT_EQ(model.resources->resource_entry[0]->resource[0]->content,
- "rrrrrrrrrrrrr1.1");
- EXPECT_EQ(model.resources->resource_entry[0]->resource[1]->content,
- "rrrrrrrrrrrrr1.2");
- EXPECT_EQ(model.resources->resource_entry[1]->resource[0]->content,
- "rrrrrrrrrrrrr2.1");
- EXPECT_EQ(model.resources->resource_entry[1]->resource[1]->content,
- "rrrrrrrrrrrrr2.2");
-}
-
-} // namespace libtextclassifier3
diff --git a/native/utils/base/logging.h b/native/utils/base/logging.h
index b465230..4939b7e 100644
--- a/native/utils/base/logging.h
+++ b/native/utils/base/logging.h
@@ -144,7 +144,7 @@
// Debug checks: a TC3_DCHECK<suffix> macro should behave like TC3_CHECK<suffix>
// in debug mode and don't check / don't print anything in non-debug mode.
-#ifdef NDEBUG
+#if defined(NDEBUG) && !defined(TC3_DEBUG_LOGGING)
#define TC3_DCHECK(x) TC3_NULLSTREAM
#define TC3_DCHECK_EQ(x, y) TC3_NULLSTREAM
diff --git a/native/utils/base/logging_raw.cc b/native/utils/base/logging_raw.cc
index ccaef22..e3a73e2 100644
--- a/native/utils/base/logging_raw.cc
+++ b/native/utils/base/logging_raw.cc
@@ -17,8 +17,14 @@
#include "utils/base/logging_raw.h"
#include <stdio.h>
+
#include <string>
+#define TC3_RETURN_IF_NOT_ERROR_OR_FATAL \
+ if (severity != ERROR && severity != FATAL) { \
+ return; \
+ }
+
// NOTE: this file contains two implementations: one for Android, one for all
// other cases. We always build exactly one implementation.
#if defined(__ANDROID__)
@@ -49,13 +55,10 @@
void LowLevelLogging(LogSeverity severity, const std::string& tag,
const std::string& message) {
- const int android_log_level = GetAndroidLogLevel(severity);
#if !defined(TC3_DEBUG_LOGGING)
- if (android_log_level != ANDROID_LOG_ERROR &&
- android_log_level != ANDROID_LOG_FATAL) {
- return;
- }
+ TC3_RETURN_IF_NOT_ERROR_OR_FATAL
#endif
+ const int android_log_level = GetAndroidLogLevel(severity);
__android_log_write(android_log_level, tag.c_str(), message.c_str());
}
@@ -88,6 +91,9 @@
void LowLevelLogging(LogSeverity severity, const std::string &tag,
const std::string &message) {
+#if !defined(TC3_DEBUG_LOGGING)
+ TC3_RETURN_IF_NOT_ERROR_OR_FATAL
+#endif
fprintf(stderr, "[%s] %s : %s\n", LogSeverityToString(severity), tag.c_str(),
message.c_str());
fflush(stderr);
diff --git a/native/utils/base/status_test.cc b/native/utils/base/status_test.cc
deleted file mode 100644
index 82d5aad..0000000
--- a/native/utils/base/status_test.cc
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/base/status.h"
-
-#include "utils/base/logging.h"
-#include "utils/base/status_macros.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(StatusTest, PrintsAbortedStatus) {
- logging::LoggingStringStream stream;
- stream << Status::UNKNOWN;
- EXPECT_EQ(Status::UNKNOWN.error_code(), 2);
- EXPECT_EQ(Status::UNKNOWN.CanonicalCode(), StatusCode::UNKNOWN);
- EXPECT_EQ(Status::UNKNOWN.error_message(), "");
- EXPECT_EQ(stream.message, "2");
-}
-
-TEST(StatusTest, PrintsOKStatus) {
- logging::LoggingStringStream stream;
- stream << Status::OK;
- EXPECT_EQ(Status::OK.error_code(), 0);
- EXPECT_EQ(Status::OK.CanonicalCode(), StatusCode::OK);
- EXPECT_EQ(Status::OK.error_message(), "");
- EXPECT_EQ(stream.message, "0");
-}
-
-TEST(StatusTest, UnknownStatusHasRightAttributes) {
- EXPECT_EQ(Status::UNKNOWN.error_code(), 2);
- EXPECT_EQ(Status::UNKNOWN.CanonicalCode(), StatusCode::UNKNOWN);
- EXPECT_EQ(Status::UNKNOWN.error_message(), "");
-}
-
-TEST(StatusTest, OkStatusHasRightAttributes) {
- EXPECT_EQ(Status::OK.error_code(), 0);
- EXPECT_EQ(Status::OK.CanonicalCode(), StatusCode::OK);
- EXPECT_EQ(Status::OK.error_message(), "");
-}
-
-TEST(StatusTest, CustomStatusHasRightAttributes) {
- Status status(StatusCode::INVALID_ARGUMENT, "You can't put this here!");
- EXPECT_EQ(status.error_code(), 3);
- EXPECT_EQ(status.CanonicalCode(), StatusCode::INVALID_ARGUMENT);
- EXPECT_EQ(status.error_message(), "You can't put this here!");
-}
-
-TEST(StatusTest, AssignmentPreservesMembers) {
- Status status(StatusCode::INVALID_ARGUMENT, "You can't put this here!");
-
- Status status2 = status;
-
- EXPECT_EQ(status2.error_code(), 3);
- EXPECT_EQ(status2.CanonicalCode(), StatusCode::INVALID_ARGUMENT);
- EXPECT_EQ(status2.error_message(), "You can't put this here!");
-}
-
-TEST(StatusTest, ReturnIfErrorOkStatus) {
- bool returned_due_to_error = true;
- auto lambda = [&returned_due_to_error](const Status& s) {
- TC3_RETURN_IF_ERROR(s);
- returned_due_to_error = false;
- return Status::OK;
- };
-
- // OK should allow execution to continue and the returned status should also
- // be OK.
- Status status = lambda(Status());
- EXPECT_EQ(status.error_code(), 0);
- EXPECT_EQ(status.CanonicalCode(), StatusCode::OK);
- EXPECT_EQ(status.error_message(), "");
- EXPECT_FALSE(returned_due_to_error);
-}
-
-TEST(StatusTest, ReturnIfErrorInvalidArgumentStatus) {
- bool returned_due_to_error = true;
- auto lambda = [&returned_due_to_error](const Status& s) {
- TC3_RETURN_IF_ERROR(s);
- returned_due_to_error = false;
- return Status::OK;
- };
-
- // INVALID_ARGUMENT should cause an early return.
- Status invalid_arg_status(StatusCode::INVALID_ARGUMENT, "You can't do that!");
- Status status = lambda(invalid_arg_status);
- EXPECT_EQ(status.error_code(), 3);
- EXPECT_EQ(status.CanonicalCode(), StatusCode::INVALID_ARGUMENT);
- EXPECT_EQ(status.error_message(), "You can't do that!");
- EXPECT_TRUE(returned_due_to_error);
-}
-
-TEST(StatusTest, ReturnIfErrorUnknownStatus) {
- bool returned_due_to_error = true;
- auto lambda = [&returned_due_to_error](const Status& s) {
- TC3_RETURN_IF_ERROR(s);
- returned_due_to_error = false;
- return Status::OK;
- };
-
- // UNKNOWN should cause an early return.
- Status unknown_status(StatusCode::UNKNOWN,
- "We also know there are known unknowns.");
- libtextclassifier3::Status status = lambda(unknown_status);
- EXPECT_EQ(status.error_code(), 2);
- EXPECT_EQ(status.CanonicalCode(), StatusCode::UNKNOWN);
- EXPECT_EQ(status.error_message(), "We also know there are known unknowns.");
- EXPECT_TRUE(returned_due_to_error);
-}
-
-TEST(StatusTest, ReturnIfErrorOnlyInvokesExpressionOnce) {
- int num_invocations = 0;
- auto ok_internal_expr = [&num_invocations]() {
- ++num_invocations;
- return Status::OK;
- };
- auto ok_lambda = [&ok_internal_expr]() {
- TC3_RETURN_IF_ERROR(ok_internal_expr());
- return Status::OK;
- };
-
- libtextclassifier3::Status status = ok_lambda();
- EXPECT_EQ(status.CanonicalCode(), StatusCode::OK);
- EXPECT_EQ(num_invocations, 1);
-
- num_invocations = 0;
- auto error_internal_expr = [&num_invocations]() {
- ++num_invocations;
- return Status::UNKNOWN;
- };
- auto error_lambda = [&error_internal_expr]() {
- TC3_RETURN_IF_ERROR(error_internal_expr());
- return Status::OK;
- };
-
- status = error_lambda();
- EXPECT_EQ(status.CanonicalCode(), StatusCode::UNKNOWN);
- EXPECT_EQ(num_invocations, 1);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/base/statusor_test.cc b/native/utils/base/statusor_test.cc
deleted file mode 100644
index 6e8afb1..0000000
--- a/native/utils/base/statusor_test.cc
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/base/statusor.h"
-
-#include "utils/base/logging.h"
-#include "utils/base/status.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(StatusOrTest, DoesntDieWhenOK) {
- StatusOr<std::string> status_or_string = std::string("Hello World");
- EXPECT_TRUE(status_or_string.ok());
- EXPECT_EQ(status_or_string.ValueOrDie(), "Hello World");
-}
-
-TEST(StatusOrTest, DiesWhenNotOK) {
- StatusOr<std::string> status_or_string = {Status::UNKNOWN};
- EXPECT_FALSE(status_or_string.ok());
- EXPECT_DEATH(status_or_string.ValueOrDie(),
- "Attempting to fetch value of non-OK StatusOr: 2");
-}
-
-// Foo is NOT default constructible and can be implicitly converted to from int.
-class Foo {
- public:
- // Copy value conversion
- Foo(int i) : i_(i) {} // NOLINT
- int i() const { return i_; }
-
- private:
- int i_;
-};
-
-TEST(StatusOrTest, HandlesNonDefaultConstructibleValues) {
- StatusOr<Foo> foo_or(Foo(7));
- EXPECT_TRUE(foo_or.ok());
- EXPECT_EQ(foo_or.ValueOrDie().i(), 7);
-
- StatusOr<Foo> error_or(Status::UNKNOWN);
- EXPECT_FALSE(error_or.ok());
- EXPECT_EQ(error_or.status().CanonicalCode(), StatusCode::UNKNOWN);
-}
-
-class Bar {
- public:
- // Move value conversion
- Bar(Foo&& f) : i_(2 * f.i()) {} // NOLINT
-
- // Movable, but not copyable.
- Bar(const Bar& other) = delete;
- Bar& operator=(const Bar& rhs) = delete;
- Bar(Bar&& other) = default;
- Bar& operator=(Bar&& rhs) = default;
-
- int i() const { return i_; }
-
- private:
- int i_;
-};
-
-TEST(StatusOrTest, HandlesValueConversion) {
- // Copy value conversion constructor : StatusOr<Foo>(const int&)
- StatusOr<Foo> foo_status(19);
- EXPECT_TRUE(foo_status.ok());
- EXPECT_EQ(foo_status.ValueOrDie().i(), 19);
-
- // Move value conversion constructor : StatusOr<Bar>(Foo&&)
- StatusOr<Bar> bar_status(std::move(foo_status));
- EXPECT_TRUE(bar_status.ok());
- EXPECT_EQ(bar_status.ValueOrDie().i(), 38);
-
- StatusOr<int> int_status(19);
- // Copy conversion constructor : StatusOr<Foo>(const StatusOr<int>&)
- StatusOr<Foo> copied_status(int_status);
- EXPECT_TRUE(copied_status.ok());
- EXPECT_EQ(copied_status.ValueOrDie().i(), 19);
-
- // Move conversion constructor : StatusOr<Bar>(StatusOr<Foo>&&)
- StatusOr<Bar> moved_status(std::move(copied_status));
- EXPECT_TRUE(moved_status.ok());
- EXPECT_EQ(moved_status.ValueOrDie().i(), 38);
-
- // Move conversion constructor with error : StatusOr<Bar>(StatusOr<Foo>&&)
- StatusOr<Foo> error_status(Status::UNKNOWN);
- StatusOr<Bar> moved_error_status(std::move(error_status));
- EXPECT_FALSE(moved_error_status.ok());
-}
-
-struct OkFn {
- StatusOr<int> operator()() { return 42; }
-};
-TEST(StatusOrTest, AssignOrReturnValOk) {
- auto lambda = []() {
- TC3_ASSIGN_OR_RETURN(int i, OkFn()(), -1);
- return i;
- };
-
- // OkFn() should return a valid integer, so lambda should return that integer.
- EXPECT_EQ(lambda(), 42);
-}
-
-struct FailFn {
- StatusOr<int> operator()() { return Status::UNKNOWN; }
-};
-TEST(StatusOrTest, AssignOrReturnValError) {
- auto lambda = []() {
- TC3_ASSIGN_OR_RETURN(int i, FailFn()(), -1);
- return i;
- };
-
- // FailFn() should return an error, so lambda should return -1.
- EXPECT_EQ(lambda(), -1);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/calendar/calendar_test-include.cc b/native/utils/calendar/calendar_test-include.cc
deleted file mode 100644
index 7fe6f53..0000000
--- a/native/utils/calendar/calendar_test-include.cc
+++ /dev/null
@@ -1,324 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/calendar/calendar_test-include.h"
-
-namespace libtextclassifier3 {
-namespace test_internal {
-
-static constexpr int kWednesday = 4;
-
-TEST_F(CalendarTest, Interface) {
- int64 time;
- DatetimeGranularity granularity;
- std::string timezone;
- DatetimeParsedData data;
- bool result = calendarlib_.InterpretParseData(
- data, /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity);
- TC3_LOG(INFO) << result;
-}
-
-TEST_F(CalendarTest, SetsZeroTimeWhenNotRelative) {
- int64 time;
- DatetimeGranularity granularity;
- DatetimeParsedData data;
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, 2018);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/1L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
-}
-
-TEST_F(CalendarTest, RoundingToGranularityBasic) {
- int64 time;
- DatetimeGranularity granularity;
- DatetimeParsedData data;
-
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, 2018);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
-
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::MONTH, 4);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */);
-
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_MONTH, 25);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */);
-
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 9);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */);
-
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 33);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */);
-
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::SECOND, 59);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1524641639000 /* Apr 25 2018 09:33:59 */);
-}
-
-TEST_F(CalendarTest, RoundingToGranularityWeek) {
- int64 time;
- DatetimeGranularity granularity;
- // Prepare data structure that means: "next week"
- DatetimeParsedData data;
- data.SetRelativeValue(DatetimeComponent::ComponentType::WEEK,
- DatetimeComponent::RelativeQualifier::NEXT);
- data.SetRelativeCount(DatetimeComponent::ComponentType::WEEK, 1);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"de-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 342000000L /* Mon Jan 05 1970 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 255600000L /* Sun Jan 04 1970 00:00:00 */);
-}
-
-TEST_F(CalendarTest, RelativeTime) {
- const int64 ref_time = 1524648839000L; /* 25 April 2018 09:33:59 */
- int64 time;
- DatetimeGranularity granularity;
-
- // Two Weds from now.
- DatetimeParsedData future_wed_parse;
- future_wed_parse.SetRelativeValue(
- DatetimeComponent::ComponentType::DAY_OF_WEEK,
- DatetimeComponent::RelativeQualifier::FUTURE);
- future_wed_parse.SetRelativeCount(
- DatetimeComponent::ComponentType::DAY_OF_WEEK, 2);
- future_wed_parse.SetAbsoluteValue(
- DatetimeComponent::ComponentType::DAY_OF_WEEK, kWednesday);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- future_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1525858439000L /* Wed May 09 2018 11:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_DAY);
-
- // Next Wed.
- DatetimeParsedData next_wed_parse;
- next_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- kWednesday);
- next_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- DatetimeComponent::RelativeQualifier::NEXT);
- next_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- 1);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- next_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1525212000000L /* Wed May 02 2018 00:00:00 */);
- EXPECT_EQ(granularity, GRANULARITY_DAY);
-
- // Same Wed.
- DatetimeParsedData same_wed_parse;
- same_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- DatetimeComponent::RelativeQualifier::THIS);
- same_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- kWednesday);
- same_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- 1);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- same_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1524607200000L /* Wed Apr 25 2018 00:00:00 */);
- EXPECT_EQ(granularity, GRANULARITY_DAY);
-
- // Previous Wed.
- DatetimeParsedData last_wed_parse;
- last_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- DatetimeComponent::RelativeQualifier::LAST);
- last_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- kWednesday);
- last_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- 1);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- last_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1524002400000L /* Wed Apr 18 2018 00:00:00 */);
- EXPECT_EQ(granularity, GRANULARITY_DAY);
-
- // Two Weds ago.
- DatetimeParsedData past_wed_parse;
- past_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- DatetimeComponent::RelativeQualifier::PAST);
- past_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- kWednesday);
- past_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- 2);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- past_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1523439239000L /* Wed Apr 11 2018 11:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_DAY);
-
- // In 3 hours.
- DatetimeParsedData in_3_hours_parse;
- in_3_hours_parse.SetRelativeValue(
- DatetimeComponent::ComponentType::HOUR,
- DatetimeComponent::RelativeQualifier::FUTURE);
- in_3_hours_parse.SetRelativeCount(DatetimeComponent::ComponentType::HOUR, 3);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- in_3_hours_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1524659639000L /* Wed Apr 25 2018 14:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_HOUR);
-
- // In 5 minutes.
- DatetimeParsedData in_5_minutes_parse;
- in_5_minutes_parse.SetRelativeValue(
- DatetimeComponent::ComponentType::MINUTE,
- DatetimeComponent::RelativeQualifier::FUTURE);
- in_5_minutes_parse.SetRelativeCount(DatetimeComponent::ComponentType::MINUTE,
- 5);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- in_5_minutes_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1524649139000L /* Wed Apr 25 2018 14:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_MINUTE);
-
- // In 10 seconds.
- DatetimeParsedData in_10_seconds_parse;
- in_10_seconds_parse.SetRelativeValue(
- DatetimeComponent::ComponentType::SECOND,
- DatetimeComponent::RelativeQualifier::FUTURE);
- in_10_seconds_parse.SetRelativeCount(DatetimeComponent::ComponentType::SECOND,
- 10);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- in_10_seconds_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1524648849000L /* Wed Apr 25 2018 14:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_SECOND);
-}
-
-TEST_F(CalendarTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
- int64 time;
- DatetimeGranularity granularity;
- DatetimeParsedData data;
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 7);
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
- /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", /*prefer_future_for_unspecified_date=*/true,
- &time, &granularity));
- EXPECT_EQ(time, 1567401000000L /* Sept 02 2019 07:10:00 */);
-}
-
-TEST_F(CalendarTest,
- DoesntAddADayWhenTimeInThePastAndDayNotSpecifiedAndDisabled) {
- int64 time;
- DatetimeGranularity granularity;
- DatetimeParsedData data;
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 7);
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
- /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1567314600000L /* Sept 01 2019 07:10:00 */);
-}
-
-TEST_F(CalendarTest, DoesntAddADayWhenTimeInTheFutureAndDayNotSpecified) {
- int64 time;
- DatetimeGranularity granularity;
- DatetimeParsedData data;
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 9);
- data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
- /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", /*prefer_future_for_unspecified_date=*/true,
- &time, &granularity));
- EXPECT_EQ(time, 1567321800000L /* Sept 01 2019 09:10:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
- /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
- EXPECT_EQ(time, 1567321800000L /* Sept 01 2019 09:10:00 */);
-}
-
-} // namespace test_internal
-} // namespace libtextclassifier3
diff --git a/native/utils/calendar/calendar_test-include.h b/native/utils/calendar/calendar_test-include.h
deleted file mode 100644
index 58ad6e0..0000000
--- a/native/utils/calendar/calendar_test-include.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// This is a shared test between icu and javaicu calendar implementations.
-// It is meant to be #include'd.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
-#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
-
-#include "gtest/gtest.h"
-
-#if defined TC3_CALENDAR_ICU
-#include "utils/calendar/calendar-icu.h"
-#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) VAR()
-#elif defined TC3_CALENDAR_APPLE
-#include "utils/calendar/calendar-apple.h"
-#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) VAR()
-#elif defined TC3_CALENDAR_JAVAICU
-#include <jni.h>
-extern JNIEnv* g_jenv;
-#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) \
- VAR(JniCache::Create(g_jenv))
-#include "utils/calendar/calendar-javaicu.h"
-#else
-#error Unsupported calendar implementation.
-#endif
-
-// This can get overridden in the javaicu version which needs to pass an JNIEnv*
-// argument to the constructor.
-#ifndef TC3_TESTING_CREATE_CALENDARLIB_INSTANCE
-
-#endif
-
-namespace libtextclassifier3 {
-namespace test_internal {
-
-class CalendarTest : public ::testing::Test {
- protected:
- CalendarTest() : TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(calendarlib_) {}
- CalendarLib calendarlib_;
-};
-
-} // namespace test_internal
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
diff --git a/native/utils/calendar/calendar_test.cc b/native/utils/calendar/calendar_test.cc
deleted file mode 100644
index 54ed2a0..0000000
--- a/native/utils/calendar/calendar_test.cc
+++ /dev/null
@@ -1,20 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "gtest/gtest.h"
-
-// The actual code of the test is in the following include:
-#include "utils/calendar/calendar_test-include.h"
diff --git a/native/utils/checksum_test.cc b/native/utils/checksum_test.cc
deleted file mode 100644
index dd04956..0000000
--- a/native/utils/checksum_test.cc
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/checksum.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(LuhnTest, CorrectlyHandlesSimpleCases) {
- EXPECT_TRUE(VerifyLuhnChecksum("3782 8224 6310 005"));
- EXPECT_FALSE(VerifyLuhnChecksum("0"));
- EXPECT_FALSE(VerifyLuhnChecksum("1"));
- EXPECT_FALSE(VerifyLuhnChecksum("0A"));
-}
-
-TEST(LuhnTest, CorrectlyVerifiesPaymentCardNumbers) {
- // Fake test numbers.
- EXPECT_TRUE(VerifyLuhnChecksum("3782 8224 6310 005"));
- EXPECT_TRUE(VerifyLuhnChecksum("371449635398431"));
- EXPECT_TRUE(VerifyLuhnChecksum("5610591081018250"));
- EXPECT_TRUE(VerifyLuhnChecksum("38520000023237"));
- EXPECT_TRUE(VerifyLuhnChecksum("6011000990139424"));
- EXPECT_TRUE(VerifyLuhnChecksum("3566002020360505"));
- EXPECT_TRUE(VerifyLuhnChecksum("5105105105105100"));
- EXPECT_TRUE(VerifyLuhnChecksum("4012 8888 8888 1881"));
-}
-
-TEST(LuhnTest, HandlesWhitespace) {
- EXPECT_TRUE(
- VerifyLuhnChecksum("3782 8224 6310 005 ", /*ignore_whitespace=*/true));
- EXPECT_FALSE(
- VerifyLuhnChecksum("3782 8224 6310 005 ", /*ignore_whitespace=*/false));
-}
-
-TEST(LuhnTest, HandlesEdgeCases) {
- EXPECT_FALSE(VerifyLuhnChecksum(" ", /*ignore_whitespace=*/true));
- EXPECT_FALSE(VerifyLuhnChecksum(" ", /*ignore_whitespace=*/false));
- EXPECT_FALSE(VerifyLuhnChecksum("", /*ignore_whitespace=*/true));
- EXPECT_FALSE(VerifyLuhnChecksum("", /*ignore_whitespace=*/false));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/container/double-array-trie.h b/native/utils/container/double-array-trie.h
index e1b43da..39c8822 100644
--- a/native/utils/container/double-array-trie.h
+++ b/native/utils/container/double-array-trie.h
@@ -30,7 +30,7 @@
// A trie node specifies a node in the tree, either an intermediate node or
// a leaf node.
// A leaf node contains the id as an int of the string match. This id is encoded
-// in the lower 30 bits, thus the number of distinct ids is 2^30.
+// in the lower 31 bits, thus the number of distinct ids is 2^31.
// An intermediate node has an associated label and an offset to its children.
// The label is encoded in the least significant byte and must match the input
// character during matching.
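
The packing described in the comment above can be illustrated with a small, self-contained sketch. It assumes a node made of a 1-bit leaf flag plus a 31-bit value, with the label in the low byte and the child offset above it; the ToyTrieNode name and exact field layout are illustrative only and may differ from the library's actual TrieNode.

#include <cassert>
#include <cstdint>

struct ToyTrieNode {
  uint32_t is_leaf : 1;  // 1 -> leaf node, value holds the match id.
  uint32_t value : 31;   // Leaf: match id (up to 2^31 distinct ids).
                         // Intermediate: label in the low byte, child offset above it.
};

int main() {
  // Leaf node: the 31-bit payload is the id of the matched string.
  ToyTrieNode leaf{/*is_leaf=*/1, /*value=*/42};
  assert(leaf.is_leaf && leaf.value == 42);

  // Intermediate node: label 'h' in the least significant byte, offset 3 above it.
  ToyTrieNode inner{/*is_leaf=*/0, /*value=*/(3u << 8) | 'h'};
  assert((inner.value & 0xFF) == 'h');  // Label must match the input character.
  assert((inner.value >> 8) == 3u);     // Offset to the children block.
  return 0;
}
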
diff --git a/native/utils/container/double-array-trie_test.cc b/native/utils/container/double-array-trie_test.cc
deleted file mode 100644
index 8ceec00..0000000
--- a/native/utils/container/double-array-trie_test.cc
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/container/double-array-trie.h"
-
-#include <fstream>
-#include <string>
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-std::string GetTestConfigPath() {
- return "";
-}
-
-TEST(DoubleArrayTest, Lookup) {
- // Test trie that contains pieces "hell", "hello", "o", "there".
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- DoubleArrayTrie trie(reinterpret_cast<const TrieNode*>(config.data()),
- config.size() / sizeof(TrieNode));
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches("hello there", &matches));
- EXPECT_EQ(matches.size(), 2);
- EXPECT_EQ(matches[0].id, 0 /*hell*/);
- EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
- EXPECT_EQ(matches[1].id, 1 /*hello*/);
- EXPECT_EQ(matches[1].match_length, 5 /*hello*/);
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches("he", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches("abcd", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches("", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches("hi there", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches(StringPiece("\0", 1), &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(
- trie.FindAllPrefixMatches(StringPiece("\xff, \xfe", 2), &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- StringSet::Match match;
- EXPECT_TRUE(trie.LongestPrefixMatch("hella there", &match));
- EXPECT_EQ(match.id, 0 /*hell*/);
- }
-
- {
- StringSet::Match match;
- EXPECT_TRUE(trie.LongestPrefixMatch("hello there", &match));
- EXPECT_EQ(match.id, 1 /*hello*/);
- }
-
- {
- StringSet::Match match;
- EXPECT_TRUE(trie.LongestPrefixMatch("abcd", &match));
- EXPECT_EQ(match.id, -1);
- }
-
- {
- StringSet::Match match;
- EXPECT_TRUE(trie.LongestPrefixMatch("", &match));
- EXPECT_EQ(match.id, -1);
- }
-
- {
- int value;
- EXPECT_TRUE(trie.Find("hell", &value));
- EXPECT_EQ(value, 0);
- }
-
- {
- int value;
- EXPECT_FALSE(trie.Find("hella", &value));
- }
-
- {
- int value;
- EXPECT_TRUE(trie.Find("hello", &value));
- EXPECT_EQ(value, 1 /*hello*/);
- }
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/container/sorted-strings-table_test.cc b/native/utils/container/sorted-strings-table_test.cc
deleted file mode 100644
index a93b197..0000000
--- a/native/utils/container/sorted-strings-table_test.cc
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/container/sorted-strings-table.h"
-
-#include <vector>
-
-#include "utils/base/integral_types.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(SortedStringsTest, Lookup) {
- const char pieces[] = "hell\0hello\0o\0there\0";
- const uint32 offsets[] = {0, 5, 11, 13};
-
- SortedStringsTable table(/*num_pieces=*/4, offsets, StringPiece(pieces, 18),
- /*use_linear_scan_threshold=*/1);
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("hello there", &matches));
- EXPECT_EQ(matches.size(), 2);
- EXPECT_EQ(matches[0].id, 0 /*hell*/);
- EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
- EXPECT_EQ(matches[1].id, 1 /*hello*/);
- EXPECT_EQ(matches[1].match_length, 5 /*hello*/);
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("he", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("he", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("abcd", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("hi there", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches(StringPiece("\0", 1), &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<StringSet::Match> matches;
- EXPECT_TRUE(
- table.FindAllPrefixMatches(StringPiece("\xff, \xfe", 2), &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- StringSet::Match match;
- EXPECT_TRUE(table.LongestPrefixMatch("hella there", &match));
- EXPECT_EQ(match.id, 0 /*hell*/);
- }
-
- {
- StringSet::Match match;
- EXPECT_TRUE(table.LongestPrefixMatch("hello there", &match));
- EXPECT_EQ(match.id, 1 /*hello*/);
- }
-
- {
- StringSet::Match match;
- EXPECT_TRUE(table.LongestPrefixMatch("abcd", &match));
- EXPECT_EQ(match.id, -1);
- }
-
- {
- StringSet::Match match;
- EXPECT_TRUE(table.LongestPrefixMatch("", &match));
- EXPECT_EQ(match.id, -1);
- }
-
- {
- int value;
- EXPECT_TRUE(table.Find("hell", &value));
- EXPECT_EQ(value, 0);
- }
-
- {
- int value;
- EXPECT_FALSE(table.Find("hella", &value));
- }
-
- {
- int value;
- EXPECT_TRUE(table.Find("hello", &value));
- EXPECT_EQ(value, 1 /*hello*/);
- }
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers.cc b/native/utils/flatbuffers.cc
index 53805e0..bb3bbb7 100644
--- a/native/utils/flatbuffers.cc
+++ b/native/utils/flatbuffers.cc
@@ -183,34 +183,6 @@
return libtextclassifier3::GetFieldByOffsetOrNull(type_, field_offset);
}
-bool ReflectiveFlatbuffer::IsMatchingType(const reflection::Field* field,
- const Variant& value) const {
- switch (field->type()->base_type()) {
- case reflection::Bool:
- return value.HasBool();
- case reflection::Byte:
- return value.HasInt8();
- case reflection::UByte:
- return value.HasUInt8();
- case reflection::Int:
- return value.HasInt();
- case reflection::UInt:
- return value.HasUInt();
- case reflection::Long:
- return value.HasInt64();
- case reflection::ULong:
- return value.HasUInt64();
- case reflection::Float:
- return value.HasFloat();
- case reflection::Double:
- return value.HasDouble();
- case reflection::String:
- return value.HasString();
- default:
- return false;
- }
-}
-
bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
const std::string& value) {
switch (field->type()->base_type()) {
@@ -518,7 +490,7 @@
case reflection::Vector:
switch (field->type()->element()) {
case reflection::Int:
- AppendFromVector<int>(from, field);
+ AppendFromVector<int32>(from, field);
break;
case reflection::UInt:
AppendFromVector<uint>(from, field);
diff --git a/native/utils/flatbuffers.h b/native/utils/flatbuffers.h
index e40c1bc..d987609 100644
--- a/native/utils/flatbuffers.h
+++ b/native/utils/flatbuffers.h
@@ -130,8 +130,37 @@
reflection::Field const** field);
// Checks whether a variant value type agrees with a field type.
- bool IsMatchingType(const reflection::Field* field,
- const Variant& value) const;
+ template <typename T>
+ bool IsMatchingType(const reflection::BaseType type) const {
+ switch (type) {
+ case reflection::Bool:
+ return std::is_same<T, bool>::value;
+ case reflection::Byte:
+ return std::is_same<T, int8>::value;
+ case reflection::UByte:
+ return std::is_same<T, uint8>::value;
+ case reflection::Int:
+ return std::is_same<T, int32>::value;
+ case reflection::UInt:
+ return std::is_same<T, uint32>::value;
+ case reflection::Long:
+ return std::is_same<T, int64>::value;
+ case reflection::ULong:
+ return std::is_same<T, uint64>::value;
+ case reflection::Float:
+ return std::is_same<T, float>::value;
+ case reflection::Double:
+ return std::is_same<T, double>::value;
+ case reflection::String:
+ return std::is_same<T, std::string>::value ||
+ std::is_same<T, StringPiece>::value ||
+ std::is_same<T, const char*>::value;
+ case reflection::Obj:
+ return std::is_same<T, ReflectiveFlatbuffer>::value;
+ default:
+ return false;
+ }
+ }
// Sets a (primitive) field to a specific value.
// Returns true if successful, and false if the field was not found or the
@@ -154,7 +183,7 @@
return false;
}
Variant variant_value(value);
- if (!IsMatchingType(field, variant_value)) {
+ if (!IsMatchingType<T>(field->type()->base_type())) {
TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
<< "`, expected: " << field->type()->base_type()
<< ", got: " << variant_value.GetType();
@@ -193,6 +222,11 @@
template <typename T>
TypedRepeatedField<T>* Repeated(const reflection::Field* field) {
+ if (!IsMatchingType<T>(field->type()->element())) {
+ TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
+ << "`";
+ return nullptr;
+ }
return static_cast<TypedRepeatedField<T>*>(Repeated(field));
}
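
The templated check above follows a common pattern: a runtime schema tag is compared against the compile-time C++ type via std::is_same. Below is a minimal standalone sketch of that pattern; the ToyBaseType enum and ToyIsMatchingType name are made up for the example and are not part of the library API.

#include <cstdint>
#include <string>
#include <type_traits>

enum class ToyBaseType { Bool, Int, Long, Double, String };

// Returns true if the C++ type T is an acceptable value type for the tag.
template <typename T>
bool ToyIsMatchingType(const ToyBaseType type) {
  switch (type) {
    case ToyBaseType::Bool:   return std::is_same<T, bool>::value;
    case ToyBaseType::Int:    return std::is_same<T, int32_t>::value;
    case ToyBaseType::Long:   return std::is_same<T, int64_t>::value;
    case ToyBaseType::Double: return std::is_same<T, double>::value;
    case ToyBaseType::String: return std::is_same<T, std::string>::value ||
                                     std::is_same<T, const char*>::value;
  }
  return false;
}

int main() {
  const bool ok = ToyIsMatchingType<int32_t>(ToyBaseType::Int);   // true
  const bool bad = ToyIsMatchingType<float>(ToyBaseType::Int);    // false -> caller logs a mismatch
  return (ok && !bad) ? 0 : 1;
}

Checking against T directly is what lets the patched Repeated<T>() validate the vector's element type as well, which the previous Variant-based check could not express.
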
diff --git a/native/utils/flatbuffers_test.cc b/native/utils/flatbuffers_test.cc
deleted file mode 100644
index e3aaa6c..0000000
--- a/native/utils/flatbuffers_test.cc
+++ /dev/null
@@ -1,370 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/flatbuffers.h"
-
-#include <fstream>
-#include <map>
-#include <memory>
-#include <string>
-
-#include "utils/flatbuffers_generated.h"
-#include "utils/flatbuffers_test_generated.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/reflection.h"
-#include "flatbuffers/reflection_generated.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-std::string GetTestMetadataPath() {
- return "flatbuffers_test.bfbs";
-}
-
-std::string LoadTestMetadata() {
- std::ifstream test_config_stream(GetTestMetadataPath());
- return std::string((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
-}
-
-TEST(FlatbuffersTest, PrimitiveFieldsAreCorrectlySet) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- EXPECT_TRUE(buffer != nullptr);
- EXPECT_TRUE(buffer->Set("an_int_field", 42));
- EXPECT_TRUE(buffer->Set("a_long_field", 84ll));
- EXPECT_TRUE(buffer->Set("a_bool_field", true));
- EXPECT_TRUE(buffer->Set("a_float_field", 1.f));
- EXPECT_TRUE(buffer->Set("a_double_field", 1.0));
-
- // Try to parse with the generated code.
- std::string serialized_entity_data = buffer->Serialize();
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(entity_data->an_int_field, 42);
- EXPECT_EQ(entity_data->a_long_field, 84);
- EXPECT_EQ(entity_data->a_bool_field, true);
- EXPECT_NEAR(entity_data->a_float_field, 1.f, 1e-4);
- EXPECT_NEAR(entity_data->a_double_field, 1.f, 1e-4);
-}
-
-TEST(FlatbuffersTest, HandlesUnknownFields) {
- std::string metadata_buffer = LoadTestMetadata();
- const reflection::Schema* schema =
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
- ReflectiveFlatbufferBuilder reflective_builder(schema);
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- EXPECT_TRUE(buffer != nullptr);
-
- // Add a field that is not known to the (statically generated) code.
- EXPECT_TRUE(buffer->Set("mystic", "this is an unknown field."));
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(flatbuffers::Offset<void>(buffer->Serialize(&builder)));
-
- // Try to read the field again.
- const flatbuffers::Table* extra =
- flatbuffers::GetAnyRoot(builder.GetBufferPointer());
- EXPECT_EQ(extra
- ->GetPointer<const flatbuffers::String*>(
- buffer->GetFieldOrNull("mystic")->offset())
- ->str(),
- "this is an unknown field.");
-}
-
-TEST(FlatbuffersTest, HandlesNestedFields) {
- std::string metadata_buffer = LoadTestMetadata();
- const reflection::Schema* schema =
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
- ReflectiveFlatbufferBuilder reflective_builder(schema);
-
- FlatbufferFieldPathT path;
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "flight_number";
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "carrier_code";
- flatbuffers::FlatBufferBuilder path_builder;
- path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
-
- ReflectiveFlatbuffer* parent = nullptr;
- reflection::Field const* field = nullptr;
- EXPECT_TRUE(
- buffer->GetFieldWithParent(flatbuffers::GetRoot<FlatbufferFieldPath>(
- path_builder.GetBufferPointer()),
- &parent, &field));
- EXPECT_EQ(parent, buffer->Mutable("flight_number"));
- EXPECT_EQ(field,
- buffer->Mutable("flight_number")->GetFieldOrNull("carrier_code"));
-}
-
-TEST(FlatbuffersTest, HandlesMultipleNestedFields) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- ReflectiveFlatbuffer* flight_info = buffer->Mutable("flight_number");
- flight_info->Set("carrier_code", "LX");
- flight_info->Set("flight_code", 38);
-
- ReflectiveFlatbuffer* contact_info = buffer->Mutable("contact_info");
- EXPECT_TRUE(contact_info->Set("first_name", "Barack"));
- EXPECT_TRUE(contact_info->Set("last_name", "Obama"));
- EXPECT_TRUE(contact_info->Set("phone_number", "1-800-TEST"));
- EXPECT_TRUE(contact_info->Set("score", 1.f));
-
- // Try to parse with the generated code.
- std::string serialized_entity_data = buffer->Serialize();
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
- EXPECT_EQ(entity_data->flight_number->flight_code, 38);
- EXPECT_EQ(entity_data->contact_info->first_name, "Barack");
- EXPECT_EQ(entity_data->contact_info->last_name, "Obama");
- EXPECT_EQ(entity_data->contact_info->phone_number, "1-800-TEST");
- EXPECT_NEAR(entity_data->contact_info->score, 1.f, 1e-4);
-}
-
-TEST(FlatbuffersTest, HandlesFieldsSetWithNamePath) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
-
- FlatbufferFieldPathT path;
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "flight_number";
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "carrier_code";
- flatbuffers::FlatBufferBuilder path_builder;
- path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- // Test setting value using Set function.
- buffer->Mutable("flight_number")->Set("flight_code", 38);
- // Test setting value using FlatbufferFieldPath.
- buffer->Set(flatbuffers::GetRoot<FlatbufferFieldPath>(
- path_builder.GetBufferPointer()),
- "LX");
-
- // Try to parse with the generated code.
- std::string serialized_entity_data = buffer->Serialize();
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
- EXPECT_EQ(entity_data->flight_number->flight_code, 38);
-}
-
-TEST(FlatbuffersTest, HandlesFieldsSetWithOffsetPath) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
-
- FlatbufferFieldPathT path;
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_offset = 14;
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_offset = 4;
- flatbuffers::FlatBufferBuilder path_builder;
- path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- // Test setting value using Set function.
- buffer->Mutable("flight_number")->Set("flight_code", 38);
- // Test setting value using FlatbufferFieldPath.
- buffer->Set(flatbuffers::GetRoot<FlatbufferFieldPath>(
- path_builder.GetBufferPointer()),
- "LX");
-
- // Try to parse with the generated code.
- std::string serialized_entity_data = buffer->Serialize();
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
- EXPECT_EQ(entity_data->flight_number->flight_code, 38);
-}
-
-TEST(FlatbuffersTest, PartialBuffersAreCorrectlyMerged) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- buffer->Set("an_int_field", 42);
- buffer->Set("a_long_field", 84ll);
- ReflectiveFlatbuffer* flight_info = buffer->Mutable("flight_number");
- flight_info->Set("carrier_code", "LX");
- flight_info->Set("flight_code", 38);
- auto* reminders = buffer->Repeated<ReflectiveFlatbuffer>("reminders");
- ReflectiveFlatbuffer* reminder1 = reminders->Add();
- reminder1->Set("title", "reminder1");
- auto* reminder1_notes = reminder1->Repeated<std::string>("notes");
- reminder1_notes->Add("note1");
- reminder1_notes->Add("note2");
-
- // Create message to merge.
- test::EntityDataT additional_entity_data;
- additional_entity_data.an_int_field = 43;
- additional_entity_data.flight_number.reset(new test::FlightNumberInfoT);
- additional_entity_data.flight_number->flight_code = 39;
- additional_entity_data.contact_info.reset(new test::ContactInfoT);
- additional_entity_data.contact_info->first_name = "Barack";
- additional_entity_data.reminders.push_back(
- std::unique_ptr<test::ReminderT>(new test::ReminderT));
- additional_entity_data.reminders[0]->notes.push_back("additional note1");
- additional_entity_data.reminders[0]->notes.push_back("additional note2");
- additional_entity_data.numbers.push_back(9);
- additional_entity_data.numbers.push_back(10);
- additional_entity_data.strings.push_back("str1");
- additional_entity_data.strings.push_back("str2");
-
- flatbuffers::FlatBufferBuilder to_merge_builder;
- to_merge_builder.Finish(
- test::EntityData::Pack(to_merge_builder, &additional_entity_data));
-
- // Merge it.
- EXPECT_TRUE(buffer->MergeFrom(
- flatbuffers::GetAnyRoot(to_merge_builder.GetBufferPointer())));
-
- // Try to parse it with the generated code.
- std::string serialized_entity_data = buffer->Serialize();
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(entity_data->an_int_field, 43);
- EXPECT_EQ(entity_data->a_long_field, 84);
- EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
- EXPECT_EQ(entity_data->flight_number->flight_code, 39);
- EXPECT_EQ(entity_data->contact_info->first_name, "Barack");
- ASSERT_EQ(entity_data->reminders.size(), 2);
- EXPECT_EQ(entity_data->reminders[1]->notes[0], "additional note1");
- EXPECT_EQ(entity_data->reminders[1]->notes[1], "additional note2");
- ASSERT_EQ(entity_data->numbers.size(), 2);
- EXPECT_EQ(entity_data->numbers[0], 9);
- EXPECT_EQ(entity_data->numbers[1], 10);
- ASSERT_EQ(entity_data->strings.size(), 2);
- EXPECT_EQ(entity_data->strings[0], "str1");
- EXPECT_EQ(entity_data->strings[1], "str2");
-}
-
-TEST(FlatbuffersTest, PrimitiveAndNestedFieldsAreCorrectlyFlattened) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- buffer->Set("an_int_field", 42);
- buffer->Set("a_long_field", 84ll);
- ReflectiveFlatbuffer* flight_info = buffer->Mutable("flight_number");
- flight_info->Set("carrier_code", "LX");
- flight_info->Set("flight_code", 38);
-
- std::map<std::string, Variant> entity_data_map = buffer->AsFlatMap();
- EXPECT_EQ(4, entity_data_map.size());
- EXPECT_EQ(42, entity_data_map["an_int_field"].IntValue());
- EXPECT_EQ(84, entity_data_map["a_long_field"].Int64Value());
- EXPECT_EQ("LX", entity_data_map["flight_number.carrier_code"].StringValue());
- EXPECT_EQ(38, entity_data_map["flight_number.flight_code"].IntValue());
-}
-
-TEST(FlatbuffersTest, ToTextProtoWorks) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- buffer->Set("an_int_field", 42);
- buffer->Set("a_long_field", 84ll);
- ReflectiveFlatbuffer* flight_info = buffer->Mutable("flight_number");
- flight_info->Set("carrier_code", "LX");
- flight_info->Set("flight_code", 38);
-
- EXPECT_EQ(buffer->ToTextProto(),
- "a_long_field: 84, an_int_field: 42, flight_number "
- "{flight_code: 38, carrier_code: 'LX'}");
-}
-
-TEST(FlatbuffersTest, RepeatedFieldSetThroughReflectionCanBeRead) {
- std::string metadata_buffer = LoadTestMetadata();
- const reflection::Schema* schema =
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
- ReflectiveFlatbufferBuilder reflective_builder(schema);
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
-
- auto reminders = buffer->Repeated<ReflectiveFlatbuffer>("reminders");
- {
- auto reminder = reminders->Add();
- reminder->Set("title", "test reminder");
- auto notes = reminder->Repeated<std::string>("notes");
- notes->Add("note A");
- notes->Add("note B");
- }
- {
- auto reminder = reminders->Add();
- reminder->Set("title", "test reminder 2");
- auto notes = reminder->Repeated<std::string>("notes");
- notes->Add("note i");
- notes->Add("note ii");
- notes->Add("note iii");
- }
- const std::string serialized_entity_data = buffer->Serialize();
-
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(2, entity_data->reminders.size());
- EXPECT_EQ("test reminder", entity_data->reminders[0]->title);
- EXPECT_THAT(entity_data->reminders[0]->notes,
- testing::ElementsAreArray({"note A", "note B"}));
- EXPECT_EQ("test reminder 2", entity_data->reminders[1]->title);
- EXPECT_THAT(entity_data->reminders[1]->notes,
- testing::ElementsAreArray({"note i", "note ii", "note iii"}));
-}
-
-TEST(FlatbuffersTest, ResolvesFieldOffsets) {
- std::string metadata_buffer = LoadTestMetadata();
- const reflection::Schema* schema =
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
- FlatbufferFieldPathT path;
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "flight_number";
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "carrier_code";
-
- EXPECT_TRUE(SwapFieldNamesForOffsetsInPath(schema, &path));
-
- EXPECT_THAT(path.field[0]->field_name, testing::IsEmpty());
- EXPECT_EQ(14, path.field[0]->field_offset);
- EXPECT_THAT(path.field[1]->field_name, testing::IsEmpty());
- EXPECT_EQ(4, path.field[1]->field_offset);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers_test.fbs b/native/utils/flatbuffers_test.fbs
deleted file mode 100644
index f208ff4..0000000
--- a/native/utils/flatbuffers_test.fbs
+++ /dev/null
@@ -1,49 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-namespace libtextclassifier3.test;
-
-table FlightNumberInfo {
- carrier_code: string;
- flight_code: int;
-}
-
-table ContactInfo {
- first_name: string;
- last_name: string;
- phone_number: string;
- score: float;
-}
-
-table Reminder {
- title: string;
- notes: [string];
-}
-
-table EntityData {
- an_int_field: int;
- a_long_field: int64;
- a_bool_field: bool;
- a_float_field: float;
- a_double_field: double;
- flight_number: FlightNumberInfo;
- contact_info: ContactInfo;
- reminders: [Reminder];
- numbers: [int];
- strings: [string];
-}
-
-root_type libtextclassifier3.test.EntityData;
diff --git a/native/utils/grammar/lexer.cc b/native/utils/grammar/lexer.cc
index 960f6c4..190545a 100644
--- a/native/utils/grammar/lexer.cc
+++ b/native/utils/grammar/lexer.cc
@@ -16,6 +16,8 @@
#include "utils/grammar/lexer.h"
+#include <unordered_map>
+
namespace libtextclassifier3::grammar {
namespace {
@@ -25,9 +27,20 @@
return matcher->ArenaSize() <= kMaxMemoryUsage;
}
+Match* CheckedAddMatch(const Nonterm nonterm,
+ const CodepointSpan codepoint_span,
+ const int match_offset, const int16 type,
+ Matcher* matcher) {
+ if (nonterm == kUnassignedNonterm || !CheckMemoryUsage(matcher)) {
+ return nullptr;
+ }
+ return matcher->AllocateAndInitMatch<Match>(nonterm, codepoint_span,
+ match_offset, type);
+}
+
void CheckedEmit(const Nonterm nonterm, const CodepointSpan codepoint_span,
const int match_offset, int16 type, Matcher* matcher) {
- if (CheckMemoryUsage(matcher)) {
+ if (nonterm != kUnassignedNonterm && CheckMemoryUsage(matcher)) {
matcher->AddMatch(matcher->AllocateAndInitMatch<Match>(
nonterm, codepoint_span, match_offset, type));
}
@@ -35,104 +48,95 @@
} // namespace
-void Lexer::Emit(const StringPiece value, const CodepointSpan codepoint_span,
- const int match_offset, const TokenType type,
- const RulesSet_::Nonterminals* nonterms,
+void Lexer::Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms,
Matcher* matcher) const {
- // Emit the token as terminal.
- if (CheckMemoryUsage(matcher)) {
- matcher->AddTerminal(codepoint_span, match_offset, value);
- }
-
- // Emit <token> if used by rules.
- if (nonterms->token_nt() != kUnassignedNonterm) {
- CheckedEmit(nonterms->token_nt(), codepoint_span, match_offset,
- Match::kTokenType, matcher);
- }
-
- // Emit token type specific non-terminals.
- if (type == TOKEN_TYPE_DIGITS) {
- // Emit <digits> if used by the rules.
- if (nonterms->digits_nt() != kUnassignedNonterm) {
- CheckedEmit(nonterms->digits_nt(), codepoint_span, match_offset,
- Match::kDigitsType, matcher);
+ switch (symbol.type) {
+ case Symbol::Type::TYPE_MATCH: {
+ // Just emit the match.
+ matcher->AddMatch(symbol.match);
+ return;
}
+ case Symbol::Type::TYPE_DIGITS: {
+ // Emit <digits> if used by the rules.
+ CheckedEmit(nonterms->digits_nt(), symbol.codepoint_span,
+ symbol.match_offset, Match::kDigitsType, matcher);
- // Emit <n_digits> if used by the rules.
- if (nonterms->n_digits_nt() != nullptr) {
- const int num_digits = codepoint_span.second - codepoint_span.first;
- if (num_digits <= nonterms->n_digits_nt()->size()) {
- if (const Nonterm n_digits =
- nonterms->n_digits_nt()->Get(num_digits - 1)) {
- CheckedEmit(n_digits, codepoint_span, match_offset,
+ // Emit <n_digits> if used by the rules.
+ if (nonterms->n_digits_nt() != nullptr) {
+ const int num_digits =
+ symbol.codepoint_span.second - symbol.codepoint_span.first;
+ if (num_digits <= nonterms->n_digits_nt()->size()) {
+ CheckedEmit(nonterms->n_digits_nt()->Get(num_digits - 1),
+ symbol.codepoint_span, symbol.match_offset,
Match::kDigitsType, matcher);
}
}
+ break;
}
+ default:
+ break;
}
+
+ // Emit the token as terminal.
+ if (CheckMemoryUsage(matcher)) {
+ matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset,
+ symbol.lexeme);
+ }
+
+ // Emit <token> if used by rules.
+ CheckedEmit(nonterms->token_nt(), symbol.codepoint_span, symbol.match_offset,
+ Match::kTokenType, matcher);
}
-Lexer::TokenType Lexer::GetTokenType(
+Lexer::Symbol::Type Lexer::GetSymbolType(
const UnicodeText::const_iterator& it) const {
if (unilib_.IsPunctuation(*it)) {
- return TOKEN_TYPE_PUNCTUATION;
+ return Symbol::Type::TYPE_PUNCTUATION;
} else if (unilib_.IsDigit(*it)) {
- return TOKEN_TYPE_DIGITS;
- } else if (unilib_.IsWhitespace(*it)) {
- return TOKEN_TYPE_WHITESPACE;
+ return Symbol::Type::TYPE_DIGITS;
} else {
- return TOKEN_TYPE_TERM;
+ return Symbol::Type::TYPE_TERM;
}
}
-int Lexer::ProcessToken(const StringPiece value, const int prev_token_end,
- const CodepointSpan codepoint_span,
- const RulesSet_::Nonterminals* nonterms,
- Matcher* matcher) const {
+void Lexer::ProcessToken(const StringPiece value, const int prev_token_end,
+ const CodepointSpan codepoint_span,
+ std::vector<Lexer::Symbol>* symbols) const {
+ // Possibly split token.
UnicodeText token_unicode = UTF8ToUnicodeText(value.data(), value.size(),
/*do_copy=*/false);
-
- if (unilib_.IsWhitespace(*token_unicode.begin())) {
- // Ignore whitespace tokens.
- return prev_token_end;
- }
-
- // Possibly split token.
int last_end = prev_token_end;
auto token_end = token_unicode.end();
auto it = token_unicode.begin();
- TokenType type = GetTokenType(it);
+ Symbol::Type type = GetSymbolType(it);
CodepointIndex sub_token_start = codepoint_span.first;
while (it != token_end) {
auto next = std::next(it);
int num_codepoints = 1;
- TokenType next_type;
+ Symbol::Type next_type;
while (next != token_end) {
- next_type = GetTokenType(next);
- if (type == TOKEN_TYPE_PUNCTUATION || next_type != type) {
+ next_type = GetSymbolType(next);
+ if (type == Symbol::Type::TYPE_PUNCTUATION || next_type != type) {
break;
}
++next;
++num_codepoints;
}
-
- // Emit token.
- StringPiece sub_token =
- StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data());
- if (type != TOKEN_TYPE_WHITESPACE) {
- Emit(sub_token,
- CodepointSpan{sub_token_start, sub_token_start + num_codepoints},
- /*match_offset=*/last_end, type, nonterms, matcher);
- last_end = sub_token_start + num_codepoints;
- }
+ symbols->push_back(Symbol{
+ type, CodepointSpan{sub_token_start, sub_token_start + num_codepoints},
+ /*match_offset=*/last_end,
+ /*lexeme=*/
+ StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data())});
+ last_end = sub_token_start + num_codepoints;
it = next;
type = next_type;
sub_token_start = last_end;
}
- return last_end;
}
-void Lexer::Process(const std::vector<Token>& tokens, Matcher* matcher) const {
+void Lexer::Process(const std::vector<Token>& tokens,
+ const std::vector<Match*>& matches,
+ Matcher* matcher) const {
if (tokens.empty()) {
return;
}
@@ -140,28 +144,84 @@
const RulesSet_::Nonterminals* nonterminals = matcher->nonterminals();
// Initialize processing of new text.
- int prev_token_end = 0;
+ CodepointIndex prev_token_end = 0;
+ std::vector<Symbol> symbols;
matcher->Reset();
- // Emit start symbol if used by the grammar.
- if (nonterminals->start_nt() != kUnassignedNonterm) {
- matcher->AddMatch(matcher->AllocateAndInitMatch<Match>(
- nonterminals->start_nt(), CodepointSpan{0, 0},
- /*match_offset=*/0));
+  // The matcher expects the terminals and non-terminals it receives to be in
+  // non-decreasing end-position order. The sort below makes sure the
+  // pre-defined matches adhere to that order.
+  // Ideally, we would just have to emit a predefined match whenever we see that
+  // the next token we feed would be ending later.
+  // But as we implicitly ignore whitespace, we have to merge preceding
+  // whitespace into the match start so that the tokens and non-terminals we
+  // feed appear next to each other without whitespace.
+  // We keep track of real token starts and preceding whitespace in
+ // `token_match_start`, so that we can extend a predefined match's start to
+ // include the preceding whitespace.
+ std::unordered_map<CodepointIndex, CodepointIndex> token_match_start;
+
+ // Add start symbols.
+ if (Match* match =
+ CheckedAddMatch(nonterminals->start_nt(), CodepointSpan{0, 0},
+ /*match_offset=*/0, Match::kBreakType, matcher)) {
+ symbols.push_back(Symbol(match));
+ }
+ if (Match* match =
+ CheckedAddMatch(nonterminals->wordbreak_nt(), CodepointSpan{0, 0},
+ /*match_offset=*/0, Match::kBreakType, matcher)) {
+ symbols.push_back(Symbol(match));
}
for (const Token& token : tokens) {
- prev_token_end = ProcessToken(token.value,
- /*prev_token_end=*/prev_token_end,
- CodepointSpan{token.start, token.end},
- nonterminals, matcher);
+    // Record match starts for token boundaries, so that we can snap pre-defined
+    // matches to them.
+ if (prev_token_end != token.start) {
+ token_match_start[token.start] = prev_token_end;
+ }
+
+ ProcessToken(token.value,
+ /*prev_token_end=*/prev_token_end,
+ CodepointSpan{token.start, token.end}, &symbols);
+ prev_token_end = token.end;
+
+ // Add word break symbol if used by the grammar.
+ if (Match* match = CheckedAddMatch(
+ nonterminals->wordbreak_nt(), CodepointSpan{token.end, token.end},
+ /*match_offset=*/token.end, Match::kBreakType, matcher)) {
+ symbols.push_back(Symbol(match));
+ }
}
- // Emit end symbol if used by the grammar.
- if (nonterminals->end_nt() != kUnassignedNonterm) {
- matcher->AddMatch(matcher->AllocateAndInitMatch<Match>(
- nonterminals->end_nt(), CodepointSpan{prev_token_end, prev_token_end},
- /*match_offset=*/prev_token_end));
+ // Add end symbol if used by the grammar.
+ if (Match* match = CheckedAddMatch(
+ nonterminals->end_nt(), CodepointSpan{prev_token_end, prev_token_end},
+ /*match_offset=*/prev_token_end, Match::kBreakType, matcher)) {
+ symbols.push_back(Symbol(match));
+ }
+
+ // Add predefined matches.
+ for (Match* match : matches) {
+ // Decrease match offset to include preceding whitespace.
+ auto token_match_start_it = token_match_start.find(match->match_offset);
+ if (token_match_start_it != token_match_start.end()) {
+ match->match_offset = token_match_start_it->second;
+ }
+ symbols.push_back(Symbol(match));
+ }
+
+ std::sort(symbols.begin(), symbols.end(),
+ [](const Symbol& a, const Symbol& b) {
+ // Sort by increasing (end, start) position to guarantee the
+ // matcher requirement that the tokens are fed in non-decreasing
+ // end position order.
+ return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
+ std::tie(b.codepoint_span.second, b.codepoint_span.first);
+ });
+
+ // Emit symbols to matcher.
+ for (const Symbol& symbol : symbols) {
+ Emit(symbol, nonterminals, matcher);
}
}
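
To make the ordering requirement concrete, here is a toy, self-contained example of the (end, start) sort used above; the ToySymbol struct and the spans are invented for illustration and are not part of the lexer API.

#include <algorithm>
#include <tuple>
#include <utility>
#include <vector>

struct ToySymbol {
  std::pair<int, int> codepoint_span;  // {start, end} as codepoint offsets.
};

int main() {
  // Four symbols collected in arbitrary order.
  std::vector<ToySymbol> symbols = {{{10, 18}}, {{0, 3}}, {{4, 18}}, {{3, 4}}};
  std::sort(symbols.begin(), symbols.end(),
            [](const ToySymbol& a, const ToySymbol& b) {
              return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
                     std::tie(b.codepoint_span.second, b.codepoint_span.first);
            });
  // Result: {0,3}, {3,4}, {4,18}, {10,18} -- end positions never decrease, so
  // the matcher's non-decreasing end-position requirement is satisfied.
  return 0;
}
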
diff --git a/native/utils/grammar/lexer.h b/native/utils/grammar/lexer.h
index fd1a99d..87fd504 100644
--- a/native/utils/grammar/lexer.h
+++ b/native/utils/grammar/lexer.h
@@ -57,9 +57,8 @@
// This makes it appear to the Matcher as if the tokens are adjacent -- so
// whitespace is simply ignored.
//
-// A minor optimization: We don't bother to output <token> unless
-// the grammar includes a nonterminal with that name. Similarly, we don't
-// output <digits> unless the grammar uses them.
+// A minor optimization: We don't bother to output nonterminals if the grammar
+// rules don't reference them.
#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
@@ -78,30 +77,77 @@
public:
explicit Lexer(const UniLib& unilib) : unilib_(unilib) {}
- void Process(const std::vector<Token>& tokens, Matcher* matcher) const;
+ // Processes a tokenized text. Classifies the tokens and feeds them to the
+ // matcher. Predefined existing matches `matches` will be fed to the matcher
+ // alongside the tokens.
+ void Process(const std::vector<Token>& tokens,
+ const std::vector<Match*>& matches, Matcher* matcher) const;
private:
- enum TokenType {
- TOKEN_TYPE_TERM,
- TOKEN_TYPE_WHITESPACE,
- TOKEN_TYPE_DIGITS,
- TOKEN_TYPE_PUNCTUATION
+ // A lexical symbol with an identified meaning that represents raw tokens,
+ // token categories or predefined text matches.
+ // It is the unit fed to the grammar matcher.
+ struct Symbol {
+ // The type of the lexical symbol.
+ enum class Type {
+ // A raw token.
+ TYPE_TERM,
+
+ // A symbol representing a string of digits.
+ TYPE_DIGITS,
+
+ // Punctuation characters.
+ TYPE_PUNCTUATION,
+
+ // A predefined match.
+ TYPE_MATCH
+ };
+
+ explicit Symbol() = default;
+
+ // Constructs a symbol of a given type with an anchor in the text.
+ Symbol(const Type type, const CodepointSpan codepoint_span,
+ const int match_offset, StringPiece lexeme)
+ : type(type),
+ codepoint_span(codepoint_span),
+ match_offset(match_offset),
+ lexeme(lexeme) {}
+
+ // Constructs a symbol from a pre-defined match.
+ explicit Symbol(Match* match)
+ : type(Type::TYPE_MATCH),
+ codepoint_span(match->codepoint_span),
+ match_offset(match->match_offset),
+ match(match) {}
+
+    // The type of the symbol.
+ Type type;
+
+ // The span in the text as codepoint offsets.
+ CodepointSpan codepoint_span;
+
+ // The match start offset (including preceding whitespace) as codepoint
+ // offset.
+ int match_offset;
+
+ // The symbol text value.
+ StringPiece lexeme;
+
+ // The predefined match.
+ Match* match;
};
- // Processes a single token: the token is split, classified and passed to the
- // matcher. Returns the end of the last part emitted.
- int ProcessToken(const StringPiece value, const int prev_token_end,
- const CodepointSpan codepoint_span,
- const RulesSet_::Nonterminals* nonterms,
- Matcher* matcher) const;
+ // Processes a single token: the token is split and classified into symbols.
+ void ProcessToken(const StringPiece value, const int prev_token_end,
+ const CodepointSpan codepoint_span,
+ std::vector<Symbol>* symbols) const;
// Emits a token to the matcher.
- void Emit(const StringPiece value, const CodepointSpan codepoint_span,
- int match_offset, const TokenType type,
- const RulesSet_::Nonterminals* nonterms, Matcher* matcher) const;
+ void Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms,
+ Matcher* matcher) const;
// Gets the type of a character.
- TokenType GetTokenType(const UnicodeText::const_iterator& it) const;
+ Symbol::Type GetSymbolType(const UnicodeText::const_iterator& it) const;
private:
const UniLib& unilib_;
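
As a toy illustration of the whitespace handling documented in this header (not the library's tokenizer): each emitted unit carries the end of the previous unit as its match_offset, which is how whitespace becomes invisible to the matcher. The ToyUnit struct and spans are invented for the example.

#include <cstdio>

struct ToyUnit {
  int start, end;    // Codepoint span of the token itself.
  int match_offset;  // End of the previous token; swallows the whitespace gap.
};

int main() {
  // "The quick": "The" = [0, 3), whitespace = [3, 4), "quick" = [4, 9).
  ToyUnit the{0, 3, /*match_offset=*/0};
  ToyUnit quick{4, 9, /*match_offset=*/the.end};  // 3, not 4: whitespace is skipped.
  std::printf("quick: span [%d, %d), match_offset %d\n",
              quick.start, quick.end, quick.match_offset);
  return 0;
}
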
diff --git a/native/utils/grammar/lexer_test.cc b/native/utils/grammar/lexer_test.cc
deleted file mode 100644
index bcf2309..0000000
--- a/native/utils/grammar/lexer_test.cc
+++ /dev/null
@@ -1,437 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Unit tests for the lexer.
-
-#include "utils/grammar/lexer.h"
-
-#include <string>
-#include <vector>
-
-#include "utils/grammar/callback-delegate.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/utils/ir.h"
-#include "utils/tokenizer.h"
-#include "utils/utf8/unilib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3::grammar {
-namespace {
-
-using testing::ElementsAre;
-using testing::Eq;
-using testing::Value;
-
-// Superclass of all tests here.
-class LexerTest : public testing::Test {
- protected:
- LexerTest()
- : INIT_UNILIB_FOR_TESTING(unilib_),
- tokenizer_(TokenizationType_ICU, &unilib_,
- /*codepoint_ranges=*/{},
- /*internal_tokenizer_codepoint_ranges=*/{},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/true),
- lexer_(unilib_) {}
-
- // Creates a grammar just checking the specified terminals.
- std::string GrammarForTerminals(const std::vector<std::string>& terminals) {
- const CallbackId callback = 1;
- Ir ir;
- for (const std::string& terminal : terminals) {
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, terminal);
- }
- return ir.SerializeAsFlatbuffer();
- }
-
- UniLib unilib_;
- Tokenizer tokenizer_;
- Lexer lexer_;
-};
-
-struct TestMatchResult {
- int begin;
- int end;
- std::string terminal;
- std::string nonterminal;
-};
-
-MATCHER_P3(IsTerminal, begin, end, terminal, "") {
- return Value(arg.begin, begin) && Value(arg.end, end) &&
- Value(arg.terminal, terminal);
-}
-
-MATCHER_P3(IsNonterminal, begin, end, name, "") {
- return Value(arg.begin, begin) && Value(arg.end, end) &&
- Value(arg.nonterminal, name);
-}
-
-// This is a simple callback for testing purposes.
-class TestCallbackDelegate : public CallbackDelegate {
- public:
- explicit TestCallbackDelegate(
- const RulesSet_::DebugInformation* debug_information)
- : debug_information_(debug_information) {}
-
- void MatchFound(const Match* match, const CallbackId, const int64,
- Matcher*) override {
- TestMatchResult result;
- result.begin = match->match_offset;
- result.end = match->codepoint_span.second;
- if (match->IsTerminalRule()) {
- result.terminal = match->terminal;
- } else if (match->IsUnaryRule()) {
- // We use callbacks on unary rules to attach a callback to a
- // predefined lhs.
- result.nonterminal = GetNonterminalName(match->unary_rule_rhs()->lhs);
- } else {
- result.nonterminal = GetNonterminalName(match->lhs);
- }
- log_.push_back(result);
- }
-
- void Clear() { log_.clear(); }
-
- const std::vector<TestMatchResult>& log() const { return log_; }
-
- private:
- std::string GetNonterminalName(const Nonterm nonterminal) const {
- if (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry =
- debug_information_->nonterminal_names()->LookupByKey(nonterminal)) {
- return entry->value()->str();
- }
- // Unnamed Nonterm.
- return "()";
- }
-
- const RulesSet_::DebugInformation* debug_information_;
- std::vector<TestMatchResult> log_;
-};
-
-TEST_F(LexerTest, HandlesSimpleLetters) {
- std::string rules_buffer = GrammarForTerminals({"a", "is", "this", "word"});
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(tokenizer_.Tokenize("This is a word"), &matcher);
-
- EXPECT_THAT(test_logger.log().size(), Eq(4));
-}
-
-TEST_F(LexerTest, HandlesConcatedLettersAndDigit) {
- std::string rules_buffer =
- GrammarForTerminals({"1234", "4321", "a", "cde", "this"});
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(tokenizer_.Tokenize("1234This a4321cde"), &matcher);
-
- EXPECT_THAT(test_logger.log().size(), Eq(5));
-}
-
-TEST_F(LexerTest, HandlesPunctuation) {
- std::string rules_buffer = GrammarForTerminals({"10", "18", "2014", "/"});
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(tokenizer_.Tokenize("10/18/2014"), &matcher);
-
- EXPECT_THAT(test_logger.log().size(), Eq(5));
-}
-
-TEST_F(LexerTest, HandlesUTF8Punctuation) {
- std::string rules_buffer =
- GrammarForTerminals({"电话", ":", "0871", "—", "6857", "(", "曹"});
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(tokenizer_.Tokenize("电话:0871—6857(曹"), &matcher);
-
- EXPECT_THAT(test_logger.log().size(), Eq(7));
-}
-
-TEST_F(LexerTest, HandlesMixedPunctuation) {
- std::string rules_buffer =
- GrammarForTerminals({"电话", ":", "0871", "—", "6857", "(", "曹"});
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(tokenizer_.Tokenize("电话 :0871—6857(曹"), &matcher);
-
- EXPECT_THAT(test_logger.log().size(), Eq(7));
-}
-
-// Tests that the tokenizer adds the correct tokens, including <digits>, to
-// the Matcher.
-TEST_F(LexerTest, CorrectTokenOutputWithDigits) {
- const CallbackId callback = 1;
- Ir ir;
- const Nonterm digits = ir.AddUnshareableNonterminal("<digits>");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, digits);
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "the");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "quick");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "brown");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "fox");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "2345");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "88");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, ".");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "<");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "&");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "\xE2\x80\x94");
- std::string rules_buffer =
- ir.SerializeAsFlatbuffer(/*include_debug_information=*/true);
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(
- tokenizer_.Tokenize("The.qUIck\n brown2345fox88 \xE2\x80\x94 the"),
- &matcher);
-
- EXPECT_THAT(
- test_logger.log(),
- ElementsAre(IsTerminal(0, 3, "the"), IsTerminal(3, 4, "."),
- IsTerminal(4, 9, "quick"), IsTerminal(9, 16, "brown"),
-
- // Lexer automatically creates a digits nonterminal.
- IsTerminal(16, 20, "2345"), IsNonterminal(16, 20, "<digits>"),
-
- IsTerminal(20, 23, "fox"),
-
- // Lexer automatically creates a digits nonterminal.
- IsTerminal(23, 25, "88"), IsNonterminal(23, 25, "<digits>"),
-
- IsTerminal(25, 27, "\xE2\x80\x94"),
- IsTerminal(27, 31, "the")));
-}
-
-// Tests that the tokenizer adds the correct tokens, including the
-// special <token> tokens, to the Matcher.
-TEST_F(LexerTest, CorrectTokenOutputWithGenericTokens) {
- const CallbackId callback = 1;
- Ir ir;
- const Nonterm token = ir.AddUnshareableNonterminal("<token>");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, token);
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "the");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "quick");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "brown");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "fox");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "2345");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "88");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, ".");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "<");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "&");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "\xE2\x80\x94");
- std::string rules_buffer =
- ir.SerializeAsFlatbuffer(/*include_debug_information=*/true);
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(
- tokenizer_.Tokenize("The.qUIck\n brown2345fox88 \xE2\x80\x94 the"),
- &matcher);
-
- EXPECT_THAT(
- test_logger.log(),
- // Lexer will create a <token> nonterminal for each token.
- ElementsAre(IsTerminal(0, 3, "the"), IsNonterminal(0, 3, "<token>"),
- IsTerminal(3, 4, "."), IsNonterminal(3, 4, "<token>"),
- IsTerminal(4, 9, "quick"), IsNonterminal(4, 9, "<token>"),
- IsTerminal(9, 16, "brown"), IsNonterminal(9, 16, "<token>"),
- IsTerminal(16, 20, "2345"), IsNonterminal(16, 20, "<token>"),
- IsTerminal(20, 23, "fox"), IsNonterminal(20, 23, "<token>"),
- IsTerminal(23, 25, "88"), IsNonterminal(23, 25, "<token>"),
- IsTerminal(25, 27, "\xE2\x80\x94"),
- IsNonterminal(25, 27, "<token>"), IsTerminal(27, 31, "the"),
- IsNonterminal(27, 31, "<token>")));
-}
-
-// Tests that the tokenizer adds the correct tokens, including <digits>, to
-// the Matcher. This test includes UTF8 letters.
-TEST_F(LexerTest, CorrectTokenOutputWithDigitsAndUTF8) {
- const CallbackId callback = 1;
- Ir ir;
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}},
- ir.AddUnshareableNonterminal("<digits>"));
-
- const std::string em_dash = "\xE2\x80\x94"; // "Em dash" in UTF-8
- const std::string capital_i_acute = "\xC3\x8D";
- const std::string lower_i_acute = "\xC3\xAD";
-
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "the");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, ".");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}},
- "quick" + lower_i_acute + lower_i_acute + lower_i_acute + "h");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, em_dash);
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, lower_i_acute + "h");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "22");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "brown");
- std::string rules_buffer =
- ir.SerializeAsFlatbuffer(/*include_debug_information=*/true);
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(tokenizer_.Tokenize("The.qUIck" + lower_i_acute +
- capital_i_acute + lower_i_acute + "h" +
- em_dash + capital_i_acute + "H22brown"),
- &matcher);
-
- EXPECT_THAT(
- test_logger.log(),
- ElementsAre(IsTerminal(0, 3, "the"), IsTerminal(3, 4, "."),
- IsTerminal(4, 13, "quickíííh"), IsTerminal(13, 14, "—"),
- IsTerminal(14, 16, "íh"), IsTerminal(16, 18, "22"),
- IsNonterminal(16, 18, "<digits>"),
- IsTerminal(18, 23, "brown")));
-}
-
-// Tests that the tokenizer adds the correct tokens to the Matcher.
-// For this test, there's no <digits> nonterminal in the Rules, and some
-// tokens aren't in any grammar rules.
-TEST_F(LexerTest, CorrectTokenOutputWithoutDigits) {
- const CallbackId callback = 1;
- Ir ir;
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "the");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "2345");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "88");
- std::string rules_buffer =
- ir.SerializeAsFlatbuffer(/*include_debug_information=*/true);
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(tokenizer_.Tokenize("The.qUIck\n brown2345fox88"), &matcher);
-
- EXPECT_THAT(test_logger.log(),
- ElementsAre(IsTerminal(0, 3, "the"), IsTerminal(16, 20, "2345"),
- IsTerminal(23, 25, "88")));
-}
-
-// Tests that the tokenizer adds the correct <n_digits> tokens to the Matcher.
-TEST_F(LexerTest, CorrectTokenOutputWithNDigits) {
- const CallbackId callback = 1;
- Ir ir;
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}},
- ir.AddUnshareableNonterminal("<digits>"));
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}},
- ir.AddUnshareableNonterminal("<2_digits>"));
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}},
- ir.AddUnshareableNonterminal("<4_digits>"));
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "the");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "2345");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "88");
- std::string rules_buffer =
- ir.SerializeAsFlatbuffer(/*include_debug_information=*/true);
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(tokenizer_.Tokenize("The.qUIck\n brown2345fox88"), &matcher);
-
- EXPECT_THAT(
- test_logger.log(),
- ElementsAre(IsTerminal(0, 3, "the"),
- // Lexer should generate <digits> and <4_digits> for 2345.
- IsTerminal(16, 20, "2345"), IsNonterminal(16, 20, "<digits>"),
- IsNonterminal(16, 20, "<4_digits>"),
-
- // Lexer should generate <digits> and <2_digits> for 88.
- IsTerminal(23, 25, "88"), IsNonterminal(23, 25, "<digits>"),
- IsNonterminal(23, 25, "<2_digits>")));
-}
-
-// Tests that the tokenizer splits "million+" into separate tokens.
-TEST_F(LexerTest, CorrectTokenOutputWithPlusSigns) {
- const CallbackId callback = 1;
- Ir ir;
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "the");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "2345");
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}}, "+");
- const std::string lower_i_acute = "\xC3\xAD";
- ir.Add(Ir::Lhs{kUnassignedNonterm, {callback}},
- lower_i_acute + lower_i_acute);
- std::string rules_buffer =
- ir.SerializeAsFlatbuffer(/*include_debug_information=*/true);
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- lexer_.Process(tokenizer_.Tokenize("The+2345++the +" + lower_i_acute +
- lower_i_acute + "+"),
- &matcher);
-
- EXPECT_THAT(test_logger.log(),
- ElementsAre(IsTerminal(0, 3, "the"), IsTerminal(3, 4, "+"),
- IsTerminal(4, 8, "2345"), IsTerminal(8, 9, "+"),
- IsTerminal(9, 10, "+"), IsTerminal(10, 13, "the"),
- IsTerminal(13, 15, "+"), IsTerminal(15, 17, "íí"),
- IsTerminal(17, 18, "+")));
-}
-
-// Tests that the tokenizer correctly uses the anchor tokens.
-TEST_F(LexerTest, HandlesStartAnchor) {
- const CallbackId log_callback = 1;
- Ir ir;
- // <test> ::= <^> the test <$>
- ir.Add(Ir::Lhs{ir.AddNonterminal("<test>"), {log_callback}},
- std::vector<Nonterm>{
- ir.AddUnshareableNonterminal(kStartNonterm),
- ir.Add(Ir::Lhs{kUnassignedNonterm, {log_callback}}, "the"),
- ir.Add(Ir::Lhs{kUnassignedNonterm, {log_callback}}, "test"),
- ir.AddUnshareableNonterminal(kEndNonterm),
- });
- std::string rules_buffer =
- ir.SerializeAsFlatbuffer(/*include_debug_information=*/true);
- const RulesSet* rules = flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules->debug_information());
- Matcher matcher(unilib_, rules, &test_logger);
-
- // Make sure the grammar recognizes "the test".
- lexer_.Process(tokenizer_.Tokenize("the test"), &matcher);
- EXPECT_THAT(test_logger.log(),
- // Expect logging of the two terminals and then matching of the
- // nonterminal.
- ElementsAre(IsTerminal(0, 3, "the"), IsTerminal(3, 8, "test"),
- IsNonterminal(0, 8, "<test>")));
-
- // Make sure that only left anchored matches are propagated.
- test_logger.Clear();
- lexer_.Process(tokenizer_.Tokenize("the the test"), &matcher); // NOTYPO
- EXPECT_THAT(test_logger.log(),
- // Expect that "<test>" nonterminal is not matched.
- ElementsAre(IsTerminal(0, 3, "the"), IsTerminal(3, 7, "the"),
- IsTerminal(7, 12, "test")));
-
- // Make sure that only right anchored matches are propagated.
- test_logger.Clear();
- lexer_.Process(tokenizer_.Tokenize("the test test"), &matcher);
- EXPECT_THAT(test_logger.log(),
- // Expect that "<test>" nonterminal is not matched.
- ElementsAre(IsTerminal(0, 3, "the"), IsTerminal(3, 8, "test"),
- IsTerminal(8, 13, "test")));
-}
-
-} // namespace
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/match.h b/native/utils/grammar/match.h
index 2e578d8..c1c7022 100644
--- a/native/utils/grammar/match.h
+++ b/native/utils/grammar/match.h
@@ -33,8 +33,9 @@
static const int16 kUnknownType = 0;
static const int16 kTokenType = -1;
static const int16 kDigitsType = -2;
- static const int16 kCapturingMatch = -3;
- static const int16 kAssertionMatch = -4;
+ static const int16 kBreakType = -3;
+ static const int16 kCapturingMatch = -4;
+ static const int16 kAssertionMatch = -5;
void Init(const Nonterm arg_lhs, const CodepointSpan arg_codepoint_span,
const int arg_match_offset, const int arg_type = kUnknownType) {
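For orientation, here is a minimal sketch (not part of this change) of how a match tagged with the new kBreakType could be created through the Init() signature shown in the hunk above. The `wordbreak_nonterm` value is a placeholder id (e.g. the wordbreak_nt entry added to the rules flatbuffer later in this change); how the lexer actually emits break matches is outside this diff.

  // Sketch only: construct a zero-width match carrying the new kBreakType.
  // `wordbreak_nonterm` is a placeholder Nonterm id, not defined here.
  Match* match = matcher->AllocateMatch(sizeof(Match));
  match->Init(wordbreak_nonterm, CodepointSpan{5, 5}, /*arg_match_offset=*/5,
              /*arg_type=*/Match::kBreakType);
  matcher->AddMatch(match);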
diff --git a/native/utils/grammar/matcher.h b/native/utils/grammar/matcher.h
index 64db80f..662ff9c 100644
--- a/native/utils/grammar/matcher.h
+++ b/native/utils/grammar/matcher.h
@@ -19,9 +19,8 @@
// A lexer passes token to the matcher: literal terminal strings and token
// types. It passes tokens to the matcher by calling AddTerminal() and
// AddMatch() for literal terminals and token types, respectively.
-// The lexer passes
-// each token along with the [begin, end) position range in which it
-// occurs. So for an input string "Groundhog February 2, 2007", the
+// The lexer passes each token along with the [begin, end) position range
+// in which it occurs. So for an input string "Groundhog February 2, 2007", the
// lexer would tell the matcher that:
//
// "Groundhog" occurs at [0, 9)
@@ -47,7 +46,7 @@
// to the matcher in left-to-right order, strictly speaking, their "end"
// positions must be nondecreasing. (This constraint allows a more
// efficient matching algorithm.) The "begin" positions can be in any
-// order. Zero-width tokens (begin==end) are not allowed.
+// order.
//
// There are two kinds of supported callbacks:
// (1) OUTPUT: Callbacks are the only output mechanism a matcher has. For each
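To make the token-passing contract described above concrete, here is a short sketch (not part of this change) of the calls a lexer would make for the "Groundhog February 2, 2007" example from the comment; the matcher construction mirrors the tests deleted below, and `unilib`, `rules_set`, and `delegate` are assumed to exist.

  // Sketch: feed the example string to a matcher, one terminal per token,
  // with nondecreasing end positions as required by the comment above.
  Matcher matcher(unilib, rules_set, &delegate);
  matcher.AddTerminal(/*begin=*/0, /*end=*/9, "groundhog");
  matcher.AddTerminal(/*begin=*/10, /*end=*/18, "february");
  matcher.AddTerminal(/*begin=*/19, /*end=*/20, "2");
  matcher.AddTerminal(/*begin=*/20, /*end=*/21, ",");
  matcher.AddTerminal(/*begin=*/22, /*end=*/26, "2007");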
diff --git a/native/utils/grammar/matcher_test.cc b/native/utils/grammar/matcher_test.cc
deleted file mode 100644
index e5c0ab9..0000000
--- a/native/utils/grammar/matcher_test.cc
+++ /dev/null
@@ -1,520 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/grammar/matcher.h"
-
-#include <iostream>
-#include <string>
-#include <vector>
-
-#include "utils/grammar/callback-delegate.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/types.h"
-#include "utils/grammar/utils/rules.h"
-#include "utils/strings/append.h"
-#include "utils/utf8/unilib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3::grammar {
-namespace {
-
-using testing::ElementsAre;
-using testing::IsEmpty;
-using testing::Value;
-
-struct TestMatchResult {
- int begin;
- int end;
- std::string terminal;
- std::string nonterminal;
- int callback_id;
- int callback_param;
-};
-
-MATCHER_P3(IsTerminal, begin, end, terminal, "") {
- return Value(arg.begin, begin) && Value(arg.end, end) &&
- Value(arg.terminal, terminal);
-}
-
-MATCHER_P4(IsTerminalWithCallback, begin, end, terminal, callback, "") {
- return (ExplainMatchResult(IsTerminal(begin, end, terminal), arg,
- result_listener) &&
- Value(arg.callback_id, callback));
-}
-
-MATCHER_P3(IsNonterminal, begin, end, name, "") {
- return Value(arg.begin, begin) && Value(arg.end, end) &&
- Value(arg.nonterminal, name);
-}
-
-MATCHER_P4(IsNonterminalWithCallback, begin, end, name, callback, "") {
- return (ExplainMatchResult(IsNonterminal(begin, end, name), arg,
- result_listener) &&
- Value(arg.callback_id, callback));
-}
-
-MATCHER_P5(IsNonterminalWithCallbackAndParam, begin, end, name, callback, param,
- "") {
- return (
- ExplainMatchResult(IsNonterminalWithCallback(begin, end, name, callback),
- arg, result_listener) &&
- Value(arg.callback_param, param));
-}
-
-// Superclass of all tests.
-class MatcherTest : public testing::Test {
- protected:
- MatcherTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
-
- UniLib unilib_;
-};
-
-// This is a simple delegate implementation for testing purposes.
-// All it does is produce a record of all matches that were added.
-class TestCallbackDelegate : public CallbackDelegate {
- public:
- explicit TestCallbackDelegate(
- const RulesSet_::DebugInformation* debug_information)
- : debug_information_(debug_information) {}
-
- void MatchFound(const Match* match, const CallbackId callback_id,
- const int64 callback_param, Matcher*) override {
- TestMatchResult result;
- result.begin = match->codepoint_span.first;
- result.end = match->codepoint_span.second;
- result.callback_id = callback_id;
- result.callback_param = static_cast<int>(callback_param);
- result.nonterminal = GetNonterminalName(match->lhs);
- if (match->IsTerminalRule()) {
- result.terminal = match->terminal;
- }
- log_.push_back(result);
- }
-
- void ClearLog() { log_.clear(); }
-
- const std::vector<TestMatchResult> GetLogAndReset() {
- const auto result = log_;
- ClearLog();
- return result;
- }
-
- protected:
- std::string GetNonterminalName(const Nonterm nonterminal) const {
- if (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry =
- debug_information_->nonterminal_names()->LookupByKey(nonterminal)) {
- return entry->value()->str();
- }
- // Unnamed Nonterm.
- return "()";
- }
-
- std::vector<TestMatchResult> log_;
- const RulesSet_::DebugInformation* debug_information_;
-};
-
-TEST_F(MatcherTest, HandlesBasicOperations) {
- // Create an example grammar.
- Rules rules;
- const CallbackId callback = 1;
- rules.Add("<test>", {"the", "quick", "brown", "fox"}, callback);
- rules.Add("<action>", {"<test>"}, callback);
- const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
- /*include_debug_information=*/true);
- const RulesSet* rules_set =
- flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules_set->debug_information());
- Matcher matcher(unilib_, rules_set, &test_logger);
-
- matcher.AddTerminal(0, 1, "the");
- matcher.AddTerminal(1, 2, "quick");
- matcher.AddTerminal(2, 3, "brown");
- matcher.AddTerminal(3, 4, "fox");
-
- EXPECT_THAT(test_logger.GetLogAndReset(),
- ElementsAre(IsNonterminal(0, 4, "<test>"),
- IsNonterminal(0, 4, "<action>")));
-}
-
-std::string CreateTestGrammar() {
- // Create an example grammar.
- Rules rules;
-
- // Callbacks on terminal rules.
- rules.Add("<output_5>", {"quick"}, 6);
- rules.Add("<output_0>", {"the"}, 1);
-
- // Callbacks on non-terminal rules.
- rules.Add("<output_1>", {"the", "quick", "brown", "fox"}, 2);
- rules.Add("<output_2>", {"the", "quick"}, 3, static_cast<int64>(-1));
- rules.Add("<output_3>", {"brown", "fox"}, 4);
- // Now a complex thing: "the* brown fox".
- rules.Add("<thestarbrownfox>", {"brown", "fox"}, 5);
- rules.Add("<thestarbrownfox>", {"the", "<thestarbrownfox>"}, 5);
-
- return rules.Finalize().SerializeAsFlatbuffer(
- /*include_debug_information=*/true);
-}
-
-std::string CreateTestGrammarWithOptionalElements() {
- // Create an example grammar.
- Rules rules;
- rules.Add("<output_0>", {"a?", "b?", "c?", "d?", "e"}, 1);
- rules.Add("<output_1>", {"a", "b?", "c", "d?", "e"}, 2);
- rules.Add("<output_2>", {"a", "b?", "c", "d", "e?"}, 3);
-
- return rules.Finalize().SerializeAsFlatbuffer(
- /*include_debug_information=*/true);
-}
-
-Nonterm FindNontermForName(const RulesSet* rules,
- const std::string& nonterminal_name) {
- for (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry :
- *rules->debug_information()->nonterminal_names()) {
- if (entry->value()->str() == nonterminal_name) {
- return entry->key();
- }
- }
- return kUnassignedNonterm;
-}
-
-TEST_F(MatcherTest, HandlesBasicOperationsWithCallbacks) {
- const std::string rules_buffer = CreateTestGrammar();
- const RulesSet* rules_set =
- flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules_set->debug_information());
- Matcher matcher(unilib_, rules_set, &test_logger);
-
- matcher.AddTerminal(0, 1, "the");
- EXPECT_THAT(test_logger.GetLogAndReset(),
- ElementsAre(IsTerminalWithCallback(/*begin=*/0, /*end=*/1, "the",
- /*callback=*/1)));
- matcher.AddTerminal(1, 2, "quick");
- EXPECT_THAT(
- test_logger.GetLogAndReset(),
- ElementsAre(IsTerminalWithCallback(/*begin=*/1, /*end=*/2, "quick",
- /*callback=*/6),
- IsNonterminalWithCallbackAndParam(
- /*begin=*/0, /*end=*/2, "<output_2>",
- /*callback=*/3, -1)));
-
- matcher.AddTerminal(2, 3, "brown");
- EXPECT_THAT(test_logger.GetLogAndReset(), IsEmpty());
-
- matcher.AddTerminal(3, 4, "fox");
- EXPECT_THAT(
- test_logger.GetLogAndReset(),
- ElementsAre(
- IsNonterminalWithCallback(/*begin=*/0, /*end=*/4, "<output_1>",
- /*callback=*/2),
- IsNonterminalWithCallback(/*begin=*/2, /*end=*/4, "<output_3>",
- /*callback=*/4),
- IsNonterminalWithCallback(/*begin=*/2, /*end=*/4, "<thestarbrownfox>",
- /*callback=*/5)));
-
- matcher.AddTerminal(3, 5, "fox");
- EXPECT_THAT(
- test_logger.GetLogAndReset(),
- ElementsAre(
- IsNonterminalWithCallback(/*begin=*/0, /*end=*/5, "<output_1>",
- /*callback=*/2),
- IsNonterminalWithCallback(/*begin=*/2, /*end=*/5, "<output_3>",
- /*callback=*/4),
- IsNonterminalWithCallback(/*begin=*/2, /*end=*/5, "<thestarbrownfox>",
- /*callback=*/5)));
-
- matcher.AddTerminal(4, 6, "fox"); // Not adjacent to "brown".
- EXPECT_THAT(test_logger.GetLogAndReset(), IsEmpty());
-}
-
-TEST_F(MatcherTest, HandlesRecursiveRules) {
- const std::string rules_buffer = CreateTestGrammar();
- const RulesSet* rules_set =
- flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules_set->debug_information());
- Matcher matcher(unilib_, rules_set, &test_logger);
-
- matcher.AddTerminal(0, 1, "the");
- matcher.AddTerminal(1, 2, "the");
- matcher.AddTerminal(2, 4, "the");
- matcher.AddTerminal(3, 4, "the");
- matcher.AddTerminal(4, 5, "brown");
- matcher.AddTerminal(5, 6, "fox"); // Generates 5 of <thestarbrownfox>
-
- EXPECT_THAT(
- test_logger.GetLogAndReset(),
- ElementsAre(IsTerminal(/*begin=*/0, /*end=*/1, "the"),
- IsTerminal(/*begin=*/1, /*end=*/2, "the"),
- IsTerminal(/*begin=*/2, /*end=*/4, "the"),
- IsTerminal(/*begin=*/3, /*end=*/4, "the"),
- IsNonterminal(/*begin=*/4, /*end=*/6, "<output_3>"),
- IsNonterminal(/*begin=*/4, /*end=*/6, "<thestarbrownfox>"),
- IsNonterminal(/*begin=*/3, /*end=*/6, "<thestarbrownfox>"),
- IsNonterminal(/*begin=*/2, /*end=*/6, "<thestarbrownfox>"),
- IsNonterminal(/*begin=*/1, /*end=*/6, "<thestarbrownfox>"),
- IsNonterminal(/*begin=*/0, /*end=*/6, "<thestarbrownfox>")));
-}
-
-TEST_F(MatcherTest, HandlesManualAddMatchCalls) {
- const std::string rules_buffer = CreateTestGrammar();
- const RulesSet* rules_set =
- flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules_set->debug_information());
- Matcher matcher(unilib_, rules_set, &test_logger);
-
- // Test having the lexer call AddMatch() instead of AddTerminal()
- matcher.AddTerminal(-4, 37, "the");
- Match* match = matcher.AllocateMatch(sizeof(Match));
- match->codepoint_span = {37, 42};
- match->match_offset = 37;
- match->lhs = FindNontermForName(rules_set, "<thestarbrownfox>");
- matcher.AddMatch(match);
-
- EXPECT_THAT(test_logger.GetLogAndReset(),
- ElementsAre(IsTerminal(/*begin=*/-4, /*end=*/37, "the"),
- IsNonterminal(/*begin=*/-4, /*end=*/42,
- "<thestarbrownfox>")));
-}
-
-TEST_F(MatcherTest, HandlesOptionalRuleElements) {
- const std::string rules_buffer = CreateTestGrammarWithOptionalElements();
- const RulesSet* rules_set =
- flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules_set->debug_information());
- Matcher matcher(unilib_, rules_set, &test_logger);
-
- // Run the matcher on "a b c d e".
- matcher.AddTerminal(0, 1, "a");
- matcher.AddTerminal(1, 2, "b");
- matcher.AddTerminal(2, 3, "c");
- matcher.AddTerminal(3, 4, "d");
- matcher.AddTerminal(4, 5, "e");
-
- EXPECT_THAT(test_logger.GetLogAndReset(),
- ElementsAre(IsNonterminal(/*begin=*/0, /*end=*/4, "<output_2>"),
- IsTerminal(/*begin=*/4, /*end=*/5, "e"),
- IsNonterminal(/*begin=*/0, /*end=*/5, "<output_0>"),
- IsNonterminal(/*begin=*/0,
- /*end=*/5, "<output_1>"),
- IsNonterminal(/*begin=*/0, /*end=*/5, "<output_2>"),
- IsNonterminal(/*begin=*/1,
- /*end=*/5, "<output_0>"),
- IsNonterminal(/*begin=*/2, /*end=*/5, "<output_0>"),
- IsNonterminal(/*begin=*/3,
- /*end=*/5, "<output_0>")));
-}
-
-class FilterTestCallbackDelegate : public TestCallbackDelegate {
- public:
- FilterTestCallbackDelegate(
- const RulesSet_::DebugInformation* debug_information,
- const std::string& filter)
- : TestCallbackDelegate(debug_information), filter_(filter) {}
-
- void MatchFound(const Match* candidate, const CallbackId callback_id,
- const int64 callback_param, Matcher* matcher) override {
- TestCallbackDelegate::MatchFound(candidate, callback_id, callback_param,
- matcher);
- // Filter callback.
- if (callback_id == 1) {
- if (candidate->IsTerminalRule() && filter_ != candidate->terminal) {
- return;
- } else {
- std::vector<const Match*> terminals = SelectTerminals(candidate);
- if (terminals.empty() || terminals.front()->terminal != filter_) {
- return;
- }
- }
- Match* match = matcher->AllocateAndInitMatch<Match>(*candidate);
- matcher->AddMatch(match);
- matcher->AddTerminal(match->codepoint_span, match->match_offset,
- "faketerminal");
- }
- }
-
- protected:
- const std::string filter_;
-};
-
-std::string CreateExampleGrammarWithFilters() {
- // Create an example grammar.
- Rules rules;
- const CallbackId filter_callback = 1;
- rules.DefineFilter(filter_callback);
-
- rules.Add("<term_pass>", {"hello"}, filter_callback);
- rules.Add("<nonterm_pass>", {"hello", "there"}, filter_callback);
- rules.Add("<term_fail>", {"there"}, filter_callback);
- rules.Add("<nonterm_fail>", {"there", "world"}, filter_callback);
-
- // We use this to test whether AddTerminal() worked from inside a filter
- // callback.
- const CallbackId output_callback = 2;
- rules.Add("<output_faketerminal>", {"faketerminal"}, output_callback);
-
- // We use this to test whether AddMatch() worked from inside a filter
- // callback.
- rules.Add("<output_term_pass>", {"<term_pass>"}, output_callback);
- rules.Add("<output_nonterm_pass>", {"<nonterm_pass>"}, output_callback);
- rules.Add("<output_term_fail>", {"<term_fail>"}, output_callback);
- rules.Add("<output_nonterm_fail>", {"<term_nonterm_fail>"}, output_callback);
-
- // We use this to make sure rules with output callbacks are always adding
- // their lhs to the chart. This is to check for off-by-one errors in the
- // callback numbering, make sure we don't mistakenly treat the output
- // callback as a filter callback.
- rules.Add("<output>", {"<output_faketerminal>"}, output_callback);
- rules.Add("<output>", {"<output_term_pass>"}, output_callback);
- rules.Add("<output>", {"<output_nonterm_pass>"}, output_callback);
-
- return rules.Finalize().SerializeAsFlatbuffer(
- /*include_debug_information=*/true);
-}
-
-TEST_F(MatcherTest, HandlesTerminalFilters) {
- const std::string rules_buffer = CreateExampleGrammarWithFilters();
- const RulesSet* rules_set =
- flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- FilterTestCallbackDelegate test_logger(rules_set->debug_information(),
- "hello");
- Matcher matcher(unilib_, rules_set, &test_logger);
- matcher.AddTerminal(0, 1, "hello");
-
- EXPECT_THAT(
- test_logger.GetLogAndReset(),
- ElementsAre(
- // Bubbling up of:
- // "hello" -> "<term_pass>" -> "<output_term_pass>" -> "<output>"
- IsNonterminal(0, 1, "<term_pass>"),
- IsNonterminal(0, 1, "<output_term_pass>"),
- IsNonterminal(0, 1, "<output>"),
-
- // Bubbling up of:
- // "faketerminal" -> "<output_faketerminal>" -> "<output>"
- IsNonterminal(0, 1, "<output_faketerminal>"),
- IsNonterminal(0, 1, "<output>")));
-}
-
-TEST_F(MatcherTest, HandlesNonterminalFilters) {
- const std::string rules_buffer = CreateExampleGrammarWithFilters();
- const RulesSet* rules_set =
- flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- FilterTestCallbackDelegate test_logger(rules_set->debug_information(),
- "hello");
- Matcher matcher(unilib_, rules_set, &test_logger);
-
- matcher.AddTerminal(0, 1, "hello");
- test_logger.ClearLog();
- matcher.AddTerminal(1, 2, "there");
-
- EXPECT_THAT(test_logger.GetLogAndReset(),
- ElementsAre(
- // "<term_fail>" is filtered, and doesn't bubble up.
- IsNonterminal(1, 2, "<term_fail>"),
-
- // <nonterm_pass> ::= hello there
- IsNonterminal(0, 2, "<nonterm_pass>"),
- IsNonterminal(0, 2, "<output_faketerminal>"),
- IsNonterminal(0, 2, "<output>"),
- IsNonterminal(0, 2, "<output_nonterm_pass>"),
- IsNonterminal(0, 2, "<output>")));
-}
-
-TEST_F(MatcherTest, HandlesWhitespaceGapLimits) {
- Rules rules;
- rules.Add("<iata>", {"lx"});
- rules.Add("<iata>", {"aa"});
- // Require no whitespace between code and flight number.
- rules.Add("<flight_number>", {"<iata>", "<4_digits>"}, /*callback=*/1, 0,
- /*max_whitespace_gap=*/0);
- const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
- /*include_debug_information=*/true);
- const RulesSet* rules_set =
- flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules_set->debug_information());
- Matcher matcher(unilib_, rules_set, &test_logger);
-
- // Check that the grammar triggers on LX1138.
- {
- matcher.AddTerminal(0, 2, "LX");
- matcher.AddMatch(matcher.AllocateAndInitMatch<Match>(
- rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
- CodepointSpan{2, 6}, /*match_offset=*/2));
- EXPECT_THAT(test_logger.GetLogAndReset(),
- ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
- }
-
- // Check that the grammar doesn't trigger on LX 1138.
- {
- matcher.AddTerminal(6, 8, "LX");
- matcher.AddMatch(matcher.AllocateAndInitMatch<Match>(
- rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
- CodepointSpan{9, 13}, /*match_offset=*/8));
- EXPECT_THAT(test_logger.GetLogAndReset(), IsEmpty());
- }
-}
-
-TEST_F(MatcherTest, HandlesCaseSensitiveTerminals) {
- Rules rules;
- rules.Add("<iata>", {"LX"}, /*callback=*/kNoCallback, 0,
- /*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
- rules.Add("<iata>", {"AA"}, /*callback=*/kNoCallback, 0,
- /*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
- rules.Add("<iata>", {"dl"}, /*callback=*/kNoCallback, 0,
- /*max_whitespace_gap*/ -1, /*case_sensitive=*/false);
- // Require no whitespace between code and flight number.
- rules.Add("<flight_number>", {"<iata>", "<4_digits>"}, /*callback=*/1, 0,
- /*max_whitespace_gap=*/0);
- const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
- /*include_debug_information=*/true);
- const RulesSet* rules_set =
- flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
- TestCallbackDelegate test_logger(rules_set->debug_information());
- Matcher matcher(unilib_, rules_set, &test_logger);
-
- // Check that the grammar triggers on LX1138.
- {
- matcher.AddTerminal(0, 2, "LX");
- matcher.AddMatch(matcher.AllocateAndInitMatch<Match>(
- rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
- CodepointSpan{2, 6}, /*match_offset=*/2));
- EXPECT_THAT(test_logger.GetLogAndReset(),
- ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
- }
-
- // Check that the grammar doesn't trigger on lx1138.
- {
- matcher.AddTerminal(6, 8, "lx");
- matcher.AddMatch(matcher.AllocateAndInitMatch<Match>(
- rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
- CodepointSpan{8, 12}, /*match_offset=*/8));
- EXPECT_THAT(test_logger.GetLogAndReset(), IsEmpty());
- }
-
- // Check that the grammar does trigger on dl1138.
- {
- matcher.AddTerminal(12, 14, "dl");
- matcher.AddMatch(matcher.AllocateAndInitMatch<Match>(
- rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
- CodepointSpan{14, 18}, /*match_offset=*/14));
- EXPECT_THAT(test_logger.GetLogAndReset(),
- ElementsAre(IsNonterminal(12, 18, "<flight_number>")));
- }
-}
-
-} // namespace
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/rules-utils_test.cc b/native/utils/grammar/rules-utils_test.cc
deleted file mode 100644
index e562508..0000000
--- a/native/utils/grammar/rules-utils_test.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/grammar/rules-utils.h"
-
-#include <vector>
-
-#include "utils/grammar/match.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3::grammar {
-namespace {
-
-using testing::ElementsAre;
-using testing::Value;
-
-// Create test match object.
-Match CreateMatch(const CodepointIndex begin, const CodepointIndex end) {
- Match match;
- match.Init(0, CodepointSpan{begin, end},
- /*arg_match_offset=*/begin);
- return match;
-}
-
-MATCHER_P(IsRuleMatch, candidate, "") {
- return Value(arg.rule_id, candidate.rule_id) &&
- Value(arg.match, candidate.match);
-}
-
-TEST(UtilsTest, DeduplicatesMatches) {
- // Overlapping matches from the same rule.
- Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(0, 2)};
- const std::vector<RuleMatch> candidates = {{&matches[0], /*rule_id=*/0},
- {&matches[1], /*rule_id=*/0},
- {&matches[2], /*rule_id=*/0}};
-
- // Keep longest.
- EXPECT_THAT(DeduplicateMatches(candidates),
- ElementsAre(IsRuleMatch(candidates[2])));
-}
-
-TEST(UtilsTest, DeduplicatesMatchesPerRule) {
- // Overlapping matches from different rules.
- Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(0, 2)};
- const std::vector<RuleMatch> candidates = {{&matches[0], /*rule_id=*/0},
- {&matches[1], /*rule_id=*/0},
- {&matches[2], /*rule_id=*/0},
- {&matches[0], /*rule_id=*/1}};
-
- // Keep longest for rule 0, but also keep match from rule 1.
- EXPECT_THAT(
- DeduplicateMatches(candidates),
- ElementsAre(IsRuleMatch(candidates[2]), IsRuleMatch(candidates[3])));
-}
-
-TEST(UtilsTest, KeepNonoverlapping) {
- // Non-overlapping matches.
- Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(2, 3)};
- const std::vector<RuleMatch> candidates = {{&matches[0], /*rule_id=*/0},
- {&matches[1], /*rule_id=*/0},
- {&matches[2], /*rule_id=*/0}};
-
- // Keep all matches.
- EXPECT_THAT(
- DeduplicateMatches(candidates),
- ElementsAre(IsRuleMatch(candidates[0]), IsRuleMatch(candidates[1]),
- IsRuleMatch(candidates[2])));
-}
-
-} // namespace
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/rules.fbs b/native/utils/grammar/rules.fbs
index 2de7a95..41c2c86 100755
--- a/native/utils/grammar/rules.fbs
+++ b/native/utils/grammar/rules.fbs
@@ -128,6 +128,9 @@
// `n_digits_nt[k]` is the id of the nonterminal indicating a string of
// `k` digits.
n_digits_nt:[int];
+
+ // Id of the nonterminal indicating a word or token boundary.
+ wordbreak_nt:int;
}
// Callback information.
diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc
index 5184c6b..6be1128 100644
--- a/native/utils/grammar/utils/ir.cc
+++ b/native/utils/grammar/utils/ir.cc
@@ -406,6 +406,7 @@
output->nonterminals.reset(new RulesSet_::NonterminalsT);
output->nonterminals->start_nt = GetNonterminalForName(kStartNonterm);
output->nonterminals->end_nt = GetNonterminalForName(kEndNonterm);
+ output->nonterminals->wordbreak_nt = GetNonterminalForName(kWordBreakNonterm);
output->nonterminals->token_nt = GetNonterminalForName(kTokenNonterm);
output->nonterminals->digits_nt = GetNonterminalForName(kDigitsNonterm);
for (int i = 1; i <= kMaxNDigitsNontermLength; i++) {
diff --git a/native/utils/grammar/utils/ir.h b/native/utils/grammar/utils/ir.h
index 7e5c984..89087c6 100644
--- a/native/utils/grammar/utils/ir.h
+++ b/native/utils/grammar/utils/ir.h
@@ -31,6 +31,7 @@
// Pre-defined nonterminal classes that the lexer can handle.
constexpr const char* kStartNonterm = "<^>";
constexpr const char* kEndNonterm = "<$>";
+constexpr const char* kWordBreakNonterm = "<\b>";
constexpr const char* kTokenNonterm = "<token>";
constexpr const char* kDigitsNonterm = "<digits>";
constexpr const char* kNDigitsNonterm = "<%d_digits>";
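A brief sketch (an assumption, not taken from this change) of reading the new word-break nonterminal back out of a serialized rule set; the wordbreak_nt() accessor name follows FlatBuffers codegen conventions for the field added to rules.fbs above.

  // Sketch: look up the word-break nonterminal id from a serialized RulesSet.
  const RulesSet* rules_set = flatbuffers::GetRoot<RulesSet>(buffer.data());
  const Nonterm wordbreak = rules_set->nonterminals()->wordbreak_nt();
  if (wordbreak != kUnassignedNonterm) {
    // Grammars that reference "<\b>" get a valid id here; a lexer could then
    // emit matches of this nonterminal at token boundaries (assumption).
  }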
diff --git a/native/utils/grammar/utils/ir_test.cc b/native/utils/grammar/utils/ir_test.cc
deleted file mode 100644
index b3f0417..0000000
--- a/native/utils/grammar/utils/ir_test.cc
+++ /dev/null
@@ -1,238 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/grammar/utils/ir.h"
-
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/types.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3::grammar {
-namespace {
-
-using testing::Eq;
-using testing::IsEmpty;
-using testing::Ne;
-
-TEST(IrTest, HandlesSharingWithTerminalRules) {
- Ir ir;
-
- // <t1> ::= the
- const Nonterm t1 = ir.Add(kUnassignedNonterm, "the");
-
- // <t2> ::= quick
- const Nonterm t2 = ir.Add(kUnassignedNonterm, "quick");
-
- // <t3> ::= quick -- should share with <t2>
- const Nonterm t3 = ir.Add(kUnassignedNonterm, "quick");
-
- // <t4> ::= quick -- specify unshareable <t4>
- // <t4> ::= brown
- const Nonterm t4_unshareable = ir.AddUnshareableNonterminal();
- ir.Add(t4_unshareable, "quick");
- ir.Add(t4_unshareable, "brown");
-
- // <t5> ::= brown -- should not be shared with <t4>
- const Nonterm t5 = ir.Add(kUnassignedNonterm, "brown");
-
- // <t6> ::= brown -- specify unshareable <t6>
- const Nonterm t6_unshareable = ir.AddUnshareableNonterminal();
- ir.Add(t6_unshareable, "brown");
-
- // <t7> ::= brown -- should share with <t5>
- const Nonterm t7 = ir.Add(kUnassignedNonterm, "brown");
-
- EXPECT_THAT(t1, Ne(kUnassignedNonterm));
- EXPECT_THAT(t2, Ne(kUnassignedNonterm));
- EXPECT_THAT(t1, Ne(t2));
- EXPECT_THAT(t2, Eq(t3));
- EXPECT_THAT(t4_unshareable, Ne(kUnassignedNonterm));
- EXPECT_THAT(t4_unshareable, Ne(t3));
- EXPECT_THAT(t4_unshareable, Ne(t5));
- EXPECT_THAT(t6_unshareable, Ne(kUnassignedNonterm));
- EXPECT_THAT(t6_unshareable, Ne(t4_unshareable));
- EXPECT_THAT(t6_unshareable, Ne(t5));
- EXPECT_THAT(t7, Eq(t5));
-}
-
-TEST(IrTest, HandlesSharingWithNonterminalRules) {
- Ir ir;
-
- // Setup a few terminal rules.
- const std::vector<Nonterm> rhs = {
- ir.Add(kUnassignedNonterm, "the"), ir.Add(kUnassignedNonterm, "quick"),
- ir.Add(kUnassignedNonterm, "brown"), ir.Add(kUnassignedNonterm, "fox")};
-
- // Check for proper sharing using nonterminal rules.
- for (int rhs_length = 1; rhs_length <= rhs.size(); rhs_length++) {
- std::vector<Nonterm> rhs_truncated = rhs;
- rhs_truncated.resize(rhs_length);
- const Nonterm nt_u = ir.AddUnshareableNonterminal();
- ir.Add(nt_u, rhs_truncated);
- const Nonterm nt_1 = ir.Add(kUnassignedNonterm, rhs_truncated);
- const Nonterm nt_2 = ir.Add(kUnassignedNonterm, rhs_truncated);
-
- EXPECT_THAT(nt_1, Eq(nt_2));
- EXPECT_THAT(nt_1, Ne(nt_u));
- }
-}
-
-TEST(IrTest, HandlesSharingWithCallbacksWithSameParameters) {
- // Test sharing in the presence of callbacks.
- constexpr CallbackId kOutput1 = 1;
- constexpr CallbackId kOutput2 = 2;
- constexpr CallbackId kFilter1 = 3;
- constexpr CallbackId kFilter2 = 4;
- Ir ir(/*filters=*/{kFilter1, kFilter2});
-
- const Nonterm x1 = ir.Add(kUnassignedNonterm, "hello");
- const Nonterm x2 =
- ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput1, 0}}, "hello");
- const Nonterm x3 =
- ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter1, 0}}, "hello");
- const Nonterm x4 =
- ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput2, 0}}, "hello");
- const Nonterm x5 =
- ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter2, 0}}, "hello");
-
- // Duplicate entry.
- const Nonterm x6 =
- ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput2, 0}}, "hello");
-
- EXPECT_THAT(x2, Eq(x1));
- EXPECT_THAT(x3, Ne(x1));
- EXPECT_THAT(x4, Eq(x1));
- EXPECT_THAT(x5, Ne(x1));
- EXPECT_THAT(x5, Ne(x3));
- EXPECT_THAT(x6, Ne(x3));
-}
-
-TEST(IrTest, HandlesSharingWithCallbacksWithDifferentParameters) {
- // Test sharing in the presence of callbacks.
- constexpr CallbackId kOutput = 1;
- constexpr CallbackId kFilter = 2;
- Ir ir(/*filters=*/{kFilter});
-
- const Nonterm x1 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput, 0}}, "world");
- const Nonterm x2 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput, 1}}, "world");
- const Nonterm x3 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter, 0}}, "world");
- const Nonterm x4 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter, 1}}, "world");
-
- EXPECT_THAT(x2, Eq(x1));
- EXPECT_THAT(x3, Ne(x1));
- EXPECT_THAT(x4, Ne(x1));
- EXPECT_THAT(x4, Ne(x3));
-}
-
-TEST(IrTest, SerializesRulesToFlatbufferFormat) {
- constexpr CallbackId kOutput = 1;
- Ir ir;
- const Nonterm verb = ir.AddUnshareableNonterminal();
- ir.Add(verb, "buy");
- ir.Add(Ir::Lhs{verb, {kOutput}}, "bring");
- ir.Add(verb, "upbring");
- ir.Add(verb, "remind");
- const Nonterm set_reminder = ir.AddUnshareableNonterminal();
- ir.Add(set_reminder,
- std::vector<Nonterm>{ir.Add(kUnassignedNonterm, "remind"),
- ir.Add(kUnassignedNonterm, "me"),
- ir.Add(kUnassignedNonterm, "to"), verb});
- const Nonterm action = ir.AddUnshareableNonterminal();
- ir.Add(action, set_reminder);
- RulesSetT rules;
- ir.Serialize(/*include_debug_information=*/false, &rules);
-
- EXPECT_THAT(rules.rules.size(), Eq(1));
-
- // Only one rule uses a callback, the rest will be encoded directly.
- EXPECT_THAT(rules.lhs.size(), Eq(1));
- EXPECT_THAT(rules.lhs.front().callback_id(), kOutput);
-
- // 6 distinct terminals: "buy", "upbring", "bring", "remind", "me" and "to".
- EXPECT_THAT(
- rules.rules.front()->lowercase_terminal_rules->terminal_offsets.size(),
- Eq(6));
- EXPECT_THAT(rules.rules.front()->terminal_rules->terminal_offsets, IsEmpty());
-
- // As "bring" is a suffix of "upbring" it is expected to be suffix merged in
- // the string pool
- EXPECT_THAT(rules.terminals,
- Eq(std::string("buy\0me\0remind\0to\0upbring\0", 25)));
-
- EXPECT_THAT(rules.rules.front()->binary_rules.size(), Eq(3));
-
- // One unary rule: <action> ::= <set_reminder>
- EXPECT_THAT(rules.rules.front()->unary_rules.size(), Eq(1));
-}
-
-TEST(IrTest, HandlesRulesSharding) {
- Ir ir(/*filters=*/{}, /*num_shards=*/2);
- const Nonterm verb = ir.AddUnshareableNonterminal();
- const Nonterm set_reminder = ir.AddUnshareableNonterminal();
-
- // Shard 0: en
- ir.Add(verb, "buy");
- ir.Add(verb, "bring");
- ir.Add(verb, "remind");
- ir.Add(set_reminder,
- std::vector<Nonterm>{ir.Add(kUnassignedNonterm, "remind"),
- ir.Add(kUnassignedNonterm, "me"),
- ir.Add(kUnassignedNonterm, "to"), verb});
-
- // Shard 1: de
- ir.Add(verb, "kaufen", /*case_sensitive=*/false, /*shard=*/1);
- ir.Add(verb, "bringen", /*case_sensitive=*/false, /*shard=*/1);
- ir.Add(verb, "erinnern", /*case_sensitive=*/false, /*shard=*/1);
- ir.Add(set_reminder,
- std::vector<Nonterm>{ir.Add(kUnassignedNonterm, "erinnere",
- /*case_sensitive=*/false, /*shard=*/1),
- ir.Add(kUnassignedNonterm, "mich",
- /*case_sensitive=*/false, /*shard=*/1),
- ir.Add(kUnassignedNonterm, "zu",
- /*case_sensitive=*/false, /*shard=*/1),
- verb},
- /*shard=*/1);
-
- // Test that terminal strings are correctly merged into the shared
- // string pool.
- RulesSetT rules;
- ir.Serialize(/*include_debug_information=*/false, &rules);
-
- EXPECT_THAT(rules.rules.size(), Eq(2));
-
- // 5 distinct terminals: "buy", "bring", "remind", "me" and "to".
- EXPECT_THAT(rules.rules[0]->lowercase_terminal_rules->terminal_offsets.size(),
- Eq(5));
- EXPECT_THAT(rules.rules[0]->terminal_rules->terminal_offsets, IsEmpty());
-
- // 6 distinct terminals: "kaufen", "bringen", "erinnern", "erinnere", "mich"
- // and "zu".
- EXPECT_THAT(rules.rules[1]->lowercase_terminal_rules->terminal_offsets.size(),
- Eq(6));
- EXPECT_THAT(rules.rules[1]->terminal_rules->terminal_offsets, IsEmpty());
-
- EXPECT_THAT(rules.terminals,
- Eq(std::string("bring\0bringen\0buy\0erinnere\0erinnern\0kaufen\0"
- "me\0mich\0remind\0to\0zu\0",
- 64)));
-
- EXPECT_THAT(rules.rules[0]->binary_rules.size(), Eq(3));
- EXPECT_THAT(rules.rules[1]->binary_rules.size(), Eq(3));
-}
-
-} // namespace
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/rules.cc b/native/utils/grammar/utils/rules.cc
index af7a56f..4dbab61 100644
--- a/native/utils/grammar/utils/rules.cc
+++ b/native/utils/grammar/utils/rules.cc
@@ -18,6 +18,7 @@
#include <set>
+#include "utils/grammar/utils/ir.h"
#include "utils/strings/append.h"
namespace libtextclassifier3::grammar {
@@ -26,7 +27,8 @@
// Returns whether a nonterminal is a pre-defined one.
bool IsPredefinedNonterminal(const std::string& nonterminal_name) {
if (nonterminal_name == kStartNonterm || nonterminal_name == kEndNonterm ||
- nonterminal_name == kTokenNonterm || nonterminal_name == kDigitsNonterm) {
+ nonterminal_name == kTokenNonterm || nonterminal_name == kDigitsNonterm ||
+ nonterminal_name == kWordBreakNonterm) {
return true;
}
for (int digits = 1; digits <= kMaxNDigitsNontermLength; digits++) {
@@ -99,6 +101,22 @@
rhs_nonterms, rule.shard);
}
+// Check whether this component is a non-terminal.
+bool IsNonterminal(StringPiece rhs_component) {
+ return rhs_component[0] == '<' &&
+ rhs_component[rhs_component.size() - 1] == '>';
+}
+
+// Sanity check for common typos -- '<' or '>' in a terminal.
+void ValidateTerminal(StringPiece rhs_component) {
+ TC3_CHECK_EQ(rhs_component.find('<'), std::string::npos)
+ << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
+ TC3_CHECK_EQ(rhs_component.find('>'), std::string::npos)
+ << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
+ TC3_CHECK_EQ(rhs_component.find('?'), std::string::npos)
+ << "Rhs terminal `" << rhs_component << "` contains a question mark.";
+}
+
} // namespace
int Rules::AddNonterminal(StringPiece nonterminal_name) {
@@ -175,20 +193,12 @@
}
// Check whether this component is a non-terminal.
- if (rhs_component[0] == '<' &&
- rhs_component[rhs_component.size() - 1] == '>') {
+ if (IsNonterminal(rhs_component)) {
rhs_elements.push_back(RhsElement(AddNonterminal(rhs_component)));
} else {
// A terminal.
// Sanity check for common typos -- '<' or '>' in a terminal.
- TC3_CHECK_EQ(rhs_component.find('<'), std::string::npos)
- << "Rhs terminal `" << rhs_component
- << "` contains an angle bracket.";
- TC3_CHECK_EQ(rhs_component.find('>'), std::string::npos)
- << "Rhs terminal `" << rhs_component
- << "` contains an angle bracket.";
- TC3_CHECK_EQ(rhs_component.find('?'), std::string::npos)
- << "Rhs terminal `" << rhs_component << "` contains a question mark.";
+ ValidateTerminal(rhs_component);
rhs_elements.push_back(RhsElement(rhs_component.ToString()));
}
}
@@ -203,7 +213,7 @@
optional_element_indices.end(), &omit_these);
}
-Ir Rules::Finalize() const {
+Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
Ir rules(filters_, num_shards_);
std::unordered_map<int, Nonterm> nonterminal_ids;
@@ -212,7 +222,9 @@
// Define all used predefined nonterminals.
for (const auto it : nonterminal_names_) {
- if (IsPredefinedNonterminal(it.first)) {
+ if (IsPredefinedNonterminal(it.first) ||
+ predefined_nonterminals.find(it.first) !=
+ predefined_nonterminals.end()) {
nonterminal_ids[it.second] = rules.AddUnshareableNonterminal(it.first);
}
}
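The refactor above only extracts existing checks into IsNonterminal() and ValidateTerminal(); a quick illustration of the distinction they encode (illustration only, not new behavior):

  // "<verb>" is treated as a nonterminal reference; "remind" as a terminal.
  IsNonterminal("<verb>");     // true: starts with '<' and ends with '>'.
  IsNonterminal("remind");     // false: a plain terminal.
  ValidateTerminal("remind");  // passes.
  // ValidateTerminal("remind>") or ValidateTerminal("fox?") would
  // TC3_CHECK-fail, catching typos such as stray brackets or a '?' that was
  // meant to mark an optional element.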
diff --git a/native/utils/grammar/utils/rules.h b/native/utils/grammar/utils/rules.h
index 8a5257c..42fa7cd 100644
--- a/native/utils/grammar/utils/rules.h
+++ b/native/utils/grammar/utils/rules.h
@@ -110,7 +110,11 @@
void DefineFilter(const CallbackId filter_id) { filters_.insert(filter_id); }
// Lowers the rule set into the intermediate representation.
- Ir Finalize() const;
+ // Treats nonterminals given by the argument `predefined_nonterminals` as
+ // defined externally. This makes it possible to define rules that depend on
+ // non-terminals produced by, e.g., existing text annotations and that will be
+ // fed to the matcher by the lexer.
+ Ir Finalize(const std::set<std::string>& predefined_nonterminals = {}) const;
private:
// Expands optional components in rules.
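A minimal usage sketch of the new Finalize() parameter; the nonterminal name "<flight_code>" is hypothetical and used only for illustration.

  // Sketch: "<flight_code>" is produced externally (e.g. by an annotator) and
  // fed to the matcher by the lexer, so no grammar rule defines it here.
  // Declaring it as predefined still gives it an id in the rule set.
  Rules rules;
  rules.Add("<flight>", {"<flight_code>", "<4_digits>"});
  const Ir ir = rules.Finalize(/*predefined_nonterminals=*/{"<flight_code>"});
  RulesSetT frozen_rules;
  ir.Serialize(/*include_debug_information=*/false, &frozen_rules);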
diff --git a/native/utils/grammar/utils/rules_test.cc b/native/utils/grammar/utils/rules_test.cc
deleted file mode 100644
index 84183f4..0000000
--- a/native/utils/grammar/utils/rules_test.cc
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/grammar/utils/rules.h"
-
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/utils/ir.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3::grammar {
-namespace {
-
-using testing::Eq;
-using testing::IsEmpty;
-
-TEST(SerializeRulesTest, HandlesSimpleRuleSet) {
- Rules rules;
-
- rules.Add("<verb>", {"buy"});
- rules.Add("<verb>", {"bring"});
- rules.Add("<verb>", {"remind"});
- rules.Add("<reminder>", {"remind", "me", "to", "<verb>"});
- rules.Add("<action>", {"<reminder>"});
-
- const Ir ir = rules.Finalize();
- RulesSetT frozen_rules;
- ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
-
- EXPECT_THAT(frozen_rules.rules.size(), Eq(1));
- EXPECT_THAT(frozen_rules.lhs, IsEmpty());
- EXPECT_THAT(frozen_rules.terminals,
- Eq(std::string("bring\0buy\0me\0remind\0to\0", 23)));
- EXPECT_THAT(frozen_rules.rules.front()->binary_rules.size(), Eq(3));
- EXPECT_THAT(frozen_rules.rules.front()->unary_rules.size(), Eq(1));
-}
-
-TEST(SerializeRulesTest, HandlesRulesSetWithCallbacks) {
- Rules rules;
- const CallbackId output = 1;
- const CallbackId filter = 2;
- rules.DefineFilter(filter);
-
- rules.Add("<verb>", {"buy"});
- rules.Add("<verb>", {"bring"}, output, 0);
- rules.Add("<verb>", {"remind"}, output, 0);
- rules.Add("<reminder>", {"remind", "me", "to", "<verb>"});
- rules.Add("<action>", {"<reminder>"}, filter, 0);
-
- const Ir ir = rules.Finalize();
- RulesSetT frozen_rules;
- ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
-
- EXPECT_THAT(frozen_rules.rules.size(), Eq(1));
- EXPECT_THAT(frozen_rules.terminals,
- Eq(std::string("bring\0buy\0me\0remind\0to\0", 23)));
-
- // We have two identical output calls and one filter call in the rule set
- // definition above.
- EXPECT_THAT(frozen_rules.lhs.size(), Eq(2));
-
- EXPECT_THAT(frozen_rules.rules.front()->binary_rules.size(), Eq(3));
- EXPECT_THAT(frozen_rules.rules.front()->unary_rules.size(), Eq(1));
-}
-
-TEST(SerializeRulesTest, HandlesRulesWithWhitespaceGapLimits) {
- Rules rules;
- rules.Add("<iata>", {"lx"});
- rules.Add("<iata>", {"aa"});
- rules.Add("<flight>", {"<iata>", "<4_digits>"}, kNoCallback, 0,
- /*max_whitespace_gap=*/0);
-
- const Ir ir = rules.Finalize();
- RulesSetT frozen_rules;
- ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
-
- EXPECT_THAT(frozen_rules.rules.size(), Eq(1));
- EXPECT_THAT(frozen_rules.terminals, Eq(std::string("aa\0lx\0", 6)));
- EXPECT_THAT(frozen_rules.lhs.size(), Eq(1));
-}
-
-TEST(SerializeRulesTest, HandlesCaseSensitiveTerminals) {
- Rules rules;
- rules.Add("<iata>", {"LX"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
- /*case_sensitive=*/true);
- rules.Add("<iata>", {"AA"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
- /*case_sensitive=*/true);
- rules.Add("<iata>", {"dl"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
- /*case_sensitive=*/false);
- rules.Add("<flight>", {"<iata>", "<4_digits>"}, kNoCallback, 0,
- /*max_whitespace_gap=*/0);
-
- const Ir ir = rules.Finalize();
- RulesSetT frozen_rules;
- ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
-
- EXPECT_THAT(frozen_rules.rules.size(), Eq(1));
- EXPECT_THAT(frozen_rules.terminals, Eq(std::string("AA\0LX\0dl\0", 9)));
- EXPECT_THAT(frozen_rules.lhs.size(), Eq(1));
-}
-
-TEST(SerializeRulesTest, HandlesMultipleShards) {
- Rules rules(/*num_shards=*/2);
- rules.Add("<iata>", {"LX"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
- /*case_sensitive=*/true, /*shard=*/0);
- rules.Add("<iata>", {"aa"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
- /*case_sensitive=*/false, /*shard=*/1);
-
- const Ir ir = rules.Finalize();
- RulesSetT frozen_rules;
- ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
-
- EXPECT_THAT(frozen_rules.rules.size(), Eq(2));
- EXPECT_THAT(frozen_rules.terminals, Eq(std::string("LX\0aa\0", 6)));
-}
-
-} // namespace
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/i18n/locale_test.cc b/native/utils/i18n/locale_test.cc
deleted file mode 100644
index faea4f6..0000000
--- a/native/utils/i18n/locale_test.cc
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/i18n/locale.h"
-
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(LocaleTest, ParseUnknown) {
- Locale locale = Locale::Invalid();
- EXPECT_FALSE(locale.IsValid());
-}
-
-TEST(LocaleTest, ParseSwissEnglish) {
- Locale locale = Locale::FromBCP47("en-CH");
- EXPECT_TRUE(locale.IsValid());
- EXPECT_EQ(locale.Language(), "en");
- EXPECT_EQ(locale.Script(), "");
- EXPECT_EQ(locale.Region(), "CH");
-}
-
-TEST(LocaleTest, ParseChineseChina) {
- Locale locale = Locale::FromBCP47("zh-CN");
- EXPECT_TRUE(locale.IsValid());
- EXPECT_EQ(locale.Language(), "zh");
- EXPECT_EQ(locale.Script(), "");
- EXPECT_EQ(locale.Region(), "CN");
-}
-
-TEST(LocaleTest, ParseChineseTaiwan) {
- Locale locale = Locale::FromBCP47("zh-Hant-TW");
- EXPECT_TRUE(locale.IsValid());
- EXPECT_EQ(locale.Language(), "zh");
- EXPECT_EQ(locale.Script(), "Hant");
- EXPECT_EQ(locale.Region(), "TW");
-}
-
-TEST(LocaleTest, ParseEnglish) {
- Locale locale = Locale::FromBCP47("en");
- EXPECT_TRUE(locale.IsValid());
- EXPECT_EQ(locale.Language(), "en");
- EXPECT_EQ(locale.Script(), "");
- EXPECT_EQ(locale.Region(), "");
-}
-
-TEST(LocaleTest, ParseChineseTraditional) {
- Locale locale = Locale::FromBCP47("zh-Hant");
- EXPECT_TRUE(locale.IsValid());
- EXPECT_EQ(locale.Language(), "zh");
- EXPECT_EQ(locale.Script(), "Hant");
- EXPECT_EQ(locale.Region(), "");
-}
-
-TEST(LocaleTest, IsAnyLocaleSupportedMatch) {
- std::vector<Locale> locales = {Locale::FromBCP47("zh-HK"),
- Locale::FromBCP47("en-UK")};
- std::vector<Locale> supported_locales = {Locale::FromBCP47("en")};
-
- EXPECT_TRUE(Locale::IsAnyLocaleSupported(locales, supported_locales,
- /*default_value=*/false));
-}
-
-TEST(LocaleTest, IsAnyLocaleSupportedNotMatch) {
- std::vector<Locale> locales = {Locale::FromBCP47("zh-tw")};
- std::vector<Locale> supported_locales = {Locale::FromBCP47("en"),
- Locale::FromBCP47("fr")};
-
- EXPECT_FALSE(Locale::IsAnyLocaleSupported(locales, supported_locales,
- /*default_value=*/false));
-}
-
-TEST(LocaleTest, IsAnyLocaleSupportedAnyLocale) {
- std::vector<Locale> locales = {Locale::FromBCP47("zh-tw")};
- std::vector<Locale> supported_locales = {Locale::FromBCP47("*")};
-
- EXPECT_TRUE(Locale::IsAnyLocaleSupported(locales, supported_locales,
- /*default_value=*/false));
-}
-
-TEST(LocaleTest, IsAnyLocaleSupportedEmptyLocales) {
- std::vector<Locale> supported_locales = {Locale::FromBCP47("en")};
-
- EXPECT_TRUE(Locale::IsAnyLocaleSupported({}, supported_locales,
- /*default_value=*/true));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/intents/intent-generator.cc b/native/utils/intents/intent-generator.cc
index 7dc24e4..36453bb 100644
--- a/native/utils/intents/intent-generator.cc
+++ b/native/utils/intents/intent-generator.cc
@@ -56,6 +56,8 @@
static constexpr const char* kDeviceLocaleKey = "device_locales";
static constexpr const char* kFormatKey = "format";
+static constexpr const int kIndexStackTop = -1;
+
// An Android specific Lua environment with JNI backed callbacks.
class JniLuaEnvironment : public LuaEnvironment {
public:
@@ -91,8 +93,19 @@
// Reads the extras from the Lua result.
void ReadExtras(std::map<std::string, Variant>* extra);
- // Reads the intent categories array from a Lua result.
- void ReadCategories(std::vector<std::string>* category);
+ // Reads a vector from a Lua result. `read_element_func` is the function that
+ // reads a single element from the Lua side.
+ template <class T>
+ std::vector<T> ReadVector(const std::function<T()> read_element_func) const;
+
+ // Reads a string vector from a Lua result.
+ std::vector<std::string> ReadStringVector() const;
+
+ // Reads a float vector from a Lua result.
+ std::vector<float> ReadFloatVector() const;
+
+ // Reads an int vector from a Lua result.
+ std::vector<int> ReadIntVector() const;
// Retrieves user manager if not previously done.
bool RetrieveUserManager();
@@ -193,7 +206,7 @@
}
int JniLuaEnvironment::HandleExternalCallback() {
- const StringPiece key = ReadString(/*index=*/-1);
+ const StringPiece key = ReadString(kIndexStackTop);
if (key.Equals(kHashKey)) {
Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleHash>();
return 1;
@@ -208,7 +221,7 @@
}
int JniLuaEnvironment::HandleAndroidCallback() {
- const StringPiece key = ReadString(/*index=*/-1);
+ const StringPiece key = ReadString(kIndexStackTop);
if (key.Equals(kDeviceLocaleKey)) {
// Provide the locale as table with the individual fields set.
lua_newtable(state_);
@@ -292,7 +305,7 @@
return 0;
}
- const StringPiece key_str = ReadString(/*index=*/-1);
+ const StringPiece key_str = ReadString(kIndexStackTop);
if (key_str.empty()) {
TC3_LOG(ERROR) << "Expected string, got null.";
lua_error(state_);
@@ -410,7 +423,7 @@
}
int JniLuaEnvironment::HandleUrlHost() {
- const StringPiece url = ReadString(/*index=*/-1);
+ const StringPiece url = ReadString(kIndexStackTop);
const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
if (!status_or_parsed_uri.ok()) {
@@ -443,7 +456,7 @@
}
int JniLuaEnvironment::HandleHash() {
- const StringPiece input = ReadString(/*index=*/-1);
+ const StringPiece input = ReadString(kIndexStackTop);
lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
return 1;
}
@@ -464,7 +477,7 @@
return false;
}
- const StringPiece resource_name = ReadString(/*index=*/-1);
+ const StringPiece resource_name = ReadString(kIndexStackTop);
std::string resource_content;
if (!resources_.GetResourceContent(device_locales_, resource_name,
&resource_content)) {
@@ -492,10 +505,10 @@
int resource_id;
switch (lua_type(state_, -1)) {
case LUA_TNUMBER:
- resource_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
+ resource_id = static_cast<int>(lua_tonumber(state_, kIndexStackTop));
break;
case LUA_TSTRING: {
- const StringPiece resource_name_str = ReadString(/*index=*/-1);
+ const StringPiece resource_name_str = ReadString(kIndexStackTop);
if (resource_name_str.empty()) {
TC3_LOG(ERROR) << "No resource name provided.";
lua_error(state_);
@@ -595,27 +608,28 @@
while (lua_next(state_, /*idx=*/-2)) {
const StringPiece key = ReadString(/*index=*/-2);
if (key.Equals("title_without_entity")) {
- result.title_without_entity = ReadString(/*index=*/-1).ToString();
+ result.title_without_entity = ReadString(kIndexStackTop).ToString();
} else if (key.Equals("title_with_entity")) {
- result.title_with_entity = ReadString(/*index=*/-1).ToString();
+ result.title_with_entity = ReadString(kIndexStackTop).ToString();
} else if (key.Equals("description")) {
- result.description = ReadString(/*index=*/-1).ToString();
+ result.description = ReadString(kIndexStackTop).ToString();
} else if (key.Equals("description_with_app_name")) {
- result.description_with_app_name = ReadString(/*index=*/-1).ToString();
+ result.description_with_app_name = ReadString(kIndexStackTop).ToString();
} else if (key.Equals("action")) {
- result.action = ReadString(/*index=*/-1).ToString();
+ result.action = ReadString(kIndexStackTop).ToString();
} else if (key.Equals("data")) {
- result.data = ReadString(/*index=*/-1).ToString();
+ result.data = ReadString(kIndexStackTop).ToString();
} else if (key.Equals("type")) {
- result.type = ReadString(/*index=*/-1).ToString();
+ result.type = ReadString(kIndexStackTop).ToString();
} else if (key.Equals("flags")) {
- result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
+ result.flags = static_cast<int>(lua_tointeger(state_, kIndexStackTop));
} else if (key.Equals("package_name")) {
- result.package_name = ReadString(/*index=*/-1).ToString();
+ result.package_name = ReadString(kIndexStackTop).ToString();
} else if (key.Equals("request_code")) {
- result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
+ result.request_code =
+ static_cast<int>(lua_tointeger(state_, kIndexStackTop));
} else if (key.Equals("category")) {
- ReadCategories(&result.category);
+ result.category = ReadStringVector();
} else if (key.Equals("extra")) {
ReadExtras(&result.extra);
} else {
@@ -627,25 +641,43 @@
return result;
}
-void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) {
- // Read category array.
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected categories table, got: "
- << lua_type(state_, /*idx=*/-1);
+template <class T>
+std::vector<T> JniLuaEnvironment::ReadVector(
+ const std::function<T()> read_element_func) const {
+ std::vector<T> vector;
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected a table, got: "
+ << lua_type(state_, kIndexStackTop);
lua_pop(state_, 1);
- return;
+ return {};
}
lua_pushnil(state_);
while (lua_next(state_, /*idx=*/-2)) {
- category->push_back(ReadString(/*index=*/-1).ToString());
+ vector.push_back(read_element_func());
lua_pop(state_, 1);
}
+ return vector;
+}
+
+std::vector<std::string> JniLuaEnvironment::ReadStringVector() const {
+ return ReadVector<std::string>(
+ [this]() { return this->ReadString(kIndexStackTop).ToString(); });
+}
+
+std::vector<float> JniLuaEnvironment::ReadFloatVector() const {
+  return ReadVector<float>(
+      [this]() { return static_cast<float>(lua_tonumber(state_, kIndexStackTop)); });
+}
+
+std::vector<int> JniLuaEnvironment::ReadIntVector() const {
+  return ReadVector<int>(
+      [this]() { return static_cast<int>(lua_tonumber(state_, kIndexStackTop)); });
}
void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) {
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected extras table, got: "
- << lua_type(state_, /*idx=*/-1);
+ << lua_type(state_, kIndexStackTop);
lua_pop(state_, 1);
return;
}
@@ -654,9 +686,9 @@
// Each entry is a table specifying name and value.
// The value is specified via a type specific field as Lua doesn't allow
// to easily distinguish between different number types.
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected a table for an extra, got: "
- << lua_type(state_, /*idx=*/-1);
+ << lua_type(state_, kIndexStackTop);
lua_pop(state_, 1);
return;
}
@@ -667,17 +699,26 @@
while (lua_next(state_, /*idx=*/-2)) {
const StringPiece key = ReadString(/*index=*/-2);
if (key.Equals("name")) {
- name = ReadString(/*index=*/-1).ToString();
+ name = ReadString(kIndexStackTop).ToString();
} else if (key.Equals("int_value")) {
- value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
+ value = Variant(static_cast<int>(lua_tonumber(state_, kIndexStackTop)));
} else if (key.Equals("long_value")) {
- value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
+ value =
+ Variant(static_cast<int64>(lua_tonumber(state_, kIndexStackTop)));
} else if (key.Equals("float_value")) {
- value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
+ value =
+ Variant(static_cast<float>(lua_tonumber(state_, kIndexStackTop)));
} else if (key.Equals("bool_value")) {
- value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
+ value =
+ Variant(static_cast<bool>(lua_toboolean(state_, kIndexStackTop)));
} else if (key.Equals("string_value")) {
- value = Variant(ReadString(/*index=*/-1).ToString());
+ value = Variant(ReadString(kIndexStackTop).ToString());
+ } else if (key.Equals("string_array_value")) {
+ value = Variant(ReadStringVector());
+ } else if (key.Equals("float_array_value")) {
+ value = Variant(ReadFloatVector());
+ } else if (key.Equals("int_array_value")) {
+ value = Variant(ReadIntVector());
} else {
TC3_LOG(INFO) << "Unknown extra field: " << key.ToString();
}
@@ -695,7 +736,7 @@
int JniLuaEnvironment::ReadRemoteActionTemplates(
std::vector<RemoteActionTemplate>* result) {
// Read result.
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Unexpected result for snippet: " << lua_type(state_, -1);
lua_error(state_);
return LUA_ERRRUN;
@@ -704,9 +745,9 @@
// Read remote action templates array.
lua_pushnil(state_);
while (lua_next(state_, /*idx=*/-2)) {
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected intent table, got: "
- << lua_type(state_, /*idx=*/-1);
+ << lua_type(state_, kIndexStackTop);
lua_pop(state_, 1);
continue;
}
@@ -897,6 +938,7 @@
// Retrieve generator for specified entity.
auto it = generators_.find(classification.collection);
if (it == generators_.end()) {
+ TC3_VLOG(INFO) << "Cannot find a generator for the specified collection.";
return true;
}
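Reviewer note (illustrative, not part of the patch): the new ReadVector helper, which the string/float/int array extras ("string_array_value", "float_array_value", "int_array_value") now flow through, is the standard Lua C API table-iteration pattern: check that the stack top holds a table, seed the traversal with lua_pushnil, let lua_next push key/value pairs, read the value at the stack top, and pop only the value before the next step. A minimal standalone sketch of that pattern follows; the function name, the double element type, and the lua.hpp header are assumptions for the example, while the real code goes through the JniLuaEnvironment wrappers above.

#include <vector>

#include "lua.hpp"  // Assumed header name for the Lua C/C++ API.

// Reads the Lua table at the top of the stack into a vector of numbers.
// Mirrors ReadVector above: lua_pushnil() seeds the key, lua_next() pushes a
// key/value pair on each iteration, the value (stack top, index -1) is read,
// and only the value is popped so the key stays in place for the next call.
// The caller still owns the value at the top of the stack on entry.
std::vector<double> ReadNumberTable(lua_State* state) {
  std::vector<double> result;
  if (lua_type(state, /*idx=*/-1) != LUA_TTABLE) {
    return result;  // Not a table; nothing to read.
  }
  lua_pushnil(state);  // First key.
  while (lua_next(state, /*idx=*/-2) != 0) {
    result.push_back(lua_tonumber(state, /*idx=*/-1));
    lua_pop(state, 1);  // Pop the value, keep the key.
  }
  return result;
}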
diff --git a/native/utils/intents/jni.cc b/native/utils/intents/jni.cc
index ed7737e..bd0fc7d 100644
--- a/native/utils/intents/jni.cc
+++ b/native/utils/intents/jni.cc
@@ -77,7 +77,12 @@
"(Ljava/lang/String;Z)V");
TC3_GET_METHOD(named_variant_class_, named_variant_from_string_, "<init>",
"(Ljava/lang/String;Ljava/lang/String;)V");
-
+ TC3_GET_METHOD(named_variant_class_, named_variant_from_string_array_,
+ "<init>", "(Ljava/lang/String;[Ljava/lang/String;)V");
+ TC3_GET_METHOD(named_variant_class_, named_variant_from_float_array_,
+ "<init>", "(Ljava/lang/String;[F)V");
+ TC3_GET_METHOD(named_variant_class_, named_variant_from_int_array_, "<init>",
+ "(Ljava/lang/String;[I)V");
return handler;
}
@@ -124,6 +129,38 @@
return result;
}
+StatusOr<ScopedLocalRef<jfloatArray>>
+RemoteActionTemplatesHandler::AsFloatArray(
+ const std::vector<float>& values) const {
+ if (values.empty()) {
+ return {{nullptr, jni_cache_->GetEnv()}};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jfloatArray> result,
+ JniHelper::NewFloatArray(jni_cache_->GetEnv(), values.size()));
+
+ jni_cache_->GetEnv()->SetFloatArrayRegion(result.get(), /*start=*/0,
+ /*len=*/values.size(),
+ &(values[0]));
+ return result;
+}
+
+StatusOr<ScopedLocalRef<jintArray>> RemoteActionTemplatesHandler::AsIntArray(
+ const std::vector<int>& values) const {
+ if (values.empty()) {
+ return {{nullptr, jni_cache_->GetEnv()}};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jintArray> result,
+ JniHelper::NewIntArray(jni_cache_->GetEnv(), values.size()));
+
+ jni_cache_->GetEnv()->SetIntArrayRegion(result.get(), /*start=*/0,
+ /*len=*/values.size(), &(values[0]));
+ return result;
+}
+
StatusOr<ScopedLocalRef<jobject>> RemoteActionTemplatesHandler::AsNamedVariant(
const std::string& name_str, const Variant& value) const {
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> name,
@@ -165,6 +202,33 @@
value_jstring.get());
}
+ case Variant::TYPE_STRING_VECTOR_VALUE: {
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobjectArray> value_jstring_array,
+ AsStringArray(value.StringVectorValue()));
+
+ return JniHelper::NewObject(env, named_variant_class_.get(),
+ named_variant_from_string_array_, name.get(),
+ value_jstring_array.get());
+ }
+
+ case Variant::TYPE_FLOAT_VECTOR_VALUE: {
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jfloatArray> value_jfloat_array,
+ AsFloatArray(value.FloatVectorValue()));
+
+ return JniHelper::NewObject(env, named_variant_class_.get(),
+ named_variant_from_float_array_, name.get(),
+ value_jfloat_array.get());
+ }
+
+ case Variant::TYPE_INT_VECTOR_VALUE: {
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jintArray> value_jint_array,
+ AsIntArray(value.IntVectorValue()));
+
+ return JniHelper::NewObject(env, named_variant_class_.get(),
+                                  named_variant_from_int_array_, name.get(),
+ value_jint_array.get());
+ }
+
case Variant::TYPE_EMPTY:
return {Status::UNKNOWN};
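Reviewer note (illustrative, not part of the patch): the new TYPE_INT_VECTOR_VALUE branch amounts to building a jintArray and handing it to an assumed NamedVariant(String name, int[] value) constructor. The sketch below spells that out with raw JNI calls and without the StatusOr/ScopedLocalRef plumbing; the function name and parameters are invented for the example, and it assumes jint and int have the same width, as AsIntArray above already does.

#include <jni.h>

#include <vector>

// Illustrative only: raw-JNI equivalent of AsIntArray + AsNamedVariant for an
// int vector. Returns a local reference (nullptr on failure); the caller is
// responsible for deleting it.
jobject MakeIntArrayNamedVariant(JNIEnv* env, jclass named_variant_class,
                                 jmethodID from_int_array_ctor, jstring name,
                                 const std::vector<int>& values) {
  jintArray array = env->NewIntArray(static_cast<jsize>(values.size()));
  if (array == nullptr || env->ExceptionCheck()) {
    return nullptr;
  }
  if (!values.empty()) {
    // Assumes jint has the same representation as int on this platform.
    env->SetIntArrayRegion(array, /*start=*/0,
                           static_cast<jsize>(values.size()),
                           reinterpret_cast<const jint*>(values.data()));
  }
  return env->NewObject(named_variant_class, from_int_array_ctor, name, array);
}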
diff --git a/native/utils/intents/jni.h b/native/utils/intents/jni.h
index a185472..a2d57a0 100644
--- a/native/utils/intents/jni.h
+++ b/native/utils/intents/jni.h
@@ -60,6 +60,10 @@
const Optional<int>& optional) const;
StatusOr<ScopedLocalRef<jobjectArray>> AsStringArray(
const std::vector<std::string>& values) const;
+ StatusOr<ScopedLocalRef<jfloatArray>> AsFloatArray(
+ const std::vector<float>& values) const;
+ StatusOr<ScopedLocalRef<jintArray>> AsIntArray(
+ const std::vector<int>& values) const;
StatusOr<ScopedLocalRef<jobject>> AsNamedVariant(const std::string& name,
const Variant& value) const;
StatusOr<ScopedLocalRef<jobjectArray>> AsNamedVariantArray(
@@ -98,6 +102,9 @@
jmethodID named_variant_from_double_ = nullptr;
jmethodID named_variant_from_bool_ = nullptr;
jmethodID named_variant_from_string_ = nullptr;
+ jmethodID named_variant_from_string_array_ = nullptr;
+ jmethodID named_variant_from_float_array_ = nullptr;
+ jmethodID named_variant_from_int_array_ = nullptr;
};
} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-helper.cc b/native/utils/java/jni-helper.cc
index 34e0ab8..6618cfa 100644
--- a/native/utils/java/jni-helper.cc
+++ b/native/utils/java/jni-helper.cc
@@ -120,6 +120,15 @@
return result;
}
+StatusOr<ScopedLocalRef<jfloatArray>> JniHelper::NewFloatArray(JNIEnv* env,
+ jsize length) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ ScopedLocalRef<jfloatArray> result(env->NewFloatArray(length), env);
+ TC3_NOT_NULL_OR_RETURN;
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
StatusOr<ScopedLocalRef<jobjectArray>> JniHelper::NewObjectArray(
JNIEnv* env, jsize length, jclass element_class, jobject initial_element) {
TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
diff --git a/native/utils/java/jni-helper.h b/native/utils/java/jni-helper.h
index 56a07b7..71d31cb 100644
--- a/native/utils/java/jni-helper.h
+++ b/native/utils/java/jni-helper.h
@@ -93,6 +93,8 @@
jsize length);
static StatusOr<ScopedLocalRef<jstring>> NewStringUTF(JNIEnv* env,
const char* bytes);
+ static StatusOr<ScopedLocalRef<jfloatArray>> NewFloatArray(JNIEnv* env,
+ jsize length);
// Call* methods.
TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(CallObjectMethod, jobject,
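Reviewer note (illustrative, not part of the patch): the new NewFloatArray helper follows the same StatusOr<ScopedLocalRef<...>> convention as the existing factory methods, so a caller typically unwraps it with TC3_ASSIGN_OR_RETURN and then fills the array directly. A minimal caller sketch, with the function name and the hard-coded values invented for the example and the relevant headers assumed to be included:

// Illustrative caller sketch only.
StatusOr<ScopedLocalRef<jfloatArray>> MakeScores(JNIEnv* env) {
  TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jfloatArray> scores,
                       JniHelper::NewFloatArray(env, /*length=*/3));
  const float values[] = {0.1f, 0.5f, 0.9f};
  env->SetFloatArrayRegion(scores.get(), /*start=*/0, /*len=*/3, values);
  return scores;
}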
diff --git a/native/utils/memory/mmap.h b/native/utils/memory/mmap.h
index acce7db..974cc02 100644
--- a/native/utils/memory/mmap.h
+++ b/native/utils/memory/mmap.h
@@ -130,7 +130,7 @@
}
}
- const MmapHandle &handle() { return handle_; }
+ const MmapHandle &handle() const { return handle_; }
private:
MmapHandle handle_;
diff --git a/native/utils/normalization_test.cc b/native/utils/normalization_test.cc
deleted file mode 100644
index 1bf9fae..0000000
--- a/native/utils/normalization_test.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/normalization.h"
-
-#include <string>
-
-#include "utils/base/integral_types.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::Eq;
-
-class NormalizationTest : public testing::Test {
- protected:
- NormalizationTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
-
- std::string NormalizeTextCodepointWise(const std::string& text,
- const int32 codepointwise_ops) {
- return libtextclassifier3::NormalizeTextCodepointWise(
- &unilib_, codepointwise_ops,
- UTF8ToUnicodeText(text, /*do_copy=*/false))
- .ToUTF8String();
- }
-
- UniLib unilib_;
-};
-
-TEST_F(NormalizationTest, ReturnsIdenticalStringWhenNoNormalization) {
- EXPECT_THAT(NormalizeTextCodepointWise(
- "Never gonna let you down.",
- NormalizationOptions_::CodepointwiseNormalizationOp_NONE),
- Eq("Never gonna let you down."));
-}
-
-#if !defined(TC3_UNILIB_DUMMY)
-TEST_F(NormalizationTest, DropsWhitespace) {
- EXPECT_THAT(
- NormalizeTextCodepointWise(
- "Never gonna let you down.",
- NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE),
- Eq("Nevergonnaletyoudown."));
- EXPECT_THAT(
- NormalizeTextCodepointWise(
- "Never\tgonna\t\tlet\tyou\tdown.",
- NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE),
- Eq("Nevergonnaletyoudown."));
- EXPECT_THAT(
- NormalizeTextCodepointWise(
- "Never\u2003gonna\u2003let\u2003you\u2003down.",
- NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE),
- Eq("Nevergonnaletyoudown."));
-}
-
-TEST_F(NormalizationTest, DropsPunctuation) {
- EXPECT_THAT(
- NormalizeTextCodepointWise(
- "Never gonna let you down.",
- NormalizationOptions_::CodepointwiseNormalizationOp_DROP_PUNCTUATION),
- Eq("Never gonna let you down"));
- EXPECT_THAT(
- NormalizeTextCodepointWise(
- "αʹ. Σημεῖόν ἐστιν, οὗ μέρος οὐθέν.",
- NormalizationOptions_::CodepointwiseNormalizationOp_DROP_PUNCTUATION),
- Eq("αʹ Σημεῖόν ἐστιν οὗ μέρος οὐθέν"));
- EXPECT_THAT(
- NormalizeTextCodepointWise(
- "978—3—16—148410—0",
- NormalizationOptions_::CodepointwiseNormalizationOp_DROP_PUNCTUATION),
- Eq("9783161484100"));
-}
-
-TEST_F(NormalizationTest, LowercasesUnicodeText) {
- EXPECT_THAT(
- NormalizeTextCodepointWise(
- "αʹ. Σημεῖόν ἐστιν, οὗ μέρος οὐθέν.",
- NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE),
- Eq("αʹ. σημεῖόν ἐστιν, οὗ μέρος οὐθέν."));
- EXPECT_THAT(
- NormalizeTextCodepointWise(
- "αʹ. Σημεῖόν ἐστιν, οὗ μέρος οὐθέν.",
- NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
- NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE),
- Eq("αʹ.σημεῖόνἐστιν,οὗμέροςοὐθέν."));
-}
-
-TEST_F(NormalizationTest, UppercasesUnicodeText) {
- EXPECT_THAT(
- NormalizeTextCodepointWise(
- "Κανένας άνθρωπος δεν ξέρει",
- NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE),
- Eq("ΚΑΝΈΝΑΣ ΆΝΘΡΩΠΟΣ ΔΕΝ ΞΈΡΕΙ"));
- EXPECT_THAT(
- NormalizeTextCodepointWise(
- "Κανένας άνθρωπος δεν ξέρει",
- NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
- NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE),
- Eq("ΚΑΝΈΝΑΣΆΝΘΡΩΠΟΣΔΕΝΞΈΡΕΙ"));
-}
-#endif
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/regex-match_test.cc b/native/utils/regex-match_test.cc
deleted file mode 100644
index c45fb29..0000000
--- a/native/utils/regex-match_test.cc
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/regex-match.h"
-
-#include <memory>
-
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class RegexMatchTest : public testing::Test {
- protected:
- RegexMatchTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- UniLib unilib_;
-};
-
-#ifdef TC3_UNILIB_ICU
-#ifndef TC3_DISABLE_LUA
-TEST_F(RegexMatchTest, HandlesSimpleVerification) {
- EXPECT_TRUE(VerifyMatch(/*context=*/"", /*matcher=*/nullptr, "return true;"));
-}
-#endif // TC3_DISABLE_LUA
-
-#ifndef TC3_DISABLE_LUA
-TEST_F(RegexMatchTest, HandlesCustomVerification) {
- UnicodeText pattern = UTF8ToUnicodeText("(\\d{16})",
- /*do_copy=*/true);
- UnicodeText message = UTF8ToUnicodeText("cc: 4012888888881881",
- /*do_copy=*/true);
- const std::string verifier = R"(
-function luhn(candidate)
- local sum = 0
- local num_digits = string.len(candidate)
- local parity = num_digits % 2
- for pos = 1,num_digits do
- d = tonumber(string.sub(candidate, pos, pos))
- if pos % 2 ~= parity then
- d = d * 2
- end
- if d > 9 then
- d = d - 9
- end
- sum = sum + d
- end
- return (sum % 10) == 0
-end
-return luhn(match[1].text);
- )";
- const std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib_.CreateRegexPattern(pattern);
- ASSERT_TRUE(regex_pattern != nullptr);
- const std::unique_ptr<UniLib::RegexMatcher> matcher =
- regex_pattern->Matcher(message);
- ASSERT_TRUE(matcher != nullptr);
- int status = UniLib::RegexMatcher::kNoError;
- ASSERT_TRUE(matcher->Find(&status) &&
- status == UniLib::RegexMatcher::kNoError);
-
- EXPECT_TRUE(VerifyMatch(message.ToUTF8String(), matcher.get(), verifier));
-}
-#endif // TC3_DISABLE_LUA
-
-TEST_F(RegexMatchTest, RetrievesMatchGroupTest) {
- UnicodeText pattern =
- UTF8ToUnicodeText("never gonna (?:give (you) up|let (you) down)",
- /*do_copy=*/true);
- const std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib_.CreateRegexPattern(pattern);
- ASSERT_TRUE(regex_pattern != nullptr);
- UnicodeText message =
- UTF8ToUnicodeText("never gonna give you up - never gonna let you down");
- const std::unique_ptr<UniLib::RegexMatcher> matcher =
- regex_pattern->Matcher(message);
- ASSERT_TRUE(matcher != nullptr);
- int status = UniLib::RegexMatcher::kNoError;
-
- ASSERT_TRUE(matcher->Find(&status) &&
- status == UniLib::RegexMatcher::kNoError);
- EXPECT_THAT(GetCapturingGroupText(matcher.get(), 0).value(),
- testing::Eq("never gonna give you up"));
- EXPECT_THAT(GetCapturingGroupText(matcher.get(), 1).value(),
- testing::Eq("you"));
- EXPECT_FALSE(GetCapturingGroupText(matcher.get(), 2).has_value());
-
- ASSERT_TRUE(matcher->Find(&status) &&
- status == UniLib::RegexMatcher::kNoError);
- EXPECT_THAT(GetCapturingGroupText(matcher.get(), 0).value(),
- testing::Eq("never gonna let you down"));
- EXPECT_FALSE(GetCapturingGroupText(matcher.get(), 1).has_value());
- EXPECT_THAT(GetCapturingGroupText(matcher.get(), 2).value(),
- testing::Eq("you"));
-}
-#endif
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/resources_test.cc b/native/utils/resources_test.cc
deleted file mode 100644
index c385f39..0000000
--- a/native/utils/resources_test.cc
+++ /dev/null
@@ -1,287 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/resources.h"
-#include "utils/i18n/locale.h"
-#include "utils/resources_generated.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class ResourcesTest
- : public testing::TestWithParam<testing::tuple<bool, bool>> {
- protected:
- ResourcesTest() {}
-
- std::string BuildTestResources(bool add_default_language = true) const {
- ResourcePoolT test_resources;
-
- // Test locales.
- test_resources.locale.emplace_back(new LanguageTagT);
- test_resources.locale.back()->language = "en";
- test_resources.locale.back()->region = "US";
- test_resources.locale.emplace_back(new LanguageTagT);
- test_resources.locale.back()->language = "en";
- test_resources.locale.back()->region = "GB";
- test_resources.locale.emplace_back(new LanguageTagT);
- test_resources.locale.back()->language = "de";
- test_resources.locale.back()->region = "DE";
- test_resources.locale.emplace_back(new LanguageTagT);
- test_resources.locale.back()->language = "fr";
- test_resources.locale.back()->region = "FR";
- test_resources.locale.emplace_back(new LanguageTagT);
- test_resources.locale.back()->language = "pt";
- test_resources.locale.back()->region = "PT";
- test_resources.locale.emplace_back(new LanguageTagT);
- test_resources.locale.back()->language = "pt";
- test_resources.locale.emplace_back(new LanguageTagT);
- test_resources.locale.back()->language = "zh";
- test_resources.locale.back()->script = "Hans";
- test_resources.locale.back()->region = "CN";
- test_resources.locale.emplace_back(new LanguageTagT);
- test_resources.locale.back()->language = "zh";
- test_resources.locale.emplace_back(new LanguageTagT);
- test_resources.locale.back()->language = "fr";
- test_resources.locale.back()->language = "fr-CA";
- if (add_default_language) {
- test_resources.locale.emplace_back(new LanguageTagT); // default
- }
-
- // Test entries.
- test_resources.resource_entry.emplace_back(new ResourceEntryT);
- test_resources.resource_entry.back()->name = /*resource_name=*/"A";
-
- // en-US, default
- test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
- test_resources.resource_entry.back()->resource.back()->content = "localize";
- test_resources.resource_entry.back()->resource.back()->locale.push_back(0);
- if (add_default_language) {
- test_resources.resource_entry.back()->resource.back()->locale.push_back(
- 9);
- }
-
- // en-GB
- test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
- test_resources.resource_entry.back()->resource.back()->content = "localise";
- test_resources.resource_entry.back()->resource.back()->locale.push_back(1);
-
- // de-DE
- test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
- test_resources.resource_entry.back()->resource.back()->content =
- "lokalisieren";
- test_resources.resource_entry.back()->resource.back()->locale.push_back(2);
-
- // fr-FR, fr-CA
- test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
- test_resources.resource_entry.back()->resource.back()->content =
- "localiser";
- test_resources.resource_entry.back()->resource.back()->locale.push_back(3);
- test_resources.resource_entry.back()->resource.back()->locale.push_back(8);
-
- // pt-PT
- test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
- test_resources.resource_entry.back()->resource.back()->content =
- "localizar";
- test_resources.resource_entry.back()->resource.back()->locale.push_back(4);
-
- // pt
- test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
- test_resources.resource_entry.back()->resource.back()->content =
- "concentrar";
- test_resources.resource_entry.back()->resource.back()->locale.push_back(5);
-
- // zh-Hans-CN
- test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
- test_resources.resource_entry.back()->resource.back()->content = "龙";
- test_resources.resource_entry.back()->resource.back()->locale.push_back(6);
-
- // zh
- test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
- test_resources.resource_entry.back()->resource.back()->content = "龍";
- test_resources.resource_entry.back()->resource.back()->locale.push_back(7);
-
- if (compress()) {
- EXPECT_TRUE(CompressResources(
- &test_resources,
- /*build_compression_dictionary=*/build_dictionary()));
- }
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(ResourcePool::Pack(builder, &test_resources));
-
- return std::string(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
- }
-
- bool compress() const { return testing::get<0>(GetParam()); }
-
- bool build_dictionary() const { return testing::get<1>(GetParam()); }
-};
-
-INSTANTIATE_TEST_SUITE_P(Compression, ResourcesTest,
- testing::Combine(testing::Bool(), testing::Bool()));
-
-TEST_P(ResourcesTest, CorrectlyHandlesExactMatch) {
- std::string test_resources = BuildTestResources();
- Resources resources(
- flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
- std::string content;
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("en-US")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("localize", content);
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("en-GB")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("localise", content);
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("pt-PT")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("localizar", content);
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-Hans-CN")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("龙", content);
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("龍", content);
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("fr-CA")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("localiser", content);
-}
-
-TEST_P(ResourcesTest, CorrectlyHandlesTie) {
- std::string test_resources = BuildTestResources();
- Resources resources(
- flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
- // Uses first best match in case of a tie.
- std::string content;
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("en-CA")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("localize", content);
-}
-
-TEST_P(ResourcesTest, RequiresLanguageMatch) {
- {
- std::string test_resources =
- BuildTestResources(/*add_default_language=*/false);
- Resources resources(
- flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
- EXPECT_FALSE(resources.GetResourceContent({Locale::FromBCP47("es-US")},
- /*resource_name=*/"A",
- /*result=*/nullptr));
- }
- {
- std::string test_resources =
- BuildTestResources(/*add_default_language=*/true);
- Resources resources(
- flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
- std::string content;
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("es-US")},
- /*resource_name=*/"A",
- /*result=*/&content));
- EXPECT_EQ("localize", content);
- }
-}
-
-TEST_P(ResourcesTest, HandlesFallback) {
- std::string test_resources = BuildTestResources();
- Resources resources(
- flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
- std::string content;
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("fr-CH")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("localiser", content);
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-Hans")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("龙", content);
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-Hans-ZZ")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("龙", content);
-
- // Fallback to default, en-US.
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("ru")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("localize", content);
-}
-
-TEST_P(ResourcesTest, HandlesFallbackMultipleLocales) {
- std::string test_resources = BuildTestResources();
- Resources resources(
- flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
- std::string content;
-
- // Still use inexact match with primary locale if language matches,
- // even though secondary locale would match exactly.
- EXPECT_TRUE(resources.GetResourceContent(
- {Locale::FromBCP47("fr-CH"), Locale::FromBCP47("en-US")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("localiser", content);
-
- // Use secondary language instead of default fallback if that is an exact
- // language match.
- EXPECT_TRUE(resources.GetResourceContent(
- {Locale::FromBCP47("ru"), Locale::FromBCP47("de")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("lokalisieren", content);
-
- // Use tertiary language.
- EXPECT_TRUE(resources.GetResourceContent(
- {Locale::FromBCP47("ru"), Locale::FromBCP47("it-IT"),
- Locale::FromBCP47("de")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("lokalisieren", content);
-
- // Default fallback if no locale matches.
- EXPECT_TRUE(resources.GetResourceContent(
- {Locale::FromBCP47("ru"), Locale::FromBCP47("it-IT"),
- Locale::FromBCP47("es")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("localize", content);
-}
-
-TEST_P(ResourcesTest, PreferGenericCallback) {
- std::string test_resources = BuildTestResources();
- Resources resources(
- flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
- std::string content;
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("pt-BR")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("concentrar", content); // Falls back to pt, not pt-PT.
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-Hant")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("龍", content); // Falls back to zh, not zh-Hans-CN.
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-Hant-CN")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("龍", content); // Falls back to zh, not zh-Hans-CN.
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-CN")},
- /*resource_name=*/"A", &content));
- EXPECT_EQ("龍", content); // Falls back to zh, not zh-Hans-CN.
-}
-
-TEST_P(ResourcesTest, PreferGenericWhenGeneric) {
- std::string test_resources = BuildTestResources();
- Resources resources(
- flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
- std::string content;
- EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("pt")},
- /*resource_name=*/"A", &content));
-
- // Uses pt, not pt-PT.
- EXPECT_EQ("concentrar", content);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/sentencepiece/encoder_test.cc b/native/utils/sentencepiece/encoder_test.cc
deleted file mode 100644
index 740db35..0000000
--- a/native/utils/sentencepiece/encoder_test.cc
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/sentencepiece/encoder.h"
-
-#include <memory>
-#include <vector>
-
-#include "utils/base/integral_types.h"
-#include "utils/container/sorted-strings-table.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAre;
-
-TEST(EncoderTest, SimpleTokenization) {
- const char pieces_table[] = "hell\0hello\0o\0there\0";
- const uint32 offsets[] = {0, 5, 11, 13};
- float scores[] = {-0.5, -1.0, -10.0, -1.0};
- std::unique_ptr<StringSet> pieces(new SortedStringsTable(
- /*num_pieces=*/4, offsets, StringPiece(pieces_table, 18)));
- const Encoder encoder(pieces.get(),
- /*num_pieces=*/4, scores);
-
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 3, 5, 1));
- }
-
- // Make probability of hello very low:
- // hello gets now tokenized as hell + o.
- scores[1] = -100.0;
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 2, 4, 5, 1));
- }
-}
-
-TEST(EncoderTest, HandlesEdgeCases) {
- const char pieces_table[] = "hell\0hello\0o\0there\0";
- const uint32 offsets[] = {0, 5, 11, 13};
- float scores[] = {-0.5, -1.0, -10.0, -1.0};
- std::unique_ptr<StringSet> pieces(new SortedStringsTable(
- /*num_pieces=*/4, offsets, StringPiece(pieces_table, 18)));
- const Encoder encoder(pieces.get(),
- /*num_pieces=*/4, scores);
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellhello", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 2, 3, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellohell", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 3, 2, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellathere", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 1));
- }
-}
-
-TEST(EncoderTest, HandlesOutOfDictionary) {
- const char pieces_table[] = "hell\0hello\0o\0there\0";
- const uint32 offsets[] = {0, 5, 11, 13};
- float scores[] = {-0.5, -1.0, -10.0, -1.0};
- std::unique_ptr<StringSet> pieces(new SortedStringsTable(
- /*num_pieces=*/4, offsets, StringPiece(pieces_table, 18)));
- const Encoder encoder(pieces.get(),
- /*num_pieces=*/4, scores,
- /*start_code=*/0, /*end_code=*/1,
- /*encoding_offset=*/3, /*unknown_code=*/2,
- /*unknown_score=*/-100.0);
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellhello", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 3, 4, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellohell", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 4, 3, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellathere", &encoded_text));
- EXPECT_THAT(encoded_text,
- ElementsAre(0, /*hell*/ 3, /*unknown*/ 2, /*there*/ 6, 1));
- }
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/sentencepiece/normalizer_test.cc b/native/utils/sentencepiece/normalizer_test.cc
deleted file mode 100644
index 5010035..0000000
--- a/native/utils/sentencepiece/normalizer_test.cc
+++ /dev/null
@@ -1,198 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/sentencepiece/normalizer.h"
-
-#include <fstream>
-#include <string>
-
-#include "utils/container/double-array-trie.h"
-#include "utils/sentencepiece/test_utils.h"
-#include "utils/strings/stringpiece.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-std::string GetTestConfigPath() {
- return "";
-}
-
-TEST(NormalizerTest, NormalizesAsReferenceNormalizer) {
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- SentencePieceNormalizer normalizer =
- NormalizerFromSpec(config, /*add_dummy_prefix=*/true,
- /*remove_extra_whitespaces=*/true,
- /*escape_whitespaces=*/true);
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
- EXPECT_EQ(normalized, "▁hello▁there");
- }
-
- // Redundant whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
- EXPECT_EQ(normalized, "▁when▁is▁the▁world▁cup?");
- }
-
- // Different whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
- EXPECT_EQ(normalized, "▁general▁kenobi");
- }
-
- // NFKC char to multi-char normalization.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("㍿", &normalized));
- EXPECT_EQ(normalized, "▁株式会社");
- }
-
- // Half width katakana, character composition happens.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize(" グーグル ", &normalized));
- EXPECT_EQ(normalized, "▁グーグル");
- }
-
- // NFKC char to char normalization.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("①②③", &normalized));
- EXPECT_EQ(normalized, "▁123");
- }
-}
-
-TEST(NormalizerTest, NoDummyPrefix) {
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- SentencePieceNormalizer normalizer =
- NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
- /*remove_extra_whitespaces=*/true,
- /*escape_whitespaces=*/true);
-
- // NFKC char to char normalization.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
- EXPECT_EQ(normalized, "hello▁there");
- }
-
- // Redundant whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
- EXPECT_EQ(normalized, "when▁is▁the▁world▁cup?");
- }
-
- // Different whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
- EXPECT_EQ(normalized, "general▁kenobi");
- }
-
- // NFKC char to multi-char normalization.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("㍿", &normalized));
- EXPECT_EQ(normalized, "株式会社");
- }
-
- // Half width katakana, character composition happens.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize(" グーグル ", &normalized));
- EXPECT_EQ(normalized, "グーグル");
- }
-
- // NFKC char to char normalization.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("①②③", &normalized));
- EXPECT_EQ(normalized, "123");
- }
-}
-
-TEST(NormalizerTest, NoRemoveExtraWhitespace) {
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- SentencePieceNormalizer normalizer =
- NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
- /*remove_extra_whitespaces=*/false,
- /*escape_whitespaces=*/true);
-
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
- EXPECT_EQ(normalized, "hello▁there");
- }
-
- // Redundant whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
- EXPECT_EQ(normalized, "when▁is▁▁the▁▁world▁cup?");
- }
-
- // Different whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
- EXPECT_EQ(normalized, "general▁kenobi");
- }
-}
-
-TEST(NormalizerTest, NoEscapeWhitespaces) {
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- SentencePieceNormalizer normalizer =
- NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
- /*remove_extra_whitespaces=*/false,
- /*escape_whitespaces=*/false);
-
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
- EXPECT_EQ(normalized, "hello there");
- }
-
- // Redundant whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
- EXPECT_EQ(normalized, "when is the world cup?");
- }
-
- // Different whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
- EXPECT_EQ(normalized, "general kenobi");
- }
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/strings/append_test.cc b/native/utils/strings/append_test.cc
deleted file mode 100644
index 8950761..0000000
--- a/native/utils/strings/append_test.cc
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/strings/append.h"
-
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace strings {
-
-TEST(StringUtilTest, SStringAppendF) {
- std::string str;
- SStringAppendF(&str, 5, "%d %d", 0, 1);
- EXPECT_EQ(str, "0 1");
-
- SStringAppendF(&str, 1, "%d", 9);
- EXPECT_EQ(str, "0 19");
-
- SStringAppendF(&str, 1, "%d", 10);
- EXPECT_EQ(str, "0 191");
-
- str.clear();
-
- SStringAppendF(&str, 5, "%d", 100);
- EXPECT_EQ(str, "100");
-}
-
-TEST(StringUtilTest, SStringAppendFBufCalc) {
- std::string str;
- SStringAppendF(&str, 0, "%d %s %d", 1, "hello", 2);
- EXPECT_EQ(str, "1 hello 2");
-}
-
-TEST(StringUtilTest, JoinStrings) {
- std::vector<std::string> vec;
- vec.push_back("1");
- vec.push_back("2");
- vec.push_back("3");
-
- EXPECT_EQ("1,2,3", JoinStrings(",", vec));
- EXPECT_EQ("123", JoinStrings("", vec));
- EXPECT_EQ("1, 2, 3", JoinStrings(", ", vec));
- EXPECT_EQ("", JoinStrings(",", std::vector<std::string>()));
-}
-
-} // namespace strings
-} // namespace libtextclassifier3
diff --git a/native/utils/strings/numbers_test.cc b/native/utils/strings/numbers_test.cc
deleted file mode 100644
index bf2f84a..0000000
--- a/native/utils/strings/numbers_test.cc
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/strings/numbers.h"
-
-#include "utils/base/integral_types.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-void TestParseInt32(const char *c_str, bool expected_parsing_success,
- int32 expected_parsed_value = 0) {
- int32 parsed_value = 0;
- EXPECT_EQ(expected_parsing_success, ParseInt32(c_str, &parsed_value));
- if (expected_parsing_success) {
- EXPECT_EQ(expected_parsed_value, parsed_value);
- }
-}
-
-TEST(ParseInt32Test, Normal) {
- TestParseInt32("2", true, 2);
- TestParseInt32("-357", true, -357);
- TestParseInt32("7", true, 7);
- TestParseInt32("+7", true, 7);
- TestParseInt32(" +7", true, 7);
- TestParseInt32("-23", true, -23);
- TestParseInt32(" -23", true, -23);
- TestParseInt32("04", true, 4);
- TestParseInt32("07", true, 7);
- TestParseInt32("08", true, 8);
- TestParseInt32("09", true, 9);
-}
-
-TEST(ParseInt32Test, ErrorCases) {
- TestParseInt32("", false);
- TestParseInt32(" ", false);
- TestParseInt32("not-a-number", false);
- TestParseInt32("123a", false);
-}
-
-void TestParseInt64(const char *c_str, bool expected_parsing_success,
- int64 expected_parsed_value = 0) {
- int64 parsed_value = 0;
- EXPECT_EQ(expected_parsing_success, ParseInt64(c_str, &parsed_value));
- if (expected_parsing_success) {
- EXPECT_EQ(expected_parsed_value, parsed_value);
- }
-}
-
-TEST(ParseInt64Test, Normal) {
- TestParseInt64("2", true, 2);
- TestParseInt64("-357", true, -357);
- TestParseInt64("7", true, 7);
- TestParseInt64("+7", true, 7);
- TestParseInt64(" +7", true, 7);
- TestParseInt64("-23", true, -23);
- TestParseInt64(" -23", true, -23);
- TestParseInt64("07", true, 7);
- TestParseInt64("08", true, 8);
-}
-
-TEST(ParseInt64Test, ErrorCases) {
- TestParseInt64("", false);
- TestParseInt64(" ", false);
- TestParseInt64("not-a-number", false);
- TestParseInt64("23z", false);
-}
-
-void TestParseDouble(const char *c_str, bool expected_parsing_success,
- double expected_parsed_value = 0.0) {
- double parsed_value = 0.0;
- EXPECT_EQ(expected_parsing_success, ParseDouble(c_str, &parsed_value));
- if (expected_parsing_success) {
- EXPECT_NEAR(expected_parsed_value, parsed_value, 0.00001);
- }
-}
-
-TEST(ParseDoubleTest, Normal) {
- TestParseDouble("2", true, 2.0);
- TestParseDouble("-357.023", true, -357.023);
- TestParseDouble("7.04", true, 7.04);
- TestParseDouble("+7.2", true, 7.2);
- TestParseDouble(" +7.236", true, 7.236);
- TestParseDouble("-23.4", true, -23.4);
- TestParseDouble(" -23.4", true, -23.4);
-}
-
-TEST(ParseDoubleTest, ErrorCases) {
- TestParseDouble("", false);
- TestParseDouble(" ", false);
- TestParseDouble("not-a-number", false);
- TestParseDouble("23.5a", false);
-}
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/strings/stringpiece_test.cc b/native/utils/strings/stringpiece_test.cc
deleted file mode 100644
index 64808d3..0000000
--- a/native/utils/strings/stringpiece_test.cc
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(StringPieceTest, EndsWith) {
- EXPECT_TRUE(EndsWith("hello there!", "there!"));
- EXPECT_TRUE(EndsWith("hello there!", "!"));
- EXPECT_FALSE(EndsWith("hello there!", "there"));
- EXPECT_FALSE(EndsWith("hello there!", " hello there!"));
- EXPECT_TRUE(EndsWith("hello there!", ""));
- EXPECT_FALSE(EndsWith("", "hello there!"));
-}
-
-TEST(StringPieceTest, StartsWith) {
- EXPECT_TRUE(StartsWith("hello there!", "hello"));
- EXPECT_TRUE(StartsWith("hello there!", "hello "));
- EXPECT_FALSE(StartsWith("hello there!", "there!"));
- EXPECT_FALSE(StartsWith("hello there!", " hello there! "));
- EXPECT_TRUE(StartsWith("hello there!", ""));
- EXPECT_FALSE(StartsWith("", "hello there!"));
-}
-
-TEST(StringPieceTest, ConsumePrefix) {
- StringPiece str("hello there!");
- EXPECT_TRUE(ConsumePrefix(&str, "hello "));
- EXPECT_EQ(str.ToString(), "there!");
- EXPECT_TRUE(ConsumePrefix(&str, "there"));
- EXPECT_EQ(str.ToString(), "!");
- EXPECT_FALSE(ConsumePrefix(&str, "!!"));
- EXPECT_TRUE(ConsumePrefix(&str, ""));
- EXPECT_TRUE(ConsumePrefix(&str, "!"));
- EXPECT_EQ(str.ToString(), "");
- EXPECT_TRUE(ConsumePrefix(&str, ""));
- EXPECT_FALSE(ConsumePrefix(&str, "!"));
-}
-
-TEST(StringPieceTest, ConsumeSuffix) {
- StringPiece str("hello there!");
- EXPECT_TRUE(ConsumeSuffix(&str, "!"));
- EXPECT_EQ(str.ToString(), "hello there");
- EXPECT_TRUE(ConsumeSuffix(&str, " there"));
- EXPECT_EQ(str.ToString(), "hello");
- EXPECT_FALSE(ConsumeSuffix(&str, "!!"));
- EXPECT_TRUE(ConsumeSuffix(&str, ""));
- EXPECT_TRUE(ConsumeSuffix(&str, "hello"));
- EXPECT_EQ(str.ToString(), "");
- EXPECT_TRUE(ConsumeSuffix(&str, ""));
- EXPECT_FALSE(ConsumeSuffix(&str, "!"));
-}
-
-TEST(StringPieceTest, Find) {
- StringPiece str("<hello there!>");
- EXPECT_EQ(str.find('<'), 0);
- EXPECT_EQ(str.find('>'), str.length() - 1);
- EXPECT_EQ(str.find('?'), StringPiece::npos);
- EXPECT_EQ(str.find('<', str.length() - 1), StringPiece::npos);
- EXPECT_EQ(str.find('<', 0), 0);
- EXPECT_EQ(str.find('>', str.length() - 1), str.length() - 1);
-}
-
-TEST(StringPieceTest, FindStringPiece) {
- StringPiece str("<foo bar baz!>");
- EXPECT_EQ(str.find("foo"), 1);
- EXPECT_EQ(str.find("bar"), 5);
- EXPECT_EQ(str.find("baz"), 9);
- EXPECT_EQ(str.find("qux"), StringPiece::npos);
- EXPECT_EQ(str.find("?"), StringPiece::npos);
- EXPECT_EQ(str.find(">"), str.length() - 1);
- EXPECT_EQ(str.find("<", str.length() - 1), StringPiece::npos);
- EXPECT_EQ(str.find("<", 0), 0);
- EXPECT_EQ(str.find(">", str.length() - 1), str.length() - 1);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/strings/substitute_test.cc b/native/utils/strings/substitute_test.cc
deleted file mode 100644
index 94b37ab..0000000
--- a/native/utils/strings/substitute_test.cc
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/strings/substitute.h"
-
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(SubstituteTest, Substitute) {
- EXPECT_EQ("Hello, world!",
- strings::Substitute("$0, $1!", {"Hello", "world"}));
-
- // Out of order.
- EXPECT_EQ("world, Hello!",
- strings::Substitute("$1, $0!", {"Hello", "world"}));
- EXPECT_EQ("b, a, c, b",
- strings::Substitute("$1, $0, $2, $1", {"a", "b", "c"}));
-
- // Literal $
- EXPECT_EQ("$", strings::Substitute("$$", {}));
- EXPECT_EQ("$1", strings::Substitute("$$1", {}));
-
- const char* null_cstring = nullptr;
- EXPECT_EQ("Text: ''", strings::Substitute("Text: '$0'", {null_cstring}));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/strings/utf8_test.cc b/native/utils/strings/utf8_test.cc
deleted file mode 100644
index 59f3864..0000000
--- a/native/utils/strings/utf8_test.cc
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/strings/utf8.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::Eq;
-
-TEST(Utf8Test, GetNumBytesForUTF8Char) {
- EXPECT_THAT(GetNumBytesForUTF8Char("\x00"), Eq(0));
- EXPECT_THAT(GetNumBytesForUTF8Char("h"), Eq(1));
- EXPECT_THAT(GetNumBytesForUTF8Char("😋"), Eq(4));
- EXPECT_THAT(GetNumBytesForUTF8Char("㍿"), Eq(3));
-}
-
-TEST(Utf8Test, IsValidUTF8) {
- EXPECT_TRUE(IsValidUTF8("1234😋hello", 13));
- EXPECT_TRUE(IsValidUTF8("\u304A\u00B0\u106B", 8));
- EXPECT_TRUE(IsValidUTF8("this is a test😋😋😋", 26));
- EXPECT_TRUE(IsValidUTF8("\xf0\x9f\x98\x8b", 4));
- // Too short (string is too short).
- EXPECT_FALSE(IsValidUTF8("\xf0\x9f", 2));
- // Too long (too many trailing bytes).
- EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x8b\x8b", 5));
- // Too short (too few trailing bytes).
- EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x61\x61", 5));
-}
-
-TEST(Utf8Test, ValidUTF8CharLength) {
- EXPECT_THAT(ValidUTF8CharLength("1234😋hello", 13), Eq(1));
- EXPECT_THAT(ValidUTF8CharLength("\u304A\u00B0\u106B", 8), Eq(3));
- EXPECT_THAT(ValidUTF8CharLength("this is a test😋😋😋", 26), Eq(1));
- EXPECT_THAT(ValidUTF8CharLength("\xf0\x9f\x98\x8b", 4), Eq(4));
- // Too short (string is too short).
- EXPECT_THAT(ValidUTF8CharLength("\xf0\x9f", 2), Eq(-1));
- // Too long (too many trailing bytes). First character is valid.
- EXPECT_THAT(ValidUTF8CharLength("\xf0\x9f\x98\x8b\x8b", 5), Eq(4));
- // Too short (too few trailing bytes).
- EXPECT_THAT(ValidUTF8CharLength("\xf0\x9f\x98\x61\x61", 5), Eq(-1));
-}
-
-TEST(Utf8Test, CorrectlyTruncatesStrings) {
- EXPECT_THAT(SafeTruncateLength("FooBar", 3), Eq(3));
- EXPECT_THAT(SafeTruncateLength("früh", 3), Eq(2));
- EXPECT_THAT(SafeTruncateLength("مَمِمّمَّمِّ", 5), Eq(4));
-}
-
-TEST(Utf8Test, CorrectlyConvertsFromUtf8) {
- EXPECT_THAT(ValidCharToRune("a"), Eq(97));
- EXPECT_THAT(ValidCharToRune("\0"), Eq(0));
- EXPECT_THAT(ValidCharToRune("\u304A"), Eq(0x304a));
- EXPECT_THAT(ValidCharToRune("\xe3\x81\x8a"), Eq(0x304a));
-}
-
-TEST(Utf8Test, CorrectlyConvertsToUtf8) {
- char utf8_encoding[4];
- EXPECT_THAT(ValidRuneToChar(97, utf8_encoding), Eq(1));
- EXPECT_THAT(ValidRuneToChar(0, utf8_encoding), Eq(1));
- EXPECT_THAT(ValidRuneToChar(0x304a, utf8_encoding), Eq(3));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/tensor-view_test.cc b/native/utils/tensor-view_test.cc
deleted file mode 100644
index 9467264..0000000
--- a/native/utils/tensor-view_test.cc
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/tensor-view.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(TensorViewTest, TestSize) {
- std::vector<float> data{0.1, 0.2, 0.3, 0.4, 0.5, 0.6};
- const TensorView<float> tensor(data.data(), {3, 1, 2});
- EXPECT_TRUE(tensor.is_valid());
- EXPECT_EQ(tensor.shape(), (std::vector<int>{3, 1, 2}));
- EXPECT_EQ(tensor.data(), data.data());
- EXPECT_EQ(tensor.size(), 6);
- EXPECT_EQ(tensor.dims(), 3);
- EXPECT_EQ(tensor.dim(0), 3);
- EXPECT_EQ(tensor.dim(1), 1);
- EXPECT_EQ(tensor.dim(2), 2);
- std::vector<float> output_data(6);
- EXPECT_TRUE(tensor.copy_to(output_data.data(), output_data.size()));
- EXPECT_EQ(data, output_data);
-
- // Should not copy when the output is small.
- std::vector<float> small_output_data{-1, -1, -1};
- EXPECT_FALSE(
- tensor.copy_to(small_output_data.data(), small_output_data.size()));
- // The output buffer should not be changed.
- EXPECT_EQ(small_output_data, (std::vector<float>{-1, -1, -1}));
-
- const TensorView<float> invalid_tensor = TensorView<float>::Invalid();
- EXPECT_FALSE(invalid_tensor.is_valid());
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/test-utils.cc b/native/utils/test-utils.cc
deleted file mode 100644
index cad7be5..0000000
--- a/native/utils/test-utils.cc
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/test-utils.h"
-
-#include <iterator>
-
-#include "utils/codepoint-range.h"
-#include "utils/strings/utf8.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3 {
-
-using libtextclassifier3::Token;
-
-std::vector<Token> TokenizeOnSpace(const std::string& text) {
- const char32 space_codepoint = ValidCharToRune(" ");
- const UnicodeText unicode_text = UTF8ToUnicodeText(text, /*do_copy=*/false);
-
- std::vector<Token> result;
-
- int token_start_codepoint = 0;
- auto token_start_it = unicode_text.begin();
- int codepoint_idx = 0;
-
- UnicodeText::const_iterator it;
- for (it = unicode_text.begin(); it < unicode_text.end(); it++) {
- if (*it == space_codepoint) {
- result.push_back(Token{UnicodeText::UTF8Substring(token_start_it, it),
- token_start_codepoint, codepoint_idx});
-
- token_start_codepoint = codepoint_idx + 1;
- token_start_it = it;
- token_start_it++;
- }
-
- codepoint_idx++;
- }
- result.push_back(Token{UnicodeText::UTF8Substring(token_start_it, it),
- token_start_codepoint, codepoint_idx});
-
- return result;
-}
-
-} // namespace libtextclassifier3
diff --git a/native/utils/test-utils.h b/native/utils/test-utils.h
deleted file mode 100644
index 582736f..0000000
--- a/native/utils/test-utils.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Utilities for tests.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
-#define LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
-
-#include <string>
-
-#include "annotator/types.h"
-
-namespace libtextclassifier3 {
-
-// Returns a list of Tokens for a given input string.
-std::vector<Token> TokenizeOnSpace(const std::string& text);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
diff --git a/native/utils/test-utils_test.cc b/native/utils/test-utils_test.cc
deleted file mode 100644
index f596de4..0000000
--- a/native/utils/test-utils_test.cc
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/test-utils.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(TestUtilTest, Utf8) {
- std::vector<Token> tokens =
- TokenizeOnSpace("Where is Jörg Borg located? Maybe in Zürich ...");
-
- EXPECT_EQ(tokens.size(), 9);
-
- EXPECT_EQ(tokens[0].value, "Where");
- EXPECT_EQ(tokens[0].start, 0);
- EXPECT_EQ(tokens[0].end, 5);
-
- EXPECT_EQ(tokens[1].value, "is");
- EXPECT_EQ(tokens[1].start, 6);
- EXPECT_EQ(tokens[1].end, 8);
-
- EXPECT_EQ(tokens[2].value, "Jörg");
- EXPECT_EQ(tokens[2].start, 9);
- EXPECT_EQ(tokens[2].end, 13);
-
- EXPECT_EQ(tokens[3].value, "Borg");
- EXPECT_EQ(tokens[3].start, 14);
- EXPECT_EQ(tokens[3].end, 18);
-
- EXPECT_EQ(tokens[4].value, "located?");
- EXPECT_EQ(tokens[4].start, 19);
- EXPECT_EQ(tokens[4].end, 27);
-
- EXPECT_EQ(tokens[5].value, "Maybe");
- EXPECT_EQ(tokens[5].start, 28);
- EXPECT_EQ(tokens[5].end, 33);
-
- EXPECT_EQ(tokens[6].value, "in");
- EXPECT_EQ(tokens[6].start, 34);
- EXPECT_EQ(tokens[6].end, 36);
-
- EXPECT_EQ(tokens[7].value, "Zürich");
- EXPECT_EQ(tokens[7].start, 37);
- EXPECT_EQ(tokens[7].end, 43);
-
- EXPECT_EQ(tokens[8].value, "...");
- EXPECT_EQ(tokens[8].start, 44);
- EXPECT_EQ(tokens[8].end, 47);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/testing/annotator.cc b/native/utils/testing/annotator.cc
deleted file mode 100644
index 43ed1df..0000000
--- a/native/utils/testing/annotator.cc
+++ /dev/null
@@ -1,200 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/testing/annotator.h"
-
-namespace libtextclassifier3 {
-
-std::string FirstResult(const std::vector<ClassificationResult>& results) {
- if (results.empty()) {
- return "<INVALID RESULTS>";
- }
- return results[0].collection;
-}
-
-std::string ReadFile(const std::string& file_name) {
- std::ifstream file_stream(file_name);
- return std::string(std::istreambuf_iterator<char>(file_stream), {});
-}
-
-std::unique_ptr<RegexModel_::PatternT> MakePattern(
- const std::string& collection_name, const std::string& pattern,
- const bool enabled_for_classification, const bool enabled_for_selection,
- const bool enabled_for_annotation, const float score,
- const float priority_score) {
- std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
- result->collection_name = collection_name;
- result->pattern = pattern;
- // We cannot directly operate with |= on the flag, so use an int here.
- int enabled_modes = ModeFlag_NONE;
- if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
- if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
- if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
- result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
- result->target_classification_score = score;
- result->priority_score = priority_score;
- return result;
-}
-
-// Shortcut overload that doesn't require specifying the priority score.
-std::unique_ptr<RegexModel_::PatternT> MakePattern(
- const std::string& collection_name, const std::string& pattern,
- const bool enabled_for_classification, const bool enabled_for_selection,
- const bool enabled_for_annotation, const float score) {
- return MakePattern(collection_name, pattern, enabled_for_classification,
- enabled_for_selection, enabled_for_annotation,
- /*score=*/score,
- /*priority_score=*/score);
-}
-
-void AddTestRegexModel(ModelT* unpacked_model) {
- // Add test regex models.
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "person_with_age", "(Barack) (?:(Obama) )?is (\\d+) years old",
- /*enabled_for_classification=*/true,
- /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, 1.0));
-
-  // Use metadata to generate custom serialized entity data.
- ReflectiveFlatbufferBuilder entity_data_builder(
- flatbuffers::GetRoot<reflection::Schema>(
- unpacked_model->entity_data_schema.data()));
- RegexModel_::PatternT* pattern =
- unpacked_model->regex_model->patterns.back().get();
-
- {
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder.NewRoot();
- entity_data->Set("is_alive", true);
- pattern->serialized_entity_data = entity_data->Serialize();
- }
- pattern->capturing_group.emplace_back(new CapturingGroupT);
- pattern->capturing_group.emplace_back(new CapturingGroupT);
- pattern->capturing_group.emplace_back(new CapturingGroupT);
- pattern->capturing_group.emplace_back(new CapturingGroupT);
-  // Group 0 is the full match; capturing groups start at 1.
- pattern->capturing_group[1]->entity_field_path.reset(
- new FlatbufferFieldPathT);
- pattern->capturing_group[1]->entity_field_path->field.emplace_back(
- new FlatbufferFieldT);
- pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
- "first_name";
- pattern->capturing_group[2]->entity_field_path.reset(
- new FlatbufferFieldPathT);
- pattern->capturing_group[2]->entity_field_path->field.emplace_back(
- new FlatbufferFieldT);
- pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
- "last_name";
- // Set `former_us_president` field if we match Obama.
- {
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder.NewRoot();
- entity_data->Set("former_us_president", true);
- pattern->capturing_group[2]->serialized_entity_data =
- entity_data->Serialize();
- }
- pattern->capturing_group[3]->entity_field_path.reset(
- new FlatbufferFieldPathT);
- pattern->capturing_group[3]->entity_field_path->field.emplace_back(
- new FlatbufferFieldT);
- pattern->capturing_group[3]->entity_field_path->field.back()->field_name =
- "age";
-}
-
-std::string CreateEmptyModel() {
- ModelT model;
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, &model));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-// Create fake entity data schema metadata.
-void AddTestEntitySchemaData(ModelT* unpacked_model) {
-  // Cannot use the object-oriented API here, as it is not available for the
-  // reflection schema.
- flatbuffers::FlatBufferBuilder schema_builder;
- std::vector<flatbuffers::Offset<reflection::Field>> fields = {
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("first_name"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/0,
- /*offset=*/4),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("is_alive"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Bool),
- /*id=*/1,
- /*offset=*/6),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("last_name"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/2,
- /*offset=*/8),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("age"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Int),
- /*id=*/3,
- /*offset=*/10),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("former_us_president"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Bool),
- /*id=*/4,
- /*offset=*/12)};
- std::vector<flatbuffers::Offset<reflection::Enum>> enums;
- std::vector<flatbuffers::Offset<reflection::Object>> objects = {
- reflection::CreateObject(
- schema_builder,
- /*name=*/schema_builder.CreateString("EntityData"),
- /*fields=*/
- schema_builder.CreateVectorOfSortedTables(&fields))};
- schema_builder.Finish(reflection::CreateSchema(
- schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
- schema_builder.CreateVectorOfSortedTables(&enums),
- /*(unused) file_ident=*/0,
- /*(unused) file_ext=*/0,
- /*root_table*/ objects[0]));
-
- unpacked_model->entity_data_schema.assign(
- schema_builder.GetBufferPointer(),
- schema_builder.GetBufferPointer() + schema_builder.GetSize());
-}
-
-AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
- const std::string& collection,
- const float score,
- AnnotatedSpan::Source source) {
- AnnotatedSpan result;
- result.span = span;
- result.classification.push_back({collection, score});
- result.source = source;
- return result;
-}
-
-} // namespace libtextclassifier3
diff --git a/native/utils/testing/annotator.h b/native/utils/testing/annotator.h
deleted file mode 100644
index 1565e0d..0000000
--- a/native/utils/testing/annotator.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Helper utilities for testing Annotator.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
-#define LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
-
-#include <fstream>
-#include <memory>
-#include <string>
-
-#include "annotator/model_generated.h"
-#include "annotator/types.h"
-#include "flatbuffers/flatbuffers.h"
-
-namespace libtextclassifier3 {
-
-// Loads a FlatBuffer model, unpacks it, and passes it to visitor_fn so that it
-// can be modified. Afterwards, the modified unpacked model is serialized back
-// to a flatbuffer.
-template <typename Fn>
-std::string ModifyAnnotatorModel(const std::string& model_flatbuffer,
- Fn visitor_fn) {
- std::unique_ptr<ModelT> unpacked_model =
- UnPackModel(model_flatbuffer.c_str());
-
- visitor_fn(unpacked_model.get());
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- return std::string(reinterpret_cast<char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-std::string FirstResult(const std::vector<ClassificationResult>& results);
-
-std::string ReadFile(const std::string& file_name);
-
-std::unique_ptr<RegexModel_::PatternT> MakePattern(
- const std::string& collection_name, const std::string& pattern,
- const bool enabled_for_classification, const bool enabled_for_selection,
- const bool enabled_for_annotation, const float score,
- const float priority_score);
-
-// Shortcut overload that doesn't require specifying the priority score.
-std::unique_ptr<RegexModel_::PatternT> MakePattern(
- const std::string& collection_name, const std::string& pattern,
- const bool enabled_for_classification, const bool enabled_for_selection,
- const bool enabled_for_annotation, const float score);
-
-void AddTestRegexModel(ModelT* unpacked_model);
-
-std::string CreateEmptyModel();
-
-// Create fake entity data schema metadata.
-void AddTestEntitySchemaData(ModelT* unpacked_model);
-
-AnnotatedSpan MakeAnnotatedSpan(
- CodepointSpan span, const std::string& collection, const float score,
- AnnotatedSpan::Source source = AnnotatedSpan::Source::OTHER);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
diff --git a/native/utils/testing/logging_event_listener.h b/native/utils/testing/logging_event_listener.h
deleted file mode 100644
index 2663a9c..0000000
--- a/native/utils/testing/logging_event_listener.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_
-#define LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_
-
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-
-// TestEventListener that writes test results to the log so that they will be
-// visible in the logcat output in Sponge.
-// The formatting of the output is patterned after the output produced by the
-// standard PrettyUnitTestResultPrinter.
-class LoggingEventListener : public ::testing::TestEventListener {
- public:
- void OnTestProgramStart(const testing::UnitTest& unit_test) override;
-
- void OnTestIterationStart(const testing::UnitTest& unit_test,
- int iteration) override;
-
- void OnEnvironmentsSetUpStart(const testing::UnitTest& unit_test) override;
-
- void OnEnvironmentsSetUpEnd(const testing::UnitTest& unit_test) override;
-
- void OnTestCaseStart(const testing::TestCase& test_case) override;
-
- void OnTestStart(const testing::TestInfo& test_info) override;
-
- void OnTestPartResult(
- const testing::TestPartResult& test_part_result) override;
-
- void OnTestEnd(const testing::TestInfo& test_info) override;
-
- void OnTestCaseEnd(const testing::TestCase& test_case) override;
-
- void OnEnvironmentsTearDownStart(const testing::UnitTest& unit_test) override;
-
- void OnEnvironmentsTearDownEnd(const testing::UnitTest& unit_test) override;
-
- void OnTestIterationEnd(const testing::UnitTest& unit_test,
- int iteration) override;
-
- void OnTestProgramEnd(const testing::UnitTest& unit_test) override;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_
diff --git a/native/utils/tflite/dist_diversification_test.cc b/native/utils/tflite/dist_diversification_test.cc
deleted file mode 100644
index 2380116..0000000
--- a/native/utils/tflite/dist_diversification_test.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/tflite/dist_diversification.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/register.h"
-#include "tensorflow/lite/kernels/test_util.h"
-#include "tensorflow/lite/model.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class DistanceDiversificationOpModel : public tflite::SingleOpModel {
- public:
- explicit DistanceDiversificationOpModel(int matrix_rows);
- void SetDistanceMatrix(const std::initializer_list<float>& values) {
- PopulateTensor(distance_matrix_, values);
- }
- void SetNumOutput(int length) { PopulateTensor(num_results_, {length}); }
- void SetMinDistance(float min_distance) {
- PopulateTensor(min_distance_, {min_distance});
- }
- int GetOutputLen() { return ExtractVector<int>(output_len_).front(); }
- std::vector<int> GetOutputIndexes(int output_length) {
- auto res = ExtractVector<int>(output_indexes_);
- res.resize(output_length);
- return res;
- }
-
- private:
- int distance_matrix_;
- int num_results_;
- int min_distance_;
-
- int output_len_;
- int output_indexes_;
-};
-
-DistanceDiversificationOpModel::DistanceDiversificationOpModel(
- int matrix_rows) {
- distance_matrix_ = AddInput(tflite::TensorType_FLOAT32);
- min_distance_ = AddInput(tflite::TensorType_FLOAT32);
- num_results_ = AddInput(tflite::TensorType_INT32);
-
- output_indexes_ = AddOutput(tflite::TensorType_INT32);
- output_len_ = AddOutput(tflite::TensorType_INT32);
- SetCustomOp("DistanceDiversification", {},
- tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION);
- BuildInterpreter({{matrix_rows, matrix_rows}, {1}, {1}});
-}
-
-// Tests
-TEST(DistanceDiversificationOp, Simple) {
- DistanceDiversificationOpModel m(5);
- m.SetDistanceMatrix({0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.0, 0.1, 0.2,
- 0.3, 0.2, 0.1, 0.0, 0.1, 0.2, 0.3, 0.2, 0.1,
- 0.0, 0.1, 0.4, 0.3, 0.2, 0.1, 0.0});
- m.SetMinDistance(0.21);
- m.SetNumOutput(3);
- m.Invoke();
- const int output_length = m.GetOutputLen();
- EXPECT_EQ(output_length, 2);
- EXPECT_THAT(m.GetOutputIndexes(output_length), testing::ElementsAre(0, 3));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/tflite/encoder_common_test.cc b/native/utils/tflite/encoder_common_test.cc
deleted file mode 100644
index 247689f..0000000
--- a/native/utils/tflite/encoder_common_test.cc
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/tflite/encoder_common.h"
-
-#include "gtest/gtest.h"
-#include "tensorflow/lite/model.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(EncoderUtilsTest, CreateIntArray) {
- TfLiteIntArray* a = CreateIntArray({1, 2, 3});
- EXPECT_EQ(a->data[0], 1);
- EXPECT_EQ(a->data[1], 2);
- EXPECT_EQ(a->data[2], 3);
- TfLiteIntArrayFree(a);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/tflite/text_encoder_test.cc b/native/utils/tflite/text_encoder_test.cc
deleted file mode 100644
index 6386432..0000000
--- a/native/utils/tflite/text_encoder_test.cc
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <fstream>
-#include <string>
-#include <vector>
-
-#include "utils/tflite/text_encoder.h"
-#include "gtest/gtest.h"
-#include "third_party/absl/flags/flag.h"
-#include "flatbuffers/flexbuffers.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/register.h"
-#include "tensorflow/lite/kernels/test_util.h"
-#include "tensorflow/lite/model.h"
-#include "tensorflow/lite/string_util.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-std::string GetTestConfigPath() {
- return "";
-}
-
-class TextEncoderOpModel : public tflite::SingleOpModel {
- public:
- TextEncoderOpModel(std::initializer_list<int> input_strings_shape,
- std::initializer_list<int> attribute_shape);
- void SetInputText(const std::initializer_list<std::string>& strings) {
- PopulateStringTensor(input_string_, strings);
- PopulateTensor(input_length_, {static_cast<int32_t>(strings.size())});
- }
- void SetMaxOutputLength(int length) {
- PopulateTensor(input_output_maxlength_, {length});
- }
- void SetInt32Attribute(const std::initializer_list<int>& attribute) {
- PopulateTensor(input_attributes_int32_, attribute);
- }
- void SetFloatAttribute(const std::initializer_list<float>& attribute) {
- PopulateTensor(input_attributes_float_, attribute);
- }
-
- std::vector<int> GetOutputEncoding() {
- return ExtractVector<int>(output_encoding_);
- }
- std::vector<int> GetOutputPositions() {
- return ExtractVector<int>(output_positions_);
- }
- std::vector<int> GetOutputAttributeInt32() {
- return ExtractVector<int>(output_attributes_int32_);
- }
- std::vector<float> GetOutputAttributeFloat() {
- return ExtractVector<float>(output_attributes_float_);
- }
- int GetEncodedLength() { return ExtractVector<int>(output_length_)[0]; }
-
- private:
- int input_string_;
- int input_length_;
- int input_output_maxlength_;
- int input_attributes_int32_;
- int input_attributes_float_;
-
- int output_encoding_;
- int output_positions_;
- int output_length_;
- int output_attributes_int32_;
- int output_attributes_float_;
-};
-
-TextEncoderOpModel::TextEncoderOpModel(
- std::initializer_list<int> input_strings_shape,
- std::initializer_list<int> attribute_shape) {
- input_string_ = AddInput(tflite::TensorType_STRING);
- input_length_ = AddInput(tflite::TensorType_INT32);
- input_output_maxlength_ = AddInput(tflite::TensorType_INT32);
- input_attributes_int32_ = AddInput(tflite::TensorType_INT32);
- input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32);
-
- output_encoding_ = AddOutput(tflite::TensorType_INT32);
- output_positions_ = AddOutput(tflite::TensorType_INT32);
- output_length_ = AddOutput(tflite::TensorType_INT32);
- output_attributes_int32_ = AddOutput(tflite::TensorType_INT32);
- output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32);
-
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- flexbuffers::Builder builder;
- builder.Map([&]() { builder.String("text_encoder_config", config); });
- builder.Finish();
- SetCustomOp("TextEncoder", builder.GetBuffer(),
- tflite::ops::custom::Register_TEXT_ENCODER);
- BuildInterpreter(
- {input_strings_shape, {1}, {1}, attribute_shape, attribute_shape});
-}
-
-// Tests
-TEST(TextEncoderTest, SimpleEncoder) {
- TextEncoderOpModel m({1, 1}, {1, 1});
- m.SetInputText({"Hello"});
- m.SetMaxOutputLength(10);
- m.SetInt32Attribute({7});
- m.SetFloatAttribute({3.f});
-
- m.Invoke();
-
- EXPECT_EQ(m.GetEncodedLength(), 5);
- EXPECT_THAT(m.GetOutputEncoding(),
- testing::ElementsAre(1, 90, 547, 58, 2, 2, 2, 2, 2, 2));
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(0, 1, 2, 3, 4, 10, 10, 10, 10, 10));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f));
-}
-
-TEST(TextEncoderTest, ManyStrings) {
- TextEncoderOpModel m({1, 3}, {1, 3});
- m.SetInt32Attribute({1, 2, 3});
- m.SetFloatAttribute({5.f, 4.f, 3.f});
- m.SetInputText({"Hello", "Hi", "Bye"});
- m.SetMaxOutputLength(10);
-
- m.Invoke();
-
- EXPECT_EQ(m.GetEncodedLength(), 10);
- EXPECT_THAT(m.GetOutputEncoding(),
- testing::ElementsAre(547, 58, 2, 1, 862, 2, 1, 1919, 19, 2));
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(2, 3, 4, 0, 1, 2, 0, 1, 2, 3));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 3.f));
-}
-
-TEST(TextEncoderTest, LongStrings) {
- TextEncoderOpModel m({1, 4}, {1, 4});
- m.SetInt32Attribute({1, 2, 3, 4});
- m.SetFloatAttribute({5.f, 4.f, 3.f, 2.f});
- m.SetInputText({"Hello", "Hi", "Bye", "Hi"});
- m.SetMaxOutputLength(9);
-
- m.Invoke();
-
- EXPECT_EQ(m.GetEncodedLength(), 9);
- EXPECT_THAT(m.GetOutputEncoding(),
- testing::ElementsAre(862, 2, 1, 1919, 19, 2, 1, 862, 2));
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(1, 2, 0, 1, 2, 3, 0, 1, 2));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(2, 2, 3, 3, 3, 3, 4, 4, 4));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(4.f, 4.f, 3.f, 3.f, 3.f, 3.f, 2.f, 2.f, 2.f));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/tflite/token_encoder_test.cc b/native/utils/tflite/token_encoder_test.cc
deleted file mode 100644
index c7f51a1..0000000
--- a/native/utils/tflite/token_encoder_test.cc
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <vector>
-
-#include "utils/tflite/token_encoder.h"
-#include "gtest/gtest.h"
-#include "third_party/absl/flags/flag.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/register.h"
-#include "tensorflow/lite/kernels/test_util.h"
-#include "tensorflow/lite/model.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class TokenEncoderOpModel : public tflite::SingleOpModel {
- public:
- TokenEncoderOpModel(std::initializer_list<int> input_shape,
- std::initializer_list<int> attribute_shape);
- void SetNumTokens(const std::initializer_list<int>& num_tokens) {
- PopulateTensor(input_num_tokens_, num_tokens);
- PopulateTensor(input_length_, {static_cast<int32_t>(num_tokens.size())});
- }
- void SetMaxOutputLength(int length) {
- PopulateTensor(input_output_maxlength_, {length});
- }
- void SetInt32Attribute(const std::initializer_list<int>& attribute) {
- PopulateTensor(input_attributes_int32_, attribute);
- }
- void SetFloatAttribute(const std::initializer_list<float>& attribute) {
- PopulateTensor(input_attributes_float_, attribute);
- }
- std::vector<int> GetOutputPositions() {
- return ExtractVector<int>(output_positions_);
- }
- std::vector<int> GetOutputAttributeInt32() {
- return ExtractVector<int>(output_attributes_int32_);
- }
- std::vector<float> GetOutputAttributeFloat() {
- return ExtractVector<float>(output_attributes_float_);
- }
- int GetOutputLength() { return ExtractVector<int>(output_length_)[0]; }
-
- private:
- int input_num_tokens_;
- int input_length_;
- int input_output_maxlength_;
- int input_attributes_int32_;
- int input_attributes_float_;
-
- int output_positions_;
- int output_length_;
- int output_attributes_int32_;
- int output_attributes_float_;
-};
-
-TokenEncoderOpModel::TokenEncoderOpModel(
- std::initializer_list<int> input_shape,
- std::initializer_list<int> attribute_shape) {
- input_num_tokens_ = AddInput(tflite::TensorType_INT32);
- input_length_ = AddInput(tflite::TensorType_INT32);
- input_output_maxlength_ = AddInput(tflite::TensorType_INT32);
- input_attributes_int32_ = AddInput(tflite::TensorType_INT32);
- input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32);
-
- output_positions_ = AddOutput(tflite::TensorType_INT32);
- output_length_ = AddOutput(tflite::TensorType_INT32);
- output_attributes_int32_ = AddOutput(tflite::TensorType_INT32);
- output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32);
-
- SetCustomOp("TokenEncoder", {}, tflite::ops::custom::Register_TOKEN_ENCODER);
- BuildInterpreter({input_shape, {1}, {1}, attribute_shape, attribute_shape});
-}
-
-// Tests
-TEST(TokenEncoderTest, SimpleEncoder) {
- TokenEncoderOpModel m({1, 1}, {1, 1});
- m.SetNumTokens({1});
- m.SetMaxOutputLength(10);
- m.SetInt32Attribute({7});
- m.SetFloatAttribute({3.f});
-
- m.Invoke();
-
- EXPECT_EQ(m.GetOutputLength(), 3);
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(0, 1, 2, 10, 10, 10, 10, 10, 10, 10));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f));
-}
-
-TEST(TokenEncoderTest, ManyMessages) {
- TokenEncoderOpModel m({1, 3}, {1, 3});
- m.SetInt32Attribute({1, 2, 3});
- m.SetFloatAttribute({5.f, 4.f, 3.f});
- m.SetNumTokens({1, 1, 1});
- m.SetMaxOutputLength(10);
-
- m.Invoke();
-
- EXPECT_EQ(m.GetOutputLength(), 9);
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(0, 1, 2, 0, 1, 2, 0, 1, 2, 10));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 3.f));
-}
-
-TEST(TokenEncoderTest, ManyMessagesMultipleTokens) {
- TokenEncoderOpModel m({1, 4}, {1, 4});
- m.SetInt32Attribute({1, 2, 3, 4});
- m.SetFloatAttribute({5.f, 4.f, 3.f, 2.f});
- m.SetNumTokens({1, 2, 3, 4});
- m.SetMaxOutputLength(9);
-
- m.Invoke();
-
- EXPECT_EQ(m.GetOutputLength(), 9);
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(2, 3, 4, 0, 1, 2, 3, 4, 5));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(3, 3, 3, 4, 4, 4, 4, 4, 4));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(3.f, 3.f, 3.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/token-feature-extractor.cc b/native/utils/token-feature-extractor.cc
index 9faebca..b14f96e 100644
--- a/native/utils/token-feature-extractor.cc
+++ b/native/utils/token-feature-extractor.cc
@@ -109,8 +109,7 @@
if (options_.unicode_aware_features) {
UnicodeText token_unicode =
UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- const bool is_upper = unilib_.IsUpper(*token_unicode.begin());
- if (!token.value.empty() && is_upper) {
+ if (!token.value.empty() && unilib_.IsUpper(*token_unicode.begin())) {
dense_features.push_back(1.0);
} else {
dense_features.push_back(-1.0);
diff --git a/native/utils/token-feature-extractor_test.cc b/native/utils/token-feature-extractor_test.cc
deleted file mode 100644
index 9a97e42..0000000
--- a/native/utils/token-feature-extractor_test.cc
+++ /dev/null
@@ -1,556 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/token-feature-extractor.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class TokenFeatureExtractorTest : public ::testing::Test {
- protected:
- TokenFeatureExtractorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- UniLib unilib_;
-};
-
-class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
- public:
- using TokenFeatureExtractor::HashToken;
- using TokenFeatureExtractor::TokenFeatureExtractor;
-};
-
-TEST_F(TokenFeatureExtractorTest, ExtractAscii) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2, 3};
- options.extract_case_feature = true;
- options.unicode_aware_features = false;
- options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("H"),
- extractor.HashToken("e"),
- extractor.HashToken("l"),
- extractor.HashToken("l"),
- extractor.HashToken("o"),
- extractor.HashToken("^H"),
- extractor.HashToken("He"),
- extractor.HashToken("el"),
- extractor.HashToken("ll"),
- extractor.HashToken("lo"),
- extractor.HashToken("o$"),
- extractor.HashToken("^He"),
- extractor.HashToken("Hel"),
- extractor.HashToken("ell"),
- extractor.HashToken("llo"),
- extractor.HashToken("lo$")
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("w"),
- extractor.HashToken("o"),
- extractor.HashToken("r"),
- extractor.HashToken("l"),
- extractor.HashToken("d"),
- extractor.HashToken("!"),
- extractor.HashToken("^w"),
- extractor.HashToken("wo"),
- extractor.HashToken("or"),
- extractor.HashToken("rl"),
- extractor.HashToken("ld"),
- extractor.HashToken("d!"),
- extractor.HashToken("!$"),
- extractor.HashToken("^wo"),
- extractor.HashToken("wor"),
- extractor.HashToken("orl"),
- extractor.HashToken("rld"),
- extractor.HashToken("ld!"),
- extractor.HashToken("d!$"),
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{};
- options.extract_case_feature = true;
- options.unicode_aware_features = false;
- options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({extractor.HashToken("^Hello$")}));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({extractor.HashToken("^world!$")}));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractUnicode) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2, 3};
- options.extract_case_feature = true;
- options.unicode_aware_features = true;
- options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("H"),
- extractor.HashToken("ě"),
- extractor.HashToken("l"),
- extractor.HashToken("l"),
- extractor.HashToken("ó"),
- extractor.HashToken("^H"),
- extractor.HashToken("Hě"),
- extractor.HashToken("ěl"),
- extractor.HashToken("ll"),
- extractor.HashToken("ló"),
- extractor.HashToken("ó$"),
- extractor.HashToken("^Hě"),
- extractor.HashToken("Hěl"),
- extractor.HashToken("ěll"),
- extractor.HashToken("lló"),
- extractor.HashToken("ló$")
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("w"),
- extractor.HashToken("o"),
- extractor.HashToken("r"),
- extractor.HashToken("l"),
- extractor.HashToken("d"),
- extractor.HashToken("!"),
- extractor.HashToken("^w"),
- extractor.HashToken("wo"),
- extractor.HashToken("or"),
- extractor.HashToken("rl"),
- extractor.HashToken("ld"),
- extractor.HashToken("d!"),
- extractor.HashToken("!$"),
- extractor.HashToken("^wo"),
- extractor.HashToken("wor"),
- extractor.HashToken("orl"),
- extractor.HashToken("rld"),
- extractor.HashToken("ld!"),
- extractor.HashToken("d!$"),
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{};
- options.extract_case_feature = true;
- options.unicode_aware_features = true;
- options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({extractor.HashToken("^Hělló$")}));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features, testing::ElementsAreArray({
- extractor.HashToken("^world!$"),
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
-}
-
-#ifdef TC3_TEST_ICU
-TEST_F(TokenFeatureExtractorTest, ICUCaseFeature) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.extract_case_feature = true;
- options.unicode_aware_features = true;
- options.extract_selection_mask_feature = false;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
-}
-#endif
-
-TEST_F(TokenFeatureExtractorTest, DigitRemapping) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.remap_digits = true;
- options.unicode_aware_features = false;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
- &dense_features);
-
- std::vector<int> sparse_features2;
- extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
-
- extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features,
- testing::Not(testing::ElementsAreArray(sparse_features2)));
-}
-
-TEST_F(TokenFeatureExtractorTest, DigitRemappingUnicode) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.remap_digits = true;
- options.unicode_aware_features = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
- &dense_features);
-
- std::vector<int> sparse_features2;
- extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
-
- extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features,
- testing::Not(testing::ElementsAreArray(sparse_features2)));
-}
-
-TEST_F(TokenFeatureExtractorTest, LowercaseAscii) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.lowercase_tokens = true;
- options.unicode_aware_features = false;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
- &dense_features);
-
- std::vector<int> sparse_features2;
- extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
-
- extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
-}
-
-#ifdef TC3_TEST_ICU
-TEST_F(TokenFeatureExtractorTest, LowercaseUnicode) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.lowercase_tokens = true;
- options.unicode_aware_features = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"ŘŘ", 0, 6}, true, &sparse_features, &dense_features);
-
- std::vector<int> sparse_features2;
- extractor.Extract(Token{"řř", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
-}
-#endif
-
-#ifdef TC3_TEST_ICU
-TEST_F(TokenFeatureExtractorTest, RegexFeatures) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.remap_digits = false;
- options.unicode_aware_features = false;
- options.regexp_features.push_back("^[a-z]+$"); // all lower case.
- options.regexp_features.push_back("^[0-9]+$"); // all digits.
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
-
- dense_features.clear();
- extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
-
- dense_features.clear();
- extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
-
- dense_features.clear();
- extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
-}
-#endif
-
-TEST_F(TokenFeatureExtractorTest, ExtractTooLongWord) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{22};
- options.extract_case_feature = true;
- options.unicode_aware_features = true;
- options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- // Test that this runs. ASAN should catch problems.
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
- &sparse_features, &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
- extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
- // clang-format on
- }));
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
- options.extract_case_feature = true;
- options.unicode_aware_features = true;
- options.extract_selection_mask_feature = true;
-
- TestingTokenFeatureExtractor extractor_unicode(options, unilib_);
-
- options.unicode_aware_features = false;
- TestingTokenFeatureExtractor extractor_ascii(options, unilib_);
-
- for (const std::string& input :
- {"https://www.abcdefgh.com/in/xxxkkkvayio",
- "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
- "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd",
- "x", "Hello", "Hey,", "Hi", ""}) {
- std::vector<int> sparse_features_unicode;
- std::vector<float> dense_features_unicode;
- extractor_unicode.Extract(Token{input, 0, 0}, true,
- &sparse_features_unicode,
- &dense_features_unicode);
-
- std::vector<int> sparse_features_ascii;
- std::vector<float> dense_features_ascii;
- extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
- &dense_features_ascii);
-
- EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
- EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input;
- }
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractForPadToken) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.extract_case_feature = true;
- options.unicode_aware_features = false;
- options.extract_selection_mask_feature = true;
-
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token(), false, &sparse_features, &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractFiltered) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2, 3};
- options.extract_case_feature = true;
- options.unicode_aware_features = false;
- options.extract_selection_mask_feature = true;
- options.allowed_chargrams.insert("^H");
- options.allowed_chargrams.insert("ll");
- options.allowed_chargrams.insert("llo");
- options.allowed_chargrams.insert("w");
- options.allowed_chargrams.insert("!");
-  options.allowed_chargrams.insert("\xc4");  // UTF8 lead byte.
-
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token{"Hěllo", 0, 5}, true, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- 0,
- extractor.HashToken("\xc4"),
- 0,
- 0,
- 0,
- 0,
- extractor.HashToken("^H"),
- 0,
- 0,
- 0,
- extractor.HashToken("ll"),
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- extractor.HashToken("llo"),
- 0
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features, testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("w"),
- 0,
- 0,
- 0,
- 0,
- extractor.HashToken("!"),
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
- EXPECT_EQ(extractor.HashToken("<PAD>"), 1);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/tokenizer_test.cc b/native/utils/tokenizer_test.cc
deleted file mode 100644
index 0f5501d..0000000
--- a/native/utils/tokenizer_test.cc
+++ /dev/null
@@ -1,580 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/tokenizer.h"
-
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAreArray;
-
-class TestingTokenizer : public Tokenizer {
- public:
- TestingTokenizer(
- const TokenizationType type, const UniLib* unilib,
- const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
- const std::vector<const CodepointRange*>&
- internal_tokenizer_codepoint_ranges,
- const bool split_on_script_change,
- const bool icu_preserve_whitespace_tokens,
- const bool preserve_floating_numbers)
- : Tokenizer(type, unilib, codepoint_ranges,
- internal_tokenizer_codepoint_ranges, split_on_script_change,
- icu_preserve_whitespace_tokens, preserve_floating_numbers) {}
-
- using Tokenizer::FindTokenizationRange;
-};
-
-class TestingTokenizerProxy {
- public:
- TestingTokenizerProxy(
- TokenizationType type,
- const std::vector<TokenizationCodepointRangeT>& codepoint_range_configs,
- const std::vector<CodepointRangeT>& internal_codepoint_range_configs,
- const bool split_on_script_change,
- const bool icu_preserve_whitespace_tokens,
- const bool preserve_floating_numbers)
- : INIT_UNILIB_FOR_TESTING(unilib_) {
- const int num_configs = codepoint_range_configs.size();
- std::vector<const TokenizationCodepointRange*> configs_fb;
- configs_fb.reserve(num_configs);
- const int num_internal_configs = internal_codepoint_range_configs.size();
- std::vector<const CodepointRange*> internal_configs_fb;
- internal_configs_fb.reserve(num_internal_configs);
- buffers_.reserve(num_configs + num_internal_configs);
- for (int i = 0; i < num_configs; i++) {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateTokenizationCodepointRange(
- builder, &codepoint_range_configs[i]));
- buffers_.push_back(builder.Release());
- configs_fb.push_back(flatbuffers::GetRoot<TokenizationCodepointRange>(
- buffers_.back().data()));
- }
- for (int i = 0; i < num_internal_configs; i++) {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(
- CreateCodepointRange(builder, &internal_codepoint_range_configs[i]));
- buffers_.push_back(builder.Release());
- internal_configs_fb.push_back(
- flatbuffers::GetRoot<CodepointRange>(buffers_.back().data()));
- }
- tokenizer_ = std::unique_ptr<TestingTokenizer>(new TestingTokenizer(
- type, &unilib_, configs_fb, internal_configs_fb, split_on_script_change,
- icu_preserve_whitespace_tokens, preserve_floating_numbers));
- }
-
- TokenizationCodepointRange_::Role TestFindTokenizationRole(int c) const {
- const TokenizationCodepointRangeT* range =
- tokenizer_->FindTokenizationRange(c);
- if (range != nullptr) {
- return range->role;
- } else {
- return TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- }
- }
-
- std::vector<Token> Tokenize(const std::string& utf8_text) const {
- return tokenizer_->Tokenize(utf8_text);
- }
-
- private:
- UniLib unilib_;
- std::vector<flatbuffers::DetachedBuffer> buffers_;
- std::unique_ptr<TestingTokenizer> tokenizer_;
-};
-
-TEST(TokenizerTest, FindTokenizationRange) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 10;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 1234;
- config->end = 12345;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
- {}, /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
-
- // Test hits to the first group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(0),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(5),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(10),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-
- // Test a hit to the second group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(31),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(32),
- TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(33),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-
- // Test hits to the third group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-
- // Test a hit outside.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(99),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-}
-
-TEST(TokenizerTest, TokenizeOnSpace) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- configs.emplace_back();
- config = &configs.back();
- // Space character.
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
- {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
- std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
-
- EXPECT_THAT(tokens,
- ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
-}
-
-TEST(TokenizerTest, TokenizeOnSpaceAndScriptChange) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- // Latin.
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 32;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- config->script_id = 1;
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
- config->script_id = 1;
- configs.emplace_back();
- config = &configs.back();
- config->start = 33;
- config->end = 0x77F + 1;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- config->script_id = 1;
-
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
- {},
- /*split_on_script_change=*/true,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
- EXPECT_THAT(tokenizer.Tokenize("앨라배마 주 전화(123) 456-789웹사이트"),
- std::vector<Token>({Token("앨라배마", 0, 4), Token("주", 5, 6),
- Token("전화", 7, 10), Token("(123)", 10, 15),
- Token("456-789", 16, 23),
- Token("웹사이트", 23, 28)}));
-}
-
-TEST(TokenizerTest, TokenizeComplex) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt
- // Latin - cyrilic.
- // 0000..007F; Basic Latin
- // 0080..00FF; Latin-1 Supplement
- // 0100..017F; Latin Extended-A
- // 0180..024F; Latin Extended-B
- // 0250..02AF; IPA Extensions
- // 02B0..02FF; Spacing Modifier Letters
- // 0300..036F; Combining Diacritical Marks
- // 0370..03FF; Greek and Coptic
- // 0400..04FF; Cyrillic
- // 0500..052F; Cyrillic Supplement
- // 0530..058F; Armenian
- // 0590..05FF; Hebrew
- // 0600..06FF; Arabic
- // 0700..074F; Syriac
- // 0750..077F; Arabic Supplement
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 32;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 33;
- config->end = 0x77F + 1;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
-
- // CJK
- // 2E80..2EFF; CJK Radicals Supplement
- // 3000..303F; CJK Symbols and Punctuation
- // 3040..309F; Hiragana
- // 30A0..30FF; Katakana
- // 3100..312F; Bopomofo
- // 3130..318F; Hangul Compatibility Jamo
- // 3190..319F; Kanbun
- // 31A0..31BF; Bopomofo Extended
- // 31C0..31EF; CJK Strokes
- // 31F0..31FF; Katakana Phonetic Extensions
- // 3200..32FF; Enclosed CJK Letters and Months
- // 3300..33FF; CJK Compatibility
- // 3400..4DBF; CJK Unified Ideographs Extension A
- // 4DC0..4DFF; Yijing Hexagram Symbols
- // 4E00..9FFF; CJK Unified Ideographs
- // A000..A48F; Yi Syllables
- // A490..A4CF; Yi Radicals
- // A4D0..A4FF; Lisu
- // A500..A63F; Vai
- // F900..FAFF; CJK Compatibility Ideographs
- // FE30..FE4F; CJK Compatibility Forms
- // 20000..2A6DF; CJK Unified Ideographs Extension B
- // 2A700..2B73F; CJK Unified Ideographs Extension C
- // 2B740..2B81F; CJK Unified Ideographs Extension D
- // 2B820..2CEAF; CJK Unified Ideographs Extension E
- // 2CEB0..2EBEF; CJK Unified Ideographs Extension F
- // 2F800..2FA1F; CJK Compatibility Ideographs Supplement
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2E80;
- config->end = 0x2EFF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x3000;
- config->end = 0xA63F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0xF900;
- config->end = 0xFAFF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0xFE30;
- config->end = 0xFE4F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x20000;
- config->end = 0x2A6DF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2A700;
- config->end = 0x2B73F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2B740;
- config->end = 0x2B81F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2B820;
- config->end = 0x2CEAF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2CEB0;
- config->end = 0x2EBEF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2F800;
- config->end = 0x2FA1F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- // Thai.
- // 0E00..0E7F; Thai
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x0E00;
- config->end = 0x0E7F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
- {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
- std::vector<Token> tokens;
-
- tokens = tokenizer.Tokenize(
- "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。");
- EXPECT_EQ(tokens.size(), 30);
-
- tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ");
- // clang-format off
- EXPECT_THAT(
- tokens,
- ElementsAreArray({Token("問", 0, 1),
- Token("少", 1, 2),
- Token("目", 2, 3),
- Token("hello", 4, 9),
- Token("木", 10, 11),
- Token("輸", 11, 12),
- Token("ย", 12, 13),
- Token("า", 13, 14),
- Token("ม", 14, 15),
- Token("き", 15, 16),
- Token("ゃ", 16, 17)}));
- // clang-format on
-}
-
-#ifdef TC3_TEST_ICU
-TEST(TokenizerTest, ICUTokenize) {
- TestingTokenizerProxy tokenizer(TokenizationType_ICU, {}, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
- std::vector<Token> tokens = tokenizer.Tokenize("พระบาทสมเด็จพระปรมิ");
- ASSERT_EQ(tokens,
- // clang-format off
- std::vector<Token>({Token("พระบาท", 0, 6),
- Token("สมเด็จ", 6, 12),
- Token("พระ", 12, 15),
- Token("ปร", 15, 17),
- Token("มิ", 17, 19)}));
- // clang-format on
-}
-
-TEST(TokenizerTest, ICUTokenizeWithWhitespaces) {
- TestingTokenizerProxy tokenizer(TokenizationType_ICU, {}, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/true,
- /*preserve_floating_numbers=*/false);
- std::vector<Token> tokens = tokenizer.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
- ASSERT_EQ(tokens,
- // clang-format off
- std::vector<Token>({Token("พระบาท", 0, 6),
- Token(" ", 6, 7),
- Token("สมเด็จ", 7, 13),
- Token(" ", 13, 14),
- Token("พระ", 14, 17),
- Token(" ", 17, 18),
- Token("ปร", 18, 20),
- Token(" ", 20, 21),
- Token("มิ", 21, 23)}));
- // clang-format on
-}
-
-TEST(TokenizerTest, MixedTokenize) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- std::vector<CodepointRangeT> internal_configs;
-  CodepointRangeT* internal_config;
-
-  internal_configs.emplace_back();
-  internal_config = &internal_configs.back();
-  internal_config->start = 0;
-  internal_config->end = 128;
-
-  internal_configs.emplace_back();
-  internal_config = &internal_configs.back();
-  internal_config->start = 128;
-  internal_config->end = 256;
-
-  internal_configs.emplace_back();
-  internal_config = &internal_configs.back();
-  internal_config->start = 256;
-  internal_config->end = 384;
-
-  internal_configs.emplace_back();
-  internal_config = &internal_configs.back();
-  internal_config->start = 384;
-  internal_config->end = 592;
-
- TestingTokenizerProxy tokenizer(TokenizationType_MIXED, configs,
- internal_configs,
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
-
- std::vector<Token> tokens = tokenizer.Tokenize(
- "こんにちはJapanese-ląnguagę text 你好世界 http://www.google.com/");
- ASSERT_EQ(
- tokens,
- // clang-format off
- std::vector<Token>({Token("こんにちは", 0, 5),
- Token("Japanese-ląnguagę", 5, 22),
- Token("text", 23, 27),
- Token("你好", 28, 30),
- Token("世界", 30, 32),
- Token("http://www.google.com/", 33, 55)}));
- // clang-format on
-}
-
-TEST(TokenizerTest, InternalTokenizeOnScriptChange) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 256;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
-
- {
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER,
- configs, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
-
- EXPECT_EQ(tokenizer.Tokenize("앨라배마123웹사이트"),
- std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)}));
- }
-
- {
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER,
- configs, {},
- /*split_on_script_change=*/true,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
- EXPECT_EQ(tokenizer.Tokenize("앨라배마123웹사이트"),
- std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7),
- Token("웹사이트", 7, 11)}));
- }
-}
-#endif
-
-TEST(TokenizerTest, LetterDigitTokenize) {
- TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/true);
- std::vector<Token> tokens = tokenizer.Tokenize("7% -3.14 68.9#? 7% $99 .18.");
- ASSERT_EQ(tokens,
- std::vector<Token>(
- {Token("7", 0, 1), Token("%", 1, 2), Token(" ", 2, 3),
- Token("-", 3, 4), Token("3.14", 4, 8), Token(" ", 8, 9),
- Token("68.9", 9, 13), Token("#", 13, 14), Token("?", 14, 15),
- Token(" ", 15, 16), Token("7", 16, 17), Token("%", 17, 18),
- Token(" ", 18, 19), Token("$", 19, 20), Token("99", 20, 22),
- Token(" ", 22, 23), Token(".", 23, 24), Token("18", 24, 26),
- Token(".", 26, 27)}));
-}
-
-TEST(TokenizerTest, LetterDigitTokenizeUnicode) {
- TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/true);
- std::vector<Token> tokens = tokenizer.Tokenize("2 pércént 3パーセント");
- ASSERT_EQ(tokens, std::vector<Token>({Token("2", 0, 1), Token(" ", 1, 2),
- Token("pércént", 2, 9),
- Token(" ", 9, 10), Token("3", 10, 11),
- Token("パーセント", 11, 16)}));
-}
-
-TEST(TokenizerTest, LetterDigitTokenizeWithDots) {
- TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/true);
- std::vector<Token> tokens = tokenizer.Tokenize("3 3﹒2 3.3%");
- ASSERT_EQ(tokens,
- std::vector<Token>({Token("3", 0, 1), Token(" ", 1, 2),
- Token("3﹒2", 2, 5), Token(" ", 5, 6),
- Token("3.3", 6, 9), Token("%", 9, 10)}));
-}
-
-TEST(TokenizerTest, LetterDigitTokenizeDoNotPreserveFloatingNumbers) {
- TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
- std::vector<Token> tokens = tokenizer.Tokenize("15.12.2019 january's 3.2");
- ASSERT_EQ(tokens,
- std::vector<Token>(
- {Token("15", 0, 2), Token(".", 2, 3), Token("12", 3, 5),
- Token(".", 5, 6), Token("2019", 6, 10), Token(" ", 10, 11),
- Token("january", 11, 18), Token("'", 18, 19),
- Token("s", 19, 20), Token(" ", 20, 21), Token("3", 21, 22),
- Token(".", 22, 23), Token("2", 23, 24)}));
-}
-
-TEST(TokenizerTest, LetterDigitTokenizeStrangeStringFloatingNumbers) {
- TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
- std::vector<Token> tokens = tokenizer.Tokenize("The+2345++the +íí+");
- ASSERT_EQ(tokens,
- std::vector<Token>({Token("The", 0, 3), Token("+", 3, 4),
- Token("2345", 4, 8), Token("+", 8, 9),
- Token("+", 9, 10), Token("the", 10, 13),
- Token(" ", 13, 14), Token("+", 14, 15),
- Token("íí", 15, 17), Token("+", 17, 18)}));
-}
-
-TEST(TokenizerTest, LetterDigitTokenizeWhitespacesInSameToken) {
- TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false,
- /*preserve_floating_numbers=*/false);
- std::vector<Token> tokens = tokenizer.Tokenize("2 3 4 5");
- ASSERT_EQ(tokens, std::vector<Token>({Token("2", 0, 1), Token(" ", 1, 2),
- Token("3", 2, 3), Token(" ", 3, 5),
- Token("4", 5, 6), Token(" ", 6, 9),
- Token("5", 9, 10)}));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unicodetext.h b/native/utils/utf8/unicodetext.h
index cf78e47..9810480 100644
--- a/native/utils/utf8/unicodetext.h
+++ b/native/utils/utf8/unicodetext.h
@@ -223,6 +223,8 @@
// std::string, or from ::string to std::string, because if this happens it
// often results in invalid memory access to a temporary object created during
// such conversion (if do_copy == false).
+// NOTE: For efficiency, these methods do not check whether the input string is
+// well-formed UTF-8. Call UnicodeText::is_valid() explicitly when needed.
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len,
bool do_copy = true);
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy = true);
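An illustrative sketch (not part of the patch) of the validation pattern the NOTE above recommends. It assumes only UTF8ToUnicodeText() and UnicodeText::is_valid() as declared in this header; the helper name is hypothetical.

#include "utils/utf8/unicodetext.h"

namespace libtextclassifier3 {

// Hypothetical helper (illustration only): rejects byte buffers that are not
// well-formed UTF-8 before any further processing, since UTF8ToUnicodeText()
// itself does not validate its input.
inline bool IsWellFormedUtf8(const char* utf8_buf, int len) {
  return UTF8ToUnicodeText(utf8_buf, len, /*do_copy=*/false).is_valid();
}

}  // namespace libtextclassifier3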
diff --git a/native/utils/utf8/unicodetext_test.cc b/native/utils/utf8/unicodetext_test.cc
deleted file mode 100644
index 4e8883b..0000000
--- a/native/utils/utf8/unicodetext_test.cc
+++ /dev/null
@@ -1,228 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/utf8/unicodetext.h"
-
-#include "utils/strings/stringpiece.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class UnicodeTextTest : public testing::Test {
- protected:
- UnicodeTextTest() : empty_text_() {
- text_.push_back(0x1C0);
- text_.push_back(0x4E8C);
- text_.push_back(0xD7DB);
- text_.push_back(0x34);
- text_.push_back(0x1D11E);
- }
-
- UnicodeText empty_text_;
- UnicodeText text_;
-};
-
-TEST(UnicodeTextTest, ConstructionFromUnicodeText) {
- UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);
- EXPECT_EQ(UnicodeText(text).ToUTF8String(), "1234😋hello");
- EXPECT_EQ(UnicodeText(text, /*do_copy=*/false).ToUTF8String(), "1234😋hello");
-}
-
-// Tests for our modifications of UnicodeText.
-TEST(UnicodeTextTest, Custom) {
- UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);
- EXPECT_EQ(text.ToUTF8String(), "1234😋hello");
- EXPECT_EQ(text.size_codepoints(), 10);
- EXPECT_EQ(text.size_bytes(), 13);
-
- auto it_begin = text.begin();
- std::advance(it_begin, 4);
- auto it_end = text.begin();
- std::advance(it_end, 6);
- EXPECT_EQ(text.UTF8Substring(it_begin, it_end), "😋h");
-}
-
-TEST(UnicodeTextTest, StringPieceView) {
- std::string raw_text = "1234😋hello";
- UnicodeText text =
- UTF8ToUnicodeText(StringPiece(raw_text), /*do_copy=*/false);
- EXPECT_EQ(text.ToUTF8String(), "1234😋hello");
- EXPECT_EQ(text.size_codepoints(), 10);
- EXPECT_EQ(text.size_bytes(), 13);
-
- auto it_begin = text.begin();
- std::advance(it_begin, 4);
- auto it_end = text.begin();
- std::advance(it_end, 6);
- EXPECT_EQ(text.UTF8Substring(it_begin, it_end), "😋h");
-}
-
-TEST(UnicodeTextTest, Substring) {
- UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);
-
- EXPECT_EQ(
- UnicodeText::Substring(std::next(text.begin(), 4),
- std::next(text.begin(), 6), /*do_copy=*/true),
- UTF8ToUnicodeText("😋h"));
- EXPECT_EQ(
- UnicodeText::Substring(std::next(text.begin(), 4),
- std::next(text.begin(), 6), /*do_copy=*/false),
- UTF8ToUnicodeText("😋h"));
- EXPECT_EQ(UnicodeText::Substring(text, 4, 6, /*do_copy=*/true),
- UTF8ToUnicodeText("😋h"));
- EXPECT_EQ(UnicodeText::Substring(text, 4, 6, /*do_copy=*/false),
- UTF8ToUnicodeText("😋h"));
-}
-
-TEST(UnicodeTextTest, Ownership) {
- const std::string src = "\u304A\u00B0\u106B";
-
- UnicodeText alias;
- alias.PointToUTF8(src.data(), src.size());
- EXPECT_EQ(alias.data(), src.data());
- UnicodeText::const_iterator it = alias.begin();
- EXPECT_EQ(*it++, 0x304A);
- EXPECT_EQ(*it++, 0x00B0);
- EXPECT_EQ(*it++, 0x106B);
- EXPECT_EQ(it, alias.end());
-
- UnicodeText t = alias; // Copy initialization copies the data.
- EXPECT_NE(t.data(), alias.data());
-}
-
-TEST(UnicodeTextTest, Validation) {
- EXPECT_TRUE(UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false).is_valid());
- EXPECT_TRUE(
- UTF8ToUnicodeText("\u304A\u00B0\u106B", /*do_copy=*/false).is_valid());
- EXPECT_TRUE(
- UTF8ToUnicodeText("this is a test😋😋😋", /*do_copy=*/false).is_valid());
- EXPECT_TRUE(
- UTF8ToUnicodeText("\xf0\x9f\x98\x8b", /*do_copy=*/false).is_valid());
-  // Too short (truncated 4-byte sequence).
- EXPECT_FALSE(UTF8ToUnicodeText("\xf0\x9f", /*do_copy=*/false).is_valid());
- // Too long (too many trailing bytes).
- EXPECT_FALSE(
- UTF8ToUnicodeText("\xf0\x9f\x98\x8b\x8b", /*do_copy=*/false).is_valid());
- // Too short (too few trailing bytes).
- EXPECT_FALSE(
- UTF8ToUnicodeText("\xf0\x9f\x98\x61\x61", /*do_copy=*/false).is_valid());
- // Invalid with context.
- EXPECT_FALSE(
- UTF8ToUnicodeText("hello \xf0\x9f\x98\x61\x61 world1", /*do_copy=*/false)
- .is_valid());
-}
-
-class IteratorTest : public UnicodeTextTest {};
-
-TEST_F(IteratorTest, Iterates) {
- UnicodeText::const_iterator iter = text_.begin();
- EXPECT_EQ(0x1C0, *iter);
- EXPECT_EQ(&iter, &++iter); // operator++ returns *this.
- EXPECT_EQ(0x4E8C, *iter++);
- EXPECT_EQ(0xD7DB, *iter);
- // Make sure you can dereference more than once.
- EXPECT_EQ(0xD7DB, *iter);
- EXPECT_EQ(0x34, *++iter);
- EXPECT_EQ(0x1D11E, *++iter);
- ASSERT_TRUE(iter != text_.end());
- iter++;
- EXPECT_TRUE(iter == text_.end());
-}
-
-TEST_F(IteratorTest, MultiPass) {
- // Also tests Default Constructible and Assignable.
- UnicodeText::const_iterator i1, i2;
- i1 = text_.begin();
- i2 = i1;
- EXPECT_EQ(0x4E8C, *++i1);
- EXPECT_TRUE(i1 != i2);
- EXPECT_EQ(0x1C0, *i2);
- ++i2;
- EXPECT_TRUE(i1 == i2);
- EXPECT_EQ(0x4E8C, *i2);
-}
-
-TEST_F(IteratorTest, ReverseIterates) {
- UnicodeText::const_iterator iter = text_.end();
- EXPECT_TRUE(iter == text_.end());
- iter--;
- ASSERT_TRUE(iter != text_.end());
- EXPECT_EQ(0x1D11E, *iter--);
- EXPECT_EQ(0x34, *iter);
- EXPECT_EQ(0xD7DB, *--iter);
- // Make sure you can dereference more than once.
- EXPECT_EQ(0xD7DB, *iter);
- --iter;
- EXPECT_EQ(0x4E8C, *iter--);
- EXPECT_EQ(0x1C0, *iter);
- EXPECT_TRUE(iter == text_.begin());
-}
-
-TEST_F(IteratorTest, Comparable) {
- UnicodeText::const_iterator i1, i2;
- i1 = text_.begin();
- i2 = i1;
- ++i2;
-
- EXPECT_TRUE(i1 < i2);
- EXPECT_TRUE(text_.begin() <= i1);
- EXPECT_FALSE(i1 >= i2);
- EXPECT_FALSE(i1 > text_.end());
-}
-
-TEST_F(IteratorTest, Advance) {
- UnicodeText::const_iterator iter = text_.begin();
- EXPECT_EQ(0x1C0, *iter);
- std::advance(iter, 4);
- EXPECT_EQ(0x1D11E, *iter);
- ++iter;
- EXPECT_TRUE(iter == text_.end());
-}
-
-TEST_F(IteratorTest, Distance) {
- UnicodeText::const_iterator iter = text_.begin();
- EXPECT_EQ(0, std::distance(text_.begin(), iter));
- EXPECT_EQ(5, std::distance(iter, text_.end()));
- ++iter;
- ++iter;
- EXPECT_EQ(2, std::distance(text_.begin(), iter));
- EXPECT_EQ(3, std::distance(iter, text_.end()));
- ++iter;
- ++iter;
- EXPECT_EQ(4, std::distance(text_.begin(), iter));
- ++iter;
- EXPECT_EQ(0, std::distance(iter, text_.end()));
-}
-
-class OperatorTest : public UnicodeTextTest {};
-
-TEST_F(OperatorTest, Clear) {
- UnicodeText empty_text(UTF8ToUnicodeText("", /*do_copy=*/false));
- EXPECT_FALSE(text_ == empty_text);
- text_.clear();
- EXPECT_TRUE(text_ == empty_text);
-}
-
-TEST_F(OperatorTest, Empty) {
- EXPECT_TRUE(empty_text_.empty());
- EXPECT_FALSE(text_.empty());
- text_.clear();
- EXPECT_TRUE(text_.empty());
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib_test-include.cc b/native/utils/utf8/unilib_test-include.cc
deleted file mode 100644
index 7b2a179..0000000
--- a/native/utils/utf8/unilib_test-include.cc
+++ /dev/null
@@ -1,502 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/utf8/unilib_test-include.h"
-
-#include "utils/base/logging.h"
-#include "gmock/gmock.h"
-
-namespace libtextclassifier3 {
-namespace test_internal {
-
-using ::testing::ElementsAre;
-
-TEST_F(UniLibTest, CharacterClassesAscii) {
- EXPECT_TRUE(unilib_.IsOpeningBracket('('));
- EXPECT_TRUE(unilib_.IsClosingBracket(')'));
- EXPECT_FALSE(unilib_.IsWhitespace(')'));
- EXPECT_TRUE(unilib_.IsWhitespace(' '));
- EXPECT_FALSE(unilib_.IsDigit(')'));
- EXPECT_TRUE(unilib_.IsDigit('0'));
- EXPECT_TRUE(unilib_.IsDigit('9'));
- EXPECT_FALSE(unilib_.IsUpper(')'));
- EXPECT_TRUE(unilib_.IsUpper('A'));
- EXPECT_TRUE(unilib_.IsUpper('Z'));
- EXPECT_FALSE(unilib_.IsLower(')'));
- EXPECT_TRUE(unilib_.IsLower('a'));
- EXPECT_TRUE(unilib_.IsLower('z'));
- EXPECT_TRUE(unilib_.IsPunctuation('!'));
- EXPECT_TRUE(unilib_.IsPunctuation('?'));
- EXPECT_TRUE(unilib_.IsPunctuation('#'));
- EXPECT_TRUE(unilib_.IsPunctuation('('));
- EXPECT_FALSE(unilib_.IsPunctuation('0'));
- EXPECT_FALSE(unilib_.IsPunctuation('$'));
- EXPECT_TRUE(unilib_.IsPercentage('%'));
- EXPECT_TRUE(unilib_.IsPercentage(u'%'));
- EXPECT_TRUE(unilib_.IsSlash('/'));
- EXPECT_TRUE(unilib_.IsSlash(u'/'));
- EXPECT_TRUE(unilib_.IsMinus('-'));
- EXPECT_TRUE(unilib_.IsMinus(u'-'));
- EXPECT_TRUE(unilib_.IsNumberSign('#'));
- EXPECT_TRUE(unilib_.IsNumberSign(u'#'));
- EXPECT_TRUE(unilib_.IsDot('.'));
- EXPECT_TRUE(unilib_.IsDot(u'.'));
-
- EXPECT_TRUE(unilib_.IsLatinLetter('A'));
- EXPECT_TRUE(unilib_.IsArabicLetter(u'ب')); // ARABIC LETTER BEH
- EXPECT_TRUE(
- unilib_.IsCyrillicLetter(u'ᲀ')); // CYRILLIC SMALL LETTER ROUNDED VE
- EXPECT_TRUE(unilib_.IsChineseLetter(u'豈')); // CJK COMPATIBILITY IDEOGRAPH
- EXPECT_TRUE(unilib_.IsJapaneseLetter(u'ぁ')); // HIRAGANA LETTER SMALL A
- EXPECT_TRUE(unilib_.IsKoreanLetter(u'ㄱ')); // HANGUL LETTER KIYEOK
- EXPECT_TRUE(unilib_.IsThaiLetter(u'ก')); // THAI CHARACTER KO KAI
- EXPECT_TRUE(unilib_.IsCJTletter(u'ก')); // THAI CHARACTER KO KAI
- EXPECT_FALSE(unilib_.IsCJTletter('A'));
-
- EXPECT_TRUE(unilib_.IsLetter('A'));
- EXPECT_TRUE(unilib_.IsLetter(u'A'));
- EXPECT_TRUE(unilib_.IsLetter(u'ト')); // KATAKANA LETTER TO
- EXPECT_TRUE(unilib_.IsLetter(u'ト')); // HALFWIDTH KATAKANA LETTER TO
- EXPECT_TRUE(unilib_.IsLetter(u'豈')); // CJK COMPATIBILITY IDEOGRAPH
-
- EXPECT_EQ(unilib_.ToLower('A'), 'a');
- EXPECT_EQ(unilib_.ToLower('Z'), 'z');
- EXPECT_EQ(unilib_.ToLower(')'), ')');
- EXPECT_EQ(unilib_.ToLowerText(UTF8ToUnicodeText("Never gonna give you up."))
- .ToUTF8String(),
- "never gonna give you up.");
- EXPECT_EQ(unilib_.ToUpper('a'), 'A');
- EXPECT_EQ(unilib_.ToUpper('z'), 'Z');
- EXPECT_EQ(unilib_.ToUpper(')'), ')');
- EXPECT_EQ(unilib_.ToUpperText(UTF8ToUnicodeText("Never gonna let you down."))
- .ToUTF8String(),
- "NEVER GONNA LET YOU DOWN.");
- EXPECT_EQ(unilib_.GetPairedBracket(')'), '(');
- EXPECT_EQ(unilib_.GetPairedBracket('}'), '{');
-}
-
-TEST_F(UniLibTest, CharacterClassesUnicode) {
- EXPECT_TRUE(unilib_.IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON
- EXPECT_TRUE(unilib_.IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS
- EXPECT_FALSE(unilib_.IsWhitespace(0x23F0)); // ALARM CLOCK
- EXPECT_TRUE(unilib_.IsWhitespace(0x2003)); // EM SPACE
- EXPECT_FALSE(unilib_.IsDigit(0xA619)); // VAI SYMBOL JONG
- EXPECT_TRUE(unilib_.IsDigit(0xA620)); // VAI DIGIT ZERO
- EXPECT_TRUE(unilib_.IsDigit(0xA629)); // VAI DIGIT NINE
- EXPECT_FALSE(unilib_.IsDigit(0xA62A)); // VAI SYLLABLE NDOLE MA
- EXPECT_FALSE(unilib_.IsUpper(0x0211)); // SMALL R WITH DOUBLE GRAVE
- EXPECT_TRUE(unilib_.IsUpper(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
- EXPECT_TRUE(unilib_.IsUpper(0x0391)); // GREEK CAPITAL ALPHA
- EXPECT_TRUE(unilib_.IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL
- EXPECT_FALSE(unilib_.IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
- EXPECT_TRUE(unilib_.IsLower(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
- EXPECT_TRUE(unilib_.IsLower(0x03B1)); // GREEK SMALL ALPHA
- EXPECT_TRUE(unilib_.IsLower(0x03CB)); // GREEK SMALL UPSILON
- EXPECT_TRUE(unilib_.IsLower(0x0211)); // SMALL R WITH DOUBLE GRAVE
- EXPECT_TRUE(unilib_.IsLower(0x03C0)); // GREEK SMALL PI
- EXPECT_TRUE(unilib_.IsLower(0x007A)); // SMALL Z
- EXPECT_FALSE(unilib_.IsLower(0x005A)); // CAPITAL Z
- EXPECT_FALSE(unilib_.IsLower(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
- EXPECT_FALSE(unilib_.IsLower(0x0391)); // GREEK CAPITAL ALPHA
- EXPECT_TRUE(unilib_.IsPunctuation(0x055E)); // ARMENIAN QUESTION MARK
- EXPECT_TRUE(unilib_.IsPunctuation(0x066C)); // ARABIC THOUSANDS SEPARATOR
- EXPECT_TRUE(unilib_.IsPunctuation(0x07F7)); // NKO SYMBOL GBAKURUNEN
- EXPECT_TRUE(unilib_.IsPunctuation(0x10AF2)); // DOUBLE DOT WITHIN DOT
- EXPECT_FALSE(unilib_.IsPunctuation(0x00A3)); // POUND SIGN
- EXPECT_FALSE(unilib_.IsPunctuation(0xA838)); // NORTH INDIC RUPEE MARK
- EXPECT_TRUE(unilib_.IsPercentage(0x0025)); // PERCENT SIGN
- EXPECT_TRUE(unilib_.IsPercentage(0xFF05)); // FULLWIDTH PERCENT SIGN
- EXPECT_TRUE(unilib_.IsSlash(0x002F)); // SOLIDUS
- EXPECT_TRUE(unilib_.IsSlash(0xFF0F)); // FULLWIDTH SOLIDUS
- EXPECT_TRUE(unilib_.IsMinus(0x002D)); // HYPHEN-MINUS
- EXPECT_TRUE(unilib_.IsMinus(0xFF0D)); // FULLWIDTH HYPHEN-MINUS
- EXPECT_TRUE(unilib_.IsNumberSign(0x0023)); // NUMBER SIGN
- EXPECT_TRUE(unilib_.IsNumberSign(0xFF03)); // FULLWIDTH NUMBER SIGN
- EXPECT_TRUE(unilib_.IsDot(0x002E)); // FULL STOP
- EXPECT_TRUE(unilib_.IsDot(0xFF0E)); // FULLWIDTH FULL STOP
-
- EXPECT_TRUE(unilib_.IsLatinLetter(0x0041)); // LATIN CAPITAL LETTER A
- EXPECT_TRUE(unilib_.IsArabicLetter(0x0628)); // ARABIC LETTER BEH
- EXPECT_TRUE(
- unilib_.IsCyrillicLetter(0x1C80)); // CYRILLIC SMALL LETTER ROUNDED VE
- EXPECT_TRUE(unilib_.IsChineseLetter(0xF900)); // CJK COMPATIBILITY IDEOGRAPH
- EXPECT_TRUE(unilib_.IsJapaneseLetter(0x3041)); // HIRAGANA LETTER SMALL A
- EXPECT_TRUE(unilib_.IsKoreanLetter(0x3131)); // HANGUL LETTER KIYEOK
- EXPECT_TRUE(unilib_.IsThaiLetter(0x0E01)); // THAI CHARACTER KO KAI
- EXPECT_TRUE(unilib_.IsCJTletter(0x0E01)); // THAI CHARACTER KO KAI
- EXPECT_FALSE(unilib_.IsCJTletter(0x0041)); // LATIN CAPITAL LETTER A
-
- EXPECT_TRUE(unilib_.IsLetter(0x0041)); // LATIN CAPITAL LETTER A
- EXPECT_TRUE(unilib_.IsLetter(0xFF21)); // FULLWIDTH LATIN CAPITAL LETTER A
- EXPECT_TRUE(unilib_.IsLetter(0x30C8)); // KATAKANA LETTER TO
- EXPECT_TRUE(unilib_.IsLetter(0xFF84)); // HALFWIDTH KATAKANA LETTER TO
- EXPECT_TRUE(unilib_.IsLetter(0xF900)); // CJK COMPATIBILITY IDEOGRAPH
-
- EXPECT_EQ(unilib_.ToLower(0x0391), 0x03B1); // GREEK ALPHA
- EXPECT_EQ(unilib_.ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA
- EXPECT_EQ(unilib_.ToLower(0x03C0), 0x03C0); // GREEK SMALL PI
- EXPECT_EQ(unilib_.ToLower(0x03A3), 0x03C3); // GREEK CAPITAL LETTER SIGMA
- EXPECT_EQ(unilib_.ToLowerText(UTF8ToUnicodeText("Κανένας άνθρωπος δεν ξέρει"))
- .ToUTF8String(),
- "κανένας άνθρωπος δεν ξέρει");
- EXPECT_EQ(unilib_.ToUpper(0x03B1), 0x0391); // GREEK ALPHA
- EXPECT_EQ(unilib_.ToUpper(0x03CB), 0x03AB); // GREEK UPSILON WITH DIALYTIKA
- EXPECT_EQ(unilib_.ToUpper(0x0391), 0x0391); // GREEK CAPITAL ALPHA
- EXPECT_EQ(unilib_.ToUpper(0x03C3), 0x03A3); // GREEK CAPITAL LETTER SIGMA
- EXPECT_EQ(unilib_.ToUpper(0x03C2), 0x03A3); // GREEK CAPITAL LETTER SIGMA
- EXPECT_EQ(unilib_.ToUpperText(UTF8ToUnicodeText("Κανένας άνθρωπος δεν ξέρει"))
- .ToUTF8String(),
- "ΚΑΝΈΝΑΣ ΆΝΘΡΩΠΟΣ ΔΕΝ ΞΈΡΕΙ");
- EXPECT_EQ(unilib_.GetPairedBracket(0x0F3C), 0x0F3D);
- EXPECT_EQ(unilib_.GetPairedBracket(0x0F3D), 0x0F3C);
-}
-
-TEST_F(UniLibTest, RegexInterface) {
- const UnicodeText regex_pattern =
- UTF8ToUnicodeText("[0-9]+", /*do_copy=*/true);
- std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
- const UnicodeText input = UTF8ToUnicodeText("hello 0123", /*do_copy=*/false);
- int status;
- std::unique_ptr<UniLib::RegexMatcher> matcher = pattern->Matcher(input);
- TC3_LOG(INFO) << matcher->Matches(&status);
- TC3_LOG(INFO) << matcher->Find(&status);
- TC3_LOG(INFO) << matcher->Start(0, &status);
- TC3_LOG(INFO) << matcher->End(0, &status);
- TC3_LOG(INFO) << matcher->Group(0, &status).size_codepoints();
-}
-
-TEST_F(UniLibTest, Regex) {
- // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
- // test the regex functionality with it to verify we are handling the indices
- // correctly.
- const UnicodeText regex_pattern =
- UTF8ToUnicodeText("[0-9]+😋", /*do_copy=*/false);
- std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
- int status;
- std::unique_ptr<UniLib::RegexMatcher> matcher;
-
- matcher = pattern->Matcher(UTF8ToUnicodeText("0123😋", /*do_copy=*/false));
- EXPECT_TRUE(matcher->Matches(&status));
- EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_TRUE(matcher->Matches(&status)); // Check that the state is reset.
- EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-
- matcher = pattern->Matcher(
- UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
- EXPECT_FALSE(matcher->Matches(&status));
- EXPECT_FALSE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-
- matcher = pattern->Matcher(
- UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
- EXPECT_TRUE(matcher->Find(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Start(0, &status), 8);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->End(0, &status), 13);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-}
-
-TEST_F(UniLibTest, RegexLazy) {
- std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateLazyRegexPattern(
- UTF8ToUnicodeText("[a-z][0-9]", /*do_copy=*/false));
- int status;
- std::unique_ptr<UniLib::RegexMatcher> matcher;
-
- matcher = pattern->Matcher(UTF8ToUnicodeText("a3", /*do_copy=*/false));
- EXPECT_TRUE(matcher->Matches(&status));
- EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_TRUE(matcher->Matches(&status)); // Check that the state is reset.
- EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-
- matcher = pattern->Matcher(UTF8ToUnicodeText("3a", /*do_copy=*/false));
- EXPECT_FALSE(matcher->Matches(&status));
- EXPECT_FALSE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-}
-
-TEST_F(UniLibTest, RegexGroups) {
- // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
- // test the regex functionality with it to verify we are handling the indices
- // correctly.
- const UnicodeText regex_pattern =
- UTF8ToUnicodeText("([0-9])([0-9]+)😋", /*do_copy=*/false);
- std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
- int status;
- std::unique_ptr<UniLib::RegexMatcher> matcher;
-
- matcher = pattern->Matcher(
- UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
- EXPECT_TRUE(matcher->Find(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Start(0, &status), 8);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Start(1, &status), 8);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Start(2, &status), 9);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->End(0, &status), 13);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->End(1, &status), 9);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->End(2, &status), 12);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(1, &status).ToUTF8String(), "0");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(2, &status).ToUTF8String(), "123");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-}
-
-TEST_F(UniLibTest, RegexGroupsNotAllGroupsInvolved) {
- const UnicodeText regex_pattern =
- UTF8ToUnicodeText("([0-9])([a-z])?", /*do_copy=*/false);
- std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
- int status;
- std::unique_ptr<UniLib::RegexMatcher> matcher;
-
- matcher = pattern->Matcher(UTF8ToUnicodeText("7", /*do_copy=*/false));
- EXPECT_TRUE(matcher->Find(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "7");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(1, &status).ToUTF8String(), "7");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(2, &status).ToUTF8String(), "");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-}
-
-TEST_F(UniLibTest, RegexGroupsEmptyResult) {
- const UnicodeText regex_pattern =
- UTF8ToUnicodeText("(.*)", /*do_copy=*/false);
- std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
- int status;
- std::unique_ptr<UniLib::RegexMatcher> matcher;
-
- matcher = pattern->Matcher(UTF8ToUnicodeText("", /*do_copy=*/false));
- EXPECT_TRUE(matcher->Find(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(1, &status).ToUTF8String(), "");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-}
-
-TEST_F(UniLibTest, BreakIterator) {
- const UnicodeText text = UTF8ToUnicodeText("some text", /*do_copy=*/false);
- std::unique_ptr<UniLib::BreakIterator> iterator =
- unilib_.CreateBreakIterator(text);
- std::vector<int> break_indices;
- int break_index = 0;
- while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) {
- break_indices.push_back(break_index);
- }
- EXPECT_THAT(break_indices, ElementsAre(4, 5, 9));
-}
-
-TEST_F(UniLibTest, BreakIterator4ByteUTF8) {
- const UnicodeText text = UTF8ToUnicodeText("😀😂😋", /*do_copy=*/false);
- std::unique_ptr<UniLib::BreakIterator> iterator =
- unilib_.CreateBreakIterator(text);
- std::vector<int> break_indices;
- int break_index = 0;
- while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) {
- break_indices.push_back(break_index);
- }
- EXPECT_THAT(break_indices, ElementsAre(1, 2, 3));
-}
-
-TEST_F(UniLibTest, Integer32Parse) {
- int result;
- EXPECT_TRUE(
- unilib_.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
- EXPECT_EQ(result, 123);
-}
-
-TEST_F(UniLibTest, Integer32ParseFloatNumber) {
- int result;
- EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("12.3", /*do_copy=*/false),
- &result));
-}
-
-TEST_F(UniLibTest, Integer32ParseLongNumber) {
- int32 result;
- EXPECT_TRUE(unilib_.ParseInt32(
- UTF8ToUnicodeText("1000000000", /*do_copy=*/false), &result));
- EXPECT_EQ(result, 1000000000);
-}
-
-TEST_F(UniLibTest, Integer32ParseEmptyString) {
- int result;
- EXPECT_FALSE(
- unilib_.ParseInt32(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
-}
-
-TEST_F(UniLibTest, Integer32ParseFullWidth) {
- int result;
- // The input string here is full width
- EXPECT_TRUE(unilib_.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false),
- &result));
- EXPECT_EQ(result, 123);
-}
-
-TEST_F(UniLibTest, Integer32ParseNotNumber) {
- int result;
- // The input string here is full width
- EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("1a3", /*do_copy=*/false),
- &result));
- // Strings starting with "nan" are not numbers.
- EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("Nancy",
- /*do_copy=*/false),
- &result));
- // Strings starting with "inf" are not numbers
- EXPECT_FALSE(unilib_.ParseInt32(
- UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
-}
-
-TEST_F(UniLibTest, Integer64Parse) {
- int64 result;
- EXPECT_TRUE(
- unilib_.ParseInt64(UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
- EXPECT_EQ(result, 123);
-}
-
-TEST_F(UniLibTest, Integer64ParseFloatNumber) {
- int64 result;
- EXPECT_FALSE(unilib_.ParseInt64(UTF8ToUnicodeText("12.3", /*do_copy=*/false),
- &result));
-}
-
-TEST_F(UniLibTest, Integer64ParseLongNumber) {
- int64 result;
-  // The limitation comes from the javaicu implementation: parseDouble has no
-  // ICU support and parseInt limits the size of the number.
- EXPECT_TRUE(unilib_.ParseInt64(
- UTF8ToUnicodeText("1000000000", /*do_copy=*/false), &result));
- EXPECT_EQ(result, 1000000000);
-}
-
-TEST_F(UniLibTest, Integer64ParseEmptyString) {
- int64 result;
- EXPECT_FALSE(
- unilib_.ParseInt64(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
-}
-
-TEST_F(UniLibTest, Integer64ParseFullWidth) {
- int64 result;
- // The input string here is full width
- EXPECT_TRUE(unilib_.ParseInt64(UTF8ToUnicodeText("123", /*do_copy=*/false),
- &result));
- EXPECT_EQ(result, 123);
-}
-
-TEST_F(UniLibTest, Integer64ParseNotNumber) {
- int64 result;
- // The input string here is full width
- EXPECT_FALSE(unilib_.ParseInt64(UTF8ToUnicodeText("1a4", /*do_copy=*/false),
- &result));
- // Strings starting with "nan" are not numbers.
- EXPECT_FALSE(unilib_.ParseInt64(UTF8ToUnicodeText("Nancy",
- /*do_copy=*/false),
- &result));
- // Strings starting with "inf" are not numbers
- EXPECT_FALSE(unilib_.ParseInt64(
- UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
-}
-
-TEST_F(UniLibTest, DoubleParse) {
- double result;
- EXPECT_TRUE(unilib_.ParseDouble(UTF8ToUnicodeText("1.23", /*do_copy=*/false),
- &result));
- EXPECT_EQ(result, 1.23);
-}
-
-TEST_F(UniLibTest, DoubleParseLongNumber) {
- double result;
-  // The limitation comes from the javaicu implementation: parseDouble has no
-  // ICU support and parseInt limits the size of the number.
- EXPECT_TRUE(unilib_.ParseDouble(
- UTF8ToUnicodeText("999999999.999999999", /*do_copy=*/false), &result));
- EXPECT_EQ(result, 999999999.999999999);
-}
-
-TEST_F(UniLibTest, DoubleParseWithoutFractionalPart) {
- double result;
- EXPECT_TRUE(unilib_.ParseDouble(UTF8ToUnicodeText("123", /*do_copy=*/false),
- &result));
- EXPECT_EQ(result, 123);
-}
-
-TEST_F(UniLibTest, DoubleParseEmptyString) {
- double result;
- EXPECT_FALSE(
- unilib_.ParseDouble(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
-}
-
-TEST_F(UniLibTest, DoubleParsePrecedingDot) {
- double result;
- EXPECT_FALSE(unilib_.ParseDouble(UTF8ToUnicodeText(".123", /*do_copy=*/false),
- &result));
-}
-
-TEST_F(UniLibTest, DoubleParseTrailingDot) {
- double result;
- EXPECT_FALSE(unilib_.ParseDouble(UTF8ToUnicodeText("123.", /*do_copy=*/false),
- &result));
-}
-
-TEST_F(UniLibTest, DoubleParseMultipleDots) {
- double result;
- EXPECT_FALSE(unilib_.ParseDouble(
- UTF8ToUnicodeText("1.2.3", /*do_copy=*/false), &result));
-}
-
-TEST_F(UniLibTest, DoubleParseFullWidth) {
- double result;
- // The input string here is full width
- EXPECT_TRUE(unilib_.ParseDouble(
- UTF8ToUnicodeText("1.23", /*do_copy=*/false), &result));
- EXPECT_EQ(result, 1.23);
-}
-
-TEST_F(UniLibTest, DoubleParseNotNumber) {
- double result;
- // The input string here is full width
- EXPECT_FALSE(unilib_.ParseDouble(
- UTF8ToUnicodeText("1a5", /*do_copy=*/false), &result));
- // Strings starting with "nan" are not numbers.
- EXPECT_FALSE(unilib_.ParseDouble(
- UTF8ToUnicodeText("Nancy", /*do_copy=*/false), &result));
- // Strings starting with "inf" are not numbers
- EXPECT_FALSE(unilib_.ParseDouble(
- UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
-}
-
-} // namespace test_internal
-} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib_test-include.h b/native/utils/utf8/unilib_test-include.h
deleted file mode 100644
index 342a00c..0000000
--- a/native/utils/utf8/unilib_test-include.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
-#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
-
-#include "utils/utf8/unilib.h"
-#include "gtest/gtest.h"
-
-#if defined TC3_UNILIB_ICU
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
-#elif defined TC3_UNILIB_JAVAICU
-#include <jni.h>
-extern JNIEnv* g_jenv;
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR(JniCache::Create(g_jenv))
-#elif defined TC3_UNILIB_APPLE
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
-#elif defined TC3_UNILIB_DUMMY
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
-#endif
-
-namespace libtextclassifier3 {
-namespace test_internal {
-
-class UniLibTest : public ::testing::Test {
- protected:
- UniLibTest() : TC3_TESTING_CREATE_UNILIB_INSTANCE(unilib_) {}
- UniLib unilib_;
-};
-
-} // namespace test_internal
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
diff --git a/native/utils/utf8/unilib_test.cc b/native/utils/utf8/unilib_test.cc
deleted file mode 100644
index 01b5164..0000000
--- a/native/utils/utf8/unilib_test.cc
+++ /dev/null
@@ -1,18 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// The actual code of the test is in the following include:
-#include "utils/utf8/unilib_test-include.h"
diff --git a/native/utils/variant.h b/native/utils/variant.h
index 153a4e8..cb206ee 100644
--- a/native/utils/variant.h
+++ b/native/utils/variant.h
@@ -19,6 +19,7 @@
#include <map>
#include <string>
+#include <vector>
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
@@ -41,6 +42,9 @@
TYPE_DOUBLE_VALUE = 8,
TYPE_BOOL_VALUE = 9,
TYPE_STRING_VALUE = 10,
+ TYPE_STRING_VECTOR_VALUE = 11,
+ TYPE_FLOAT_VECTOR_VALUE = 12,
+ TYPE_INT_VECTOR_VALUE = 13,
};
Variant() : type_(TYPE_EMPTY) {}
@@ -68,6 +72,12 @@
: type_(TYPE_STRING_VALUE), string_value_(value) {}
explicit Variant(const bool value)
: type_(TYPE_BOOL_VALUE), bool_value_(value) {}
+ explicit Variant(const std::vector<std::string>& value)
+ : type_(TYPE_STRING_VECTOR_VALUE), string_vector_value_(value) {}
+ explicit Variant(const std::vector<float>& value)
+ : type_(TYPE_FLOAT_VECTOR_VALUE), float_vector_value_(value) {}
+ explicit Variant(const std::vector<int>& value)
+ : type_(TYPE_INT_VECTOR_VALUE), int_vector_value_(value) {}
Variant& operator=(const Variant&) = default;
@@ -121,6 +131,21 @@
return string_value_;
}
+ const std::vector<std::string>& StringVectorValue() const {
+ TC3_CHECK(HasStringVector());
+ return string_vector_value_;
+ }
+
+ const std::vector<float>& FloatVectorValue() const {
+ TC3_CHECK(HasFloatVector());
+ return float_vector_value_;
+ }
+
+ const std::vector<int>& IntVectorValue() const {
+ TC3_CHECK(HasIntVector());
+ return int_vector_value_;
+ }
+
// Converts the value of this variant to its string representation, regardless
// of the type of the actual value.
std::string ToString() const;
@@ -145,6 +170,12 @@
bool HasString() const { return type_ == TYPE_STRING_VALUE; }
+ bool HasStringVector() const { return type_ == TYPE_STRING_VECTOR_VALUE; }
+
+ bool HasFloatVector() const { return type_ == TYPE_FLOAT_VECTOR_VALUE; }
+
+ bool HasIntVector() const { return type_ == TYPE_INT_VECTOR_VALUE; }
+
Type GetType() const { return type_; }
bool HasValue() const { return type_ != TYPE_EMPTY; }
@@ -163,6 +194,9 @@
bool bool_value_;
};
std::string string_value_;
+ std::vector<std::string> string_vector_value_;
+ std::vector<float> float_vector_value_;
+ std::vector<int> int_vector_value_;
};
// Pretty-printing function for Variant.
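An illustrative sketch (not part of the patch) of how the vector-valued Variant members added above might be used. It assumes only the constructors, Has*Vector() predicates, and *VectorValue() accessors introduced in this hunk; the function names are hypothetical.

#include <string>
#include <vector>

#include "utils/variant.h"

namespace libtextclassifier3 {

// Hypothetical example: sums the elements of an int-vector Variant and
// returns 0 for any other payload type.
inline int SumIntVector(const Variant& value) {
  if (!value.HasIntVector()) {
    return 0;
  }
  int sum = 0;
  for (const int element : value.IntVectorValue()) {
    sum += element;
  }
  return sum;
}

inline int Example() {
  const Variant names(std::vector<std::string>{"alpha", "beta"});
  const Variant numbers(std::vector<int>{1, 2, 3});
  // names.GetType() == Variant::TYPE_STRING_VECTOR_VALUE and
  // numbers.GetType() == Variant::TYPE_INT_VECTOR_VALUE hold here.
  return SumIntVector(numbers) + SumIntVector(names);  // 6 + 0 = 6.
}

}  // namespace libtextclassifier3

Note that, as the hunk shows, the new vector members are stored next to string_value_ outside the anonymous union, so these non-trivial types keep their ordinary copy semantics.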
diff --git a/native/utils/variant_test.cc b/native/utils/variant_test.cc
deleted file mode 100644
index 347d53f..0000000
--- a/native/utils/variant_test.cc
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/variant.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(VariantTest, GetType) {
- EXPECT_EQ(Variant().GetType(), Variant::TYPE_EMPTY);
- EXPECT_EQ(Variant(static_cast<int8_t>(9)).GetType(),
- Variant::TYPE_INT8_VALUE);
- EXPECT_EQ(Variant(static_cast<uint8_t>(9)).GetType(),
- Variant::TYPE_UINT8_VALUE);
- EXPECT_EQ(Variant(static_cast<int>(9)).GetType(), Variant::TYPE_INT_VALUE);
- EXPECT_EQ(Variant(static_cast<uint>(9)).GetType(), Variant::TYPE_UINT_VALUE);
- EXPECT_EQ(Variant(static_cast<int64>(9)).GetType(),
- Variant::TYPE_INT64_VALUE);
- EXPECT_EQ(Variant(static_cast<uint64>(9)).GetType(),
- Variant::TYPE_UINT64_VALUE);
- EXPECT_EQ(Variant(static_cast<float>(9)).GetType(),
- Variant::TYPE_FLOAT_VALUE);
- EXPECT_EQ(Variant(static_cast<double>(9)).GetType(),
- Variant::TYPE_DOUBLE_VALUE);
- EXPECT_EQ(Variant(true).GetType(), Variant::TYPE_BOOL_VALUE);
- EXPECT_EQ(Variant("hello").GetType(), Variant::TYPE_STRING_VALUE);
-}
-
-TEST(VariantTest, HasValue) {
- EXPECT_FALSE(Variant().HasValue());
- EXPECT_TRUE(Variant(static_cast<int8_t>(9)).HasValue());
- EXPECT_TRUE(Variant(static_cast<uint8_t>(9)).HasValue());
- EXPECT_TRUE(Variant(static_cast<int>(9)).HasValue());
- EXPECT_TRUE(Variant(static_cast<uint>(9)).HasValue());
- EXPECT_TRUE(Variant(static_cast<int64>(9)).HasValue());
- EXPECT_TRUE(Variant(static_cast<uint64>(9)).HasValue());
- EXPECT_TRUE(Variant(static_cast<float>(9)).HasValue());
- EXPECT_TRUE(Variant(static_cast<double>(9)).HasValue());
- EXPECT_TRUE(Variant(true).HasValue());
- EXPECT_TRUE(Variant("hello").HasValue());
-}
-
-TEST(VariantTest, Value) {
- EXPECT_EQ(Variant(static_cast<int8_t>(9)).Int8Value(), 9);
- EXPECT_EQ(Variant(static_cast<uint8_t>(9)).UInt8Value(), 9);
- EXPECT_EQ(Variant(static_cast<int>(9)).IntValue(), 9);
- EXPECT_EQ(Variant(static_cast<uint>(9)).UIntValue(), 9);
- EXPECT_EQ(Variant(static_cast<int64>(9)).Int64Value(), 9);
- EXPECT_EQ(Variant(static_cast<uint64>(9)).UInt64Value(), 9);
- EXPECT_EQ(Variant(static_cast<float>(9)).FloatValue(), 9);
- EXPECT_EQ(Variant(static_cast<double>(9)).DoubleValue(), 9);
- EXPECT_EQ(Variant(true).BoolValue(), true);
- EXPECT_EQ(Variant("hello").StringValue(), "hello");
-}
-
-} // namespace
-} // namespace libtextclassifier3