Export libtextclassifier
Test: atest libtextclassifier_tests -- --all-abi
Change-Id: I292f8509ae4bedf7800e36a1b2a615001da537d6
diff --git a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
index a51c95d..4838503 100644
--- a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
+++ b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
@@ -147,6 +147,9 @@
public static LabeledIntent.TitleChooser createTitleChooser(String actionType) {
if (ConversationAction.TYPE_OPEN_URL.equals(actionType)) {
return (labeledIntent, resolveInfo) -> {
+ if (resolveInfo == null) {
+ return labeledIntent.titleWithEntity;
+ }
if (resolveInfo.handleAllWebDataURI) {
return labeledIntent.titleWithEntity;
}
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index 5c028ef..429ed6b 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -192,7 +192,10 @@
string,
request.getStartIndex(),
request.getEndIndex(),
- new AnnotatorModel.SelectionOptions(localesString, detectLanguageTags));
+ AnnotatorModel.SelectionOptions.builder()
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(detectLanguageTags)
+ .build());
final int start = startEnd[0];
final int end = startEnd[1];
if (start < end
@@ -206,11 +209,12 @@
string,
start,
end,
- new AnnotatorModel.ClassificationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- detectLanguageTags),
+ AnnotatorModel.ClassificationOptions.builder()
+ .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
+ .setReferenceTimezone(refTime.getZone().getId())
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(detectLanguageTags)
+ .build(),
// Passing null here to suppress intent generation
// TODO: Use an explicit flag to suppress it.
/* appContext */ null,
@@ -256,13 +260,14 @@
string,
request.getStartIndex(),
request.getEndIndex(),
- new AnnotatorModel.ClassificationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- String.join(",", detectLanguageTags),
- AnnotatorModel.AnnotationUsecase.SMART.getValue(),
- LocaleList.getDefault().toLanguageTags()),
+ AnnotatorModel.ClassificationOptions.builder()
+ .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
+ .setReferenceTimezone(refTime.getZone().getId())
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
+ .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
+ .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
+ .build(),
context,
getResourceLocalesString());
if (results.length > 0) {
@@ -309,14 +314,15 @@
final AnnotatorModel.AnnotatedSpan[] annotations =
annotatorImpl.annotate(
textString,
- new AnnotatorModel.AnnotationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- String.join(",", detectLanguageTags),
- entitiesToIdentify,
- AnnotatorModel.AnnotationUsecase.SMART.getValue(),
- isSerializedEntityDataEnabled));
+ AnnotatorModel.AnnotationOptions.builder()
+ .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
+ .setReferenceTimezone(refTime.getZone().getId())
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
+ .setEntityTypes(entitiesToIdentify)
+ .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
+ .setIsSerializedEntityDataEnabled(isSerializedEntityDataEnabled)
+ .build());
for (AnnotatorModel.AnnotatedSpan span : annotations) {
final AnnotatorModel.ClassificationResult[] results = span.getClassification();
if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) {
diff --git a/java/src/com/android/textclassifier/common/intent/LabeledIntent.java b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
index b56d0bb..b32e1ce 100644
--- a/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
+++ b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
@@ -29,6 +29,7 @@
import androidx.core.content.ContextCompat;
import androidx.core.graphics.drawable.IconCompat;
import com.android.textclassifier.common.base.TcLog;
+import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import javax.annotation.Nullable;
@@ -94,8 +95,26 @@
final ResolveInfo resolveInfo = pm.resolveActivity(intent, 0);
if (resolveInfo == null || resolveInfo.activityInfo == null) {
- TcLog.w(TAG, "resolveInfo or activityInfo is null");
- return null;
+ // Failed to resolve the intent. It could be because there are no apps to handle
+ // the intent. It could also be because the calling app has no visibility of the target app
+ // due to the app visibility feature introduced in R. For privacy reasons, we don't want to
+ // force users of our library to ask for visibility of the http/https view intent.
+ // Getting visibility of this intent effectively means getting visibility of ~70% of apps,
+ // which defeats the purpose of the app visibility feature. Practically speaking, all devices
+ // are very likely to have a browser installed. Thus, if it is a web intent, we assume we
+ // failed to resolve the intent just because of the app visibility feature, in which case we
+ // return an implicit intent without an icon.
+ if (isWebIntent()) {
+ IconCompat icon = IconCompat.createWithResource(context, android.R.drawable.ic_menu_more);
+ RemoteActionCompat action =
+ createRemoteAction(
+ context, intent, icon, /* shouldShowIcon= */ false, resolveInfo, titleChooser);
+ // Create a clone so that the client does not modify the original intent.
+ return new Result(new Intent(intent), action);
+ } else {
+ TcLog.w(TAG, "resolveInfo or activityInfo is null");
+ return null;
+ }
}
if (!hasPermission(context, resolveInfo.activityInfo)) {
TcLog.d(TAG, "No permission to access: " + resolveInfo.activityInfo);
@@ -126,6 +145,19 @@
// RemoteAction requires that there be an icon.
icon = IconCompat.createWithResource(context, android.R.drawable.ic_menu_more);
}
+ RemoteActionCompat action =
+ createRemoteAction(
+ context, resolvedIntent, icon, shouldShowIcon, resolveInfo, titleChooser);
+ return new Result(resolvedIntent, action);
+ }
+
+ private RemoteActionCompat createRemoteAction(
+ Context context,
+ Intent resolvedIntent,
+ IconCompat icon,
+ boolean shouldShowIcon,
+ @Nullable ResolveInfo resolveInfo,
+ @Nullable TitleChooser titleChooser) {
final PendingIntent pendingIntent = createPendingIntent(context, resolvedIntent, requestCode);
titleChooser = titleChooser == null ? DEFAULT_TITLE_CHOOSER : titleChooser;
CharSequence title = titleChooser.chooseTitle(this, resolveInfo);
@@ -134,12 +166,25 @@
title = DEFAULT_TITLE_CHOOSER.chooseTitle(this, resolveInfo);
}
final RemoteActionCompat action =
- new RemoteActionCompat(icon, title, resolveDescription(resolveInfo, pm), pendingIntent);
+ new RemoteActionCompat(
+ icon,
+ title,
+ resolveDescription(resolveInfo, context.getPackageManager()),
+ pendingIntent);
action.setShouldShowIcon(shouldShowIcon);
- return new Result(resolvedIntent, action);
+ return action;
}
- private String resolveDescription(ResolveInfo resolveInfo, PackageManager packageManager) {
+ private boolean isWebIntent() {
+ if (!Intent.ACTION_VIEW.equals(intent.getAction())) {
+ return false;
+ }
+ final String scheme = intent.getScheme();
+ return Objects.equal(scheme, "http") || Objects.equal(scheme, "https");
+ }
+
+ private String resolveDescription(
+ @Nullable ResolveInfo resolveInfo, PackageManager packageManager) {
if (!TextUtils.isEmpty(descriptionWithAppName)) {
// Example string format of descriptionWithAppName: "Use %1$s to open map".
String applicationName = getApplicationName(resolveInfo, packageManager);
@@ -169,8 +214,9 @@
}
@Nullable
- private static String getApplicationName(ResolveInfo resolveInfo, PackageManager packageManager) {
- if (resolveInfo.activityInfo == null) {
+ private static String getApplicationName(
+ @Nullable ResolveInfo resolveInfo, PackageManager packageManager) {
+ if (resolveInfo == null || resolveInfo.activityInfo == null) {
return null;
}
if ("android".equals(resolveInfo.activityInfo.packageName)) {
@@ -214,6 +260,6 @@
* is guaranteed to have a non-null {@code activityInfo}.
*/
@Nullable
- CharSequence chooseTitle(LabeledIntent labeledIntent, ResolveInfo resolveInfo);
+ CharSequence chooseTitle(LabeledIntent labeledIntent, @Nullable ResolveInfo resolveInfo);
}
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
index 59dc41a..427e89e 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
@@ -25,6 +25,8 @@
import android.app.RemoteAction;
import android.content.ComponentName;
import android.content.Intent;
+import android.content.pm.ActivityInfo;
+import android.content.pm.ResolveInfo;
import android.graphics.drawable.Icon;
import android.net.Uri;
import android.os.Bundle;
@@ -34,6 +36,7 @@
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import com.android.textclassifier.common.intent.LabeledIntent;
+import com.android.textclassifier.common.intent.LabeledIntent.TitleChooser;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.RemoteActionTemplate;
@@ -289,6 +292,92 @@
assertThat(labeledIntentResult.resolvedIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
}
+ @Test
+ public void createTitleChooser_notOpenUrl() {
+ assertThat(ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_CALL_PHONE))
+ .isNull();
+ }
+
+ @Test
+ public void createTitleChooser_openUrl_resolveInfoIsNull() {
+ TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_OPEN_URL);
+ LabeledIntent labeledIntent = createWebLabeledIntent();
+
+ assertThat(titleChooser.chooseTitle(labeledIntent, /* resolveInfo= */ null).toString())
+ .isEqualTo("titleWithEntity");
+ }
+
+ @Test
+ public void createTitleChooser_openUrl_packageIsNotAndroidAndHandleAllWebDataUriTrue() {
+ TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_OPEN_URL);
+ LabeledIntent labeledIntent = createWebLabeledIntent();
+
+ assertThat(
+ titleChooser
+ .chooseTitle(
+ labeledIntent,
+ createResolveInfo("com.android.chrome", /* handleAllWebDataURI= */ true))
+ .toString())
+ .isEqualTo("titleWithEntity");
+ }
+
+ @Test
+ public void createTitleChooser_openUrl_packageIsNotAndroidAndHandleAllWebDataUriFalse() {
+ TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_OPEN_URL);
+ LabeledIntent labeledIntent = createWebLabeledIntent();
+
+ assertThat(
+ titleChooser
+ .chooseTitle(
+ labeledIntent,
+ createResolveInfo("com.youtube", /* handleAllWebDataURI= */ false))
+ .toString())
+ .isEqualTo("titleWithoutEntity");
+ }
+
+ @Test
+ public void createTitleChooser_openUrl_packageIsAndroidAndHandleAllWebDataUriFalse() {
+ TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_OPEN_URL);
+ LabeledIntent labeledIntent = createWebLabeledIntent();
+
+ assertThat(
+ titleChooser
+ .chooseTitle(
+ labeledIntent, createResolveInfo("android", /* handleAllWebDataURI= */ false))
+ .toString())
+ .isEqualTo("titleWithEntity");
+ }
+
+ @Test
+ public void createTitleChooser_openUrl_packageIsAndroidAndHandleAllWebDataUriTrue() {
+ TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(ConversationAction.TYPE_OPEN_URL);
+ LabeledIntent labeledIntent = createWebLabeledIntent();
+
+ assertThat(
+ titleChooser
+ .chooseTitle(
+ labeledIntent, createResolveInfo("android", /* handleAllWebDataURI= */ true))
+ .toString())
+ .isEqualTo("titleWithEntity");
+ }
+
+ private LabeledIntent createWebLabeledIntent() {
+ Intent webIntent = new Intent(Intent.ACTION_VIEW);
+ webIntent.setData(Uri.parse("http://www.android.com"));
+ return new LabeledIntent(
+ "titleWithoutEntity",
+ "titleWithEntity",
+ "description",
+ "descriptionWithAppName",
+ webIntent,
+ /* requestCode= */ 0);
+ }
+
private static ZonedDateTime createZonedDateTimeFromMsUtc(long msUtc) {
return ZonedDateTime.ofInstant(Instant.ofEpochMilli(msUtc), ZoneId.of("UTC"));
}
@@ -303,4 +392,12 @@
assertThat(nativeMessage.getDetectedTextLanguageTags()).isEqualTo(LOCALE_TAG);
assertThat(nativeMessage.getReferenceTimeMsUtc()).isEqualTo(referenceTimeInMsUtc);
}
+
+ private static ResolveInfo createResolveInfo(String packageName, boolean handleAllWebDataURI) {
+ ResolveInfo resolveInfo = new ResolveInfo();
+ resolveInfo.activityInfo = new ActivityInfo();
+ resolveInfo.activityInfo.packageName = packageName;
+ resolveInfo.handleAllWebDataURI = handleAllWebDataURI;
+ return resolveInfo;
+ }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
index a1d9dcf..dfc09a7 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
@@ -154,4 +154,56 @@
assertThat(result.remoteAction.getContentDescription().toString())
.isEqualTo("Use fake to open map");
}
+
+ @Test
+ public void resolve_noVisibilityToWebIntentHandler() {
+ Context context =
+ new FakeContextBuilder()
+ .setIntentComponent(Intent.ACTION_VIEW, /* component= */ null)
+ .build();
+ Intent webIntent = new Intent(Intent.ACTION_VIEW);
+ webIntent.setData(Uri.parse("https://www.android.com"));
+ LabeledIntent labeledIntent =
+ new LabeledIntent(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ /* descriptionWithAppName= */ null,
+ webIntent,
+ REQUEST_CODE);
+
+ LabeledIntent.Result result = labeledIntent.resolve(context, /*titleChooser*/ null);
+
+ assertThat(result).isNotNull();
+ assertThat(result.remoteAction.getTitle().toString()).isEqualTo(TITLE_WITH_ENTITY);
+ assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+ assertThat(result.resolvedIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
+ assertThat(result.resolvedIntent.getComponent()).isNull();
+ }
+
+ @Test
+ public void resolve_noVisibilityToWebIntentHandler_withDescriptionWithAppName() {
+ Context context =
+ new FakeContextBuilder()
+ .setIntentComponent(Intent.ACTION_VIEW, /* component= */ null)
+ .build();
+ Intent webIntent = new Intent(Intent.ACTION_VIEW);
+ webIntent.setData(Uri.parse("https://www.android.com"));
+ LabeledIntent labeledIntent =
+ new LabeledIntent(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ /* descriptionWithAppName= */ "name",
+ webIntent,
+ REQUEST_CODE);
+
+ LabeledIntent.Result result = labeledIntent.resolve(context, /*titleChooser*/ null);
+
+ assertThat(result).isNotNull();
+ assertThat(result.remoteAction.getTitle().toString()).isEqualTo(TITLE_WITH_ENTITY);
+ assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+ assertThat(result.resolvedIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
+ assertThat(result.resolvedIntent.getComponent()).isNull();
+ }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
index f2b8223..b52509c 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
@@ -46,7 +46,7 @@
/** Util functions to make statsd testing easier by using adb shell cmd stats commands. */
public class StatsdTestUtils {
private static final String TAG = "StatsdTestUtils";
- private static final long SHORT_WAIT_MS = 1000;
+ private static final long LONG_WAIT_MS = 5000;
private StatsdTestUtils() {}
@@ -75,8 +75,8 @@
* Extracts logged atoms from the report, sorted by logging time, and deletes the saved report.
*/
public static ImmutableList<Atom> getLoggedAtoms(long configId) throws Exception {
- // There is no callback to notify us the log is collected. So we do a short wait here.
- Thread.sleep(SHORT_WAIT_MS);
+ // There is no callback to notify us the log is collected. So we do a wait here.
+ Thread.sleep(LONG_WAIT_MS);
ConfigMetricsReportList reportList = getAndRemoveReportList(configId);
assertThat(reportList.getReportsCount()).isEqualTo(1);
diff --git a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
index 3af04e8..8b6cf2e 100644
--- a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -16,7 +16,9 @@
package com.google.android.textclassifier;
+import android.content.res.AssetFileDescriptor;
import java.util.concurrent.atomic.AtomicBoolean;
+import javax.annotation.Nullable;
/**
* Java wrapper for ActionsSuggestions native library interface. This library is used to suggest
@@ -37,7 +39,7 @@
* Creates a new instance of Actions predictor, using the provided model image, given as a file
* descriptor.
*/
- public ActionsSuggestionsModel(int fileDescriptor, byte[] serializedPreconditions) {
+ public ActionsSuggestionsModel(int fileDescriptor, @Nullable byte[] serializedPreconditions) {
actionsModelPtr = nativeNewActionsModel(fileDescriptor, serializedPreconditions);
if (actionsModelPtr == 0L) {
throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
@@ -52,7 +54,7 @@
* Creates a new instance of Actions predictor, using the provided model image, given as a file
* path.
*/
- public ActionsSuggestionsModel(String path, byte[] serializedPreconditions) {
+ public ActionsSuggestionsModel(String path, @Nullable byte[] serializedPreconditions) {
actionsModelPtr = nativeNewActionsModelFromPath(path, serializedPreconditions);
if (actionsModelPtr == 0L) {
throw new IllegalArgumentException("Couldn't initialize actions model from given file.");
@@ -63,6 +65,27 @@
this(path, /* serializedPreconditions= */ null);
}
+ /**
+ * Creates a new instance of Actions predictor, using the provided model image, given as an {@link
+ * AssetFileDescriptor}.
+ */
+ public ActionsSuggestionsModel(
+ AssetFileDescriptor assetFileDescriptor, @Nullable byte[] serializedPreconditions) {
+ actionsModelPtr =
+ nativeNewActionsModelWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength(),
+ serializedPreconditions);
+ if (actionsModelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
+ }
+ }
+
+ public ActionsSuggestionsModel(AssetFileDescriptor assetFileDescriptor) {
+ this(assetFileDescriptor, /* serializedPreconditions= */ null);
+ }
+
/** Suggests actions / replies to the given conversation. */
public ActionSuggestion[] suggestActions(
Conversation conversation, ActionSuggestionOptions options, AnnotatorModel annotator) {
@@ -115,32 +138,56 @@
return nativeGetLocales(fd);
}
+ /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
+ public static String getLocales(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetLocalesWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Returns the version of the model. */
public static int getVersion(int fd) {
return nativeGetVersion(fd);
}
+ /** Returns the version of the model. */
+ public static int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetVersionWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Returns the name of the model. */
public static String getName(int fd) {
return nativeGetName(fd);
}
+ /** Returns the name of the model. */
+ public static String getName(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetNameWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Action suggestion that contains a response text and the type of the response. */
public static final class ActionSuggestion {
- private final String responseText;
+ @Nullable private final String responseText;
private final String actionType;
private final float score;
- private final NamedVariant[] entityData;
- private final byte[] serializedEntityData;
- private final RemoteActionTemplate[] remoteActionTemplates;
+ @Nullable private final NamedVariant[] entityData;
+ @Nullable private final byte[] serializedEntityData;
+ @Nullable private final RemoteActionTemplate[] remoteActionTemplates;
public ActionSuggestion(
- String responseText,
+ @Nullable String responseText,
String actionType,
float score,
- NamedVariant[] entityData,
- byte[] serializedEntityData,
- RemoteActionTemplate[] remoteActionTemplates) {
+ @Nullable NamedVariant[] entityData,
+ @Nullable byte[] serializedEntityData,
+ @Nullable RemoteActionTemplate[] remoteActionTemplates) {
this.responseText = responseText;
this.actionType = actionType;
this.score = score;
@@ -149,6 +196,7 @@
this.remoteActionTemplates = remoteActionTemplates;
}
+ @Nullable
public String getResponseText() {
return responseText;
}
@@ -162,14 +210,17 @@
return score;
}
+ @Nullable
public NamedVariant[] getEntityData() {
return entityData;
}
+ @Nullable
public byte[] getSerializedEntityData() {
return serializedEntityData;
}
+ @Nullable
public RemoteActionTemplate[] getRemoteActionTemplates() {
return remoteActionTemplates;
}
@@ -178,17 +229,17 @@
/** Represents a single message in the conversation. */
public static final class ConversationMessage {
private final int userId;
- private final String text;
+ @Nullable private final String text;
private final long referenceTimeMsUtc;
- private final String referenceTimezone;
- private final String detectedTextLanguageTags;
+ @Nullable private final String referenceTimezone;
+ @Nullable private final String detectedTextLanguageTags;
public ConversationMessage(
int userId,
- String text,
+ @Nullable String text,
long referenceTimeMsUtc,
- String referenceTimezone,
- String detectedTextLanguageTags) {
+ @Nullable String referenceTimezone,
+ @Nullable String detectedTextLanguageTags) {
this.userId = userId;
this.text = text;
this.referenceTimeMsUtc = referenceTimeMsUtc;
@@ -201,6 +252,7 @@
return userId;
}
+ @Nullable
public String getText() {
return text;
}
@@ -213,11 +265,13 @@
return referenceTimeMsUtc;
}
+ @Nullable
public String getReferenceTimezone() {
return referenceTimezone;
}
/** Returns a comma separated list of BCP 47 language tags. */
+ @Nullable
public String getDetectedTextLanguageTags() {
return detectedTextLanguageTags;
}
diff --git a/jni/com/google/android/textclassifier/AnnotatorModel.java b/jni/com/google/android/textclassifier/AnnotatorModel.java
index 7658bf5..a116f0a 100644
--- a/jni/com/google/android/textclassifier/AnnotatorModel.java
+++ b/jni/com/google/android/textclassifier/AnnotatorModel.java
@@ -16,8 +16,10 @@
package com.google.android.textclassifier;
+import android.content.res.AssetFileDescriptor;
import java.util.Collection;
import java.util.concurrent.atomic.AtomicBoolean;
+import javax.annotation.Nullable;
/**
* Java wrapper for Annotator native library interface. This library is used for detecting entities
@@ -49,7 +51,7 @@
private long annotatorPtr;
// To tell GC to keep the LangID model alive at least as long as this object.
- private LangIdModel langIdModel;
+ @Nullable private LangIdModel langIdModel;
/** Enumeration for specifying the usecase of the annotations. */
public static enum AnnotationUsecase {
@@ -73,6 +75,25 @@
}
};
+ /** Enumeration for specifying the annotate mode. */
+ public static enum AnnotateMode {
+ /** Result contains entity annotation for each input fragment. */
+ ENTITY_ANNOTATION(0),
+
+ /** Result will include both entity annotation and topicality annotation. */
+ ENTITY_AND_TOPICALITY_ANNOTATION(1);
+
+ private final int value;
+
+ AnnotateMode(int value) {
+ this.value = value;
+ }
+
+ public int getValue() {
+ return value;
+ }
+ };
+
/**
* Creates a new instance of SmartSelect predictor, using the provided model image, given as a
* file descriptor.
@@ -95,6 +116,21 @@
}
}
+ /**
+ * Creates a new instance of SmartSelect predictor, using the provided model image, given as an
+ * {@link AssetFileDescriptor}.
+ */
+ public AnnotatorModel(AssetFileDescriptor assetFileDescriptor) {
+ annotatorPtr =
+ nativeNewAnnotatorWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ if (annotatorPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize TC from asset file descriptor.");
+ }
+ }
+
/** Initializes the knowledge engine, passing the given serialized config to it. */
public void initializeKnowledgeEngine(byte[] serializedConfig) {
if (!nativeInitializeKnowledgeEngine(annotatorPtr, serializedConfig)) {
@@ -121,12 +157,26 @@
* before this object is closed. Also, this object does not take the memory ownership of the given
* LangIdModel object.
*/
- public void setLangIdModel(LangIdModel langIdModel) {
+ public void setLangIdModel(@Nullable LangIdModel langIdModel) {
this.langIdModel = langIdModel;
nativeSetLangId(annotatorPtr, langIdModel == null ? 0 : langIdModel.getNativePointer());
}
/**
+ * Initializes the person name engine, using the provided model image, given as an {@link
+ * AssetFileDescriptor}.
+ */
+ public void initializePersonNameEngine(AssetFileDescriptor assetFileDescriptor) {
+ if (!nativeInitializePersonNameEngine(
+ annotatorPtr,
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength())) {
+ throw new IllegalArgumentException("Couldn't initialize the person name engine");
+ }
+ }
+
+ /**
* Given a string context and current selection, computes the selection suggestion.
*
* <p>The begin and end are character indices into the context UTF8 string. selectionBegin is the
@@ -183,8 +233,7 @@
* Annotates multiple fragments of text at once. There will be one AnnotatedSpan array for each
* input fragment to annotate.
*/
- public AnnotatedSpan[][] annotateStructuredInput(
- InputFragment[] fragments, AnnotationOptions options) {
+ public Annotations annotateStructuredInput(InputFragment[] fragments, AnnotationOptions options) {
return nativeAnnotateStructuredInput(annotatorPtr, fragments, options);
}
@@ -219,16 +268,40 @@
return nativeGetLocales(fd);
}
+ /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
+ public static String getLocales(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetLocalesWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Returns the version of the model. */
public static int getVersion(int fd) {
return nativeGetVersion(fd);
}
+ /** Returns the version of the model. */
+ public static int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetVersionWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Returns the name of the model. */
public static String getName(int fd) {
return nativeGetName(fd);
}
+ /** Returns the name of the model. */
+ public static String getName(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetNameWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
/** Information about a parsed time/date. */
public static final class DatetimeResult {
@@ -261,20 +334,20 @@
public static final class ClassificationResult {
private final String collection;
private final float score;
- private final DatetimeResult datetimeResult;
- private final byte[] serializedKnowledgeResult;
- private final String contactName;
- private final String contactGivenName;
- private final String contactFamilyName;
- private final String contactNickname;
- private final String contactEmailAddress;
- private final String contactPhoneNumber;
- private final String contactId;
- private final String appName;
- private final String appPackageName;
- private final NamedVariant[] entityData;
- private final byte[] serializedEntityData;
- private final RemoteActionTemplate[] remoteActionTemplates;
+ @Nullable private final DatetimeResult datetimeResult;
+ @Nullable private final byte[] serializedKnowledgeResult;
+ @Nullable private final String contactName;
+ @Nullable private final String contactGivenName;
+ @Nullable private final String contactFamilyName;
+ @Nullable private final String contactNickname;
+ @Nullable private final String contactEmailAddress;
+ @Nullable private final String contactPhoneNumber;
+ @Nullable private final String contactId;
+ @Nullable private final String appName;
+ @Nullable private final String appPackageName;
+ @Nullable private final NamedVariant[] entityData;
+ @Nullable private final byte[] serializedEntityData;
+ @Nullable private final RemoteActionTemplate[] remoteActionTemplates;
private final long durationMs;
private final long numericValue;
private final double numericDoubleValue;
@@ -282,20 +355,20 @@
public ClassificationResult(
String collection,
float score,
- DatetimeResult datetimeResult,
- byte[] serializedKnowledgeResult,
- String contactName,
- String contactGivenName,
- String contactFamilyName,
- String contactNickname,
- String contactEmailAddress,
- String contactPhoneNumber,
- String contactId,
- String appName,
- String appPackageName,
- NamedVariant[] entityData,
- byte[] serializedEntityData,
- RemoteActionTemplate[] remoteActionTemplates,
+ @Nullable DatetimeResult datetimeResult,
+ @Nullable byte[] serializedKnowledgeResult,
+ @Nullable String contactName,
+ @Nullable String contactGivenName,
+ @Nullable String contactFamilyName,
+ @Nullable String contactNickname,
+ @Nullable String contactEmailAddress,
+ @Nullable String contactPhoneNumber,
+ @Nullable String contactId,
+ @Nullable String appName,
+ @Nullable String appPackageName,
+ @Nullable NamedVariant[] entityData,
+ @Nullable byte[] serializedEntityData,
+ @Nullable RemoteActionTemplate[] remoteActionTemplates,
long durationMs,
long numericValue,
double numericDoubleValue) {
@@ -330,58 +403,72 @@
return score;
}
+ @Nullable
public DatetimeResult getDatetimeResult() {
return datetimeResult;
}
+ @Nullable
public byte[] getSerializedKnowledgeResult() {
return serializedKnowledgeResult;
}
+ @Nullable
public String getContactName() {
return contactName;
}
+ @Nullable
public String getContactGivenName() {
return contactGivenName;
}
+ @Nullable
public String getContactFamilyName() {
return contactFamilyName;
}
+ @Nullable
public String getContactNickname() {
return contactNickname;
}
+ @Nullable
public String getContactEmailAddress() {
return contactEmailAddress;
}
+ @Nullable
public String getContactPhoneNumber() {
return contactPhoneNumber;
}
+ @Nullable
public String getContactId() {
return contactId;
}
+ @Nullable
public String getAppName() {
return appName;
}
+ @Nullable
public String getAppPackageName() {
return appPackageName;
}
+ @Nullable
public NamedVariant[] getEntityData() {
return entityData;
}
+ @Nullable
public byte[] getSerializedEntityData() {
return serializedEntityData;
}
+ @Nullable
public RemoteActionTemplate[] getRemoteActionTemplates() {
return remoteActionTemplates;
}
@@ -424,6 +511,28 @@
}
}
+ /**
+ * Represents the result of an Annotate call, which includes both entity annotations and
+ * topicality annotations.
+ */
+ public static final class Annotations {
+ private final AnnotatedSpan[][] annotatedSpans;
+ private final ClassificationResult[] topicalityResults;
+
+ Annotations(AnnotatedSpan[][] annotatedSpans, ClassificationResult[] topicalityResults) {
+ this.annotatedSpans = annotatedSpans;
+ this.topicalityResults = topicalityResults;
+ }
+
+ public AnnotatedSpan[][] getAnnotatedSpans() {
+ return annotatedSpans;
+ }
+
+ public ClassificationResult[] getTopicalityResults() {
+ return topicalityResults;
+ }
+ }
+
/** Represents a fragment of text to the AnnotateStructuredInput call. */
public static final class InputFragment {
@@ -470,37 +579,101 @@
}
}
- /**
- * Represents options for the suggestSelection call. TODO(b/63427420): Use location with Selection
- * options.
- */
+ /** Represents options for the suggestSelection call. */
public static final class SelectionOptions {
- private final String locales;
- private final String detectedTextLanguageTags;
+ @Nullable private final String locales;
+ @Nullable private final String detectedTextLanguageTags;
private final int annotationUsecase;
private final double userLocationLat;
private final double userLocationLng;
private final float userLocationAccuracyMeters;
+ private final boolean usePodNer;
- public SelectionOptions(
- String locales, String detectedTextLanguageTags, int annotationUsecase) {
+ private SelectionOptions(
+ @Nullable String locales,
+ @Nullable String detectedTextLanguageTags,
+ int annotationUsecase,
+ double userLocationLat,
+ double userLocationLng,
+ float userLocationAccuracyMeters,
+ boolean usePodNer) {
this.locales = locales;
this.detectedTextLanguageTags = detectedTextLanguageTags;
this.annotationUsecase = annotationUsecase;
- this.userLocationLat = INVALID_LATITUDE;
- this.userLocationLng = INVALID_LONGITUDE;
- this.userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ this.userLocationLat = userLocationLat;
+ this.userLocationLng = userLocationLng;
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
+ this.usePodNer = usePodNer;
}
- public SelectionOptions(String locales, String detectedTextLanguageTags) {
- this(locales, detectedTextLanguageTags, AnnotationUsecase.SMART.getValue());
+ /** Can be used to build a SelectionOptions instance. */
+ public static class Builder {
+ @Nullable private String locales;
+ @Nullable private String detectedTextLanguageTags;
+ private int annotationUsecase = AnnotationUsecase.SMART.getValue();
+ private double userLocationLat = INVALID_LATITUDE;
+ private double userLocationLng = INVALID_LONGITUDE;
+ private float userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ private boolean usePodNer = true;
+
+ public Builder setLocales(@Nullable String locales) {
+ this.locales = locales;
+ return this;
+ }
+
+ public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) {
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ return this;
+ }
+
+ public Builder setAnnotationUsecase(int annotationUsecase) {
+ this.annotationUsecase = annotationUsecase;
+ return this;
+ }
+
+ public Builder setUserLocationLat(double userLocationLat) {
+ this.userLocationLat = userLocationLat;
+ return this;
+ }
+
+ public Builder setUserLocationLng(double userLocationLng) {
+ this.userLocationLng = userLocationLng;
+ return this;
+ }
+
+ public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) {
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
+ return this;
+ }
+
+ public Builder setUsePodNer(boolean usePodNer) {
+ this.usePodNer = usePodNer;
+ return this;
+ }
+
+ public SelectionOptions build() {
+ return new SelectionOptions(
+ locales,
+ detectedTextLanguageTags,
+ annotationUsecase,
+ userLocationLat,
+ userLocationLng,
+ userLocationAccuracyMeters,
+ usePodNer);
+ }
}
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ @Nullable
public String getLocales() {
return locales;
}
/** Returns a comma separated list of BCP 47 language tags. */
+ @Nullable
public String getDetectedTextLanguageTags() {
return detectedTextLanguageTags;
}
@@ -520,53 +693,128 @@
public float getUserLocationAccuracyMeters() {
return userLocationAccuracyMeters;
}
+
+ public boolean getUsePodNer() {
+ return usePodNer;
+ }
}
- /**
- * Represents options for the classifyText call. TODO(b/63427420): Use location with
- * Classification options.
- */
+ /** Represents options for the classifyText call. */
public static final class ClassificationOptions {
private final long referenceTimeMsUtc;
private final String referenceTimezone;
- private final String locales;
- private final String detectedTextLanguageTags;
+ @Nullable private final String locales;
+ @Nullable private final String detectedTextLanguageTags;
private final int annotationUsecase;
private final double userLocationLat;
private final double userLocationLng;
private final float userLocationAccuracyMeters;
private final String userFamiliarLanguageTags;
+ private final boolean usePodNer;
- public ClassificationOptions(
+ private ClassificationOptions(
long referenceTimeMsUtc,
String referenceTimezone,
- String locales,
- String detectedTextLanguageTags,
+ @Nullable String locales,
+ @Nullable String detectedTextLanguageTags,
int annotationUsecase,
- String userFamiliarLanguageTags) {
+ double userLocationLat,
+ double userLocationLng,
+ float userLocationAccuracyMeters,
+ String userFamiliarLanguageTags,
+ boolean usePodNer) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
this.locales = locales;
this.detectedTextLanguageTags = detectedTextLanguageTags;
this.annotationUsecase = annotationUsecase;
- this.userLocationLat = INVALID_LATITUDE;
- this.userLocationLng = INVALID_LONGITUDE;
- this.userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ this.userLocationLat = userLocationLat;
+ this.userLocationLng = userLocationLng;
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
this.userFamiliarLanguageTags = userFamiliarLanguageTags;
+ this.usePodNer = usePodNer;
}
- public ClassificationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags) {
- this(
- referenceTimeMsUtc,
- referenceTimezone,
- locales,
- detectedTextLanguageTags,
- AnnotationUsecase.SMART.getValue(),
- "");
+ /** Can be used to build a ClassificationOptions instance. */
+ public static class Builder {
+ private long referenceTimeMsUtc;
+ @Nullable private String referenceTimezone;
+ @Nullable private String locales;
+ @Nullable private String detectedTextLanguageTags;
+ private int annotationUsecase = AnnotationUsecase.SMART.getValue();
+ private double userLocationLat = INVALID_LATITUDE;
+ private double userLocationLng = INVALID_LONGITUDE;
+ private float userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ private String userFamiliarLanguageTags = "";
+ private boolean usePodNer = true;
+
+ public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) {
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
+ return this;
+ }
+
+ public Builder setReferenceTimezone(String referenceTimezone) {
+ this.referenceTimezone = referenceTimezone;
+ return this;
+ }
+
+ public Builder setLocales(@Nullable String locales) {
+ this.locales = locales;
+ return this;
+ }
+
+ public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) {
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ return this;
+ }
+
+ public Builder setAnnotationUsecase(int annotationUsecase) {
+ this.annotationUsecase = annotationUsecase;
+ return this;
+ }
+
+ public Builder setUserLocationLat(double userLocationLat) {
+ this.userLocationLat = userLocationLat;
+ return this;
+ }
+
+ public Builder setUserLocationLng(double userLocationLng) {
+ this.userLocationLng = userLocationLng;
+ return this;
+ }
+
+ public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) {
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
+ return this;
+ }
+
+ public Builder setUserFamiliarLanguageTags(String userFamiliarLanguageTags) {
+ this.userFamiliarLanguageTags = userFamiliarLanguageTags;
+ return this;
+ }
+
+ public Builder setUsePodNer(boolean usePodNer) {
+ this.usePodNer = usePodNer;
+ return this;
+ }
+
+ public ClassificationOptions build() {
+ return new ClassificationOptions(
+ referenceTimeMsUtc,
+ referenceTimezone,
+ locales,
+ detectedTextLanguageTags,
+ annotationUsecase,
+ userLocationLat,
+ userLocationLng,
+ userLocationAccuracyMeters,
+ userFamiliarLanguageTags,
+ usePodNer);
+ }
+ }
+
+ public static Builder builder() {
+ return new Builder();
}
public long getReferenceTimeMsUtc() {
@@ -577,11 +825,13 @@
return referenceTimezone;
}
+ @Nullable
public String getLocale() {
return locales;
}
/** Returns a comma separated list of BCP 47 language tags. */
+ @Nullable
public String getDetectedTextLanguageTags() {
return detectedTextLanguageTags;
}
@@ -605,15 +855,20 @@
public String getUserFamiliarLanguageTags() {
return userFamiliarLanguageTags;
}
+
+ public boolean getUsePodNer() {
+ return usePodNer;
+ }
}
/** Represents options for the annotate call. */
public static final class AnnotationOptions {
private final long referenceTimeMsUtc;
private final String referenceTimezone;
- private final String locales;
- private final String detectedTextLanguageTags;
+ @Nullable private final String locales;
+ @Nullable private final String detectedTextLanguageTags;
private final String[] entityTypes;
+ private final int annotateMode;
private final int annotationUsecase;
private final boolean hasLocationPermission;
private final boolean hasPersonalizationPermission;
@@ -621,25 +876,29 @@
private final double userLocationLat;
private final double userLocationLng;
private final float userLocationAccuracyMeters;
+ private final boolean usePodNer;
- public AnnotationOptions(
+ private AnnotationOptions(
long referenceTimeMsUtc,
String referenceTimezone,
- String locales,
- String detectedTextLanguageTags,
- Collection<String> entityTypes,
+ @Nullable String locales,
+ @Nullable String detectedTextLanguageTags,
+ @Nullable Collection<String> entityTypes,
+ int annotateMode,
int annotationUsecase,
boolean hasLocationPermission,
boolean hasPersonalizationPermission,
boolean isSerializedEntityDataEnabled,
double userLocationLat,
double userLocationLng,
- float userLocationAccuracyMeters) {
+ float userLocationAccuracyMeters,
+ boolean usePodNer) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
this.locales = locales;
this.detectedTextLanguageTags = detectedTextLanguageTags;
this.entityTypes = entityTypes == null ? new String[0] : entityTypes.toArray(new String[0]);
+ this.annotateMode = annotateMode;
this.annotationUsecase = annotationUsecase;
this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled;
this.userLocationLat = userLocationLat;
@@ -647,68 +906,117 @@
this.userLocationAccuracyMeters = userLocationAccuracyMeters;
this.hasLocationPermission = hasLocationPermission;
this.hasPersonalizationPermission = hasPersonalizationPermission;
+ this.usePodNer = usePodNer;
}
- public AnnotationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags,
- Collection<String> entityTypes,
- int annotationUsecase,
- boolean isSerializedEntityDataEnabled,
- double userLocationLat,
- double userLocationLng,
- float userLocationAccuracyMeters) {
- this(
- referenceTimeMsUtc,
- referenceTimezone,
- locales,
- detectedTextLanguageTags,
- entityTypes,
- annotationUsecase,
- /* hasLocationPermission */ true,
- /* hasPersonalizationPermission */ true,
- isSerializedEntityDataEnabled,
- userLocationLat,
- userLocationLng,
- userLocationAccuracyMeters);
+ /** Can be used to build an AnnotationOptions instance. */
+ public static class Builder {
+ private long referenceTimeMsUtc;
+ @Nullable private String referenceTimezone;
+ @Nullable private String locales;
+ @Nullable private String detectedTextLanguageTags;
+ @Nullable private Collection<String> entityTypes;
+ private int annotateMode = AnnotateMode.ENTITY_ANNOTATION.getValue();
+ private int annotationUsecase = AnnotationUsecase.SMART.getValue();
+ private boolean hasLocationPermission = true;
+ private boolean hasPersonalizationPermission = true;
+ private boolean isSerializedEntityDataEnabled = false;
+ private double userLocationLat = INVALID_LATITUDE;
+ private double userLocationLng = INVALID_LONGITUDE;
+ private float userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ private boolean usePodNer = true;
+
+ public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) {
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
+ return this;
+ }
+
+ public Builder setReferenceTimezone(String referenceTimezone) {
+ this.referenceTimezone = referenceTimezone;
+ return this;
+ }
+
+ public Builder setLocales(@Nullable String locales) {
+ this.locales = locales;
+ return this;
+ }
+
+ public Builder setDetectedTextLanguageTags(@Nullable String detectedTextLanguageTags) {
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ return this;
+ }
+
+ public Builder setEntityTypes(Collection<String> entityTypes) {
+ this.entityTypes = entityTypes;
+ return this;
+ }
+
+ public Builder setAnnotateMode(int annotateMode) {
+ this.annotateMode = annotateMode;
+ return this;
+ }
+
+ public Builder setAnnotationUsecase(int annotationUsecase) {
+ this.annotationUsecase = annotationUsecase;
+ return this;
+ }
+
+ public Builder setHasLocationPermission(boolean hasLocationPermission) {
+ this.hasLocationPermission = hasLocationPermission;
+ return this;
+ }
+
+ public Builder setHasPersonalizationPermission(boolean hasPersonalizationPermission) {
+ this.hasPersonalizationPermission = hasPersonalizationPermission;
+ return this;
+ }
+
+ public Builder setIsSerializedEntityDataEnabled(boolean isSerializedEntityDataEnabled) {
+ this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled;
+ return this;
+ }
+
+ public Builder setUserLocationLat(double userLocationLat) {
+ this.userLocationLat = userLocationLat;
+ return this;
+ }
+
+ public Builder setUserLocationLng(double userLocationLng) {
+ this.userLocationLng = userLocationLng;
+ return this;
+ }
+
+ public Builder setUserLocationAccuracyMeters(float userLocationAccuracyMeters) {
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
+ return this;
+ }
+
+ public Builder setUsePodNer(boolean usePodNer) {
+ this.usePodNer = usePodNer;
+ return this;
+ }
+
+ public AnnotationOptions build() {
+ return new AnnotationOptions(
+ referenceTimeMsUtc,
+ referenceTimezone,
+ locales,
+ detectedTextLanguageTags,
+ entityTypes,
+ annotateMode,
+ annotationUsecase,
+ hasLocationPermission,
+ hasPersonalizationPermission,
+ isSerializedEntityDataEnabled,
+ userLocationLat,
+ userLocationLng,
+ userLocationAccuracyMeters,
+ usePodNer);
+ }
}
- public AnnotationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags,
- Collection<String> entityTypes,
- int annotationUsecase,
- boolean isSerializedEntityDataEnabled) {
- this(
- referenceTimeMsUtc,
- referenceTimezone,
- locales,
- detectedTextLanguageTags,
- entityTypes,
- annotationUsecase,
- isSerializedEntityDataEnabled,
- INVALID_LATITUDE,
- INVALID_LONGITUDE,
- INVALID_LOCATION_ACCURACY_METERS);
- }
-
- public AnnotationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags) {
- this(
- referenceTimeMsUtc,
- referenceTimezone,
- locales,
- detectedTextLanguageTags,
- null,
- AnnotationUsecase.SMART.getValue(),
- /* isSerializedEntityDataEnabled */ false);
+ public static Builder builder() {
+ return new Builder();
}
public long getReferenceTimeMsUtc() {
@@ -719,11 +1027,13 @@
return referenceTimezone;
}
+ @Nullable
public String getLocale() {
return locales;
}
/** Returns a comma separated list of BCP 47 language tags. */
+ @Nullable
public String getDetectedTextLanguageTags() {
return detectedTextLanguageTags;
}
@@ -732,6 +1042,10 @@
return entityTypes;
}
+ public int getAnnotateMode() {
+ return annotateMode;
+ }
+
public int getAnnotationUsecase() {
return annotationUsecase;
}
@@ -759,6 +1073,10 @@
public boolean hasPersonalizationPermission() {
return hasPersonalizationPermission;
}
+
+ public boolean getUsePodNer() {
+ return usePodNer;
+ }
}
/**
@@ -815,7 +1133,7 @@
private native AnnotatedSpan[] nativeAnnotate(
long context, String text, AnnotationOptions options);
- private native AnnotatedSpan[][] nativeAnnotateStructuredInput(
+ private native Annotations nativeAnnotateStructuredInput(
long context, InputFragment[] inputFragments, AnnotationOptions options);
private native byte[] nativeLookUpKnowledgeEntity(long context, String id);
diff --git a/native/Android.bp b/native/Android.bp
index 6b070b1..2cb3a80 100644
--- a/native/Android.bp
+++ b/native/Android.bp
@@ -27,7 +27,6 @@
name: "libtextclassifier_hash_defaults",
srcs: [
"utils/hash/farmhash.cc",
- "util/hash/hash.cc",
],
cflags: [
"-DNAMESPACE_FOR_HASH_FUNCTIONS=farmhash",
@@ -138,13 +137,15 @@
srcs: ["**/*.cc"],
exclude_srcs: [
- "**/*_test.cc",
- "**/*-test-lib.cc",
- "**/testing/*.cc",
+ "**/*_test.*",
+ "**/*-test-lib.*",
+ "**/testing/*.*",
"**/*test-util.*",
"**/*test-utils.*",
+ "**/*test_util.*",
+ "**/*test_utils.*",
"**/*_test-include.*",
- "**/*unittest.cc",
+ "**/*unittest.*",
],
version_script: "jni.lds",
@@ -165,8 +166,7 @@
test_suites: ["device-tests", "mts"],
data: [
- "annotator/test_data/**/*",
- "actions/test_data/**/*",
+ "**/test_data/*",
],
srcs: ["**/*.cc"],
@@ -176,18 +176,10 @@
static_libs: [
"libgmock_ndk",
"libgtest_ndk_c++",
+ "libbase_ndk"
],
- multilib: {
- lib32: {
- suffix: "32",
- cppflags: ["-DTC3_TEST_DATA_DIR=\"/data/nativetest/libtextclassifier_tests/test_data/\""],
- },
- lib64: {
- suffix: "64",
- cppflags: ["-DTC3_TEST_DATA_DIR=\"/data/nativetest64/libtextclassifier_tests/test_data/\""],
- },
- },
+ compile_multilib: "prefer32",
}
// ----------------
diff --git a/native/AndroidTest.xml b/native/AndroidTest.xml
index cee26dd..dc5ac90 100644
--- a/native/AndroidTest.xml
+++ b/native/AndroidTest.xml
@@ -20,7 +20,6 @@
<target_preparer class="com.android.compatibility.common.tradefed.targetprep.FilePusher">
<option name="cleanup" value="true" />
<option name="push" value="libtextclassifier_tests->/data/local/tmp/libtextclassifier_tests" />
- <option name="append-bitness" value="true" />
</target_preparer>
<test class="com.android.tradefed.testtype.GTest" >
diff --git a/native/FlatBufferHeaders.bp b/native/FlatBufferHeaders.bp
index 08284d8..267d188 100644
--- a/native/FlatBufferHeaders.bp
+++ b/native/FlatBufferHeaders.bp
@@ -1,4 +1,5 @@
-// Copyright (C) 2020 The Android Open Source Project
+//
+// Copyright (C) 2018 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -11,107 +12,47 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
-
-cc_library_headers {
- name: "libtextclassifier_flatbuffer_headers",
- stl: "libc++_static",
- sdk_version: "current",
- apex_available: [
- "//apex_available:platform",
- "com.android.extservices",
- ],
- generated_headers: [
- "libtextclassifier_fbgen_flatbuffers",
- "libtextclassifier_fbgen_tokenizer",
- "libtextclassifier_fbgen_codepoint_range",
- "libtextclassifier_fbgen_entity-data",
- "libtextclassifier_fbgen_zlib_buffer",
- "libtextclassifier_fbgen_resources_extra",
- "libtextclassifier_fbgen_intent_config",
- "libtextclassifier_fbgen_annotator_model",
- "libtextclassifier_fbgen_annotator_experimental_model",
- "libtextclassifier_fbgen_actions_model",
- "libtextclassifier_fbgen_tflite_text_encoder_config",
- "libtextclassifier_fbgen_lang_id_embedded_network",
- "libtextclassifier_fbgen_lang_id_model",
- "libtextclassifier_fbgen_actions-entity-data",
- "libtextclassifier_fbgen_normalization",
- "libtextclassifier_fbgen_language-tag",
- "libtextclassifier_fbgen_person_name_model",
- "libtextclassifier_fbgen_grammar_dates",
- "libtextclassifier_fbgen_timezone_code",
- "libtextclassifier_fbgen_grammar_rules"
- ],
- export_generated_headers: [
- "libtextclassifier_fbgen_flatbuffers",
- "libtextclassifier_fbgen_tokenizer",
- "libtextclassifier_fbgen_codepoint_range",
- "libtextclassifier_fbgen_entity-data",
- "libtextclassifier_fbgen_zlib_buffer",
- "libtextclassifier_fbgen_resources_extra",
- "libtextclassifier_fbgen_intent_config",
- "libtextclassifier_fbgen_annotator_model",
- "libtextclassifier_fbgen_annotator_experimental_model",
- "libtextclassifier_fbgen_actions_model",
- "libtextclassifier_fbgen_tflite_text_encoder_config",
- "libtextclassifier_fbgen_lang_id_embedded_network",
- "libtextclassifier_fbgen_lang_id_model",
- "libtextclassifier_fbgen_actions-entity-data",
- "libtextclassifier_fbgen_normalization",
- "libtextclassifier_fbgen_language-tag",
- "libtextclassifier_fbgen_person_name_model",
- "libtextclassifier_fbgen_grammar_dates",
- "libtextclassifier_fbgen_timezone_code",
- "libtextclassifier_fbgen_grammar_rules"
- ]
-}
+//
genrule {
- name: "libtextclassifier_fbgen_flatbuffers",
- srcs: ["utils/flatbuffers.fbs"],
- out: ["utils/flatbuffers_generated.h"],
+ name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ srcs: ["lang_id/common/flatbuffers/model.fbs"],
+ out: ["lang_id/common/flatbuffers/model_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_tokenizer",
- srcs: ["utils/tokenizer.fbs"],
- out: ["utils/tokenizer_generated.h"],
+ name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
+ out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_codepoint_range",
- srcs: ["utils/codepoint-range.fbs"],
- out: ["utils/codepoint-range_generated.h"],
+ name: "libtextclassifier_fbgen_actions_actions_model",
+ srcs: ["actions/actions_model.fbs"],
+ out: ["actions/actions_model_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_resources_extra",
- srcs: ["utils/resources.fbs"],
- out: ["utils/resources_generated.h"],
+ name: "libtextclassifier_fbgen_actions_actions-entity-data",
+ srcs: ["actions/actions-entity-data.fbs"],
+ out: ["actions/actions-entity-data_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_entity-data",
- srcs: ["annotator/entity-data.fbs"],
- out: ["annotator/entity-data_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_grammar_dates_timezone-code",
+ srcs: ["annotator/grammar/dates/timezone-code.fbs"],
+ out: ["annotator/grammar/dates/timezone-code_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_zlib_buffer",
- srcs: ["utils/zlib/buffer.fbs"],
- out: ["utils/zlib/buffer_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_intent_config",
- srcs: ["utils/intents/intent-config.fbs"],
- out: ["utils/intents/intent-config_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_grammar_dates_dates",
+ srcs: ["annotator/grammar/dates/dates.fbs"],
+ out: ["annotator/grammar/dates/dates_generated.h"],
defaults: ["fbgen"],
}
@@ -123,85 +64,164 @@
}
genrule {
- name: "libtextclassifier_fbgen_annotator_experimental_model",
- srcs: ["annotator/experimental/experimental.fbs"],
- out: ["annotator/experimental/experimental_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_actions_model",
- srcs: ["actions/actions_model.fbs"],
- out: ["actions/actions_model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_tflite_text_encoder_config",
- srcs: ["utils/tflite/text_encoder_config.fbs"],
- out: ["utils/tflite/text_encoder_config_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_lang_id_embedded_network",
- srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
- out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_lang_id_model",
- srcs: ["lang_id/common/flatbuffers/model.fbs"],
- out: ["lang_id/common/flatbuffers/model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_actions-entity-data",
- srcs: ["actions/actions-entity-data.fbs"],
- out: ["actions/actions-entity-data_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_normalization",
- srcs: ["utils/normalization.fbs"],
- out: ["utils/normalization_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_language-tag",
- srcs: ["utils/i18n/language-tag.fbs"],
- out: ["utils/i18n/language-tag_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_person_name_model",
+ name: "libtextclassifier_fbgen_annotator_person_name_person_name_model",
srcs: ["annotator/person_name/person_name_model.fbs"],
out: ["annotator/person_name/person_name_model_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_grammar_dates",
- srcs: ["annotator/grammar/dates/dates.fbs"],
- out: ["annotator/grammar/dates/dates_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_experimental_experimental",
+ srcs: ["annotator/experimental/experimental.fbs"],
+ out: ["annotator/experimental/experimental_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_timezone_code",
- srcs: ["annotator/grammar/dates/timezone-code.fbs"],
- out: ["annotator/grammar/dates/timezone-code_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_entity-data",
+ srcs: ["annotator/entity-data.fbs"],
+ out: ["annotator/entity-data_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_grammar_rules",
+ name: "libtextclassifier_fbgen_utils_grammar_next_semantics_expression",
+ srcs: ["utils/grammar/next/semantics/expression.fbs"],
+ out: ["utils/grammar/next/semantics/expression_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_grammar_rules",
srcs: ["utils/grammar/rules.fbs"],
out: ["utils/grammar/rules_generated.h"],
defaults: ["fbgen"],
}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_normalization",
+ srcs: ["utils/normalization.fbs"],
+ out: ["utils/normalization_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_resources",
+ srcs: ["utils/resources.fbs"],
+ out: ["utils/resources_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_i18n_language-tag",
+ srcs: ["utils/i18n/language-tag.fbs"],
+ out: ["utils/i18n/language-tag_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ srcs: ["utils/tflite/text_encoder_config.fbs"],
+ out: ["utils/tflite/text_encoder_config_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ srcs: ["utils/flatbuffers/flatbuffers.fbs"],
+ out: ["utils/flatbuffers/flatbuffers_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_container_bit-vector",
+ srcs: ["utils/container/bit-vector.fbs"],
+ out: ["utils/container/bit-vector_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_tokenizer",
+ srcs: ["utils/tokenizer.fbs"],
+ out: ["utils/tokenizer_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_codepoint-range",
+ srcs: ["utils/codepoint-range.fbs"],
+ out: ["utils/codepoint-range_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_zlib_buffer",
+ srcs: ["utils/zlib/buffer.fbs"],
+ out: ["utils/zlib/buffer_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_intents_intent-config",
+ srcs: ["utils/intents/intent-config.fbs"],
+ out: ["utils/intents/intent-config_generated.h"],
+ defaults: ["fbgen"],
+}
+
+cc_library_headers {
+ name: "libtextclassifier_flatbuffer_headers",
+ stl: "libc++_static",
+ sdk_version: "current",
+ apex_available: [
+ "//apex_available:platform",
+ "com.android.extservices",
+ ],
+ generated_headers: [
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ "libtextclassifier_fbgen_actions_actions_model",
+ "libtextclassifier_fbgen_actions_actions-entity-data",
+ "libtextclassifier_fbgen_annotator_grammar_dates_timezone-code",
+ "libtextclassifier_fbgen_annotator_grammar_dates_dates",
+ "libtextclassifier_fbgen_annotator_model",
+ "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ "libtextclassifier_fbgen_annotator_experimental_experimental",
+ "libtextclassifier_fbgen_annotator_entity-data",
+ "libtextclassifier_fbgen_utils_grammar_next_semantics_expression",
+ "libtextclassifier_fbgen_utils_grammar_rules",
+ "libtextclassifier_fbgen_utils_normalization",
+ "libtextclassifier_fbgen_utils_resources",
+ "libtextclassifier_fbgen_utils_i18n_language-tag",
+ "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ "libtextclassifier_fbgen_utils_container_bit-vector",
+ "libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_zlib_buffer",
+ "libtextclassifier_fbgen_utils_intents_intent-config",
+ ],
+ export_generated_headers: [
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
+ "libtextclassifier_fbgen_actions_actions_model",
+ "libtextclassifier_fbgen_actions_actions-entity-data",
+ "libtextclassifier_fbgen_annotator_grammar_dates_timezone-code",
+ "libtextclassifier_fbgen_annotator_grammar_dates_dates",
+ "libtextclassifier_fbgen_annotator_model",
+ "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ "libtextclassifier_fbgen_annotator_experimental_experimental",
+ "libtextclassifier_fbgen_annotator_entity-data",
+ "libtextclassifier_fbgen_utils_grammar_next_semantics_expression",
+ "libtextclassifier_fbgen_utils_grammar_rules",
+ "libtextclassifier_fbgen_utils_normalization",
+ "libtextclassifier_fbgen_utils_resources",
+ "libtextclassifier_fbgen_utils_i18n_language-tag",
+ "libtextclassifier_fbgen_utils_tflite_text_encoder_config",
+ "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ "libtextclassifier_fbgen_utils_container_bit-vector",
+ "libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_zlib_buffer",
+ "libtextclassifier_fbgen_utils_intents_intent-config",
+ ],
+}
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index 1fcd35c..93ef544 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -24,12 +24,12 @@
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
-#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
+#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
#include "tensorflow/lite/string_util.h"
@@ -52,10 +52,6 @@
const std::string& ActionsSuggestions::kShareLocation =
*[]() { return new std::string("share_location"); }();
-// Name for a datetime annotation that only includes time but no date.
-const std::string& kTimeAnnotation =
- *[]() { return new std::string("time"); }();
-
constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;
@@ -285,7 +281,7 @@
}
entity_data_builder_.reset(
- new ReflectiveFlatbufferBuilder(entity_data_schema_));
+ new MutableFlatbufferBuilder(entity_data_schema_));
} else {
entity_data_schema_ = nullptr;
}
@@ -777,7 +773,7 @@
void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
: nullptr;
FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
@@ -806,7 +802,7 @@
if (triggering) {
ActionSuggestion suggestion;
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
: nullptr;
FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
@@ -881,7 +877,7 @@
// Create action from model output.
ActionSuggestion suggestion;
suggestion.type = action_type->name()->str();
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
: nullptr;
FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
@@ -1036,30 +1032,7 @@
if (message->annotations.empty()) {
message->annotations = annotator->Annotate(
message->text, AnnotationOptionsForMessage(*message));
- for (int i = 0; i < message->annotations.size(); i++) {
- ClassificationResult* classification =
- &message->annotations[i].classification.front();
-
- // Specialize datetime annotation to time annotation if no date
- // component is present.
- if (classification->collection == Collections::DateTime() &&
- classification->datetime_parse_result.IsSet()) {
- bool has_only_time = true;
- for (const DatetimeComponent& component :
- classification->datetime_parse_result.datetime_components) {
- if (component.component_type !=
- DatetimeComponent::ComponentType::UNSPECIFIED &&
- component.component_type <
- DatetimeComponent::ComponentType::HOUR) {
- has_only_time = false;
- break;
- }
- }
- if (has_only_time) {
- classification->collection = kTimeAnnotation;
- }
- }
- }
+ ConvertDatetimeToTime(&message->annotations);
}
}
return annotated_conversation;
@@ -1160,7 +1133,7 @@
continue;
}
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
: nullptr;
@@ -1363,6 +1336,15 @@
if (conversation.messages[i].reference_time_ms_utc <
conversation.messages[i - 1].reference_time_ms_utc) {
TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
+ return response;
+ }
+ }
+
+ // Check that messages are valid utf8.
+ for (const ConversationMessage& message : conversation.messages) {
+ if (!IsValidUTF8(message.text.data(), message.text.size())) {
+ TC3_LOG(ERROR) << "Not valid utf8 provided.";
+ return response;
}
}
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 2a321f0..1fee9a1 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -34,7 +34,8 @@
#include "annotator/annotator.h"
#include "annotator/model-executor.h"
#include "annotator/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/i18n/locale.h"
#include "utils/memory/mmap.h"
#include "utils/tflite-model-executor.h"
@@ -262,7 +263,7 @@
// Builder for creating extra data.
const reflection::Schema* entity_data_schema_;
- std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
+ std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
std::unique_ptr<ActionsSuggestionsRanker> ranker_;
std::string lua_bytecode_;
diff --git a/native/actions/actions_jni.cc b/native/actions/actions_jni.cc
index 7dd0169..1648fb3 100644
--- a/native/actions/actions_jni.cc
+++ b/native/actions/actions_jni.cc
@@ -28,6 +28,7 @@
#include "annotator/annotator.h"
#include "annotator/annotator_jni_common.h"
#include "utils/base/integral_types.h"
+#include "utils/base/status_macros.h"
#include "utils/base/statusor.h"
#include "utils/intents/intent-generator.h"
#include "utils/intents/jni.h"
@@ -35,7 +36,6 @@
#include "utils/java/jni-base.h"
#include "utils/java/jni-cache.h"
#include "utils/java/jni-helper.h"
-#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"
using libtextclassifier3::ActionsSuggestions;
@@ -45,9 +45,9 @@
using libtextclassifier3::Annotator;
using libtextclassifier3::Conversation;
using libtextclassifier3::IntentGenerator;
+using libtextclassifier3::JStringToUtf8String;
using libtextclassifier3::ScopedLocalRef;
using libtextclassifier3::StatusOr;
-using libtextclassifier3::ToStlString;
// When using the Java's ICU, UniLib needs to be instantiated with a JavaVM
// pointer from JNI. When using a standard ICU the pointer is not needed and the
@@ -74,13 +74,14 @@
std::unique_ptr<IntentGenerator> intent_generator =
IntentGenerator::Create(model->model()->android_intent_options(),
model->model()->resources(), jni_cache);
- std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
- libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
-
- if (intent_generator == nullptr || template_handler == nullptr) {
+ if (intent_generator == nullptr) {
return nullptr;
}
+ TC3_ASSIGN_OR_RETURN_NULL(
+ std::unique_ptr<RemoteActionTemplatesHandler> template_handler,
+ libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache));
+
return new ActionsSuggestionsJniContext(jni_cache, std::move(model),
std::move(intent_generator),
std::move(template_handler));
@@ -166,11 +167,11 @@
serialized_entity_data,
JniHelper::NewByteArray(
env, action_result[i].serialized_entity_data.size()));
- env->SetByteArrayRegion(
- serialized_entity_data.get(), 0,
+ TC3_RETURN_IF_ERROR(JniHelper::SetByteArrayRegion(
+ env, serialized_entity_data.get(), 0,
action_result[i].serialized_entity_data.size(),
reinterpret_cast<const jbyte*>(
- action_result[i].serialized_entity_data.data()));
+ action_result[i].serialized_entity_data.data())));
}
ScopedLocalRef<jobjectArray> remote_action_templates_result;
@@ -203,7 +204,8 @@
static_cast<jfloat>(action_result[i].score),
extras.get(), serialized_entity_data.get(),
remote_action_templates_result.get()));
- env->SetObjectArrayElement(results.get(), i, result.get());
+ TC3_RETURN_IF_ERROR(
+ JniHelper::SetObjectArrayElement(env, results.get(), i, result.get()));
}
return results;
}
@@ -262,13 +264,14 @@
env, jmessage, get_detected_text_language_tags_method));
ConversationMessage message;
- TC3_ASSIGN_OR_RETURN(message.text, ToStlString(env, text.get()));
+ TC3_ASSIGN_OR_RETURN(message.text, JStringToUtf8String(env, text.get()));
message.user_id = user_id;
message.reference_time_ms_utc = reference_time;
TC3_ASSIGN_OR_RETURN(message.reference_timezone,
- ToStlString(env, reference_timezone.get()));
- TC3_ASSIGN_OR_RETURN(message.detected_text_language_tags,
- ToStlString(env, detected_text_language_tags.get()));
+ JStringToUtf8String(env, reference_timezone.get()));
+ TC3_ASSIGN_OR_RETURN(
+ message.detected_text_language_tags,
+ JStringToUtf8String(env, detected_text_language_tags.get()));
return message;
}
@@ -295,7 +298,8 @@
env, jconversation, get_conversation_messages_method));
std::vector<ConversationMessage> messages;
- const int size = env->GetArrayLength(jmessages.get());
+ TC3_ASSIGN_OR_RETURN(const int size,
+ JniHelper::GetArrayLength(env, jmessages.get()));
for (int i = 0; i < size; i++) {
TC3_ASSIGN_OR_RETURN(
ScopedLocalRef<jobject> jmessage,
@@ -353,81 +357,86 @@
using libtextclassifier3::ActionSuggestionsToJObjectArray;
using libtextclassifier3::FromJavaActionSuggestionOptions;
using libtextclassifier3::FromJavaConversation;
+using libtextclassifier3::JByteArrayToString;
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
-(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) {
+(JNIEnv* env, jobject clazz, jint fd, jbyteArray jserialized_preconditions) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
libtextclassifier3::JniCache::Create(env);
- std::string preconditions;
- if (serialized_preconditions != nullptr &&
- !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
- &preconditions)) {
- TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
- return 0;
+ std::string serialized_preconditions;
+ if (jserialized_preconditions != nullptr) {
+ TC3_ASSIGN_OR_RETURN_0(
+ serialized_preconditions,
+ JByteArrayToString(env, jserialized_preconditions),
+ TC3_LOG(ERROR) << "Could not convert serialized preconditions.");
}
+
#ifdef TC3_UNILIB_JAVAICU
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
- jni_cache,
- ActionsSuggestions::FromFileDescriptor(
- fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions)));
+ jni_cache, ActionsSuggestions::FromFileDescriptor(
+ fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+ serialized_preconditions)));
#else
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
- jni_cache, ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr,
- preconditions)));
+ jni_cache, ActionsSuggestions::FromFileDescriptor(
+ fd, /*unilib=*/nullptr, serialized_preconditions)));
#endif // TC3_UNILIB_JAVAICU
}
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
-(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
+(JNIEnv* env, jobject clazz, jstring path,
+ jbyteArray jserialized_preconditions) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
libtextclassifier3::JniCache::Create(env);
- TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
- std::string preconditions;
- if (serialized_preconditions != nullptr &&
- !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
- &preconditions)) {
- TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
- return 0;
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str,
+ JStringToUtf8String(env, path));
+ std::string serialized_preconditions;
+ if (jserialized_preconditions != nullptr) {
+ TC3_ASSIGN_OR_RETURN_0(
+ serialized_preconditions,
+ JByteArrayToString(env, jserialized_preconditions),
+ TC3_LOG(ERROR) << "Could not convert serialized preconditions.");
}
#ifdef TC3_UNILIB_JAVAICU
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
jni_cache, ActionsSuggestions::FromPath(
path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
- preconditions)));
+ serialized_preconditions)));
#else
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
jni_cache, ActionsSuggestions::FromPath(path_str, /*unilib=*/nullptr,
- preconditions)));
+ serialized_preconditions)));
#endif // TC3_UNILIB_JAVAICU
}
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size,
- jbyteArray serialized_preconditions) {
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size,
+ jbyteArray jserialized_preconditions) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
libtextclassifier3::JniCache::Create(env);
- std::string preconditions;
- if (serialized_preconditions != nullptr &&
- !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
- &preconditions)) {
- TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
- return 0;
+ std::string serialized_preconditions;
+ if (jserialized_preconditions != nullptr) {
+ TC3_ASSIGN_OR_RETURN_0(
+ serialized_preconditions,
+ JByteArrayToString(env, jserialized_preconditions),
+ TC3_LOG(ERROR) << "Could not convert serialized preconditions.");
}
#ifdef TC3_UNILIB_JAVAICU
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
jni_cache,
ActionsSuggestions::FromFileDescriptor(
fd, offset, size, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
- preconditions)));
+ serialized_preconditions)));
#else
return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
- jni_cache, ActionsSuggestions::FromFileDescriptor(
- fd, offset, size, /*unilib=*/nullptr, preconditions)));
+ jni_cache,
+ ActionsSuggestions::FromFileDescriptor(
+ fd, offset, size, /*unilib=*/nullptr, serialized_preconditions)));
#endif // TC3_UNILIB_JAVAICU
}
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
-(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
+(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation, jobject joptions,
jlong annotatorPtr, jobject app_context, jstring device_locales,
jboolean generate_intents) {
if (!ptr) {
@@ -456,7 +465,7 @@
}
TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
-(JNIEnv* env, jobject clazz, jlong model_ptr) {
+(JNIEnv* env, jobject thiz, jlong model_ptr) {
const ActionsSuggestionsJniContext* context =
reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr);
delete context;
diff --git a/native/actions/actions_jni.h b/native/actions/actions_jni.h
index 276e361..75e2e67 100644
--- a/native/actions/actions_jni.h
+++ b/native/actions/actions_jni.h
@@ -32,13 +32,13 @@
#endif
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
-(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions);
+(JNIEnv* env, jobject clazz, jint fd, jbyteArray serialized_preconditions);
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
-(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions);
+(JNIEnv* env, jobject clazz, jstring path, jbyteArray serialized_preconditions);
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size,
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size,
jbyteArray serialized_preconditions);
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 251610e..7de786f 100755
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -14,16 +14,16 @@
// limitations under the License.
//
-include "actions/actions-entity-data.fbs";
-include "annotator/model.fbs";
-include "utils/codepoint-range.fbs";
-include "utils/flatbuffers.fbs";
-include "utils/grammar/rules.fbs";
-include "utils/intents/intent-config.fbs";
include "utils/normalization.fbs";
-include "utils/resources.fbs";
+include "utils/grammar/rules.fbs";
+include "actions/actions-entity-data.fbs";
+include "utils/flatbuffers/flatbuffers.fbs";
+include "utils/intents/intent-config.fbs";
+include "utils/codepoint-range.fbs";
include "utils/tokenizer.fbs";
include "utils/zlib/buffer.fbs";
+include "utils/resources.fbs";
+include "annotator/model.fbs";
file_identifier "TC3A";
@@ -75,13 +75,10 @@
// int, the number of smart replies to produce.
input_num_suggestions:int = 4;
- // float, the output diversification distance parameter.
reserved_7:int (deprecated);
- // float, the empirical probability factor parameter.
reserved_8:int (deprecated);
- // float, the confidence threshold.
reserved_9:int (deprecated);
// Input port for hashed and embedded tokens, a (num messages, max tokens,
@@ -280,7 +277,9 @@
low_confidence_rules:RulesModel;
reserved_11:float (deprecated);
+
reserved_12:float (deprecated);
+
reserved_13:float (deprecated);
// Smart reply thresholds.
diff --git a/native/actions/feature-processor_test.cc b/native/actions/feature-processor_test.cc
index 969bbf7..e36af90 100644
--- a/native/actions/feature-processor_test.cc
+++ b/native/actions/feature-processor_test.cc
@@ -47,9 +47,9 @@
std::vector<float> storage_;
};
-class FeatureProcessorTest : public ::testing::Test {
+class ActionsFeatureProcessorTest : public ::testing::Test {
protected:
- FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+ ActionsFeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
ActionsTokenFeatureProcessorOptionsT* options) const {
@@ -62,7 +62,7 @@
UniLib unilib_;
};
-TEST_F(FeatureProcessorTest, TokenEmbeddings) {
+TEST_F(ActionsFeatureProcessorTest, TokenEmbeddings) {
ActionsTokenFeatureProcessorOptionsT options;
options.embedding_size = 4;
options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
@@ -81,7 +81,7 @@
EXPECT_THAT(token_features, SizeIs(4));
}
-TEST_F(FeatureProcessorTest, TokenEmbeddingsCaseFeature) {
+TEST_F(ActionsFeatureProcessorTest, TokenEmbeddingsCaseFeature) {
ActionsTokenFeatureProcessorOptionsT options;
options.embedding_size = 4;
options.extract_case_feature = true;
@@ -102,7 +102,7 @@
EXPECT_THAT(token_features[4], FloatEq(1.0));
}
-TEST_F(FeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
+TEST_F(ActionsFeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
ActionsTokenFeatureProcessorOptionsT options;
options.embedding_size = 4;
options.extract_case_feature = true;
diff --git a/native/actions/grammar-actions.cc b/native/actions/grammar-actions.cc
index 7f3e71f..597ee59 100644
--- a/native/actions/grammar-actions.cc
+++ b/native/actions/grammar-actions.cc
@@ -54,7 +54,7 @@
// Deduplicate, verify and populate actions from grammar matches.
bool GetActions(const Conversation& conversation,
const std::string& smart_reply_action_type,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
std::vector<ActionSuggestion>* action_suggestions) const {
std::vector<UnicodeText::const_iterator> codepoint_offsets;
const UnicodeText message_unicode =
@@ -95,7 +95,7 @@
const std::vector<UnicodeText::const_iterator>& message_codepoint_offsets,
int message_index, const std::string& smart_reply_action_type,
const grammar::Derivation& candidate,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
std::vector<ActionSuggestion>* result) const {
const RulesModel_::GrammarRules_::RuleMatch* rule_match =
grammar_rules_->rule_match()->Get(candidate.rule_id);
@@ -118,7 +118,7 @@
grammar_rules_->actions()->Get(action_id);
std::vector<ActionSuggestionAnnotation> annotations;
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder != nullptr ? entity_data_builder->NewRoot()
: nullptr;
@@ -200,7 +200,7 @@
GrammarActions::GrammarActions(
const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
const std::string& smart_reply_action_type)
: unilib_(*unilib),
grammar_rules_(grammar_rules),
diff --git a/native/actions/grammar-actions.h b/native/actions/grammar-actions.h
index fc3270d..ea8c2b4 100644
--- a/native/actions/grammar-actions.h
+++ b/native/actions/grammar-actions.h
@@ -22,7 +22,7 @@
#include "actions/actions_model_generated.h"
#include "actions/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/grammar/lexer.h"
#include "utils/grammar/types.h"
#include "utils/i18n/locale.h"
@@ -37,10 +37,10 @@
public:
enum class Callback : grammar::CallbackId { kActionRuleMatch = 1 };
- explicit GrammarActions(
- const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
- const std::string& smart_reply_action_type);
+ explicit GrammarActions(const UniLib* unilib,
+ const RulesModel_::GrammarRules* grammar_rules,
+ const MutableFlatbufferBuilder* entity_data_builder,
+ const std::string& smart_reply_action_type);
// Suggests actions for a conversation from a message stream.
bool SuggestActions(const Conversation& conversation,
@@ -51,7 +51,7 @@
const RulesModel_::GrammarRules* grammar_rules_;
const std::unique_ptr<Tokenizer> tokenizer_;
const grammar::Lexer lexer_;
- const ReflectiveFlatbufferBuilder* entity_data_builder_;
+ const MutableFlatbufferBuilder* entity_data_builder_;
const std::string smart_reply_action_type_;
// Pre-parsed locales of the rules.
diff --git a/native/actions/lua-ranker_test.cc b/native/actions/lua-ranker_test.cc
index a790042..939617b 100644
--- a/native/actions/lua-ranker_test.cc
+++ b/native/actions/lua-ranker_test.cc
@@ -19,7 +19,7 @@
#include <string>
#include "actions/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -229,8 +229,8 @@
flatbuffers::GetRoot<reflection::Schema>(serialized_schema.data());
// Create test entity data.
- ReflectiveFlatbufferBuilder builder(entity_data_schema);
- std::unique_ptr<ReflectiveFlatbuffer> buffer = builder.NewRoot();
+ MutableFlatbufferBuilder builder(entity_data_schema);
+ std::unique_ptr<MutableFlatbuffer> buffer = builder.NewRoot();
buffer->Set("test", "value_a");
const std::string serialized_entity_data_a = buffer->Serialize();
buffer->Set("test", "value_b");
diff --git a/native/actions/ranker.h b/native/actions/ranker.h
index 2ab3146..5af7c38 100644
--- a/native/actions/ranker.h
+++ b/native/actions/ranker.h
@@ -22,6 +22,7 @@
#include "actions/actions_model_generated.h"
#include "actions/types.h"
#include "utils/zlib/zlib.h"
+#include "flatbuffers/reflection.h"
namespace libtextclassifier3 {
diff --git a/native/actions/regex-actions.cc b/native/actions/regex-actions.cc
index 7d5a4b2..9a2b5a4 100644
--- a/native/actions/regex-actions.cc
+++ b/native/actions/regex-actions.cc
@@ -189,7 +189,7 @@
bool RegexActions::SuggestActions(
const Conversation& conversation,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
std::vector<ActionSuggestion>* actions) const {
// Create actions based on rules checking the last message.
const int message_index = conversation.messages.size() - 1;
@@ -206,7 +206,7 @@
const ActionSuggestionSpec* action = rule_action->action();
std::vector<ActionSuggestionAnnotation> annotations;
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder != nullptr ? entity_data_builder->NewRoot()
: nullptr;
diff --git a/native/actions/regex-actions.h b/native/actions/regex-actions.h
index 871f08b..ee0b186 100644
--- a/native/actions/regex-actions.h
+++ b/native/actions/regex-actions.h
@@ -23,7 +23,7 @@
#include "actions/actions_model_generated.h"
#include "actions/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/utf8/unilib.h"
#include "utils/zlib/zlib.h"
@@ -55,7 +55,7 @@
// Suggests actions for a conversation from a message stream using the regex
// rules.
bool SuggestActions(const Conversation& conversation,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
std::vector<ActionSuggestion>* actions) const;
private:
diff --git a/native/actions/test-utils.cc b/native/actions/test-utils.cc
index 9b003dd..426989d 100644
--- a/native/actions/test-utils.cc
+++ b/native/actions/test-utils.cc
@@ -16,6 +16,8 @@
#include "actions/test-utils.h"
+#include "flatbuffers/reflection.h"
+
namespace libtextclassifier3 {
std::string TestEntityDataSchema() {
diff --git a/native/actions/test-utils.h b/native/actions/test-utils.h
index c05d6a9..e27f510 100644
--- a/native/actions/test-utils.h
+++ b/native/actions/test-utils.h
@@ -20,7 +20,6 @@
#include <string>
#include "actions/actions_model_generated.h"
-#include "utils/flatbuffers.h"
#include "gmock/gmock.h"
namespace libtextclassifier3 {
diff --git a/native/actions/types.h b/native/actions/types.h
index e7d384f..c971529 100644
--- a/native/actions/types.h
+++ b/native/actions/types.h
@@ -23,7 +23,7 @@
#include "actions/actions-entity-data_generated.h"
#include "annotator/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
namespace libtextclassifier3 {
diff --git a/native/actions/utils.cc b/native/actions/utils.cc
index 96f6f1f..53714d6 100644
--- a/native/actions/utils.cc
+++ b/native/actions/utils.cc
@@ -16,14 +16,19 @@
#include "actions/utils.h"
+#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/normalization.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
+// Name for a datetime annotation that only includes time but no date.
+const std::string& kTimeAnnotation =
+ *[]() { return new std::string("time"); }();
+
void FillSuggestionFromSpec(const ActionSuggestionSpec* action,
- ReflectiveFlatbuffer* entity_data,
+ MutableFlatbuffer* entity_data,
ActionSuggestion* suggestion) {
if (action != nullptr) {
suggestion->score = action->score();
@@ -52,7 +57,7 @@
}
void SuggestTextRepliesFromCapturingMatch(
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
const UnicodeText& match_text, const std::string& smart_reply_action_type,
std::vector<ActionSuggestion>* actions) {
@@ -60,7 +65,7 @@
ActionSuggestion suggestion;
suggestion.response_text = match_text.ToUTF8String();
suggestion.type = smart_reply_action_type;
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder != nullptr ? entity_data_builder->NewRoot()
: nullptr;
FillSuggestionFromSpec(group->text_reply(), entity_data.get(), &suggestion);
@@ -104,7 +109,7 @@
bool MergeEntityDataFromCapturingMatch(
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
- StringPiece match_text, ReflectiveFlatbuffer* buffer) {
+ StringPiece match_text, MutableFlatbuffer* buffer) {
if (group->entity_field() != nullptr) {
if (!buffer->ParseAndSet(group->entity_field(), match_text.ToString())) {
TC3_LOG(ERROR) << "Could not set entity data from rule capturing group.";
@@ -121,4 +126,29 @@
return true;
}
+// Rewrites, in place, the collection of the first classification of each
+// annotated span from datetime to the time-only collection when the parsed
+// datetime carries no date component.
+//
+// Spans that have no classification at all are left untouched.
+void ConvertDatetimeToTime(std::vector<AnnotatedSpan>* annotations) {
+  for (AnnotatedSpan& annotation : *annotations) {
+    // Nothing to specialize for spans without classifications; calling
+    // front() on an empty vector would be undefined behavior.
+    if (annotation.classification.empty()) {
+      continue;
+    }
+    ClassificationResult* classification = &annotation.classification.front();
+    // Only datetime classifications that carry a parse result are candidates.
+    if (classification->collection != Collections::DateTime() ||
+        !classification->datetime_parse_result.IsSet()) {
+      continue;
+    }
+    // Specialize datetime annotation to time annotation if no date
+    // component is present. Specified component types ordered before HOUR
+    // are treated as date components.
+    bool has_only_time = true;
+    for (const DatetimeComponent& component :
+         classification->datetime_parse_result.datetime_components) {
+      if (component.component_type !=
+              DatetimeComponent::ComponentType::UNSPECIFIED &&
+          component.component_type < DatetimeComponent::ComponentType::HOUR) {
+        has_only_time = false;
+        break;
+      }
+    }
+    if (has_only_time) {
+      classification->collection = kTimeAnnotation;
+    }
+  }
+}
+
} // namespace libtextclassifier3
diff --git a/native/actions/utils.h b/native/actions/utils.h
index 820c79d..d8bdec2 100644
--- a/native/actions/utils.h
+++ b/native/actions/utils.h
@@ -25,7 +25,8 @@
#include "actions/actions_model_generated.h"
#include "actions/types.h"
#include "annotator/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
@@ -33,12 +34,12 @@
// Fills an action suggestion from a template.
void FillSuggestionFromSpec(const ActionSuggestionSpec* action,
- ReflectiveFlatbuffer* entity_data,
+ MutableFlatbuffer* entity_data,
ActionSuggestion* suggestion);
// Creates text replies from capturing matches.
void SuggestTextRepliesFromCapturingMatch(
- const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const MutableFlatbufferBuilder* entity_data_builder,
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
const UnicodeText& match_text, const std::string& smart_reply_action_type,
std::vector<ActionSuggestion>* actions);
@@ -60,7 +61,11 @@
// Parses and sets values from the text and merges fixed data.
bool MergeEntityDataFromCapturingMatch(
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
- StringPiece match_text, ReflectiveFlatbuffer* buffer);
+ StringPiece match_text, MutableFlatbuffer* buffer);
+
+// Changes datetime classifications to a time type if no date component is
+// present. Modifies classifications in-place.
+void ConvertDatetimeToTime(std::vector<AnnotatedSpan>* annotations);
} // namespace libtextclassifier3
diff --git a/native/actions/zlib-utils.cc b/native/actions/zlib-utils.cc
index c8ad4e7..4525a21 100644
--- a/native/actions/zlib-utils.cc
+++ b/native/actions/zlib-utils.cc
@@ -20,7 +20,6 @@
#include "utils/base/logging.h"
#include "utils/intents/zlib-utils.h"
-#include "utils/resources.h"
namespace libtextclassifier3 {
@@ -76,11 +75,6 @@
model->ranking_options->compressed_lua_ranking_script.get());
}
- // Compress resources.
- if (model->resources != nullptr) {
- CompressResources(model->resources.get());
- }
-
// Compress intent generator.
if (model->android_intent_options != nullptr) {
CompressIntentModel(model->android_intent_options.get());
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 6ee983f..2dc9b5c 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -26,6 +26,8 @@
#include <vector>
#include "annotator/collections.h"
+#include "annotator/flatbuffer-utils.h"
+#include "annotator/knowledge/knowledge-engine-types.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
@@ -37,6 +39,7 @@
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/regex-match.h"
+#include "utils/strings/append.h"
#include "utils/strings/numbers.h"
#include "utils/strings/split.h"
#include "utils/utf8/unicodetext.h"
@@ -104,7 +107,7 @@
// Returns whether the provided input is valid:
// * Valid utf8 text.
// * Sane span indices.
-bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan span) {
+bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
if (!context.is_valid()) {
return false;
}
@@ -504,6 +507,16 @@
selection_feature_processor_.get(), unilib_));
}
+ if (model_->grammar_model()) {
+ grammar_annotator_.reset(new GrammarAnnotator(
+ unilib_, model_->grammar_model(), entity_data_builder_.get()));
+ }
+
+ if (model_->pod_ner_model()) {
+ pod_ner_annotator_ =
+ PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
+ }
+
if (model_->entity_data_schema()) {
entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
model_->entity_data_schema()->Data(),
@@ -514,17 +527,12 @@
}
entity_data_builder_.reset(
- new ReflectiveFlatbufferBuilder(entity_data_schema_));
+ new MutableFlatbufferBuilder(entity_data_schema_));
} else {
entity_data_schema_ = nullptr;
entity_data_builder_ = nullptr;
}
- if (model_->grammar_model()) {
- grammar_annotator_.reset(new GrammarAnnotator(
- unilib_, model_->grammar_model(), entity_data_builder_.get()));
- }
-
if (model_->triggering_locales() &&
!ParseLocales(model_->triggering_locales()->c_str(),
&model_triggering_locales_)) {
@@ -701,6 +709,7 @@
if (ExperimentalAnnotator::IsEnabled()) {
experimental_annotator_.reset(new ExperimentalAnnotator(
model_->experimental_model(), *selection_feature_processor_, *unilib_));
+
return true;
}
return false;
@@ -708,7 +717,8 @@
namespace {
-int CountDigits(const std::string& str, CodepointSpan selection_indices) {
+int CountDigits(const std::string& str,
+ const CodepointSpan& selection_indices) {
int count = 0;
int i = 0;
const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
@@ -727,10 +737,10 @@
// Helper function, which if the initial 'span' contains only white-spaces,
// moves the selection to a single-codepoint selection on a left or right side
// of this space.
-CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
+CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
const UnicodeText& context_unicode,
const UniLib& unilib) {
- TC3_CHECK(ValidNonEmptySpan(span));
+ TC3_CHECK(span.IsValid() && !span.IsEmpty());
UnicodeText::const_iterator it;
@@ -743,10 +753,8 @@
}
}
- CodepointSpan result;
-
// Try moving left.
- result = span;
+ CodepointSpan result = span;
it = context_unicode.begin();
std::advance(it, span.first);
while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
@@ -888,60 +896,72 @@
click_indices, context_unicode, *unilib_);
}
- std::vector<AnnotatedSpan> candidates;
+ Annotations candidates;
+ // As we process a single string of context, the candidates will only
+ // contain one vector of AnnotatedSpan.
+ candidates.annotated_spans.resize(1);
InterpreterManager interpreter_manager(selection_executor_.get(),
classification_executor_.get());
std::vector<Token> tokens;
if (!ModelSuggestSelection(context_unicode, click_indices,
detected_text_language_tags, &interpreter_manager,
- &tokens, &candidates)) {
+ &tokens, &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Model suggest selection failed.";
return original_click_indices;
}
- if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
- /*is_serialized_entity_data_enabled=*/false)) {
+ const std::unordered_set<std::string> set;
+ const EnabledEntityTypes is_entity_type_enabled(set);
+ if (!RegexChunk(context_unicode, selection_regex_patterns_,
+ /*is_serialized_entity_data_enabled=*/false,
+ is_entity_type_enabled, options.annotation_usecase,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Regex suggest selection failed.";
return original_click_indices;
}
- if (!DatetimeChunk(
- UTF8ToUnicodeText(context, /*do_copy=*/false),
- /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
- options.locales, ModeFlag_SELECTION, options.annotation_usecase,
- /*is_serialized_entity_data_enabled=*/false, &candidates)) {
+ if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
+ options.locales, ModeFlag_SELECTION,
+ options.annotation_usecase,
+ /*is_serialized_entity_data_enabled=*/false,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Datetime suggest selection failed.";
return original_click_indices;
}
if (knowledge_engine_ != nullptr &&
!knowledge_engine_->Chunk(context, options.annotation_usecase,
options.location_context, Permissions(),
- &candidates)) {
+ AnnotateMode::kEntityAnnotation, &candidates)) {
TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
return original_click_indices;
}
if (contact_engine_ != nullptr &&
- !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ !contact_engine_->Chunk(context_unicode, tokens,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Contact suggest selection failed.";
return original_click_indices;
}
if (installed_app_engine_ != nullptr &&
- !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ !installed_app_engine_->Chunk(context_unicode, tokens,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Installed app suggest selection failed.";
return original_click_indices;
}
if (number_annotator_ != nullptr &&
!number_annotator_->FindAll(context_unicode, options.annotation_usecase,
- &candidates)) {
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
return original_click_indices;
}
if (duration_annotator_ != nullptr &&
!duration_annotator_->FindAll(context_unicode, tokens,
- options.annotation_usecase, &candidates)) {
+ options.annotation_usecase,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
return original_click_indices;
}
if (person_name_engine_ != nullptr &&
- !person_name_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ !person_name_engine_->Chunk(context_unicode, tokens,
+ &candidates.annotated_spans[0])) {
TC3_LOG(ERROR) << "Person name suggest selection failed.";
return original_click_indices;
}
@@ -951,24 +971,31 @@
grammar_annotator_->SuggestSelection(detected_text_language_tags,
context_unicode, click_indices,
&grammar_suggested_span)) {
- candidates.push_back(grammar_suggested_span);
+ candidates.annotated_spans[0].push_back(grammar_suggested_span);
+ }
+
+ if (pod_ner_annotator_ != nullptr && options.use_pod_ner) {
+ candidates.annotated_spans[0].push_back(
+ pod_ner_annotator_->SuggestSelection(context_unicode, click_indices));
}
if (experimental_annotator_ != nullptr) {
- candidates.push_back(experimental_annotator_->SuggestSelection(
- context_unicode, click_indices));
+ candidates.annotated_spans[0].push_back(
+ experimental_annotator_->SuggestSelection(context_unicode,
+ click_indices));
}
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
// contiguous block.
- std::sort(candidates.begin(), candidates.end(),
+ std::sort(candidates.annotated_spans[0].begin(),
+ candidates.annotated_spans[0].end(),
[](const AnnotatedSpan& a, const AnnotatedSpan& b) {
return a.span.first < b.span.first;
});
std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, tokens,
+ if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
detected_text_language_tags, options.annotation_usecase,
&interpreter_manager, &candidate_indices)) {
TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
@@ -977,32 +1004,36 @@
std::sort(candidate_indices.begin(), candidate_indices.end(),
[this, &candidates](int a, int b) {
- return GetPriorityScore(candidates[a].classification) >
- GetPriorityScore(candidates[b].classification);
+ return GetPriorityScore(
+ candidates.annotated_spans[0][a].classification) >
+ GetPriorityScore(
+ candidates.annotated_spans[0][b].classification);
});
for (const int i : candidate_indices) {
- if (SpansOverlap(candidates[i].span, click_indices) &&
- SpansOverlap(candidates[i].span, original_click_indices)) {
+ if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
+ SpansOverlap(candidates.annotated_spans[0][i].span,
+ original_click_indices)) {
// Run model classification if not present but requested and there's a
// classification collection filter specified.
- if (candidates[i].classification.empty() &&
+ if (candidates.annotated_spans[0][i].classification.empty() &&
model_->selection_options()->always_classify_suggested_selection() &&
!filtered_collections_selection_.empty()) {
- if (!ModelClassifyText(context, detected_text_language_tags,
- candidates[i].span, &interpreter_manager,
- /*embedding_cache=*/nullptr,
- &candidates[i].classification)) {
+ if (!ModelClassifyText(
+ context, detected_text_language_tags,
+ candidates.annotated_spans[0][i].span, &interpreter_manager,
+ /*embedding_cache=*/nullptr,
+ &candidates.annotated_spans[0][i].classification)) {
return original_click_indices;
}
}
// Ignore if span classification is filtered.
- if (FilteredForSelection(candidates[i])) {
+ if (FilteredForSelection(candidates.annotated_spans[0][i])) {
return original_click_indices;
}
- return candidates[i].span;
+ return candidates.annotated_spans[0][i].span;
}
}
@@ -1220,7 +1251,7 @@
}
bool Annotator::ModelSuggestSelection(
- const UnicodeText& context_unicode, CodepointSpan click_indices,
+ const UnicodeText& context_unicode, const CodepointSpan& click_indices,
const std::vector<Locale>& detected_text_language_tags,
InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const {
@@ -1254,11 +1285,11 @@
// The symmetry context span is the clicked token with symmetry_context_size
// tokens on either side.
- const TokenSpan symmetry_context_span = IntersectTokenSpans(
- ExpandTokenSpan(SingleTokenSpan(click_pos),
- /*num_tokens_left=*/symmetry_context_size,
- /*num_tokens_right=*/symmetry_context_size),
- {0, tokens->size()});
+ const TokenSpan symmetry_context_span =
+ IntersectTokenSpans(TokenSpan(click_pos).Expand(
+ /*num_tokens_left=*/symmetry_context_size,
+ /*num_tokens_right=*/symmetry_context_size),
+ AllOf(*tokens));
// Compute the extraction span based on the model type.
TokenSpan extraction_span;
@@ -1269,22 +1300,21 @@
// the bounds of the selection.
const int max_selection_span =
selection_feature_processor_->GetOptions()->max_selection_span();
- extraction_span =
- ExpandTokenSpan(symmetry_context_span,
- /*num_tokens_left=*/max_selection_span +
- bounds_sensitive_features->num_tokens_before(),
- /*num_tokens_right=*/max_selection_span +
- bounds_sensitive_features->num_tokens_after());
+ extraction_span = symmetry_context_span.Expand(
+ /*num_tokens_left=*/max_selection_span +
+ bounds_sensitive_features->num_tokens_before(),
+ /*num_tokens_right=*/max_selection_span +
+ bounds_sensitive_features->num_tokens_after());
} else {
// The extraction span is the symmetry context span expanded to include
// context_size tokens on either side.
const int context_size =
selection_feature_processor_->GetOptions()->context_size();
- extraction_span = ExpandTokenSpan(symmetry_context_span,
- /*num_tokens_left=*/context_size,
- /*num_tokens_right=*/context_size);
+ extraction_span = symmetry_context_span.Expand(
+ /*num_tokens_left=*/context_size,
+ /*num_tokens_right=*/context_size);
}
- extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
+ extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
*tokens, extraction_span)) {
@@ -1333,7 +1363,8 @@
bool Annotator::ModelClassifyText(
const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ const CodepointSpan& selection_indices,
+ InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const {
return ModelClassifyText(context, {}, detected_text_language_tags,
@@ -1343,7 +1374,7 @@
namespace internal {
std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
- CodepointSpan selection_indices,
+ const CodepointSpan& selection_indices,
TokenSpan tokens_around_selection_to_copy) {
const auto first_selection_token = std::upper_bound(
cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
@@ -1407,7 +1438,8 @@
bool Annotator::ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ const CodepointSpan& selection_indices,
+ InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const {
std::vector<Token> tokens;
@@ -1419,7 +1451,8 @@
bool Annotator::ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ const CodepointSpan& selection_indices,
+ InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results,
std::vector<Token>* tokens) const {
@@ -1450,7 +1483,7 @@
tokens, &click_pos);
const TokenSpan selection_token_span =
CodepointSpanToTokenSpan(*tokens, selection_indices);
- const int selection_num_tokens = TokenSpanSize(selection_token_span);
+ const int selection_num_tokens = selection_token_span.Size();
if (model_->classification_options()->max_num_tokens() > 0 &&
model_->classification_options()->max_num_tokens() <
selection_num_tokens) {
@@ -1473,8 +1506,7 @@
if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
// The extraction span is the selection span expanded to include a relevant
// number of tokens outside of the bounds of the selection.
- extraction_span = ExpandTokenSpan(
- selection_token_span,
+ extraction_span = selection_token_span.Expand(
/*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
/*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
} else {
@@ -1486,11 +1518,11 @@
// either side.
const int context_size =
classification_feature_processor_->GetOptions()->context_size();
- extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
- /*num_tokens_left=*/context_size,
- /*num_tokens_right=*/context_size);
+ extraction_span = TokenSpan(click_pos).Expand(
+ /*num_tokens_left=*/context_size,
+ /*num_tokens_right=*/context_size);
}
- extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
+ extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
*tokens, extraction_span)) {
@@ -1588,7 +1620,7 @@
}
bool Annotator::RegexClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
std::vector<ClassificationResult>* classification_result) const {
const std::string selection_text =
UTF8ToUnicodeText(context, /*do_copy=*/false)
@@ -1643,42 +1675,10 @@
}
}
-std::string CreateDatetimeSerializedEntityData(
- const DatetimeParseResult& parse_result) {
- EntityDataT entity_data;
- entity_data.datetime.reset(new EntityData_::DatetimeT());
- entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
- entity_data.datetime->granularity =
- static_cast<EntityData_::Datetime_::Granularity>(
- parse_result.granularity);
-
- for (const auto& c : parse_result.datetime_components) {
- EntityData_::Datetime_::DatetimeComponentT datetime_component;
- datetime_component.absolute_value = c.value;
- datetime_component.relative_count = c.relative_count;
- datetime_component.component_type =
- static_cast<EntityData_::Datetime_::DatetimeComponent_::ComponentType>(
- c.component_type);
- datetime_component.relation_type =
- EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE;
- if (c.relative_qualifier !=
- DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
- datetime_component.relation_type =
- EntityData_::Datetime_::DatetimeComponent_::RelationType_RELATIVE;
- }
- entity_data.datetime->datetime_component.emplace_back(
- new EntityData_::Datetime_::DatetimeComponentT(datetime_component));
- }
- flatbuffers::FlatBufferBuilder builder;
- FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
} // namespace
bool Annotator::DatetimeClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
const ClassificationOptions& options,
std::vector<ClassificationResult>* classification_results) const {
if (!datetime_parser_ && !cfg_datetime_parser_) {
@@ -1720,8 +1720,8 @@
for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
// Only consider the result valid if the selection and extracted datetime
// spans exactly match.
- if (std::make_pair(datetime_span.span.first + selection_indices.first,
- datetime_span.span.second + selection_indices.first) ==
+ if (CodepointSpan(datetime_span.span.first + selection_indices.first,
+ datetime_span.span.second + selection_indices.first) ==
selection_indices) {
for (const DatetimeParseResult& parse_result : datetime_span.data) {
classification_results->emplace_back(
@@ -1740,7 +1740,7 @@
}
std::vector<ClassificationResult> Annotator::ClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
const ClassificationOptions& options) const {
if (!initialized_) {
TC3_LOG(ERROR) << "Not initialized";
@@ -1772,8 +1772,7 @@
if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
selection_indices)) {
TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
- << std::get<0>(selection_indices) << " "
- << std::get<1>(selection_indices);
+ << selection_indices.first << " " << selection_indices.second;
return {};
}
@@ -1885,11 +1884,16 @@
candidates.push_back({selection_indices, {grammar_annotator_result}});
}
- ClassificationResult experimental_annotator_result;
- if (experimental_annotator_ &&
- experimental_annotator_->ClassifyText(context_unicode, selection_indices,
- &experimental_annotator_result)) {
- candidates.push_back({selection_indices, {experimental_annotator_result}});
+ ClassificationResult pod_ner_annotator_result;
+ if (pod_ner_annotator_ && options.use_pod_ner &&
+ pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
+ &pod_ner_annotator_result)) {
+ candidates.push_back({selection_indices, {pod_ner_annotator_result}});
+ }
+
+ if (experimental_annotator_) {
+ experimental_annotator_->ClassifyText(context_unicode, selection_indices,
+ candidates);
}
// Try the ML model.
@@ -1983,7 +1987,8 @@
selection_feature_processor_->GetOptions()->only_use_line_with_click(),
tokens,
/*click_pos=*/nullptr);
- const TokenSpan full_line_span = {0, tokens->size()};
+ const TokenSpan full_line_span = {0,
+ static_cast<TokenIndex>(tokens->size())};
// TODO(zilka): Add support for greater granularity of this check.
if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
@@ -2014,9 +2019,13 @@
const int offset = std::distance(context_unicode.begin(), line.first);
for (const TokenSpan& chunk : local_chunks) {
- const CodepointSpan codepoint_span =
+ CodepointSpan codepoint_span =
selection_feature_processor_->StripBoundaryCodepoints(
line_str, TokenSpanToCodepointSpan(*tokens, chunk));
+ if (model_->selection_options()->strip_unpaired_brackets()) {
+ codepoint_span =
+ StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
+ }
// Skip empty spans.
if (codepoint_span.first != codepoint_span.second) {
@@ -2133,15 +2142,16 @@
return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
}
+ const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
// Annotate with the regular expression models.
- if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
- annotation_regex_patterns_, candidates,
- options.is_serialized_entity_data_enabled)) {
+ if (!RegexChunk(
+ UTF8ToUnicodeText(context, /*do_copy=*/false),
+ annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
+ is_entity_type_enabled, options.annotation_usecase, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
}
// Annotate with the datetime model.
- const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
if ((is_entity_type_enabled(Collections::Date()) ||
is_entity_type_enabled(Collections::DateTime())) &&
!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
@@ -2166,7 +2176,15 @@
}
// Annotate with the number annotator.
- if (number_annotator_ != nullptr &&
+ bool number_annotations_enabled = true;
+ // Disable running the annotator in RAW mode if the number/percentage
+ // annotations are not explicitly requested.
+ if (options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
+ !is_entity_type_enabled(Collections::Number()) &&
+ !is_entity_type_enabled(Collections::Percentage())) {
+ number_annotations_enabled = false;
+ }
+ if (number_annotations_enabled && number_annotator_ != nullptr &&
!number_annotator_->FindAll(context_unicode, options.annotation_usecase,
candidates)) {
return Status(StatusCode::INTERNAL,
@@ -2197,6 +2215,13 @@
return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
}
+ // Annotate with the POD NER annotator.
+ if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
+ !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
+ return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
+ }
+
+ // Annotate with the experimental annotator.
if (experimental_annotator_ != nullptr &&
!experimental_annotator_->Annotate(context_unicode, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
@@ -2267,12 +2292,11 @@
return Status::OK;
}
-StatusOr<std::vector<std::vector<AnnotatedSpan>>>
-Annotator::AnnotateStructuredInput(
+StatusOr<Annotations> Annotator::AnnotateStructuredInput(
const std::vector<InputFragment>& string_fragments,
const AnnotationOptions& options) const {
- std::vector<std::vector<AnnotatedSpan>> annotation_candidates(
- string_fragments.size());
+ Annotations annotation_candidates;
+ annotation_candidates.annotated_spans.resize(string_fragments.size());
std::vector<std::string> text_to_annotate;
text_to_annotate.reserve(string_fragments.size());
@@ -2286,21 +2310,30 @@
!knowledge_engine_
->ChunkMultipleSpans(text_to_annotate, options.annotation_usecase,
options.location_context, options.permissions,
- &annotation_candidates)
+ options.annotate_mode, &annotation_candidates)
.ok()) {
return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
}
// The annotator engines shouldn't change the number of annotation vectors.
- if (annotation_candidates.size() != text_to_annotate.size()) {
+ if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
<< " texts to annotate but generated a different number of "
"lists of annotations:"
- << annotation_candidates.size();
+ << annotation_candidates.annotated_spans.size();
return Status(StatusCode::INTERNAL,
"Number of annotation candidates differs from "
"number of texts to annotate.");
}
+ // As an optimization, if the only requested entity type is Entity, we skip
+ // all annotators other than the KnowledgeEngine. This is only done in raw
+ // mode, to make sure it does not affect the result.
+ if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
+ options.entity_types.size() == 1 &&
+ *options.entity_types.begin() == Collections::Entity()) {
+ return annotation_candidates;
+ }
+
// Other annotators run on each fragment independently.
for (int i = 0; i < text_to_annotate.size(); ++i) {
AnnotationOptions annotation_options = options;
@@ -2314,10 +2347,11 @@
}
AddContactMetadataToKnowledgeClassificationResults(
- &annotation_candidates[i]);
+ &annotation_candidates.annotated_spans[i]);
- Status annotation_status = AnnotateSingleInput(
- text_to_annotate[i], annotation_options, &annotation_candidates[i]);
+ Status annotation_status =
+ AnnotateSingleInput(text_to_annotate[i], annotation_options,
+ &annotation_candidates.annotated_spans[i]);
if (!annotation_status.ok()) {
return annotation_status;
}
@@ -2329,14 +2363,14 @@
const std::string& context, const AnnotationOptions& options) const {
std::vector<InputFragment> string_fragments;
string_fragments.push_back({.text = context});
- StatusOr<std::vector<std::vector<AnnotatedSpan>>> annotations =
+ StatusOr<Annotations> annotations =
AnnotateStructuredInput(string_fragments, options);
if (!annotations.ok()) {
TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
<< annotations.status().error_message();
return {};
}
- return annotations.ValueOrDie()[0];
+ return annotations.ValueOrDie().annotated_spans[0];
}
CodepointSpan Annotator::ComputeSelectionBoundaries(
@@ -2408,7 +2442,7 @@
}
TC3_CHECK(entity_data_builder_ != nullptr);
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_->NewRoot();
TC3_CHECK(entity_data != nullptr);
@@ -2490,8 +2524,54 @@
return whole_amount;
}
+// Searches the capturing groups of `config` for a money quantity word, i.e. a
+// key of money_parsing_options()->quantities_name_to_exponent().
+// On a hit, stores the lower-cased group text in `*quantity` and the mapped
+// value in `*exponent`. Otherwise sets `*exponent` to 0 and leaves `*quantity`
+// untouched.
+void Annotator::GetMoneyQuantityFromCapturingGroup(
+ const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
+ const UnicodeText& context_unicode, std::string* quantity,
+ int* exponent) const {
+ *exponent = 0;
+ if (config->capturing_group() == nullptr) {
+ return;
+ }
+
+ const int num_groups = config->capturing_group()->size();
+ for (int i = 0; i < num_groups; i++) {
+ int status = UniLib::RegexMatcher::kNoError;
+ const int group_start = match->Start(i, &status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ continue;
+ }
+ const int group_end = match->End(i, &status);
+ // Skip groups whose bounds could not be read or are invalid.
+ if (status != UniLib::RegexMatcher::kNoError ||
+ group_start == kInvalidIndex || group_end == kInvalidIndex) {
+ continue;
+ }
+
+ // Quantity names are looked up in lower case; the output is only written
+ // once the group text turns out to be a known quantity.
+ const std::string group_text =
+ unilib_
+ ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
+ group_end, /*do_copy=*/false))
+ .ToUTF8String();
+ if (auto entry = model_->money_parsing_options()
+ ->quantities_name_to_exponent()
+ ->LookupByKey(group_text.c_str())) {
+ *quantity = group_text;
+ *exponent = entry->value();
+ return;
+ }
+ }
+}
+
bool Annotator::ParseAndFillInMoneyAmount(
- std::string* serialized_entity_data) const {
+ std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
+ const RegexModel_::Pattern* config,
+ const UnicodeText& context_unicode) const {
std::unique_ptr<EntityDataT> data =
LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
*serialized_entity_data);
@@ -2543,20 +2613,46 @@
<< data->money->unnormalized_amount;
return false;
}
+
if (it_decimal_separator == amount.end()) {
data->money->amount_decimal_part = 0;
+ data->money->nanos = 0;
} else {
const int amount_codepoints_size = amount.size_codepoints();
- if (!unilib_->ParseInt32(
- UnicodeText::Substring(
- amount, amount_codepoints_size - separator_back_index,
- amount_codepoints_size, /*do_copy=*/false),
- &data->money->amount_decimal_part)) {
+ const UnicodeText decimal_part = UnicodeText::Substring(
+ amount, amount_codepoints_size - separator_back_index,
+ amount_codepoints_size, /*do_copy=*/false);
+ if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
"the amount: "
<< data->money->unnormalized_amount;
return false;
}
+ data->money->nanos = data->money->amount_decimal_part *
+ pow(10, 9 - decimal_part.size_codepoints());
+ }
+
+ if (model_->money_parsing_options()->quantities_name_to_exponent() !=
+ nullptr) {
+ int quantity_exponent = 0;
+ std::string quantity;
+ GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
+ &quantity, &quantity_exponent);
+ // Only rescale for exponents in [1, 9]: outside that range the divisor
+ // static_cast<int>(pow(10, 9 - quantity_exponent)) is 0 or does not fit
+ // an int, which would make the modulo below undefined.
+ if (quantity_exponent > 0 && quantity_exponent <= 9) {
+ data->money->amount_whole_part =
+ data->money->amount_whole_part * pow(10, quantity_exponent) +
+ data->money->nanos / pow(10, 9 - quantity_exponent);
+ data->money->nanos = data->money->nanos %
+ static_cast<int>(pow(10, 9 - quantity_exponent)) *
+ pow(10, quantity_exponent);
+ }
+ if (quantity_exponent != 0) {
+ data->money->unnormalized_amount = strings::JoinStrings(
+ " ", {data->money->unnormalized_amount, quantity});
+ }
}
*serialized_entity_data =
@@ -2566,10 +2657,17 @@
bool Annotator::RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
- std::vector<AnnotatedSpan>* result,
- bool is_serialized_entity_data_enabled) const {
+ bool is_serialized_entity_data_enabled,
+ const EnabledEntityTypes& enabled_entity_types,
+ const AnnotationUsecase& annotation_usecase,
+ std::vector<AnnotatedSpan>* result) const {
for (int pattern_id : rules) {
const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
+ if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
+ annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
+ // This pattern's collection was not requested in RAW mode, skip it.
+ continue;
+ }
const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
if (!matcher) {
TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
@@ -2596,12 +2694,14 @@
return false;
}
- // Further parsing unnormalized_amount for money into amount_whole_part
- // and amount_decimal_part. Can't do this with regexes because we cannot
- // have empty groups (amount_decimal_part might be an empty group).
+ // Further parsing of money amount. Need this since regexes cannot have
+ // empty groups that fill in entity data (amount_decimal_part and
+ // quantity might be empty groups).
if (regex_pattern.config->collection_name()->str() ==
Collections::Money()) {
- if (!ParseAndFillInMoneyAmount(&serialized_entity_data)) {
+ if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
+ regex_pattern.config,
+ context_unicode)) {
if (model_->version() >= 706) {
// This way of parsing money entity data is enabled for models
// newer than v706 => logging errors only for them (b/156634162).
@@ -2639,11 +2739,11 @@
// The inference span is the span of interest expanded to include
// max_selection_span tokens on either side, which is how far a selection can
// stretch from the click.
- const TokenSpan inference_span = IntersectTokenSpans(
- ExpandTokenSpan(span_of_interest,
- /*num_tokens_left=*/max_selection_span,
- /*num_tokens_right=*/max_selection_span),
- {0, num_tokens});
+ const TokenSpan inference_span =
+ IntersectTokenSpans(span_of_interest.Expand(
+ /*num_tokens_left=*/max_selection_span,
+ /*num_tokens_right=*/max_selection_span),
+ {0, num_tokens});
std::vector<ScoredChunk> scored_chunks;
if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
@@ -2670,7 +2770,7 @@
// Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
// them greedily as long as they do not overlap with any previously picked
// chunks.
- std::vector<bool> token_used(TokenSpanSize(inference_span));
+ std::vector<bool> token_used(inference_span.Size());
chunks->clear();
for (const ScoredChunk& scored_chunk : scored_chunks) {
bool feasible = true;
@@ -2766,9 +2866,8 @@
TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
return false;
}
- const TokenSpan candidate_span = ExpandTokenSpan(
- SingleTokenSpan(click_pos), relative_token_span.first,
- relative_token_span.second);
+ const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
+ relative_token_span.first, relative_token_span.second);
if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
UpdateMax(&chunk_scores, candidate_span, scores[j]);
}
@@ -2803,7 +2902,7 @@
scored_chunks->clear();
if (score_single_token_spans_as_zero) {
- scored_chunks->reserve(TokenSpanSize(span_of_interest));
+ scored_chunks->reserve(span_of_interest.Size());
}
// Prepare all chunk candidates into one batch:
@@ -2819,8 +2918,7 @@
end <= inference_span.second && end - start <= max_chunk_length;
++end) {
const TokenSpan candidate_span = {start, end};
- if (score_single_token_spans_as_zero &&
- TokenSpanSize(candidate_span) == 1) {
+ if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
// Do not include the single token span in the batch, add a zero score
// for it directly to the output.
scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
index ebd762c..67f10d3 100644
--- a/native/annotator/annotator.h
+++ b/native/annotator/annotator.h
@@ -38,13 +38,15 @@
#include "annotator/model_generated.h"
#include "annotator/number/number.h"
#include "annotator/person_name/person-name-engine.h"
+#include "annotator/pod_ner/pod-ner.h"
#include "annotator/strip-unpaired-brackets.h"
#include "annotator/translate/translate.h"
#include "annotator/types.h"
#include "annotator/zlib-utils.h"
#include "utils/base/status.h"
#include "utils/base/statusor.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/i18n/locale.h"
#include "utils/memory/mmap.h"
#include "utils/utf8/unilib.h"
@@ -184,7 +186,7 @@
// Classifies the selected text given the context string.
// Returns an empty result if an error occurs.
std::vector<ClassificationResult> ClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
const ClassificationOptions& options = ClassificationOptions()) const;
// Annotates the given structed input request. Models which handle the full
@@ -197,7 +199,7 @@
// of input fragments. The order of annotation span vectors will match the
// order of input fragments. If annotation is not possible for any of the
// annotators, no annotation is returned.
- StatusOr<std::vector<std::vector<AnnotatedSpan>>> AnnotateStructuredInput(
+ StatusOr<Annotations> AnnotateStructuredInput(
const std::vector<InputFragment>& string_fragments,
const AnnotationOptions& options = AnnotationOptions()) const;
@@ -282,7 +284,7 @@
// Provides the tokens produced during tokenization of the context string for
// reuse.
bool ModelSuggestSelection(
- const UnicodeText& context_unicode, CodepointSpan click_indices,
+ const UnicodeText& context_unicode, const CodepointSpan& click_indices,
const std::vector<Locale>& detected_text_language_tags,
InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const;
@@ -292,7 +294,8 @@
// Returns true if no error occurred.
bool ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
- const std::vector<Locale>& locales, CodepointSpan selection_indices,
+ const std::vector<Locale>& detected_text_language_tags,
+ const CodepointSpan& selection_indices,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results,
@@ -302,7 +305,8 @@
bool ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ const CodepointSpan& selection_indices,
+ InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const;
@@ -310,7 +314,8 @@
bool ModelClassifyText(
const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ const CodepointSpan& selection_indices,
+ InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const;
@@ -322,13 +327,13 @@
// Classifies the selected text with the regular expressions models.
// Returns true if no error happened, false otherwise.
bool RegexClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
std::vector<ClassificationResult>* classification_result) const;
// Classifies the selected text with the date time model.
// Returns true if no error happened, false otherwise.
bool DatetimeClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const CodepointSpan& selection_indices,
const ClassificationOptions& options,
std::vector<ClassificationResult>* classification_results) const;
@@ -379,8 +384,11 @@
// Produces chunks isolated by a set of regular expressions.
bool RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
- std::vector<AnnotatedSpan>* result,
- bool is_serialized_entity_data_enabled) const;
+ bool is_serialized_entity_data_enabled,
+ const EnabledEntityTypes& enabled_entity_types,
+ const AnnotationUsecase& annotation_usecase,
+
+ std::vector<AnnotatedSpan>* result) const;
// Produces chunks from the datetime parser.
bool DatetimeChunk(const UnicodeText& context_unicode,
@@ -462,7 +470,19 @@
// Parses the money amount into whole and decimal part and fills in the
// entity data information.
- bool ParseAndFillInMoneyAmount(std::string* serialized_entity_data) const;
+ bool ParseAndFillInMoneyAmount(std::string* serialized_entity_data,
+ const UniLib::RegexMatcher* match,
+ const RegexModel_::Pattern* config,
+ const UnicodeText& context_unicode) const;
+
+ // Given the regex capturing groups, extracts the one representing the money
+ // quantity and fills in the actual string and the power of 10 the amount
+ // should be multiplied by.
+ void GetMoneyQuantityFromCapturingGroup(const UniLib::RegexMatcher* match,
+ const RegexModel_::Pattern* config,
+ const UnicodeText& context_unicode,
+ std::string* quantity,
+ int* exponent) const;
std::unique_ptr<ScopedMmap> mmap_;
bool initialized_ = false;
@@ -491,11 +511,12 @@
std::unique_ptr<const DurationAnnotator> duration_annotator_;
std::unique_ptr<const PersonNameEngine> person_name_engine_;
std::unique_ptr<const TranslateAnnotator> translate_annotator_;
+ std::unique_ptr<const PodNerAnnotator> pod_ner_annotator_;
std::unique_ptr<const ExperimentalAnnotator> experimental_annotator_;
// Builder for creating extra data.
const reflection::Schema* entity_data_schema_;
- std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
+ std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
// Locales for which the entire model triggers.
std::vector<Locale> model_triggering_locales_;
@@ -526,7 +547,7 @@
// Helper function, which if the initial 'span' contains only white-spaces,
// moves the selection to a single-codepoint selection on the left side
// of this block of white-space.
-CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
+CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
const UnicodeText& context_unicode,
const UniLib& unilib);
@@ -534,7 +555,7 @@
// 'tokens_around_selection_to_copy' (on the left, and right) tokens distant
// from the tokens that correspond to 'selection_indices'.
std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
- CodepointSpan selection_indices,
+ const CodepointSpan& selection_indices,
TokenSpan tokens_around_selection_to_copy);
} // namespace internal
diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc
index 3e04f7f..8d5ad33 100644
--- a/native/annotator/annotator_jni.cc
+++ b/native/annotator/annotator_jni.cc
@@ -21,10 +21,12 @@
#include <jni.h>
#include <type_traits>
+#include <utility>
#include <vector>
#include "annotator/annotator.h"
#include "annotator/annotator_jni_common.h"
+#include "annotator/knowledge/knowledge-engine-types.h"
#include "annotator/types.h"
#include "utils/base/integral_types.h"
#include "utils/base/status_macros.h"
@@ -35,7 +37,6 @@
#include "utils/intents/remote-action-template.h"
#include "utils/java/jni-cache.h"
#include "utils/java/jni-helper.h"
-#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"
#include "utils/strings/stringpiece.h"
#include "utils/utf8/unilib.h"
@@ -49,6 +50,7 @@
#endif
using libtextclassifier3::AnnotatedSpan;
+using libtextclassifier3::Annotations;
using libtextclassifier3::Annotator;
using libtextclassifier3::ClassificationResult;
using libtextclassifier3::CodepointSpan;
@@ -81,11 +83,9 @@
std::unique_ptr<IntentGenerator> intent_generator =
IntentGenerator::Create(model->model()->intent_options(),
model->model()->resources(), jni_cache);
- std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
- libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
- if (template_handler == nullptr) {
- return nullptr;
- }
+ TC3_ASSIGN_OR_RETURN_NULL(
+ std::unique_ptr<RemoteActionTemplatesHandler> template_handler,
+ libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache));
return new AnnotatorJniContext(jni_cache, std::move(model),
std::move(intent_generator),
@@ -151,10 +151,11 @@
TC3_ASSIGN_OR_RETURN(serialized_knowledge_result,
JniHelper::NewByteArray(
env, serialized_knowledge_result_string.size()));
- env->SetByteArrayRegion(serialized_knowledge_result.get(), 0,
- serialized_knowledge_result_string.size(),
- reinterpret_cast<const jbyte*>(
- serialized_knowledge_result_string.data()));
+ TC3_RETURN_IF_ERROR(JniHelper::SetByteArrayRegion(
+ env, serialized_knowledge_result.get(), 0,
+ serialized_knowledge_result_string.size(),
+ reinterpret_cast<const jbyte*>(
+ serialized_knowledge_result_string.data())));
}
ScopedLocalRef<jstring> contact_name;
@@ -242,11 +243,11 @@
serialized_entity_data,
JniHelper::NewByteArray(
env, classification_result.serialized_entity_data.size()));
- env->SetByteArrayRegion(
- serialized_entity_data.get(), 0,
+ TC3_RETURN_IF_ERROR(JniHelper::SetByteArrayRegion(
+ env, serialized_entity_data.get(), 0,
classification_result.serialized_entity_data.size(),
reinterpret_cast<const jbyte*>(
- classification_result.serialized_entity_data.data()));
+ classification_result.serialized_entity_data.data())));
}
ScopedLocalRef<jobjectArray> remote_action_templates_result;
@@ -340,7 +341,7 @@
return ClassificationResultsWithIntentsToJObjectArray(
env, model_context,
/*(unused) app_context=*/nullptr,
- /*(unused) devide_locale=*/nullptr,
+ /*(unused) device_locale=*/nullptr,
/*(unusued) options=*/nullptr,
/*(unused) selection_text=*/"",
/*(unused) selection_indices=*/{kInvalidIndex, kInvalidIndex},
@@ -348,9 +349,9 @@
/*generate_intents=*/false);
}
-CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
- CodepointSpan orig_indices,
- bool from_utf8) {
+std::pair<int, int> ConvertIndicesBMPUTF8(
+ const std::string& utf8_str, const std::pair<int, int>& orig_indices,
+ bool from_utf8) {
const libtextclassifier3::UnicodeText unicode_str =
libtextclassifier3::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
@@ -367,7 +368,7 @@
target_index = &unicode_index;
}
- CodepointSpan result{-1, -1};
+ std::pair<int, int> result = std::make_pair(-1, -1);
std::function<void()> assign_indices_fn = [&result, &orig_indices,
&source_index, &target_index]() {
if (orig_indices.first == *source_index) {
@@ -396,13 +397,17 @@
} // namespace
CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
- CodepointSpan bmp_indices) {
- return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
+ const std::pair<int, int>& bmp_indices) {
+ const std::pair<int, int> utf8_indices =
+ ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
+ return CodepointSpan(utf8_indices.first, utf8_indices.second);
}
-CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
- CodepointSpan utf8_indices) {
- return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
+std::pair<int, int> ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
+ const CodepointSpan& utf8_indices) {
+ return ConvertIndicesBMPUTF8(
+ utf8_str, std::make_pair(utf8_indices.first, utf8_indices.second),
+ /*from_utf8=*/true);
}
StatusOr<ScopedLocalRef<jstring>> GetLocalesFromMmap(
@@ -456,10 +461,10 @@
using libtextclassifier3::FromJavaInputFragment;
using libtextclassifier3::FromJavaSelectionOptions;
using libtextclassifier3::InputFragment;
-using libtextclassifier3::ToStlString;
+using libtextclassifier3::JStringToUtf8String;
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
-(JNIEnv* env, jobject thiz, jint fd) {
+(JNIEnv* env, jobject clazz, jint fd) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
#ifdef TC3_USE_JAVAICU
@@ -475,8 +480,9 @@
}
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
-(JNIEnv* env, jobject thiz, jstring path) {
- TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
+(JNIEnv* env, jobject clazz, jstring path) {
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str,
+ JStringToUtf8String(env, path));
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
#ifdef TC3_USE_JAVAICU
@@ -492,7 +498,7 @@
}
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
#ifdef TC3_USE_JAVAICU
@@ -516,13 +522,9 @@
Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- std::string serialized_config_string;
- TC3_ASSIGN_OR_RETURN_FALSE(jsize length,
- JniHelper::GetArrayLength(env, serialized_config));
- serialized_config_string.resize(length);
- env->GetByteArrayRegion(serialized_config, 0, length,
- reinterpret_cast<jbyte*>(const_cast<char*>(
- serialized_config_string.data())));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ const std::string serialized_config_string,
+ libtextclassifier3::JByteArrayToString(env, serialized_config));
return model->InitializeKnowledgeEngine(serialized_config_string);
}
@@ -536,13 +538,9 @@
Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- std::string serialized_config_string;
- TC3_ASSIGN_OR_RETURN_FALSE(jsize length,
- JniHelper::GetArrayLength(env, serialized_config));
- serialized_config_string.resize(length);
- env->GetByteArrayRegion(serialized_config, 0, length,
- reinterpret_cast<jbyte*>(const_cast<char*>(
- serialized_config_string.data())));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ const std::string serialized_config_string,
+ libtextclassifier3::JByteArrayToString(env, serialized_config));
return model->InitializeContactEngine(serialized_config_string);
}
@@ -556,13 +554,9 @@
Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- std::string serialized_config_string;
- TC3_ASSIGN_OR_RETURN_FALSE(jsize length,
- JniHelper::GetArrayLength(env, serialized_config));
- serialized_config_string.resize(length);
- env->GetByteArrayRegion(serialized_config, 0, length,
- reinterpret_cast<jbyte*>(const_cast<char*>(
- serialized_config_string.data())));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ const std::string serialized_config_string,
+ libtextclassifier3::JByteArrayToString(env, serialized_config));
return model->InitializeInstalledAppEngine(serialized_config_string);
}
@@ -608,20 +602,23 @@
}
const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
- ToStlString(env, context));
- CodepointSpan input_indices =
+ JStringToUtf8String(env, context));
+ const CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
TC3_ASSIGN_OR_RETURN_NULL(
libtextclassifier3::SelectionOptions selection_options,
FromJavaSelectionOptions(env, options));
CodepointSpan selection =
model->SuggestSelection(context_utf8, input_indices, selection_options);
- selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
+ const std::pair<int, int> selection_bmp =
+ ConvertIndicesUTF8ToBMP(context_utf8, selection);
TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jintArray> result,
JniHelper::NewIntArray(env, 2));
- env->SetIntArrayRegion(result.get(), 0, 1, &(std::get<0>(selection)));
- env->SetIntArrayRegion(result.get(), 1, 1, &(std::get<1>(selection)));
+ TC3_RETURN_NULL_IF_ERROR(JniHelper::SetIntArrayRegion(
+ env, result.get(), 0, 1, &(selection_bmp.first)));
+ TC3_RETURN_NULL_IF_ERROR(JniHelper::SetIntArrayRegion(
+ env, result.get(), 1, 1, &(selection_bmp.second)));
return result.release();
}
@@ -636,7 +633,7 @@
reinterpret_cast<AnnotatorJniContext*>(ptr);
TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
- ToStlString(env, context));
+ JStringToUtf8String(env, context));
const CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
TC3_ASSIGN_OR_RETURN_NULL(
@@ -672,7 +669,7 @@
const AnnotatorJniContext* model_context =
reinterpret_cast<AnnotatorJniContext*>(ptr);
TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
- ToStlString(env, context));
+ JStringToUtf8String(env, context));
TC3_ASSIGN_OR_RETURN_NULL(
libtextclassifier3::AnnotationOptions annotation_options,
FromJavaAnnotationOptions(env, options));
@@ -696,7 +693,7 @@
JniHelper::NewObjectArray(env, annotations.size(), result_class.get()));
for (int i = 0; i < annotations.size(); ++i) {
- CodepointSpan span_bmp =
+ const std::pair<int, int> span_bmp =
ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
TC3_ASSIGN_OR_RETURN_NULL(
@@ -718,8 +715,7 @@
return results.release();
}
-TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME,
- nativeAnnotateStructuredInput)
+TC3_JNI_METHOD(jobject, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotateStructuredInput)
(JNIEnv* env, jobject thiz, jlong ptr, jobjectArray jinput_fragments,
jobject options) {
if (!ptr) {
@@ -743,7 +739,7 @@
TC3_ASSIGN_OR_RETURN_NULL(
libtextclassifier3::AnnotationOptions annotation_options,
FromJavaAnnotationOptions(env, options));
- const StatusOr<std::vector<std::vector<AnnotatedSpan>>> annotations_or =
+ const StatusOr<Annotations> annotations_or =
model_context->model()->AnnotateStructuredInput(string_fragments,
annotation_options);
if (!annotations_or.ok()) {
@@ -752,8 +748,20 @@
return nullptr;
}
- std::vector<std::vector<AnnotatedSpan>> annotations =
- std::move(annotations_or.ValueOrDie());
+ Annotations annotations = std::move(annotations_or.ValueOrDie());
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jclass> annotations_class,
+ JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$Annotations"));
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ jmethodID annotations_class_constructor,
+ JniHelper::GetMethodID(
+ env, annotations_class.get(), "<init>",
+ "([[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$AnnotatedSpan;[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$ClassificationResult;)V"));
+
TC3_ASSIGN_OR_RETURN_NULL(
ScopedLocalRef<jclass> span_class,
JniHelper::FindClass(
@@ -773,26 +781,28 @@
"$AnnotatedSpan;"));
TC3_ASSIGN_OR_RETURN_NULL(
- ScopedLocalRef<jobjectArray> results,
+ ScopedLocalRef<jobjectArray> annotated_spans,
JniHelper::NewObjectArray(env, input_size, span_class_array.get()));
- for (int fragment_index = 0; fragment_index < annotations.size();
- ++fragment_index) {
+ for (int fragment_index = 0;
+ fragment_index < annotations.annotated_spans.size(); ++fragment_index) {
TC3_ASSIGN_OR_RETURN_NULL(
ScopedLocalRef<jobjectArray> jfragmentAnnotations,
- JniHelper::NewObjectArray(env, annotations[fragment_index].size(),
- span_class.get()));
+ JniHelper::NewObjectArray(
+ env, annotations.annotated_spans[fragment_index].size(),
+ span_class.get()));
for (int annotation_index = 0;
- annotation_index < annotations[fragment_index].size();
+ annotation_index < annotations.annotated_spans[fragment_index].size();
++annotation_index) {
- CodepointSpan span_bmp = ConvertIndicesUTF8ToBMP(
+ const std::pair<int, int> span_bmp = ConvertIndicesUTF8ToBMP(
string_fragments[fragment_index].text,
- annotations[fragment_index][annotation_index].span);
+ annotations.annotated_spans[fragment_index][annotation_index].span);
TC3_ASSIGN_OR_RETURN_NULL(
ScopedLocalRef<jobjectArray> classification_results,
ClassificationResultsToJObjectArray(
env, model_context,
- annotations[fragment_index][annotation_index].classification));
+ annotations.annotated_spans[fragment_index][annotation_index]
+ .classification));
TC3_ASSIGN_OR_RETURN_NULL(
ScopedLocalRef<jobject> single_annotation,
JniHelper::NewObject(env, span_class.get(), span_class_constructor,
@@ -808,14 +818,26 @@
}
}
- if (!JniHelper::SetObjectArrayElement(env, results.get(), fragment_index,
+ if (!JniHelper::SetObjectArrayElement(env, annotated_spans.get(),
+ fragment_index,
jfragmentAnnotations.get())
.ok()) {
return nullptr;
}
}
- return results.release();
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> topicality_results,
+ ClassificationResultsToJObjectArray(env, model_context,
+ annotations.topicality_results));
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobject> annotations_result,
+ JniHelper::NewObject(env, annotations_class.get(),
+ annotations_class_constructor, annotated_spans.get(),
+ topicality_results.get()));
+
+ return annotations_result.release();
}
TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
@@ -825,7 +847,8 @@
return nullptr;
}
const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- TC3_ASSIGN_OR_RETURN_NULL(const std::string id_utf8, ToStlString(env, id));
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string id_utf8,
+ JStringToUtf8String(env, id));
std::string serialized_knowledge_result;
if (!model->LookUpKnowledgeEntity(id_utf8, &serialized_knowledge_result)) {
return nullptr;
@@ -834,9 +857,9 @@
TC3_ASSIGN_OR_RETURN_NULL(
ScopedLocalRef<jbyteArray> result,
JniHelper::NewByteArray(env, serialized_knowledge_result.size()));
- env->SetByteArrayRegion(
- result.get(), 0, serialized_knowledge_result.size(),
- reinterpret_cast<const jbyte*>(serialized_knowledge_result.data()));
+ TC3_RETURN_NULL_IF_ERROR(JniHelper::SetByteArrayRegion(
+ env, result.get(), 0, serialized_knowledge_result.size(),
+ reinterpret_cast<const jbyte*>(serialized_knowledge_result.data())));
return result.release();
}
@@ -864,7 +887,7 @@
}
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
@@ -880,7 +903,7 @@
}
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersionWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
return GetVersionFromMmap(env, mmap.get());
@@ -896,7 +919,7 @@
}
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
diff --git a/native/annotator/annotator_jni.h b/native/annotator/annotator_jni.h
index 39a9d9a..0abaf46 100644
--- a/native/annotator/annotator_jni.h
+++ b/native/annotator/annotator_jni.h
@@ -29,13 +29,13 @@
// SmartSelection.
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
-(JNIEnv* env, jobject thiz, jint fd);
+(JNIEnv* env, jobject clazz, jint fd);
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
-(JNIEnv* env, jobject thiz, jstring path);
+(JNIEnv* env, jobject clazz, jstring path);
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
nativeInitializeKnowledgeEngine)
@@ -68,8 +68,7 @@
jint selection_end, jobject options, jobject app_context,
jstring device_locales);
-TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME,
- nativeAnnotateStructuredInput)
+TC3_JNI_METHOD(jobject, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotateStructuredInput)
(JNIEnv* env, jobject thiz, jlong ptr, jobjectArray jinput_fragments,
jobject options);
@@ -91,19 +90,19 @@
(JNIEnv* env, jobject clazz, jint fd);
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd);
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersionWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd);
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset)
-(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
#ifdef __cplusplus
}
@@ -114,13 +113,13 @@
// Given a utf8 string and a span expressed in Java BMP (basic multilingual
// plane) codepoints, converts it to a span expressed in utf8 codepoints.
libtextclassifier3::CodepointSpan ConvertIndicesBMPToUTF8(
- const std::string& utf8_str, libtextclassifier3::CodepointSpan bmp_indices);
+ const std::string& utf8_str, const std::pair<int, int>& bmp_indices);
// Given a utf8 string and a span expressed in utf8 codepoints, converts it to a
// span expressed in Java BMP (basic multilingual plane) codepoints.
-libtextclassifier3::CodepointSpan ConvertIndicesUTF8ToBMP(
+std::pair<int, int> ConvertIndicesUTF8ToBMP(
const std::string& utf8_str,
- libtextclassifier3::CodepointSpan utf8_indices);
+ const libtextclassifier3::CodepointSpan& utf8_indices);
} // namespace libtextclassifier3
diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc
index de58b70..155e038 100644
--- a/native/annotator/annotator_jni_common.cc
+++ b/native/annotator/annotator_jni_common.cc
@@ -16,6 +16,7 @@
#include "annotator/annotator_jni_common.h"
+#include "annotator/knowledge/knowledge-engine-types.h"
#include "utils/java/jni-base.h"
#include "utils/java/jni-helper.h"
@@ -26,13 +27,14 @@
JNIEnv* env, const jobject& jobject) {
std::unordered_set<std::string> entity_types;
jobjectArray jentity_types = reinterpret_cast<jobjectArray>(jobject);
- const int size = env->GetArrayLength(jentity_types);
+ TC3_ASSIGN_OR_RETURN(const int size,
+ JniHelper::GetArrayLength(env, jentity_types));
for (int i = 0; i < size; ++i) {
TC3_ASSIGN_OR_RETURN(
ScopedLocalRef<jstring> jentity_type,
JniHelper::GetObjectArrayElement<jstring>(env, jentity_types, i));
TC3_ASSIGN_OR_RETURN(std::string entity_type,
- ToStlString(env, jentity_type.get()));
+ JStringToUtf8String(env, jentity_type.get()));
entity_types.insert(entity_type);
}
return entity_types;
@@ -117,17 +119,27 @@
JniHelper::CallFloatMethod(
env, joptions, get_user_location_accuracy_meters));
+ // .getUsePodNer()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_use_pod_ner,
+ JniHelper::GetMethodID(env, options_class.get(), "getUsePodNer", "()Z"));
+ TC3_ASSIGN_OR_RETURN(bool use_pod_ner, JniHelper::CallBooleanMethod(
+ env, joptions, get_use_pod_ner));
+
T options;
- TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get()));
+ TC3_ASSIGN_OR_RETURN(options.locales,
+ JStringToUtf8String(env, locales.get()));
TC3_ASSIGN_OR_RETURN(options.reference_timezone,
- ToStlString(env, reference_timezone.get()));
+ JStringToUtf8String(env, reference_timezone.get()));
options.reference_time_ms_utc = reference_time;
- TC3_ASSIGN_OR_RETURN(options.detected_text_language_tags,
- ToStlString(env, detected_text_language_tags.get()));
+ TC3_ASSIGN_OR_RETURN(
+ options.detected_text_language_tags,
+ JStringToUtf8String(env, detected_text_language_tags.get()));
options.annotation_usecase =
static_cast<AnnotationUsecase>(annotation_usecase);
options.location_context = {user_location_lat, user_location_lng,
user_location_accuracy_meters};
+ options.use_pod_ner = use_pod_ner;
return options;
}
} // namespace
@@ -154,6 +166,16 @@
ScopedLocalRef<jstring> locales,
JniHelper::CallObjectMethod<jstring>(env, joptions, get_locales));
+ // .getDetectedTextLanguageTags()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_detected_text_language_tags_method,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getDetectedTextLanguageTags",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> detected_text_language_tags,
+ JniHelper::CallObjectMethod<jstring>(
+ env, joptions, get_detected_text_language_tags_method));
+
// .getAnnotationUsecase()
TC3_ASSIGN_OR_RETURN(jmethodID get_annotation_usecase,
JniHelper::GetMethodID(env, options_class.get(),
@@ -162,11 +184,49 @@
int32 annotation_usecase,
JniHelper::CallIntMethod(env, joptions, get_annotation_usecase));
+ // .getUserLocationLat()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_user_location_lat,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getUserLocationLat", "()D"));
+ TC3_ASSIGN_OR_RETURN(
+ double user_location_lat,
+ JniHelper::CallDoubleMethod(env, joptions, get_user_location_lat));
+
+ // .getUserLocationLng()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_user_location_lng,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getUserLocationLng", "()D"));
+ TC3_ASSIGN_OR_RETURN(
+ double user_location_lng,
+ JniHelper::CallDoubleMethod(env, joptions, get_user_location_lng));
+
+ // .getUserLocationAccuracyMeters()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_user_location_accuracy_meters,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getUserLocationAccuracyMeters", "()F"));
+ TC3_ASSIGN_OR_RETURN(float user_location_accuracy_meters,
+ JniHelper::CallFloatMethod(
+ env, joptions, get_user_location_accuracy_meters));
+
+ // .getUsePodNer()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_use_pod_ner,
+ JniHelper::GetMethodID(env, options_class.get(), "getUsePodNer", "()Z"));
+ TC3_ASSIGN_OR_RETURN(bool use_pod_ner, JniHelper::CallBooleanMethod(
+ env, joptions, get_use_pod_ner));
+
SelectionOptions options;
- TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get()));
+ TC3_ASSIGN_OR_RETURN(options.locales,
+ JStringToUtf8String(env, locales.get()));
options.annotation_usecase =
static_cast<AnnotationUsecase>(annotation_usecase);
-
+ TC3_ASSIGN_OR_RETURN(
+ options.detected_text_language_tags,
+ JStringToUtf8String(env, detected_text_language_tags.get()));
+ options.location_context = {user_location_lat, user_location_lng,
+ user_location_accuracy_meters};
+ options.use_pod_ner = use_pod_ner;
return options;
}
@@ -197,8 +257,9 @@
JniHelper::CallObjectMethod<jstring>(
env, joptions, get_user_familiar_language_tags));
- TC3_ASSIGN_OR_RETURN(classifier_options.user_familiar_language_tags,
- ToStlString(env, user_familiar_language_tags.get()));
+ TC3_ASSIGN_OR_RETURN(
+ classifier_options.user_familiar_language_tags,
+ JStringToUtf8String(env, user_familiar_language_tags.get()));
return classifier_options;
}
@@ -252,6 +313,13 @@
bool has_personalization_permission,
JniHelper::CallBooleanMethod(env, joptions,
has_personalization_permission_method));
+ // .getAnnotateMode()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_annotate_mode,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getAnnotateMode", "()I"));
+ TC3_ASSIGN_OR_RETURN(
+ int32 annotate_mode,
+ JniHelper::CallIntMethod(env, joptions, get_annotate_mode));
TC3_ASSIGN_OR_RETURN(
AnnotationOptions annotation_options,
@@ -266,6 +334,7 @@
has_location_permission;
annotation_options.permissions.has_personalization_permission =
has_personalization_permission;
+ annotation_options.annotate_mode = static_cast<AnnotateMode>(annotate_mode);
return annotation_options;
}
@@ -290,7 +359,7 @@
ScopedLocalRef<jstring> text,
JniHelper::CallObjectMethod<jstring>(env, jfragment, get_text));
- TC3_ASSIGN_OR_RETURN(fragment.text, ToStlString(env, text.get()));
+ TC3_ASSIGN_OR_RETURN(fragment.text, JStringToUtf8String(env, text.get()));
// .hasDatetimeOptions()
TC3_ASSIGN_OR_RETURN(jmethodID has_date_time_options_method,
@@ -323,7 +392,7 @@
env, jfragment, get_reference_timezone_method));
TC3_ASSIGN_OR_RETURN(std::string reference_timezone,
- ToStlString(env, jreference_timezone.get()));
+ JStringToUtf8String(env, jreference_timezone.get()));
fragment.datetime_options =
DatetimeOptions{.reference_time_ms_utc = reference_time,
diff --git a/native/annotator/annotator_jni_test.cc b/native/annotator/annotator_jni_test.cc
index 929fb59..a48f173 100644
--- a/native/annotator/annotator_jni_test.cc
+++ b/native/annotator/annotator_jni_test.cc
@@ -24,52 +24,52 @@
TEST(Annotator, ConvertIndicesBMPUTF8) {
// Test boundary cases.
- EXPECT_EQ(ConvertIndicesBMPToUTF8("hello", {0, 5}), std::make_pair(0, 5));
+ EXPECT_EQ(ConvertIndicesBMPToUTF8("hello", {0, 5}), CodepointSpan(0, 5));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello", {0, 5}), std::make_pair(0, 5));
EXPECT_EQ(ConvertIndicesBMPToUTF8("hello world", {0, 5}),
- std::make_pair(0, 5));
+ CodepointSpan(0, 5));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello world", {0, 5}),
std::make_pair(0, 5));
EXPECT_EQ(ConvertIndicesBMPToUTF8("😁ello world", {0, 6}),
- std::make_pair(0, 5));
+ CodepointSpan(0, 5));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁ello world", {0, 5}),
std::make_pair(0, 6));
EXPECT_EQ(ConvertIndicesBMPToUTF8("hello world", {6, 11}),
- std::make_pair(6, 11));
+ CodepointSpan(6, 11));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello world", {6, 11}),
std::make_pair(6, 11));
EXPECT_EQ(ConvertIndicesBMPToUTF8("hello worl😁", {6, 12}),
- std::make_pair(6, 11));
+ CodepointSpan(6, 11));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("hello worl😁", {6, 11}),
std::make_pair(6, 12));
// Simple example where the longer character is before the selection.
// character 😁 is 0x1f601
EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hello World.", {3, 8}),
- std::make_pair(2, 7));
+ CodepointSpan(2, 7));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hello World.", {2, 7}),
std::make_pair(3, 8));
// Longer character is before and in selection.
EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hell😁 World.", {3, 9}),
- std::make_pair(2, 7));
+ CodepointSpan(2, 7));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hell😁 World.", {2, 7}),
std::make_pair(3, 9));
// Longer character is before and after selection.
EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hello😁World.", {3, 8}),
- std::make_pair(2, 7));
+ CodepointSpan(2, 7));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hello😁World.", {2, 7}),
std::make_pair(3, 8));
// Longer character is before in after selection.
EXPECT_EQ(ConvertIndicesBMPToUTF8("😁 Hell😁😁World.", {3, 9}),
- std::make_pair(2, 7));
+ CodepointSpan(2, 7));
EXPECT_EQ(ConvertIndicesUTF8ToBMP("😁 Hell😁😁World.", {2, 7}),
std::make_pair(3, 9));
diff --git a/native/annotator/cached-features.cc b/native/annotator/cached-features.cc
index 480c044..1a14a42 100644
--- a/native/annotator/cached-features.cc
+++ b/native/annotator/cached-features.cc
@@ -88,10 +88,9 @@
click_pos -= extraction_span_.first;
AppendFeaturesInternal(
- /*intended_span=*/ExpandTokenSpan(SingleTokenSpan(click_pos),
- options_->context_size(),
- options_->context_size()),
- /*read_mask_span=*/{0, TokenSpanSize(extraction_span_)}, output_features);
+ /*intended_span=*/TokenSpan(click_pos).Expand(options_->context_size(),
+ options_->context_size()),
+ /*read_mask_span=*/{0, extraction_span_.Size()}, output_features);
}
void CachedFeatures::AppendBoundsSensitiveFeaturesForSpan(
@@ -118,16 +117,15 @@
/*intended_span=*/{selected_span.second -
config->num_tokens_inside_right(),
selected_span.second + config->num_tokens_after()},
- /*read_mask_span=*/{selected_span.first, TokenSpanSize(extraction_span_)},
- output_features);
+ /*read_mask_span=*/
+ {selected_span.first, extraction_span_.Size()}, output_features);
if (config->include_inside_bag()) {
AppendBagFeatures(selected_span, output_features);
}
if (config->include_inside_length()) {
- output_features->push_back(
- static_cast<float>(TokenSpanSize(selected_span)));
+ output_features->push_back(static_cast<float>(selected_span.Size()));
}
}
@@ -161,7 +159,7 @@
for (int i = bag_span.first; i < bag_span.second; ++i) {
for (int j = 0; j < NumFeaturesPerToken(); ++j) {
(*output_features)[offset + j] +=
- (*features_)[i * NumFeaturesPerToken() + j] / TokenSpanSize(bag_span);
+ (*features_)[i * NumFeaturesPerToken() + j] / bag_span.Size();
}
}
}
diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc
index b8e1b7a..c42ddf0 100644
--- a/native/annotator/datetime/extractor.cc
+++ b/native/annotator/datetime/extractor.cc
@@ -16,6 +16,8 @@
#include "annotator/datetime/extractor.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
#include "utils/base/logging.h"
namespace libtextclassifier3 {
@@ -162,6 +164,18 @@
}
break;
}
+ case DatetimeGroupType_GROUP_ABSOLUTETIME: {
+ std::unordered_map<DatetimeComponent::ComponentType, int> values;
+ if (!ParseAbsoluteDateValues(group_text, &values)) {
+ TC3_LOG(ERROR) << "Couldn't extract Component values.";
+ return false;
+ }
+ for (const std::pair<const DatetimeComponent::ComponentType, int>&
+ date_time_pair : values) {
+ result->SetAbsoluteValue(date_time_pair.first, date_time_pair.second);
+ }
+ break;
+ }
case DatetimeGroupType_GROUP_DUMMY1:
case DatetimeGroupType_GROUP_DUMMY2:
break;
@@ -417,6 +431,26 @@
return false;
}
+// Looks `input` up among the extractor types that denote a complete time of
+// day and hands the component values listed for the match to MapInput for
+// storage in `values`:
+//   DatetimeExtractorType_NOON     -> MERIDIEM 1, MINUTE 0, HOUR 12
+//   DatetimeExtractorType_MIDNIGHT -> MERIDIEM 0, MINUTE 0, HOUR 0
+// The return value is exactly what MapInput reports for the lookup.
+bool DatetimeExtractor::ParseAbsoluteDateValues(
+    const UnicodeText& input,
+    std::unordered_map<DatetimeComponent::ComponentType, int>* values) const {
+  return MapInput(input,
+                  {
+                      {DatetimeExtractorType_NOON,
+                       {{DatetimeComponent::ComponentType::MERIDIEM, 1},
+                        {DatetimeComponent::ComponentType::MINUTE, 0},
+                        {DatetimeComponent::ComponentType::HOUR, 12}}},
+                      {DatetimeExtractorType_MIDNIGHT,
+                       {{DatetimeComponent::ComponentType::MERIDIEM, 0},
+                        {DatetimeComponent::ComponentType::MINUTE, 0},
+                        {DatetimeComponent::ComponentType::HOUR, 0}}},
+                  },
+                  values);
+}
+
bool DatetimeExtractor::ParseMeridiem(const UnicodeText& input,
int* parsed_meridiem) const {
return MapInput(input,
diff --git a/native/annotator/datetime/extractor.h b/native/annotator/datetime/extractor.h
index 0f92b2a..3f2b755 100644
--- a/native/annotator/datetime/extractor.h
+++ b/native/annotator/datetime/extractor.h
@@ -96,9 +96,19 @@
const UnicodeText& input,
DatetimeComponent::ComponentType* parsed_field_type) const;
bool ParseDayOfWeek(const UnicodeText& input, int* parsed_day_of_week) const;
+
bool ParseRelationAndConvertToRelativeCount(const UnicodeText& input,
int* relative_count) const;
+ // There are some special words which represent multiple date time components
+ // e.g. if the text says "by noon" it clearly indicates that the hour is 12,
+ // minute is 0 and meridiem is PM.
+ // The method handles such tokens and translates them into multiple date time
+ // components.
+ bool ParseAbsoluteDateValues(
+ const UnicodeText& input,
+ std::unordered_map<DatetimeComponent::ComponentType, int>* values) const;
+
const CompiledRule& rule_;
const UniLib::RegexMatcher& matcher_;
int locale_id_;
diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc
index 07b9885..c59b8e0 100644
--- a/native/annotator/duration/duration.cc
+++ b/native/annotator/duration/duration.cc
@@ -22,6 +22,7 @@
#include "annotator/collections.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
+#include "utils/base/macros.h"
#include "utils/strings/numbers.h"
#include "utils/utf8/unicodetext.h"
@@ -100,6 +101,24 @@
return result;
}
+// Returns the unit implied for a dangling quantity, i.e. a trailing number
+// that has no unit of its own, given the unit of the quantity before it
+// (`main_unit`). E.g. for "2 hours 10", 10 would have the unit "minute".
+//
+// Only HOUR (-> MINUTE) and MINUTE (-> SECOND) are supported; every other
+// input yields DurationUnit::UNKNOWN.
+DurationUnit GetDanglingQuantityUnit(const DurationUnit main_unit) {
+  switch (main_unit) {
+    case DurationUnit::HOUR:
+      return DurationUnit::MINUTE;
+    case DurationUnit::MINUTE:
+      return DurationUnit::SECOND;
+    case DurationUnit::UNKNOWN:
+      TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
+      TC3_FALLTHROUGH_INTENDED;
+    case DurationUnit::WEEK:
+    case DurationUnit::DAY:
+    case DurationUnit::SECOND:
+      break;
+  }
+  // We only support dangling units for hours and minutes. Returning here,
+  // after the switch, also covers a `main_unit` value that matches none of
+  // the case labels, so control never flows off the end of the function.
+  return DurationUnit::UNKNOWN;
+}
} // namespace internal
bool DurationAnnotator::ClassifyText(
@@ -201,26 +220,32 @@
const bool parse_ended_without_unit_for_last_mentioned_quantity =
has_quantity;
+ if (parse_ended_without_unit_for_last_mentioned_quantity) {
+ const DurationUnit main_unit = parsed_duration_atoms.rbegin()->unit;
+ if (parsed_duration.plus_half) {
+ // Process "and half" suffix.
+ end_index = quantity_end_index;
+ ParsedDurationAtom atom = ParsedDurationAtom::Half();
+ atom.unit = main_unit;
+ parsed_duration_atoms.push_back(atom);
+ } else if (options_->enable_dangling_quantity_interpretation()) {
+ // Process dangling quantity.
+ ParsedDurationAtom atom;
+ atom.value = parsed_duration.value;
+ atom.unit = GetDanglingQuantityUnit(main_unit);
+ if (atom.unit != DurationUnit::UNKNOWN) {
+ end_index = quantity_end_index;
+ parsed_duration_atoms.push_back(atom);
+ }
+ }
+ }
+
ClassificationResult classification{Collections::Duration(),
options_->score()};
classification.priority_score = options_->priority_score();
classification.duration_ms =
ParsedDurationAtomsToMillis(parsed_duration_atoms);
- // Process suffix expressions like "and half" that don't have the
- // duration_unit explicitly mentioned.
- if (parse_ended_without_unit_for_last_mentioned_quantity) {
- if (parsed_duration.plus_half) {
- end_index = quantity_end_index;
- ParsedDurationAtom atom = ParsedDurationAtom::Half();
- atom.unit = parsed_duration_atoms.rbegin()->unit;
- classification.duration_ms += ParsedDurationAtomsToMillis({atom});
- } else if (options_->enable_dangling_quantity_interpretation()) {
- end_index = quantity_end_index;
- // TODO(b/144752747) Add dangling quantity to duration_ms.
- }
- }
-
result->span = feature_processor_->StripBoundaryCodepoints(
context, {start_index, end_index});
result->classification.push_back(classification);
@@ -256,7 +281,7 @@
break;
}
- int64 value = atom.value;
+ double value = atom.value;
// This condition handles expressions like "an hour", where the quantity is
// not specified. In this case we assume quantity 1. Except for cases like
// "half hour".
@@ -287,8 +312,8 @@
return true;
}
- int32 parsed_value;
- if (ParseInt32(lowercase_token_value.c_str(), &parsed_value)) {
+ double parsed_value;
+ if (ParseDouble(lowercase_token_value.c_str(), &parsed_value)) {
value->value = parsed_value;
return true;
}
diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h
index db4bdae..1a42ac3 100644
--- a/native/annotator/duration/duration.h
+++ b/native/annotator/duration/duration.h
@@ -98,7 +98,7 @@
internal::DurationUnit unit = internal::DurationUnit::UNKNOWN;
// Quantity of the duration unit.
- int value = 0;
+ double value = 0;
// True, if half an unit was specified (either in addition, or exclusively).
// E.g. "hour and a half".
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
index a0985a2..f5e0510 100644
--- a/native/annotator/duration/duration_test.cc
+++ b/native/annotator/duration/duration_test.cc
@@ -435,21 +435,57 @@
3.5 * 60 * 1000)))))));
}
-TEST_F(DurationAnnotatorTest, CorrectlyAnnotatesSpanWithDanglingQuantity) {
+TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
const UnicodeText text = UTF8ToUnicodeText("20 minutes 10");
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
- // TODO(b/144752747) Include test for duration_ms.
EXPECT_THAT(
result,
ElementsAre(
AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 13)),
Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(Field(&ClassificationResult::collection,
- "duration")))))));
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 20 * 60 * 1000 + 10 * 1000)))))));
+}
+
+// A number trailing a unit that has no supported dangling unit must be left
+// out of the annotation: only hours and minutes imply a unit for the number
+// that follows them, seconds do not.
+TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) {
+ const UnicodeText text = UTF8ToUnicodeText("20 seconds 10");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ // The span covers just "20 seconds" (codepoints 0-10) and the trailing "10"
+ // contributes nothing to duration_ms.
+ EXPECT_THAT(
+ result,
+ ElementsAre(AllOf(
+ Field(&AnnotatedSpan::span, CodepointSpan(0, 10)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms, 20 * 1000)))))));
+}
+
+// Quantities with a fractional part are parsed as such instead of being
+// rejected or truncated to an integer.
+TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) {
+ const UnicodeText text = UTF8ToUnicodeText("in 10.2 hours");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ // The span is "10.2 hours" (codepoints 3-13); 10.2 hours is 10 hours plus
+ // 0.2 * 60 = 12 minutes.
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(3, 13)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 10 * 60 * 60 * 1000 + 12 * 60 * 1000)))))));
const DurationAnnotatorOptions* TestingJapaneseDurationAnnotatorOptions() {
@@ -472,7 +508,7 @@
options.half_expressions.push_back("半");
options.require_quantity = true;
- options.enable_dangling_quantity_interpretation = false;
+ options.enable_dangling_quantity_interpretation = true;
flatbuffers::FlatBufferBuilder builder;
builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
@@ -545,7 +581,7 @@
EXPECT_THAT(result, IsEmpty());
}
-TEST_F(JapaneseDurationAnnotatorTest, IgnoresDanglingQuantity) {
+TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
const UnicodeText text = UTF8ToUnicodeText("2 分 10 の アラーム");
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
@@ -555,12 +591,12 @@
EXPECT_THAT(
result,
ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 3)),
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 6)),
Field(&AnnotatedSpan::classification,
ElementsAre(AllOf(
Field(&ClassificationResult::collection, "duration"),
Field(&ClassificationResult::duration_ms,
- 2 * 60 * 1000)))))));
+ 2 * 60 * 1000 + 10 * 1000)))))));
}
} // namespace
diff --git a/native/annotator/entity-data.fbs b/native/annotator/entity-data.fbs
index 4c02f6d..f82eb44 100755
--- a/native/annotator/entity-data.fbs
+++ b/native/annotator/entity-data.fbs
@@ -175,11 +175,24 @@
// Whole part of the amount (e.g. 123 from "CHF 123.45").
amount_whole_part:int;
- // Decimal part of the amount (e.g. 45 from "CHF 123.45").
+ // Decimal part of the amount (e.g. 45 from "CHF 123.45"). Will be
+ // deprecated, use nanos instead.
amount_decimal_part:int;
// Money amount (e.g. 123.45 from "CHF 123.45").
unnormalized_amount:string (shared);
+
+ // Number of nano (10^-9) units of the amount fractional part.
+ // The value must be between -999,999,999 and +999,999,999 inclusive.
+ // If `amount_whole_part` is positive, `nanos` must be positive or zero.
+ // If `amount_whole_part` is zero, `nanos` can be positive, zero, or negative.
+ // If `amount_whole_part` is negative, `nanos` must be negative or zero.
+ // For example $-1.75 is represented as `amount_whole_part`=-1 and
+ // `nanos`=-750,000,000.
+ nanos:int;
+
+ // Money quantity (e.g. k from "CHF 123.45k").
+ quantity:string (shared);
}
namespace libtextclassifier3.EntityData_.Translate_;
diff --git a/native/annotator/experimental/experimental-dummy.h b/native/annotator/experimental/experimental-dummy.h
index 389aae1..1c57c7e 100644
--- a/native/annotator/experimental/experimental-dummy.h
+++ b/native/annotator/experimental/experimental-dummy.h
@@ -39,7 +39,7 @@
bool Annotate(const UnicodeText& context,
std::vector<AnnotatedSpan>* candidates) const {
- return false;
+ return true;
}
AnnotatedSpan SuggestSelection(const UnicodeText& context,
@@ -47,9 +47,9 @@
return {click, {}};
}
- bool ClassifyText(const UnicodeText& context, CodepointSpan click,
- ClassificationResult* result) const {
- return false;
+ bool ClassifyText(const UnicodeText& context, CodepointSpan selection_indices,
+ std::vector<AnnotatedSpan>& candidates) const {
+ return true;
}
};
diff --git a/native/annotator/experimental/experimental.fbs b/native/annotator/experimental/experimental.fbs
index 6e15d04..5d69d17 100755
--- a/native/annotator/experimental/experimental.fbs
+++ b/native/annotator/experimental/experimental.fbs
@@ -14,6 +14,8 @@
// limitations under the License.
//
+include "utils/container/bit-vector.fbs";
+
namespace libtextclassifier3;
table ExperimentalModel {
}
diff --git a/native/annotator/feature-processor.cc b/native/annotator/feature-processor.cc
index 8d08574..3831c5f 100644
--- a/native/annotator/feature-processor.cc
+++ b/native/annotator/feature-processor.cc
@@ -67,8 +67,8 @@
extractor_options.extract_selection_mask_feature =
options->extract_selection_mask_feature();
if (options->regexp_feature() != nullptr) {
- for (const auto& regexp_feauture : *options->regexp_feature()) {
- extractor_options.regexp_features.push_back(regexp_feauture->str());
+ for (const auto& regexp_feature : *options->regexp_feature()) {
+ extractor_options.regexp_features.push_back(regexp_feature->str());
}
}
extractor_options.remap_digits = options->remap_digits();
@@ -82,7 +82,7 @@
return extractor_options;
}
-void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection,
std::vector<Token>* tokens) {
for (auto it = tokens->begin(); it != tokens->end(); ++it) {
const UnicodeText token_word =
@@ -137,7 +137,7 @@
} // namespace internal
void FeatureProcessor::StripTokensFromOtherLines(
- const std::string& context, CodepointSpan span,
+ const std::string& context, const CodepointSpan& span,
std::vector<Token>* tokens) const {
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
@@ -145,7 +145,7 @@
}
void FeatureProcessor::StripTokensFromOtherLines(
- const UnicodeText& context_unicode, CodepointSpan span,
+ const UnicodeText& context_unicode, const CodepointSpan& span,
std::vector<Token>* tokens) const {
std::vector<UnicodeTextRange> lines =
SplitContext(context_unicode, options_->use_pipe_character_for_newline());
@@ -198,9 +198,9 @@
return tokenizer_.Tokenize(text_unicode);
}
-bool FeatureProcessor::LabelToSpan(
- const int label, const VectorSpan<Token>& tokens,
- std::pair<CodepointIndex, CodepointIndex>* span) const {
+bool FeatureProcessor::LabelToSpan(const int label,
+ const VectorSpan<Token>& tokens,
+ CodepointSpan* span) const {
if (tokens.size() != GetNumContextTokens()) {
return false;
}
@@ -221,7 +221,7 @@
if (result_begin_codepoint == kInvalidIndex ||
result_end_codepoint == kInvalidIndex) {
- *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
+ *span = CodepointSpan::kInvalid;
} else {
const UnicodeText token_begin_unicode =
UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
@@ -241,8 +241,8 @@
if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
*span = {result_begin_codepoint, result_begin_codepoint};
} else {
- *span = CodepointSpan({result_begin_codepoint + begin_ignored,
- result_end_codepoint - end_ignored});
+ *span = CodepointSpan(result_begin_codepoint + begin_ignored,
+ result_end_codepoint - end_ignored);
}
}
return true;
@@ -258,9 +258,9 @@
}
}
-bool FeatureProcessor::SpanToLabel(
- const std::pair<CodepointIndex, CodepointIndex>& span,
- const std::vector<Token>& tokens, int* label) const {
+bool FeatureProcessor::SpanToLabel(const CodepointSpan& span,
+ const std::vector<Token>& tokens,
+ int* label) const {
if (tokens.size() != GetNumContextTokens()) {
return false;
}
@@ -323,8 +323,8 @@
return true;
}
-int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
- auto it = selection_to_label_.find(span);
+int FeatureProcessor::TokenSpanToLabel(const TokenSpan& token_span) const {
+ auto it = selection_to_label_.find(token_span);
if (it != selection_to_label_.end()) {
return it->second;
} else {
@@ -333,10 +333,10 @@
}
TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
- CodepointSpan codepoint_span,
+ const CodepointSpan& codepoint_span,
bool snap_boundaries_to_containing_tokens) {
- const int codepoint_start = std::get<0>(codepoint_span);
- const int codepoint_end = std::get<1>(codepoint_span);
+ const int codepoint_start = codepoint_span.first;
+ const int codepoint_end = codepoint_span.second;
TokenIndex start_token = kInvalidIndex;
TokenIndex end_token = kInvalidIndex;
@@ -360,7 +360,7 @@
}
CodepointSpan TokenSpanToCodepointSpan(
- const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
+ const std::vector<Token>& selectable_tokens, const TokenSpan& token_span) {
return {selectable_tokens[token_span.first].start,
selectable_tokens[token_span.second - 1].end};
}
@@ -369,9 +369,9 @@
// Finds a single token that completely contains the given span.
int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
- CodepointSpan codepoint_span) {
- const int codepoint_start = std::get<0>(codepoint_span);
- const int codepoint_end = std::get<1>(codepoint_span);
+ const CodepointSpan& codepoint_span) {
+ const int codepoint_start = codepoint_span.first;
+ const int codepoint_end = codepoint_span.second;
for (int i = 0; i < selectable_tokens.size(); ++i) {
if (codepoint_start >= selectable_tokens[i].start &&
@@ -386,12 +386,12 @@
namespace internal {
-int CenterTokenFromClick(CodepointSpan span,
+int CenterTokenFromClick(const CodepointSpan& span,
const std::vector<Token>& selectable_tokens) {
- int range_begin;
- int range_end;
- std::tie(range_begin, range_end) =
+ const TokenSpan token_span =
CodepointSpanToTokenSpan(selectable_tokens, span);
+ int range_begin = token_span.first;
+ int range_end = token_span.second;
// If no exact match was found, try finding a token that completely contains
// the click span. This is useful e.g. when Android builds the selection
@@ -414,11 +414,11 @@
}
int CenterTokenFromMiddleOfSelection(
- CodepointSpan span, const std::vector<Token>& selectable_tokens) {
- int range_begin;
- int range_end;
- std::tie(range_begin, range_end) =
+ const CodepointSpan& span, const std::vector<Token>& selectable_tokens) {
+ const TokenSpan token_span =
CodepointSpanToTokenSpan(selectable_tokens, span);
+ const int range_begin = token_span.first;
+ const int range_end = token_span.second;
// Center the clicked token in the selection range.
if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
@@ -430,7 +430,7 @@
} // namespace internal
-int FeatureProcessor::FindCenterToken(CodepointSpan span,
+int FeatureProcessor::FindCenterToken(const CodepointSpan& span,
const std::vector<Token>& tokens) const {
if (options_->center_token_selection_method() ==
FeatureProcessorOptions_::
@@ -464,7 +464,7 @@
const VectorSpan<Token> tokens,
std::vector<CodepointSpan>* selection_label_spans) const {
for (int i = 0; i < label_to_selection_.size(); ++i) {
- CodepointSpan span;
+ CodepointSpan span = CodepointSpan::kInvalid;
if (!LabelToSpan(i, tokens, &span)) {
TC3_LOG(ERROR) << "Could not convert label to span: " << i;
return false;
@@ -486,15 +486,6 @@
const UnicodeText::const_iterator& span_start,
const UnicodeText::const_iterator& span_end,
bool count_from_beginning) const {
- return CountIgnoredSpanBoundaryCodepoints(span_start, span_end,
- count_from_beginning,
- ignored_span_boundary_codepoints_);
-}
-
-int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
- const UnicodeText::const_iterator& span_start,
- const UnicodeText::const_iterator& span_end, bool count_from_beginning,
- const std::unordered_set<int>& ignored_span_boundary_codepoints) const {
if (span_start == span_end) {
return 0;
}
@@ -517,8 +508,8 @@
// Move until we encounter a non-ignored character.
int num_ignored = 0;
- while (ignored_span_boundary_codepoints.find(*it) !=
- ignored_span_boundary_codepoints.end()) {
+ while (ignored_span_boundary_codepoints_.find(*it) !=
+ ignored_span_boundary_codepoints_.end()) {
++num_ignored;
if (it == it_last) {
@@ -571,37 +562,15 @@
}
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const std::string& context, CodepointSpan span) const {
- return StripBoundaryCodepoints(context, span,
- ignored_span_boundary_codepoints_,
- ignored_span_boundary_codepoints_);
-}
-
-CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const std::string& context, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const {
+ const std::string& context, const CodepointSpan& span) const {
const UnicodeText context_unicode =
UTF8ToUnicodeText(context, /*do_copy=*/false);
- return StripBoundaryCodepoints(context_unicode, span,
- ignored_prefix_span_boundary_codepoints,
- ignored_suffix_span_boundary_codepoints);
+ return StripBoundaryCodepoints(context_unicode, span);
}
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const UnicodeText& context_unicode, CodepointSpan span) const {
- return StripBoundaryCodepoints(context_unicode, span,
- ignored_span_boundary_codepoints_,
- ignored_span_boundary_codepoints_);
-}
-
-CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const UnicodeText& context_unicode, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const {
- if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
+ const UnicodeText& context_unicode, const CodepointSpan& span) const {
+ if (context_unicode.empty() || !span.IsValid() || span.IsEmpty()) {
return span;
}
@@ -610,35 +579,21 @@
UnicodeText::const_iterator span_end = context_unicode.begin();
std::advance(span_end, span.second);
- return StripBoundaryCodepoints(span_begin, span_end, span,
- ignored_prefix_span_boundary_codepoints,
- ignored_suffix_span_boundary_codepoints);
+ return StripBoundaryCodepoints(span_begin, span_end, span);
}
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
const UnicodeText::const_iterator& span_begin,
- const UnicodeText::const_iterator& span_end, CodepointSpan span) const {
- return StripBoundaryCodepoints(span_begin, span_end, span,
- ignored_span_boundary_codepoints_,
- ignored_span_boundary_codepoints_);
-}
-
-CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const UnicodeText::const_iterator& span_begin,
- const UnicodeText::const_iterator& span_end, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const {
- if (!ValidNonEmptySpan(span) || span_begin == span_end) {
+ const UnicodeText::const_iterator& span_end,
+ const CodepointSpan& span) const {
+ if (!span.IsValid() || span.IsEmpty() || span_begin == span_end) {
return span;
}
const int start_offset = CountIgnoredSpanBoundaryCodepoints(
- span_begin, span_end, /*count_from_beginning=*/true,
- ignored_prefix_span_boundary_codepoints);
+ span_begin, span_end, /*count_from_beginning=*/true);
const int end_offset = CountIgnoredSpanBoundaryCodepoints(
- span_begin, span_end, /*count_from_beginning=*/false,
- ignored_suffix_span_boundary_codepoints);
+ span_begin, span_end, /*count_from_beginning=*/false);
if (span.first + start_offset < span.second - end_offset) {
return {span.first + start_offset, span.second - end_offset};
@@ -670,21 +625,10 @@
const std::string& FeatureProcessor::StripBoundaryCodepoints(
const std::string& value, std::string* buffer) const {
- return StripBoundaryCodepoints(value, buffer,
- ignored_span_boundary_codepoints_,
- ignored_span_boundary_codepoints_);
-}
-
-const std::string& FeatureProcessor::StripBoundaryCodepoints(
- const std::string& value, std::string* buffer,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const {
const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
- const CodepointSpan stripped_span = StripBoundaryCodepoints(
- value_unicode, initial_span, ignored_prefix_span_boundary_codepoints,
- ignored_suffix_span_boundary_codepoints);
+ const CodepointSpan stripped_span =
+ StripBoundaryCodepoints(value_unicode, initial_span);
if (initial_span != stripped_span) {
const UnicodeText stripped_token_value =
@@ -735,7 +679,7 @@
}
void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
- CodepointSpan input_span,
+ const CodepointSpan& input_span,
bool only_use_line_with_click,
std::vector<Token>* tokens,
int* click_pos) const {
@@ -746,7 +690,7 @@
}
void FeatureProcessor::RetokenizeAndFindClick(
- const UnicodeText& context_unicode, CodepointSpan input_span,
+ const UnicodeText& context_unicode, const CodepointSpan& input_span,
bool only_use_line_with_click, std::vector<Token>* tokens,
int* click_pos) const {
TC3_CHECK(tokens != nullptr);
@@ -773,7 +717,7 @@
namespace internal {
-void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size,
std::vector<Token>* tokens, int* click_pos) {
int right_context_needed = relative_click_span.second + context_size;
if (*click_pos + right_context_needed + 1 >= tokens->size()) {
@@ -810,7 +754,7 @@
} // namespace internal
bool FeatureProcessor::HasEnoughSupportedCodepoints(
- const std::vector<Token>& tokens, TokenSpan token_span) const {
+ const std::vector<Token>& tokens, const TokenSpan& token_span) const {
if (options_->min_supported_codepoint_ratio() > 0) {
const float supported_codepoint_ratio =
SupportedCodepointsRatio(token_span, tokens);
@@ -824,13 +768,13 @@
}
bool FeatureProcessor::ExtractFeatures(
- const std::vector<Token>& tokens, TokenSpan token_span,
- CodepointSpan selection_span_for_feature,
+ const std::vector<Token>& tokens, const TokenSpan& token_span,
+ const CodepointSpan& selection_span_for_feature,
const EmbeddingExecutor* embedding_executor,
EmbeddingCache* embedding_cache, int feature_vector_size,
std::unique_ptr<CachedFeatures>* cached_features) const {
std::unique_ptr<std::vector<float>> features(new std::vector<float>());
- features->reserve(feature_vector_size * TokenSpanSize(token_span));
+ features->reserve(feature_vector_size * token_span.Size());
for (int i = token_span.first; i < token_span.second; ++i) {
if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
embedding_executor, embedding_cache,
@@ -862,7 +806,7 @@
}
bool FeatureProcessor::AppendTokenFeaturesWithCache(
- const Token& token, CodepointSpan selection_span_for_feature,
+ const Token& token, const CodepointSpan& selection_span_for_feature,
const EmbeddingExecutor* embedding_executor,
EmbeddingCache* embedding_cache,
std::vector<float>* output_features) const {
diff --git a/native/annotator/feature-processor.h b/native/annotator/feature-processor.h
index 78dbbce..3b865b0 100644
--- a/native/annotator/feature-processor.h
+++ b/native/annotator/feature-processor.h
@@ -49,22 +49,23 @@
// Splits tokens that contain the selection boundary inside them.
// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
-void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection,
std::vector<Token>* tokens);
// Returns the index of token that corresponds to the codepoint span.
-int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
+int CenterTokenFromClick(const CodepointSpan& span,
+ const std::vector<Token>& tokens);
// Returns the index of token that corresponds to the middle of the codepoint
// span.
int CenterTokenFromMiddleOfSelection(
- CodepointSpan span, const std::vector<Token>& selectable_tokens);
+ const CodepointSpan& span, const std::vector<Token>& selectable_tokens);
// Strips the tokens from the tokens vector that are not used for feature
// extraction because they are out of scope, or pads them so that there is
// enough tokens in the required context_size for all inferences with a click
// in relative_click_span.
-void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size,
std::vector<Token>* tokens, int* click_pos);
} // namespace internal
@@ -74,12 +75,13 @@
// token to overlap with the codepoint range to be considered part of it.
// Otherwise it must be fully included in the range.
TokenSpan CodepointSpanToTokenSpan(
- const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
+ const std::vector<Token>& selectable_tokens,
+ const CodepointSpan& codepoint_span,
bool snap_boundaries_to_containing_tokens = false);
// Converts a token span to a codepoint span in the given list of tokens.
CodepointSpan TokenSpanToCodepointSpan(
- const std::vector<Token>& selectable_tokens, TokenSpan token_span);
+ const std::vector<Token>& selectable_tokens, const TokenSpan& token_span);
// Takes care of preparing features for the span prediction model.
class FeatureProcessor {
@@ -132,25 +134,26 @@
// Retokenizes the context and input span, and finds the click position.
// Depending on the options, might modify tokens (split them or remove them).
void RetokenizeAndFindClick(const std::string& context,
- CodepointSpan input_span,
+ const CodepointSpan& input_span,
bool only_use_line_with_click,
std::vector<Token>* tokens, int* click_pos) const;
// Same as above but takes UnicodeText.
void RetokenizeAndFindClick(const UnicodeText& context_unicode,
- CodepointSpan input_span,
+ const CodepointSpan& input_span,
bool only_use_line_with_click,
std::vector<Token>* tokens, int* click_pos) const;
// Returns true if the token span has enough supported codepoints (as defined
// in the model config) or not and model should not run.
bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
- TokenSpan token_span) const;
+ const TokenSpan& token_span) const;
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
- bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
- CodepointSpan selection_span_for_feature,
+ bool ExtractFeatures(const std::vector<Token>& tokens,
+ const TokenSpan& token_span,
+ const CodepointSpan& selection_span_for_feature,
const EmbeddingExecutor* embedding_executor,
EmbeddingCache* embedding_cache, int feature_vector_size,
std::unique_ptr<CachedFeatures>* cached_features) const;
@@ -177,38 +180,17 @@
// start and end indices. If the span comprises entirely of boundary
// codepoints, the first index of span is returned for both indices.
CodepointSpan StripBoundaryCodepoints(const std::string& context,
- CodepointSpan span) const;
-
- // Same as previous, but also takes the ignored span boundary codepoints.
- CodepointSpan StripBoundaryCodepoints(
- const std::string& context, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const;
+ const CodepointSpan& span) const;
// Same as above but takes UnicodeText.
CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
- CodepointSpan span) const;
-
- // Same as the previous, but also takes the ignored span boundary codepoints.
- CodepointSpan StripBoundaryCodepoints(
- const UnicodeText& context_unicode, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const;
+ const CodepointSpan& span) const;
// Same as above but takes a pair of iterators for the span, for efficiency.
CodepointSpan StripBoundaryCodepoints(
const UnicodeText::const_iterator& span_begin,
- const UnicodeText::const_iterator& span_end, CodepointSpan span) const;
-
- // Same as previous, but also takes the ignored span boundary codepoints.
- CodepointSpan StripBoundaryCodepoints(
- const UnicodeText::const_iterator& span_begin,
- const UnicodeText::const_iterator& span_end, CodepointSpan span,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const;
+ const UnicodeText::const_iterator& span_end,
+ const CodepointSpan& span) const;
// Same as above, but takes an optional buffer for saving the modified value.
// As an optimization, returns pointer to 'value' if nothing was stripped, or
@@ -216,13 +198,6 @@
const std::string& StripBoundaryCodepoints(const std::string& value,
std::string* buffer) const;
- // Same as previous, but also takes the ignored span boundary codepoints.
- const std::string& StripBoundaryCodepoints(
- const std::string& value, std::string* buffer,
- const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
- const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
- const;
-
protected:
// Returns the class id corresponding to the given string collection
// identifier. There is a catch-all class id that the function returns for
@@ -245,11 +220,11 @@
CodepointSpan* span) const;
// Converts a span to the corresponding label given output_tokens.
- bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
+ bool SpanToLabel(const CodepointSpan& span,
const std::vector<Token>& output_tokens, int* label) const;
// Converts a token span to the corresponding label.
- int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
+ int TokenSpanToLabel(const TokenSpan& token_span) const;
// Returns the ratio of supported codepoints to total number of codepoints in
// the given token span.
@@ -268,35 +243,30 @@
const UnicodeText::const_iterator& span_end,
bool count_from_beginning) const;
- // Same as previous, but also takes the ignored span boundary codepoints.
- int CountIgnoredSpanBoundaryCodepoints(
- const UnicodeText::const_iterator& span_start,
- const UnicodeText::const_iterator& span_end, bool count_from_beginning,
- const std::unordered_set<int>& ignored_span_boundary_codepoints) const;
-
// Finds the center token index in tokens vector, using the method defined
// in options_.
- int FindCenterToken(CodepointSpan span,
+ int FindCenterToken(const CodepointSpan& span,
const std::vector<Token>& tokens) const;
// Removes all tokens from tokens that are not on a line (defined by calling
// SplitContext on the context) to which span points.
- void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
+ void StripTokensFromOtherLines(const std::string& context,
+ const CodepointSpan& span,
std::vector<Token>* tokens) const;
// Same as above but takes UnicodeText.
void StripTokensFromOtherLines(const UnicodeText& context_unicode,
- CodepointSpan span,
+ const CodepointSpan& span,
std::vector<Token>* tokens) const;
// Extracts the features of a token and appends them to the output vector.
// Uses the embedding cache to to avoid re-extracting the re-embedding the
// sparse features for the same token.
- bool AppendTokenFeaturesWithCache(const Token& token,
- CodepointSpan selection_span_for_feature,
- const EmbeddingExecutor* embedding_executor,
- EmbeddingCache* embedding_cache,
- std::vector<float>* output_features) const;
+ bool AppendTokenFeaturesWithCache(
+ const Token& token, const CodepointSpan& selection_span_for_feature,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache,
+ std::vector<float>* output_features) const;
protected:
const TokenFeatureExtractor feature_extractor_;
diff --git a/native/annotator/feature-processor_test.cc b/native/annotator/feature-processor_test.cc
new file mode 100644
index 0000000..86f25e4
--- /dev/null
+++ b/native/annotator/feature-processor_test.cc
@@ -0,0 +1,1050 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/feature-processor.h"
+
+#include "annotator/model-executor.h"
+#include "utils/tensor-view.h"
+#include "utils/utf8/unicodetext.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAreArray;
+using testing::FloatEq;
+using testing::Matcher;
+
+flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
+ const FeatureProcessorOptionsT& options) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+ return builder.Release();
+}
+
+template <typename T>
+std::vector<T> Subvector(const std::vector<T>& vector, int start, int end) {
+ return std::vector<T>(vector.begin() + start, vector.begin() + end);
+}
+
+Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
+ std::vector<Matcher<float>> matchers;
+ for (const float value : values) {
+ matchers.push_back(FloatEq(value));
+ }
+ return ElementsAreArray(matchers);
+}
+
+class TestingFeatureProcessor : public FeatureProcessor {
+ public:
+ using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
+ using FeatureProcessor::FeatureProcessor;
+ using FeatureProcessor::SpanToLabel;
+ using FeatureProcessor::StripTokensFromOtherLines;
+ using FeatureProcessor::supported_codepoint_ranges_;
+ using FeatureProcessor::SupportedCodepointsRatio;
+};
+
+// EmbeddingExecutor that emits {v, v, -v, -v} for the sole sparse feature v.
+class FakeEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+ bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ int dest_size) const override {
+ TC3_CHECK_GE(dest_size, 4);
+ EXPECT_EQ(sparse_features.size(), 1);
+ dest[0] = sparse_features.data()[0];
+ dest[1] = sparse_features.data()[0];
+ dest[2] = -sparse_features.data()[0];
+ dest[3] = -sparse_features.data()[0];
+ return true;
+ }
+
+ private:
+ std::vector<float> storage_;
+};
+
+class AnnotatorFeatureProcessorTest : public ::testing::Test {
+ protected:
+ AnnotatorFeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+ UniLib unilib_;
+};
+
+TEST_F(AnnotatorFeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěě", 6, 9),
+ Token("bař", 9, 12),
+ Token("@google.com", 12, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěěbař", 6, 12),
+ Token("@google.com", 12, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěě", 6, 9),
+ Token("bař@google.com", 9, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST_F(AnnotatorFeatureProcessorTest,
+ SplitTokensOnSelectionBoundariesCrossToken) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hě", 0, 2),
+ Token("lló", 2, 5),
+ Token("fěě", 6, 9),
+ Token("bař@google.com", 9, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, KeepLineWithClickFirst) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {0, 5};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the first line.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens,
+ ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, KeepLineWithClickSecond) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {18, 22};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the second line.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, KeepLineWithClickThird) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {24, 33};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the third line.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {18, 22};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the second line (the pipe separates it from the first).
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest,
+ KeepLineWithClickAndDoNotUsePipeAsNewLineCharacter) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ options.use_pipe_character_for_newline = false;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {18, 22};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině|Sěcond", 6, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the first line, which here includes the pipe-joined token.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray({Token("Fiřst", 0, 5),
+ Token("Lině|Sěcond", 6, 17),
+ Token("Lině", 18, 22)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, ShouldSplitLinesOnPipe) {
+ FeatureProcessorOptionsT options;
+ options.use_pipe_character_for_newline = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+ /*do_copy=*/false);
+
+ const std::vector<UnicodeTextRange>& lines = feature_processor.SplitContext(
+ context_unicode, options.use_pipe_character_for_newline);
+ EXPECT_EQ(lines.size(), 3);
+ EXPECT_EQ(UnicodeText::UTF8Substring(lines[0].first, lines[0].second),
+ "Fiřst Lině");
+ EXPECT_EQ(UnicodeText::UTF8Substring(lines[1].first, lines[1].second),
+ "Sěcond Lině");
+ EXPECT_EQ(UnicodeText::UTF8Substring(lines[2].first, lines[2].second),
+ "Thiřd Lině");
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, ShouldNotSplitLinesOnPipe) {
+ FeatureProcessorOptionsT options;
+ options.use_pipe_character_for_newline = false;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+ /*do_copy=*/false);
+
+ const std::vector<UnicodeTextRange>& lines = feature_processor.SplitContext(
+ context_unicode, options.use_pipe_character_for_newline);
+ EXPECT_EQ(lines.size(), 2);
+ EXPECT_EQ(UnicodeText::UTF8Substring(lines[0].first, lines[0].second),
+ "Fiřst Lině|Sěcond Lině");
+ EXPECT_EQ(UnicodeText::UTF8Substring(lines[1].first, lines[1].second),
+ "Thiřd Lině");
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, KeepLineWithCrosslineClick) {
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {5, 23};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 18, 23),
+ Token("Lině", 19, 23),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Expects all tokens to be kept: the click span crosses line boundaries.
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Fiřst", 0, 5), Token("Lině", 6, 10),
+ Token("Sěcond", 18, 23), Token("Lině", 19, 23),
+ Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SpanToLabel) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 1;
+ options.max_selection_span = 1;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+ std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
+ ASSERT_EQ(3, tokens.size());
+ int label;
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
+ EXPECT_EQ(kInvalidLabel, label);
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
+ EXPECT_NE(kInvalidLabel, label);
+ TokenSpan token_span;
+ feature_processor.LabelToTokenSpan(label, &token_span);
+ EXPECT_EQ(0, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ // Reconfigure with snapping enabled.
+ options.snap_label_span_boundaries_to_containing_tokens = true;
+ flatbuffers::DetachedBuffer options2_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor2(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
+ &unilib_);
+ int label2;
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+
+ // Cross a token boundary.
+ ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+
+ // Multiple tokens.
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ flatbuffers::DetachedBuffer options3_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor3(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
+ &unilib_);
+ tokens = feature_processor3.Tokenize("zero, one, two, three, four");
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
+ EXPECT_NE(kInvalidLabel, label2);
+ feature_processor3.LabelToTokenSpan(label2, &token_span);
+ EXPECT_EQ(1, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ int label3;
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 1;
+ options.max_selection_span = 1;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+ std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
+ ASSERT_EQ(3, tokens.size());
+ int label;
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
+ EXPECT_EQ(kInvalidLabel, label);
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
+ EXPECT_NE(kInvalidLabel, label);
+ TokenSpan token_span;
+ feature_processor.LabelToTokenSpan(label, &token_span);
+ EXPECT_EQ(0, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ // Reconfigure with snapping enabled.
+ options.snap_label_span_boundaries_to_containing_tokens = true;
+ flatbuffers::DetachedBuffer options2_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor2(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
+ &unilib_);
+ int label2;
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+
+ // Cross a token boundary.
+ ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+
+ // Multiple tokens.
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ flatbuffers::DetachedBuffer options3_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor3(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
+ &unilib_);
+ tokens = feature_processor3.Tokenize("zero, one, two, three, four");
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
+ EXPECT_NE(kInvalidLabel, label2);
+ feature_processor3.LabelToTokenSpan(label2, &token_span);
+ EXPECT_EQ(1, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ int label3;
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, CenterTokenFromClick) {
+ int token_index;
+
+ // Exactly aligned indices.
+ token_index = internal::CenterTokenFromClick(
+ {6, 11},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
+ EXPECT_EQ(token_index, 1);
+
+ // Click is contained in a token.
+ token_index = internal::CenterTokenFromClick(
+ {13, 17},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
+ EXPECT_EQ(token_index, 2);
+
+ // Click spans two tokens.
+ token_index = internal::CenterTokenFromClick(
+ {6, 17},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
+ EXPECT_EQ(token_index, kInvalidIndex);
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, CenterTokenFromMiddleOfSelection) {
+ int token_index;
+
+ // Selection of length 3. Exactly aligned indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {7, 27},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 2);
+
+ // Selection of length 1 token. Exactly aligned indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {21, 27},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 3);
+
+ // Selection marks sub-token range, with no tokens in it.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {29, 33},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, kInvalidIndex);
+
+ // Selection of length 2. Sub-token indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {3, 25},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 1);
+
+ // Selection of length 1. Sub-token indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {22, 34},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 4);
+
+ // An empty token list has no center token.
+ token_index = internal::CenterTokenFromMiddleOfSelection({7, 27}, {});
+ EXPECT_EQ(token_index, kInvalidIndex);
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, SupportedCodepointsRatio) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.feature_version = 2;
+ options.embedding_size = 4;
+ options.bounds_sensitive_features.reset(
+ new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+ options.bounds_sensitive_features->enabled = true;
+ options.bounds_sensitive_features->num_tokens_before = 5;
+ options.bounds_sensitive_features->num_tokens_inside_left = 3;
+ options.bounds_sensitive_features->num_tokens_inside_right = 3;
+ options.bounds_sensitive_features->num_tokens_after = 5;
+ options.bounds_sensitive_features->include_inside_bag = true;
+ options.bounds_sensitive_features->include_inside_length = true;
+
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ {
+ options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
+ auto& range = options.supported_codepoint_ranges.back();
+ range->start = 0;
+ range->end = 128;
+ }
+
+ {
+ options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
+ auto& range = options.supported_codepoint_ranges.back();
+ range->start = 10000;
+ range->end = 10001;
+ }
+
+ {
+ options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
+ auto& range = options.supported_codepoint_ranges.back();
+ range->start = 20000;
+ range->end = 30000;
+ }
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ {0, 3}, feature_processor.Tokenize("aaa bbb ccc")),
+ FloatEq(1.0));
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ {0, 3}, feature_processor.Tokenize("aaa bbb ěěě")),
+ FloatEq(2.0 / 3));
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ {0, 3}, feature_processor.Tokenize("ěěě řřř ěěě")),
+ FloatEq(0.0));
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ {0, 0}, feature_processor.Tokenize("")),
+ FloatEq(0.0));
+ EXPECT_FALSE(
+ IsCodepointInRanges(-1, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(
+ IsCodepointInRanges(0, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(
+ IsCodepointInRanges(10, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(
+ IsCodepointInRanges(127, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(
+ IsCodepointInRanges(128, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(
+ IsCodepointInRanges(9999, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(IsCodepointInRanges(
+ 10000, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(IsCodepointInRanges(
+ 10001, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(IsCodepointInRanges(
+ 25000, feature_processor.supported_codepoint_ranges_));
+
+ const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7),
+ Token("eee", 8, 11)};
+
+ options.min_supported_codepoint_ratio = 0.0;
+ flatbuffers::DetachedBuffer options2_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor2(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
+ &unilib_);
+ EXPECT_TRUE(feature_processor2.HasEnoughSupportedCodepoints(
+ tokens, /*token_span=*/{0, 3}));
+
+ options.min_supported_codepoint_ratio = 0.2;
+ flatbuffers::DetachedBuffer options3_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor3(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
+ &unilib_);
+ EXPECT_TRUE(feature_processor3.HasEnoughSupportedCodepoints(
+ tokens, /*token_span=*/{0, 3}));
+
+ options.min_supported_codepoint_ratio = 0.5;
+ flatbuffers::DetachedBuffer options4_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor4(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()),
+ &unilib_);
+ EXPECT_FALSE(feature_processor4.HasEnoughSupportedCodepoints(
+ tokens, /*token_span=*/{0, 3}));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, InSpanFeature) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.feature_version = 2;
+ options.embedding_size = 4;
+ options.extract_selection_mask_feature = true;
+ // The processor is built from the packed (flatbuffer) form of the options.
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ std::unique_ptr<CachedFeatures> cached_features;
+
+ FakeEmbeddingExecutor embedding_executor;
+ // The selection span {4, 11} used below covers the tokens "bbb" and "ccc".
+ const std::vector<Token> tokens = {Token("aaa", 0, 3), Token("bbb", 4, 7),
+ Token("ccc", 8, 11), Token("ddd", 12, 15)};
+ // NOTE(review): float 4 of each 5-float slot looks like the selection mask.
+ EXPECT_TRUE(feature_processor.ExtractFeatures(
+ tokens, /*token_span=*/{0, 4},
+ /*selection_span_for_feature=*/{4, 11}, &embedding_executor,
+ /*embedding_cache=*/nullptr, /*feature_vector_size=*/5,
+ &cached_features));
+ std::vector<float> features;
+ cached_features->AppendClickContextFeaturesForClick(1, &features);
+ ASSERT_EQ(features.size(), 25);  // 5 slots of feature_vector_size (5) floats.
+ EXPECT_THAT(features[4], FloatEq(0.0));  // Last float of slot 0.
+ EXPECT_THAT(features[9], FloatEq(0.0));  // Last float of slot 1.
+ EXPECT_THAT(features[14], FloatEq(1.0));  // Slot 2, presumably "bbb".
+ EXPECT_THAT(features[19], FloatEq(1.0));  // Slot 3, presumably "ccc".
+ EXPECT_THAT(features[24], FloatEq(0.0));  // Last float of slot 4.
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, EmbeddingCache) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.feature_version = 2;
+ options.embedding_size = 4;
+ options.bounds_sensitive_features.reset(
+ new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+ options.bounds_sensitive_features->enabled = true;
+ options.bounds_sensitive_features->num_tokens_before = 3;
+ options.bounds_sensitive_features->num_tokens_inside_left = 2;
+ options.bounds_sensitive_features->num_tokens_inside_right = 2;
+ options.bounds_sensitive_features->num_tokens_after = 3;
+ // The processor is built from the packed (flatbuffer) form of the options.
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+
+ std::unique_ptr<CachedFeatures> cached_features;
+
+ FakeEmbeddingExecutor embedding_executor;
+
+ const std::vector<Token> tokens = {
+ Token("aaa", 0, 3), Token("bbb", 4, 7), Token("ccc", 8, 11),
+ Token("ddd", 12, 15), Token("eee", 16, 19), Token("fff", 20, 23)};
+
+ // We pre-populate the cache with dummy embeddings, to make sure they are
+ // used when populating the features vector.
+ const std::vector<float> cached_padding_features = {10.0, -10.0, 10.0, -10.0};
+ const std::vector<float> cached_features1 = {1.0, 2.0, 3.0, 4.0};
+ const std::vector<float> cached_features2 = {5.0, 6.0, 7.0, 8.0};
+ FeatureProcessor::EmbeddingCache embedding_cache = {
+ {{kInvalidIndex, kInvalidIndex}, cached_padding_features},
+ {{4, 7}, cached_features1},
+ {{12, 15}, cached_features2},
+ };
+ // Pre-cached keys: the padding span, {4, 7} ("bbb") and {12, 15} ("ddd").
+ EXPECT_TRUE(feature_processor.ExtractFeatures(
+ tokens, /*token_span=*/{0, 6},
+ /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+ &embedding_executor, &embedding_cache, /*feature_vector_size=*/4,
+ &cached_features));
+ std::vector<float> features;
+ cached_features->AppendBoundsSensitiveFeaturesForSpan({2, 4}, &features);
+ ASSERT_EQ(features.size(), 40);  // (3 + 2 + 2 + 3) slots of 4 floats each.
+ // Check that the dummy embeddings were used.
+ EXPECT_THAT(Subvector(features, 0, 4),
+ ElementsAreFloat(cached_padding_features));
+ EXPECT_THAT(Subvector(features, 8, 12), ElementsAreFloat(cached_features1));
+ EXPECT_THAT(Subvector(features, 16, 20), ElementsAreFloat(cached_features2));
+ EXPECT_THAT(Subvector(features, 24, 28), ElementsAreFloat(cached_features2));
+ EXPECT_THAT(Subvector(features, 36, 40),
+ ElementsAreFloat(cached_padding_features));
+ // Check that the real embeddings were cached.
+ EXPECT_EQ(embedding_cache.size(), 7);  // 3 pre-populated + 4 new entries.
+ EXPECT_THAT(Subvector(features, 4, 8),
+ ElementsAreFloat(embedding_cache.at({0, 3})));
+ EXPECT_THAT(Subvector(features, 12, 16),
+ ElementsAreFloat(embedding_cache.at({8, 11})));
+ EXPECT_THAT(Subvector(features, 20, 24),
+ ElementsAreFloat(embedding_cache.at({8, 11})));
+ EXPECT_THAT(Subvector(features, 28, 32),
+ ElementsAreFloat(embedding_cache.at({16, 19})));
+ EXPECT_THAT(Subvector(features, 32, 36),
+ ElementsAreFloat(embedding_cache.at({20, 23})));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
+ std::vector<Token> tokens_orig{
+ Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
+ Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
+ Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
+ Token("12", 0, 0)};
+ // All 13 tokens carry the span (0, 0); only their values tell them apart.
+ std::vector<Token> tokens;
+ int click_index;
+ // NOTE(review): {0, 0} is presumably the relative click span, 2 the context.
+ // Try to click first token and see if it gets padded from left.
+ tokens = tokens_orig;
+ click_index = 0;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token(),
+ Token(),
+ Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);  // Shifted right by the two padding tokens.
+
+ // When we click the third token (index 2) nothing should get padded.
+ tokens = tokens_orig;
+ click_index = 2;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);  // Unchanged: nothing was removed on the left.
+
+ // When we click the last token tokens should get padded from the right.
+ tokens = tokens_orig;
+ click_index = 12;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0),
+ Token("11", 0, 0),
+ Token("12", 0, 0),
+ Token(),
+ Token()}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);  // 12 - 10: the kept window starts at token "10".
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, StripUnusedTokensWithRelativeClick) {
+ std::vector<Token> tokens_orig{
+ Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
+ Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
+ Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
+ Token("12", 0, 0)};
+ // All 13 tokens carry the span (0, 0); only their values tell them apart.
+ std::vector<Token> tokens;
+ int click_index;
+ // NOTE(review): the first argument is presumably the relative click span.
+ // Try to click first token and see if it gets padded from left to maximum
+ // context_size.
+ tokens = tokens_orig;
+ click_index = 0;
+ internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token(),
+ Token(),
+ Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0),
+ Token("5", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);  // Shifted right by the two padding tokens.
+
+ // Clicking to the middle with enough context should not produce any padding.
+ tokens = tokens_orig;
+ click_index = 6;
+ internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0),
+ Token("5", 0, 0),
+ Token("6", 0, 0),
+ Token("7", 0, 0),
+ Token("8", 0, 0),
+ Token("9", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 5);  // 6 - 1: the kept window starts at token "1".
+
+ // Clicking at the end should pad right to maximum context_size.
+ tokens = tokens_orig;
+ click_index = 11;
+ internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
+ Token("7", 0, 0),
+ Token("8", 0, 0),
+ Token("9", 0, 0),
+ Token("10", 0, 0),
+ Token("11", 0, 0),
+ Token("12", 0, 0),
+ Token(),
+ Token()}));
+ // clang-format on
+ EXPECT_EQ(click_index, 5);  // 11 - 6: the kept window starts at token "6".
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
+ FeatureProcessorOptionsT options;
+ options.ignored_span_boundary_codepoints.push_back('.');
+ options.ignored_span_boundary_codepoints.push_back(',');
+ options.ignored_span_boundary_codepoints.push_back('[');
+ options.ignored_span_boundary_codepoints.push_back(']');
+ // Only '.', ',', '[' and ']' are registered as ignored boundary codepoints.
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib_);
+ // No ignored codepoints at either end.
+ const std::string text1_utf8 = "ěščř";
+ const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text1.begin(), text1.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text1.begin(), text1.end(),
+ /*count_from_beginning=*/false),
+ 0);
+ // Ignored codepoints at the beginning only.
+ const std::string text2_utf8 = ".,abčd";
+ const UnicodeText text2 = UTF8ToUnicodeText(text2_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text2.begin(), text2.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text2.begin(), text2.end(),
+ /*count_from_beginning=*/false),
+ 0);
+ // Ignored codepoints at both ends, two on each side.
+ const std::string text3_utf8 = ".,abčd[]";
+ const UnicodeText text3 = UTF8ToUnicodeText(text3_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text3.begin(), text3.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text3.begin(), text3.end(),
+ /*count_from_beginning=*/false),
+ 2);
+ // A single ignored codepoint on each side.
+ const std::string text4_utf8 = "[abčd]";
+ const UnicodeText text4 = UTF8ToUnicodeText(text4_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text4.begin(), text4.end(),
+ /*count_from_beginning=*/true),
+ 1);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text4.begin(), text4.end(),
+ /*count_from_beginning=*/false),
+ 1);
+ // Empty text.
+ const std::string text5_utf8 = "";
+ const UnicodeText text5 = UTF8ToUnicodeText(text5_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text5.begin(), text5.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text5.begin(), text5.end(),
+ /*count_from_beginning=*/false),
+ 0);
+ // Range that starts 6 codepoints into the text; nothing to ignore in it.
+ const std::string text6_utf8 = "012345ěščř";
+ const UnicodeText text6 = UTF8ToUnicodeText(text6_utf8, /*do_copy=*/false);
+ UnicodeText::const_iterator text6_begin = text6.begin();
+ std::advance(text6_begin, 6);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text6_begin, text6.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text6_begin, text6.end(),
+ /*count_from_beginning=*/false),
+ 0);
+ // Sub-ranges that start at, and then end right after, the ".," run.
+ const std::string text7_utf8 = "012345.,ěščř";
+ const UnicodeText text7 = UTF8ToUnicodeText(text7_utf8, /*do_copy=*/false);
+ UnicodeText::const_iterator text7_begin = text7.begin();
+ std::advance(text7_begin, 6);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text7_begin, text7.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ UnicodeText::const_iterator text7_end = text7.begin();
+ std::advance(text7_end, 8);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text7.begin(), text7_end,
+ /*count_from_beginning=*/false),
+ 2);
+
+ // Test not stripping.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+ "Hello [[[Wořld]] or not?", {0, 24}),
+ CodepointSpan(0, 24));
+ // Test basic stripping.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+ "Hello [[[Wořld]] or not?", {6, 16}),
+ CodepointSpan(9, 14));
+ // Test stripping when everything is stripped.
+ EXPECT_EQ(
+ feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
+ CodepointSpan(6, 6));
+ // Test stripping empty string.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
+ CodepointSpan(0, 0));
+}
+
+TEST_F(AnnotatorFeatureProcessorTest, CodepointSpanToTokenSpan) {
+ const std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+ // NOTE(review): the optional third argument presumably enables snapping.
+ // Spans matching the tokens exactly.
+ EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}));
+ EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}));
+ EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}));
+ EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}));
+
+ // With exact matches, snapping to containing tokens changes nothing.
+ EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}, true));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}, true));
+ EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}, true));
+ EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}, true));
+ EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}, true));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}, true));
+
+ // Boundaries inside tokens: partial tokens are kept only when snapping.
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {1, 28}));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {1, 28}, true));
+
+ // Tokens that merely touch the span are excluded, snapping or not.
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}, true));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/flatbuffer-utils.cc b/native/annotator/flatbuffer-utils.cc
new file mode 100644
index 0000000..d4cbe4a
--- /dev/null
+++ b/native/annotator/flatbuffer-utils.cc
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/flatbuffer-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+
+bool SwapFieldNamesForOffsetsInPath(ModelT* model) {
+  // Without regex rules or an entity data schema there are no by-name field
+  // paths to resolve, which counts as success.
+  const bool has_regex_rules = model->regex_model != nullptr;
+  if (!has_regex_rules || model->entity_data_schema.empty()) {
+    return true;
+  }
+  const reflection::Schema* entity_data_schema =
+      LoadAndVerifyFlatbuffer<reflection::Schema>(
+          model->entity_data_schema.data(), model->entity_data_schema.size());
+
+  // Walk the capturing groups of every regex pattern. Groups without an
+  // entity field path are skipped; the first path for which the per-path
+  // overload reports failure ends the walk and fails the whole call.
+  for (auto& regex_pattern : model->regex_model->patterns) {
+    for (auto& capturing_group : regex_pattern->capturing_group) {
+      auto* field_path = capturing_group->entity_field_path.get();
+      if (field_path != nullptr &&
+          !SwapFieldNamesForOffsetsInPath(entity_data_schema, field_path)) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+std::string SwapFieldNamesForOffsetsInPathInSerializedModel(
+    const std::string& model) {
+  std::unique_ptr<ModelT> model_tree = UnPackModel(model.c_str());
+  TC3_CHECK(model_tree != nullptr);  // The input must unpack into a ModelT.
+  TC3_CHECK(SwapFieldNamesForOffsetsInPath(model_tree.get()));  // In place.
+  flatbuffers::FlatBufferBuilder fbb;
+  FinishModelBuffer(fbb, Model::Pack(fbb, model_tree.get()));
+  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());
+  return std::string(packed, fbb.GetSize());
+}
+
+std::string CreateDatetimeSerializedEntityData(
+    const DatetimeParseResult& parse_result) {
+  using ComponentT = EntityData_::Datetime_::DatetimeComponentT;
+  namespace component_ns = EntityData_::Datetime_::DatetimeComponent_;
+  // The top-level datetime fields are copied over from the parse result.
+  EntityDataT entity_data;
+  entity_data.datetime.reset(new EntityData_::DatetimeT());
+  EntityData_::DatetimeT* datetime = entity_data.datetime.get();
+  datetime->time_ms_utc = parse_result.time_ms_utc;
+  datetime->granularity = static_cast<EntityData_::Datetime_::Granularity>(
+      parse_result.granularity);
+  // Every parsed component is mirrored into the object tree. A component is
+  // marked relative as soon as its relative qualifier is specified.
+  for (const auto& parsed : parse_result.datetime_components) {
+    datetime->datetime_component.emplace_back(new ComponentT());
+    ComponentT* component = datetime->datetime_component.back().get();
+    component->absolute_value = parsed.value;
+    component->relative_count = parsed.relative_count;
+    component->component_type =
+        static_cast<component_ns::ComponentType>(parsed.component_type);
+    const bool is_relative = parsed.relative_qualifier !=
+                             DatetimeComponent::RelativeQualifier::UNSPECIFIED;
+    component->relation_type = is_relative
+                                   ? component_ns::RelationType_RELATIVE
+                                   : component_ns::RelationType_ABSOLUTE;
+  }
+  return PackFlatbuffer<EntityData>(&entity_data);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/flatbuffer-utils.h b/native/annotator/flatbuffer-utils.h
new file mode 100644
index 0000000..cd7d653
--- /dev/null
+++ b/native/annotator/flatbuffer-utils.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers in the annotator model.
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
+
+#include <string>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+// Resolves by-name field lookups in the model's regex rules to field offsets.
+bool SwapFieldNamesForOffsetsInPath(ModelT* model);
+
+// Same as above but for a serialized model.
+std::string SwapFieldNamesForOffsetsInPathInSerializedModel(
+    const std::string& model);
+
+// Packs the datetime parse result into a serialized EntityData flatbuffer.
+std::string CreateDatetimeSerializedEntityData(
+    const DatetimeParseResult& parse_result);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
diff --git a/native/annotator/grammar/grammar-annotator.cc b/native/annotator/grammar/grammar-annotator.cc
index baa3fac..756fcd5 100644
--- a/native/annotator/grammar/grammar-annotator.cc
+++ b/native/annotator/grammar/grammar-annotator.cc
@@ -49,8 +49,7 @@
public:
explicit GrammarAnnotatorCallbackDelegate(
const UniLib* unilib, const GrammarModel* model,
- const ReflectiveFlatbufferBuilder* entity_data_builder,
- const ModeFlag mode)
+ const MutableFlatbufferBuilder* entity_data_builder, const ModeFlag mode)
: unilib_(*unilib),
model_(model),
entity_data_builder_(entity_data_builder),
@@ -276,7 +275,7 @@
if (entity_data_builder_ == nullptr) {
return true;
}
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder_->NewRoot();
if (interpretation->serialized_entity_data() != nullptr) {
entity_data->MergeFromSerializedFlatbuffer(
@@ -342,7 +341,7 @@
const UniLib& unilib_;
const GrammarModel* model_;
- const ReflectiveFlatbufferBuilder* entity_data_builder_;
+ const MutableFlatbufferBuilder* entity_data_builder_;
const ModeFlag mode_;
// All annotation/selection/classification rule match candidates.
@@ -352,7 +351,7 @@
GrammarAnnotator::GrammarAnnotator(
const UniLib* unilib, const GrammarModel* model,
- const ReflectiveFlatbufferBuilder* entity_data_builder)
+ const MutableFlatbufferBuilder* entity_data_builder)
: unilib_(*unilib),
model_(model),
lexer_(unilib, model->rules()),
diff --git a/native/annotator/grammar/grammar-annotator.h b/native/annotator/grammar/grammar-annotator.h
index 365bb44..b9ef62c 100644
--- a/native/annotator/grammar/grammar-annotator.h
+++ b/native/annotator/grammar/grammar-annotator.h
@@ -21,7 +21,7 @@
#include "annotator/model_generated.h"
#include "annotator/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/grammar/lexer.h"
#include "utils/i18n/locale.h"
#include "utils/tokenizer.h"
@@ -39,7 +39,7 @@
explicit GrammarAnnotator(
const UniLib* unilib, const GrammarModel* model,
- const ReflectiveFlatbufferBuilder* entity_data_builder);
+ const MutableFlatbufferBuilder* entity_data_builder);
// Annotates a given text.
// Returns true if the text was successfully annotated.
@@ -63,7 +63,7 @@
const GrammarModel* model_;
const grammar::Lexer lexer_;
const Tokenizer tokenizer_;
- const ReflectiveFlatbufferBuilder* entity_data_builder_;
+ const MutableFlatbufferBuilder* entity_data_builder_;
// Pre-parsed locales of the rules.
const std::vector<std::vector<Locale>> rules_locales_;
diff --git a/native/annotator/grammar/utils.cc b/native/annotator/grammar/utils.cc
index 8b9363d..bb58190 100644
--- a/native/annotator/grammar/utils.cc
+++ b/native/annotator/grammar/utils.cc
@@ -53,13 +53,14 @@
int AddRuleClassificationResult(const std::string& collection,
const ModeFlag& enabled_modes,
- GrammarModelT* model) {
+ float priority_score, GrammarModelT* model) {
const int result_id = model->rule_classification_result.size();
model->rule_classification_result.emplace_back(new RuleClassificationResultT);
RuleClassificationResultT* result =
model->rule_classification_result.back().get();
result->collection_name = collection;
result->enabled_modes = enabled_modes;
+ result->priority_score = priority_score;
return result_id;
}
diff --git a/native/annotator/grammar/utils.h b/native/annotator/grammar/utils.h
index 4d870fd..21d383f 100644
--- a/native/annotator/grammar/utils.h
+++ b/native/annotator/grammar/utils.h
@@ -35,7 +35,7 @@
// Returns the ID associated with the created classification rule.
int AddRuleClassificationResult(const std::string& collection,
const ModeFlag& enabled_modes,
- GrammarModelT* model);
+ float priority_score, GrammarModelT* model);
} // namespace libtextclassifier3
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
index e9f688a..2a53288 100644
--- a/native/annotator/knowledge/knowledge-engine-dummy.h
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -19,6 +19,7 @@
#include <string>
+#include "annotator/knowledge/knowledge-engine-types.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/status.h"
@@ -46,17 +47,17 @@
bool Chunk(const std::string& text, AnnotationUsecase annotation_usecase,
const Optional<LocationContext>& location_context,
- const Permissions& permissions,
- std::vector<AnnotatedSpan>* result) const {
+ const Permissions& permissions, const AnnotateMode annotate_mode,
+ Annotations* result) const {
return true;
}
- Status ChunkMultipleSpans(
- const std::vector<std::string>& text_fragments,
- AnnotationUsecase annotation_usecase,
- const Optional<LocationContext>& location_context,
- const Permissions& permissions,
- std::vector<std::vector<AnnotatedSpan>>* results) const {
+ Status ChunkMultipleSpans(const std::vector<std::string>& text_fragments,
+ AnnotationUsecase annotation_usecase,
+ const Optional<LocationContext>& location_context,
+ const Permissions& permissions,
+ const AnnotateMode annotate_mode,
+ Annotations* results) const {
return Status::OK;
}
diff --git a/native/annotator/knowledge/knowledge-engine-types.h b/native/annotator/knowledge/knowledge-engine-types.h
new file mode 100644
index 0000000..9508c7b
--- /dev/null
+++ b/native/annotator/knowledge/knowledge-engine-types.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_TYPES_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_TYPES_H_
+
+namespace libtextclassifier3 {
+// Annotation mode taken by the knowledge engine's chunking entry points.
+enum AnnotateMode { kEntityAnnotation, kEntityAndTopicalityAnnotation };
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_TYPES_H_
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index bdb7a17..a9906fc 100755
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -14,16 +14,16 @@
// limitations under the License.
//
-include "annotator/entity-data.fbs";
include "annotator/experimental/experimental.fbs";
+include "annotator/entity-data.fbs";
include "annotator/grammar/dates/dates.fbs";
-include "utils/codepoint-range.fbs";
-include "utils/flatbuffers.fbs";
include "utils/grammar/rules.fbs";
include "utils/intents/intent-config.fbs";
-include "utils/normalization.fbs";
include "utils/resources.fbs";
include "utils/tokenizer.fbs";
+include "utils/codepoint-range.fbs";
+include "utils/flatbuffers/flatbuffers.fbs";
+include "utils/normalization.fbs";
include "utils/zlib/buffer.fbs";
file_identifier "TC2 ";
@@ -131,6 +131,8 @@
NINETY = 70,
HUNDRED = 71,
THOUSAND = 72,
+ NOON = 73,
+ MIDNIGHT = 74,
}
namespace libtextclassifier3;
@@ -154,6 +156,7 @@
GROUP_DUMMY1 = 12,
GROUP_DUMMY2 = 13,
+ GROUP_ABSOLUTETIME = 14,
}
// Options for the model that predicts text selection.
@@ -504,10 +507,21 @@
tokenizer_options:GrammarTokenizerOptions;
}
+namespace libtextclassifier3.MoneyParsingOptions_;
+table QuantitiesNameToExponentEntry {
+ key:string (key, shared);  // Quantity name, e.g. "million".
+ value:int;  // Power of 10 the amount is multiplied with, e.g. 6.
+}
+
namespace libtextclassifier3;
table MoneyParsingOptions {
// Separators (codepoints) marking decimal or thousand in the money amount.
separators:[int];
+
+ // Mapping between a quantity string (e.g. "million") and the power of 10
+ // it multiplies the amount with (e.g. 6 in case of "million").
+ // NOTE: The entries need to be sorted by key since we use LookupByKey.
+ quantities_name_to_exponent:[MoneyParsingOptions_.QuantitiesNameToExponentEntry];
}
namespace libtextclassifier3.ModelTriggeringOptions_;
@@ -659,6 +673,7 @@
grammar_model:GrammarModel;
conflict_resolution_options:Model_.ConflictResolutionOptions;
experimental_model:ExperimentalModel;
+ pod_ner_model:PodNerModel;
}
// Method for selecting the center token.
@@ -985,4 +1000,65 @@
backoff_options:TranslateAnnotatorOptions_.BackoffOptions;
}
+namespace libtextclassifier3.PodNerModel_.Label_;
+// Tag type of a label; the names of the non-NONE values spell B-O-I-S-E.
+enum BoiseType : int {
+ NONE = 0,
+ BEGIN = 1,
+ // No label.
+ O = 2,
+ INTERMEDIATE = 3,
+ SINGLE = 4,
+ END = 5,
+}
+
+namespace libtextclassifier3.PodNerModel_.Label_;
+enum MentionType : int {
+ UNDEFINED = 0,
+ NAM = 1,  // NOTE(review): presumably a "named" mention - confirm.
+ NOM = 2,  // NOTE(review): presumably a "nominal" mention - confirm.
+}
+
+namespace libtextclassifier3.PodNerModel_;
+table Label {
+ boise_type:Label_.BoiseType;
+ mention_type:Label_.MentionType;
+ // Points to an entry of the `collections` array declared in PodNerModel.
+ collection_id:int;
+}
+
+namespace libtextclassifier3;
+table PodNerModel {
+ tflite_model:[ubyte];  // The tflite model, stored as raw bytes.
+ word_piece_vocab:[ubyte];  // The wordpiece vocabulary, stored as raw bytes.
+ lowercase_input:bool = true;
+
+ // Index of the mention_logits tensor in the tflite model output. It can be
+ // found in the textproto output after the model is converted to tflite.
+ logits_index_in_output_tensor:int = 0;
+
+ // Whether to append a period at the end of an input that doesn't already
+ // end in punctuation.
+ append_final_period:bool = false;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+
+ // Maximum number of wordpieces supported by the model.
+ max_num_wordpieces:int = 128;
+
+ // Text with more wordpieces than max_num_wordpieces is processed with a
+ // sliding window. This is the number of wordpieces that two consecutive
+ // windows share; the overlap gives context to each word the NER model
+ // annotates.
+ sliding_window_num_wordpieces_overlap:int = 20;
+
+ // Possible collections for labeled entities, e.g., "location", "person".
+ collections:[string];
+
+ // The possible labels the NER model can output. If empty, the default
+ // labels will be used.
+ labels:[PodNerModel_.Label];
+}
+
root_type libtextclassifier3.Model;
diff --git a/native/annotator/pod_ner/pod-ner-dummy.h b/native/annotator/pod_ner/pod-ner-dummy.h
new file mode 100644
index 0000000..1246ade
--- /dev/null
+++ b/native/annotator/pod_ner/pod-ner-dummy.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_DUMMY_H_
+
+#include <memory>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Dummy version of POD NER annotator. To be included in builds that do not
+// want POD NER support. None of the methods reads or writes its arguments.
+class PodNerAnnotator {
+ public:
+ static std::unique_ptr<PodNerAnnotator> Create(const PodNerModel *model,
+ const UniLib &unilib) {
+ return nullptr;  // The dummy never hands out an annotator instance.
+ }
+ // Reports success and leaves `results` untouched.
+ bool Annotate(const UnicodeText &context,
+ std::vector<AnnotatedSpan> *results) const {
+ return true;
+ }
+ // Returns a value-initialized AnnotatedSpan regardless of the click.
+ AnnotatedSpan SuggestSelection(const UnicodeText &context,
+ CodepointSpan click) const {
+ return {};
+ }
+ // Reports success and leaves `result` untouched.
+ bool ClassifyText(const UnicodeText &context, CodepointSpan click,
+ ClassificationResult *result) const {
+ return true;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_DUMMY_H_
diff --git a/native/annotator/pod_ner/pod-ner.h b/native/annotator/pod_ner/pod-ner.h
new file mode 100644
index 0000000..3594e6e
--- /dev/null
+++ b/native/annotator/pod_ner/pod-ner.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_H_
+
+#include "annotator/pod_ner/pod-ner-dummy.h"
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_POD_NER_H_
diff --git a/native/annotator/strip-unpaired-brackets.cc b/native/annotator/strip-unpaired-brackets.cc
index b1067ad..df1fcce 100644
--- a/native/annotator/strip-unpaired-brackets.cc
+++ b/native/annotator/strip-unpaired-brackets.cc
@@ -66,7 +66,7 @@
// version.
CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
CodepointSpan span, const UniLib& unilib) {
- if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
+ if (context_unicode.empty() || !span.IsValid() || span.IsEmpty()) {
return span;
}
diff --git a/native/annotator/strip-unpaired-brackets_test.cc b/native/annotator/strip-unpaired-brackets_test.cc
index 32585ce..a7a3d29 100644
--- a/native/annotator/strip-unpaired-brackets_test.cc
+++ b/native/annotator/strip-unpaired-brackets_test.cc
@@ -30,36 +30,36 @@
TEST_F(StripUnpairedBracketsTest, StripUnpairedBrackets) {
// If the brackets match, nothing gets stripped.
EXPECT_EQ(StripUnpairedBrackets("call me (123) 456 today", {8, 17}, unilib_),
- std::make_pair(8, 17));
+ CodepointSpan(8, 17));
EXPECT_EQ(StripUnpairedBrackets("call me (123 456) today", {8, 17}, unilib_),
- std::make_pair(8, 17));
+ CodepointSpan(8, 17));
// If the brackets don't match, they get stripped.
EXPECT_EQ(StripUnpairedBrackets("call me (123 456 today", {8, 16}, unilib_),
- std::make_pair(9, 16));
+ CodepointSpan(9, 16));
EXPECT_EQ(StripUnpairedBrackets("call me )123 456 today", {8, 16}, unilib_),
- std::make_pair(9, 16));
+ CodepointSpan(9, 16));
EXPECT_EQ(StripUnpairedBrackets("call me 123 456) today", {8, 16}, unilib_),
- std::make_pair(8, 15));
+ CodepointSpan(8, 15));
EXPECT_EQ(StripUnpairedBrackets("call me 123 456( today", {8, 16}, unilib_),
- std::make_pair(8, 15));
+ CodepointSpan(8, 15));
// Strips brackets correctly from length-1 selections that consist of
// a bracket only.
EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}, unilib_),
- std::make_pair(12, 12));
+ CodepointSpan(12, 12));
EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}, unilib_),
- std::make_pair(12, 12));
+ CodepointSpan(12, 12));
// Handles invalid spans gracefully.
EXPECT_EQ(StripUnpairedBrackets("call me at today", {11, 11}, unilib_),
- std::make_pair(11, 11));
+ CodepointSpan(11, 11));
EXPECT_EQ(StripUnpairedBrackets("hello world", {0, 0}, unilib_),
- std::make_pair(0, 0));
+ CodepointSpan(0, 0));
EXPECT_EQ(StripUnpairedBrackets("hello world", {11, 11}, unilib_),
- std::make_pair(11, 11));
+ CodepointSpan(11, 11));
EXPECT_EQ(StripUnpairedBrackets("hello world", {-1, -1}, unilib_),
- std::make_pair(-1, -1));
+ CodepointSpan(-1, -1));
}
} // namespace
diff --git a/native/annotator/types-test-util.h b/native/annotator/types-test-util.h
index 1d018a1..55dd214 100644
--- a/native/annotator/types-test-util.h
+++ b/native/annotator/types-test-util.h
@@ -34,9 +34,11 @@
TC3_DECLARE_PRINT_OPERATOR(AnnotatedSpan)
TC3_DECLARE_PRINT_OPERATOR(ClassificationResult)
+TC3_DECLARE_PRINT_OPERATOR(CodepointSpan)
TC3_DECLARE_PRINT_OPERATOR(DatetimeParsedData)
TC3_DECLARE_PRINT_OPERATOR(DatetimeParseResultSpan)
TC3_DECLARE_PRINT_OPERATOR(Token)
+TC3_DECLARE_PRINT_OPERATOR(TokenSpan)
#undef TC3_DECLARE_PRINT_OPERATOR
diff --git a/native/annotator/types.cc b/native/annotator/types.cc
index be542d3..b1dde17 100644
--- a/native/annotator/types.cc
+++ b/native/annotator/types.cc
@@ -22,6 +22,21 @@
namespace libtextclassifier3 {
+const CodepointSpan CodepointSpan::kInvalid =
+ CodepointSpan(kInvalidIndex, kInvalidIndex);
+
+const TokenSpan TokenSpan::kInvalid = TokenSpan(kInvalidIndex, kInvalidIndex);
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const CodepointSpan& span) {
+ return stream << "CodepointSpan(" << span.first << ", " << span.second << ")";
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const TokenSpan& span) {
+ return stream << "TokenSpan(" << span.first << ", " << span.second << ")";
+}
+
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
const Token& token) {
if (!token.is_padding) {
diff --git a/native/annotator/types.h b/native/annotator/types.h
index 665d4b6..df6f676 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -30,9 +30,10 @@
#include <vector>
#include "annotator/entity-data_generated.h"
+#include "annotator/knowledge/knowledge-engine-types.h"
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
#include "utils/optional.h"
#include "utils/variant.h"
@@ -56,17 +57,48 @@
// Marks a span in a sequence of codepoints. The first element is the index of
// the first codepoint of the span, and the second element is the index of the
// codepoint one past the end of the span.
-// TODO(b/71982294): Make it a struct.
-using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
+struct CodepointSpan {
+ static const CodepointSpan kInvalid;
+
+
+ CodepointSpan(CodepointIndex start, CodepointIndex end)
+ : first(start), second(end) {}
+
+ CodepointSpan& operator=(const CodepointSpan& other) = default;
+
+ bool operator==(const CodepointSpan& other) const {
+ return this->first == other.first && this->second == other.second;
+ }
+
+ bool operator!=(const CodepointSpan& other) const {
+ return !(*this == other);
+ }
+
+ bool operator<(const CodepointSpan& other) const {
+ if (this->first != other.first) {
+ return this->first < other.first;
+ }
+ return this->second < other.second;
+ }
+
+  bool IsValid() const {  // 0 <= first <= second (kInvalidIndex assumed < 0).
+    return first >= 0 && first <= second && second != kInvalidIndex;
+  }
+
+ bool IsEmpty() const { return this->first == this->second; }
+
+ CodepointIndex first;
+ CodepointIndex second;
+};
+
+// Pretty-printing function for CodepointSpan.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const CodepointSpan& span);
inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
return a.first < b.second && b.first < a.second;
}
-inline bool ValidNonEmptySpan(const CodepointSpan& span) {
- return span.first < span.second && span.first >= 0 && span.second >= 0;
-}
-
template <typename T>
bool DoesCandidateConflict(
const int considered_candidate, const std::vector<T>& candidates,
@@ -102,35 +134,61 @@
// Marks a span in a sequence of tokens. The first element is the index of the
// first token in the span, and the second element is the index of the token one
// past the end of the span.
-// TODO(b/71982294): Make it a struct.
-using TokenSpan = std::pair<TokenIndex, TokenIndex>;
+struct TokenSpan {
+ static const TokenSpan kInvalid;
-// Returns the size of the token span. Assumes that the span is valid.
-inline int TokenSpanSize(const TokenSpan& token_span) {
- return token_span.second - token_span.first;
-}
+ TokenSpan() : first(kInvalidIndex), second(kInvalidIndex) {}
-// Returns a token span consisting of one token.
-inline TokenSpan SingleTokenSpan(int token_index) {
- return {token_index, token_index + 1};
-}
+ TokenSpan(TokenIndex start, TokenIndex end) : first(start), second(end) {}
-// Returns an intersection of two token spans. Assumes that both spans are valid
-// and overlapping.
+ // Creates a token span consisting of one token.
+ explicit TokenSpan(int token_index)
+ : first(token_index), second(token_index + 1) {}
+
+ TokenSpan& operator=(const TokenSpan& other) = default;
+
+ bool operator==(const TokenSpan& other) const {
+ return this->first == other.first && this->second == other.second;
+ }
+
+ bool operator!=(const TokenSpan& other) const { return !(*this == other); }
+
+ bool operator<(const TokenSpan& other) const {
+ if (this->first != other.first) {
+ return this->first < other.first;
+ }
+ return this->second < other.second;
+ }
+
+ bool IsValid() const {
+ return this->first != kInvalidIndex && this->second != kInvalidIndex;
+ }
+
+ // Returns the size of the token span. Assumes that the span is valid.
+ int Size() const { return this->second - this->first; }
+
+ // Returns an expanded token span by adding a certain number of tokens on its
+ // left and on its right.
+ TokenSpan Expand(int num_tokens_left, int num_tokens_right) const {
+ return {this->first - num_tokens_left, this->second + num_tokens_right};
+ }
+
+ TokenIndex first;
+ TokenIndex second;
+};
+
+// Pretty-printing function for TokenSpan.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const TokenSpan& span);
+
+// Returns an intersection of two token spans. Assumes that both spans are
+// valid and overlapping.
inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
const TokenSpan& token_span2) {
return {std::max(token_span1.first, token_span2.first),
std::min(token_span1.second, token_span2.second)};
}
-// Returns and expanded token span by adding a certain number of tokens on its
-// left and on its right.
-inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
- int num_tokens_left, int num_tokens_right) {
- return {token_span.first - num_tokens_left,
- token_span.second + num_tokens_right};
-}
-
// Token holds a token, its position in the original string and whether it was
// part of the input span.
struct Token {
@@ -169,7 +227,7 @@
is_padding == other.is_padding;
}
- bool IsContainedInSpan(CodepointSpan span) const {
+ bool IsContainedInSpan(const CodepointSpan& span) const {
return start >= span.first && end <= span.second;
}
};
@@ -178,6 +236,11 @@
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
const Token& token);
+// Returns a TokenSpan that covers all of the given tokens.
+inline TokenSpan AllOf(const std::vector<Token>& tokens) {
+ return {0, static_cast<TokenIndex>(tokens.size())};
+}
+
enum DatetimeGranularity {
GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this
// structure being uninitialized.
@@ -314,17 +377,18 @@
float priority_score;
DatetimeParseResultSpan()
- : target_classification_score(-1.0), priority_score(-1.0) {}
+ : span(CodepointSpan::kInvalid),
+ target_classification_score(-1.0),
+ priority_score(-1.0) {}
DatetimeParseResultSpan(const CodepointSpan& span,
const std::vector<DatetimeParseResult>& data,
const float target_classification_score,
- const float priority_score) {
- this->span = span;
- this->data = data;
- this->target_classification_score = target_classification_score;
- this->priority_score = priority_score;
- }
+ const float priority_score)
+ : span(span),
+ data(data),
+ target_classification_score(target_classification_score),
+ priority_score(priority_score) {}
bool operator==(const DatetimeParseResultSpan& other) const {
return span == other.span && data == other.data &&
@@ -456,6 +520,9 @@
// The location context passed along with each annotation.
Optional<LocationContext> location_context;
+ // If true, the POD NER annotator is used.
+ bool use_pod_ner = true;
+
bool operator==(const BaseOptions& other) const {
bool location_context_equality = this->location_context.has_value() ==
other.location_context.has_value();
@@ -525,11 +592,14 @@
// Defines the permissions for the annotators.
Permissions permissions;
+ AnnotateMode annotate_mode = AnnotateMode::kEntityAnnotation;
+
bool operator==(const AnnotationOptions& other) const {
return this->is_serialized_entity_data_enabled ==
other.is_serialized_entity_data_enabled &&
this->permissions == other.permissions &&
this->entity_types == other.entity_types &&
+ this->annotate_mode == other.annotate_mode &&
BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
}
};
@@ -552,7 +622,7 @@
enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME, PERSON_NAME };
// Unicode codepoint indices in the input string.
- CodepointSpan span = {kInvalidIndex, kInvalidIndex};
+ CodepointSpan span = CodepointSpan::kInvalid;
// Classification result for the span.
std::vector<ClassificationResult> classification;
@@ -574,6 +644,27 @@
source(arg_source) {}
};
+// Represents Annotations that correspond to all input fragments.
+struct Annotations {
+ // List of annotations found in the corresponding input fragments. For these
+ // annotations, topicality score will not be set.
+ std::vector<std::vector<AnnotatedSpan>> annotated_spans;
+
+ // List of topicality results found across all input fragments.
+ std::vector<ClassificationResult> topicality_results;
+
+ Annotations() = default;
+
+ explicit Annotations(
+ std::vector<std::vector<AnnotatedSpan>> arg_annotated_spans)
+ : annotated_spans(std::move(arg_annotated_spans)) {}
+
+ Annotations(std::vector<std::vector<AnnotatedSpan>> arg_annotated_spans,
+ std::vector<ClassificationResult> arg_topicality_results)
+ : annotated_spans(std::move(arg_annotated_spans)),
+ topicality_results(std::move(arg_topicality_results)) {}
+};
+
struct InputFragment {
std::string text;
@@ -591,7 +682,7 @@
class VectorSpan {
public:
VectorSpan() : begin_(), end_() {}
- VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit)
+ explicit VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit)
: begin_(v.begin()), end_(v.end()) {}
VectorSpan(typename std::vector<T>::const_iterator begin,
typename std::vector<T>::const_iterator end)
diff --git a/native/annotator/zlib-utils.cc b/native/annotator/zlib-utils.cc
index c3c2cf1..2f3012e 100644
--- a/native/annotator/zlib-utils.cc
+++ b/native/annotator/zlib-utils.cc
@@ -20,7 +20,6 @@
#include "utils/base/logging.h"
#include "utils/intents/zlib-utils.h"
-#include "utils/resources.h"
#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
@@ -66,11 +65,6 @@
}
}
- // Compress resources.
- if (model->resources != nullptr) {
- CompressResources(model->resources.get());
- }
-
// Compress intent generator.
if (model->intent_options != nullptr) {
CompressIntentModel(model->intent_options.get());
@@ -126,10 +120,6 @@
}
}
- if (model->resources != nullptr) {
- DecompressResources(model->resources.get());
- }
-
if (model->intent_options != nullptr) {
DecompressIntentModel(model->intent_options.get());
}
diff --git a/native/annotator/zlib-utils_test.cc b/native/annotator/zlib-utils_test.cc
index df33ea1..7e4ef08 100644
--- a/native/annotator/zlib-utils_test.cc
+++ b/native/annotator/zlib-utils_test.cc
@@ -55,42 +55,9 @@
model.intent_options->generator.back()->lua_template_generator =
std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end());
- // NOTE: The resource strings contain some repetition, so that the compressed
- // version is smaller than the uncompressed one. Because the compression code
- // looks at that as well.
- model.resources.reset(new ResourcePoolT);
- model.resources->resource_entry.emplace_back(new ResourceEntryT);
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr1.1";
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr1.2";
- model.resources->resource_entry.emplace_back(new ResourceEntryT);
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr2.1";
- model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
- model.resources->resource_entry.back()->resource.back()->content =
- "rrrrrrrrrrrrr2.2";
-
// Compress the model.
EXPECT_TRUE(CompressModel(&model));
- // Sanity check that uncompressed field is removed.
- EXPECT_TRUE(model.regex_model->patterns[0]->pattern.empty());
- EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty());
- EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty());
- EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty());
- EXPECT_TRUE(
- model.intent_options->generator[0]->lua_template_generator.empty());
- EXPECT_TRUE(
- model.intent_options->generator[1]->lua_template_generator.empty());
- EXPECT_TRUE(model.resources->resource_entry[0]->resource[0]->content.empty());
- EXPECT_TRUE(model.resources->resource_entry[0]->resource[1]->content.empty());
- EXPECT_TRUE(model.resources->resource_entry[1]->resource[0]->content.empty());
- EXPECT_TRUE(model.resources->resource_entry[1]->resource[1]->content.empty());
-
// Pack and load the model.
flatbuffers::FlatBufferBuilder builder;
builder.Finish(Model::Pack(builder, &model));
@@ -139,14 +106,6 @@
EXPECT_EQ(
model.intent_options->generator[1]->lua_template_generator,
std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end()));
- EXPECT_EQ(model.resources->resource_entry[0]->resource[0]->content,
- "rrrrrrrrrrrrr1.1");
- EXPECT_EQ(model.resources->resource_entry[0]->resource[1]->content,
- "rrrrrrrrrrrrr1.2");
- EXPECT_EQ(model.resources->resource_entry[1]->resource[0]->content,
- "rrrrrrrrrrrrr2.1");
- EXPECT_EQ(model.resources->resource_entry[1]->resource[1]->content,
- "rrrrrrrrrrrrr2.2");
}
} // namespace libtextclassifier3
diff --git a/native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc b/native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
index ee22420..d6daa3f 100644
--- a/native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
+++ b/native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
@@ -384,6 +384,7 @@
const flatbuffers::Vector<uint16_t> *scales = matrix->scales();
if (scales == nullptr) {
SAFTM_LOG(ERROR) << "nullptr scales";
+ return nullptr;
}
return scales->data();
}
diff --git a/native/lang_id/common/lite_base/integral-types.h b/native/lang_id/common/lite_base/integral-types.h
index 4c3038c..9b02296 100644
--- a/native/lang_id/common/lite_base/integral-types.h
+++ b/native/lang_id/common/lite_base/integral-types.h
@@ -19,11 +19,13 @@
#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_INTEGRAL_TYPES_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_INTEGRAL_TYPES_H_
+#include <cstdint>
+
namespace libtextclassifier3 {
namespace mobile {
typedef unsigned int uint32;
-typedef unsigned long long uint64;
+typedef uint64_t uint64;
#ifndef SWIG
typedef int int32;
@@ -37,11 +39,7 @@
typedef signed int char32;
#endif // SWIG
-#ifdef COMPILER_MSVC
-typedef __int64 int64;
-#else
-typedef long long int64; // NOLINT
-#endif // COMPILER_MSVC
+using int64 = int64_t;
// Some compile-time assertions that our new types have the intended size.
static_assert(sizeof(int) == 4, "Our typedefs depend on int being 32 bits");
diff --git a/native/lang_id/lang-id-wrapper.cc b/native/lang_id/lang-id-wrapper.cc
index 4246cce..baeb0f2 100644
--- a/native/lang_id/lang-id-wrapper.cc
+++ b/native/lang_id/lang-id-wrapper.cc
@@ -40,6 +40,13 @@
return langid_model;
}
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromUnownedBuffer(
+ const char* buffer, int size) {
+ std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
+ libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferBytes(buffer, size);
+ return langid_model;
+}
+
std::vector<std::pair<std::string, float>> GetPredictions(
const libtextclassifier3::mobile::lang_id::LangId* model, const std::string& text) {
return GetPredictions(model, text.data(), text.size());
diff --git a/native/lang_id/lang-id-wrapper.h b/native/lang_id/lang-id-wrapper.h
index 47e6f44..8e65af7 100644
--- a/native/lang_id/lang-id-wrapper.h
+++ b/native/lang_id/lang-id-wrapper.h
@@ -35,6 +35,11 @@
std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromDescriptor(
const int fd);
+// Loads the LangId model from a buffer. The buffer needs to outlive the LangId
+// instance.
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromUnownedBuffer(
+ const char* buffer, int size);
+
// Returns the LangId predictions (locale, confidence) from the given LangId
// model. The maximum number of predictions returned will be computed internally
// relatively to the noise threshold.
diff --git a/native/lang_id/lang-id_jni.cc b/native/lang_id/lang-id_jni.cc
index 30753dc..e4bb5d8 100644
--- a/native/lang_id/lang-id_jni.cc
+++ b/native/lang_id/lang-id_jni.cc
@@ -28,9 +28,9 @@
#include "lang_id/lang-id.h"
using libtextclassifier3::JniHelper;
+using libtextclassifier3::JStringToUtf8String;
using libtextclassifier3::ScopedLocalRef;
using libtextclassifier3::StatusOr;
-using libtextclassifier3::ToStlString;
using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
using libtextclassifier3::mobile::lang_id::LangId;
@@ -63,7 +63,7 @@
env, result_class.get(), result_class_constructor,
predicted_language.get(),
static_cast<jfloat>(lang_id_predictions[i].second)));
- env->SetObjectArrayElement(results.get(), i, result.get());
+ JniHelper::SetObjectArrayElement(env, results.get(), i, result.get());
}
return results;
}
@@ -74,7 +74,7 @@
} // namespace
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
-(JNIEnv* env, jobject thiz, jint fd) {
+(JNIEnv* env, jobject clazz, jint fd) {
std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
if (!lang_id->is_valid()) {
return reinterpret_cast<jlong>(nullptr);
@@ -83,8 +83,9 @@
}
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
-(JNIEnv* env, jobject thiz, jstring path) {
- TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
+(JNIEnv* env, jobject clazz, jstring path) {
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str,
+ JStringToUtf8String(env, path));
std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
if (!lang_id->is_valid()) {
return reinterpret_cast<jlong>(nullptr);
@@ -93,13 +94,14 @@
}
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
-(JNIEnv* env, jobject clazz, jlong ptr, jstring text) {
+(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
LangId* model = reinterpret_cast<LangId*>(ptr);
if (!model) {
return nullptr;
}
- TC3_ASSIGN_OR_RETURN_NULL(const std::string text_str, ToStlString(env, text));
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string text_str,
+ JStringToUtf8String(env, text));
const std::vector<std::pair<std::string, float>>& prediction_results =
libtextclassifier3::langid::GetPredictions(model, text_str);
@@ -111,7 +113,7 @@
}
TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
-(JNIEnv* env, jobject clazz, jlong ptr) {
+(JNIEnv* env, jobject thiz, jlong ptr) {
if (!ptr) {
TC3_LOG(ERROR) << "Trying to close null LangId.";
return;
@@ -121,7 +123,7 @@
}
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jlong ptr) {
+(JNIEnv* env, jobject thiz, jlong ptr) {
if (!ptr) {
return -1;
}
diff --git a/native/lang_id/lang-id_jni.h b/native/lang_id/lang-id_jni.h
index 219349c..5eb2b00 100644
--- a/native/lang_id/lang-id_jni.h
+++ b/native/lang_id/lang-id_jni.h
@@ -40,13 +40,13 @@
(JNIEnv* env, jobject clazz, jstring path);
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
-(JNIEnv* env, jobject clazz, jlong ptr, jstring text);
+(JNIEnv* env, jobject thiz, jlong ptr, jstring text);
TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
-(JNIEnv* env, jobject clazz, jlong ptr);
+(JNIEnv* env, jobject thiz, jlong ptr);
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jlong ptr);
+(JNIEnv* env, jobject thiz, jlong ptr);
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
(JNIEnv* env, jobject clazz, jint fd);
diff --git a/native/models/actions_suggestions.en.model b/native/models/actions_suggestions.en.model
index d4b0ced..74422f6 100755
--- a/native/models/actions_suggestions.en.model
+++ b/native/models/actions_suggestions.en.model
Binary files differ
diff --git a/native/models/actions_suggestions.universal.model b/native/models/actions_suggestions.universal.model
index 2ee546c..f74fed4 100755
--- a/native/models/actions_suggestions.universal.model
+++ b/native/models/actions_suggestions.universal.model
Binary files differ
diff --git a/native/models/textclassifier.ar.model b/native/models/textclassifier.ar.model
index 923d8af..d9710b9 100755
--- a/native/models/textclassifier.ar.model
+++ b/native/models/textclassifier.ar.model
Binary files differ
diff --git a/native/models/textclassifier.en.model b/native/models/textclassifier.en.model
index aec4302..f5fcc23 100755
--- a/native/models/textclassifier.en.model
+++ b/native/models/textclassifier.en.model
Binary files differ
diff --git a/native/models/textclassifier.es.model b/native/models/textclassifier.es.model
index 7ff3d73..33dddec 100755
--- a/native/models/textclassifier.es.model
+++ b/native/models/textclassifier.es.model
Binary files differ
diff --git a/native/models/textclassifier.fr.model b/native/models/textclassifier.fr.model
index cc5f488..45df2f1 100755
--- a/native/models/textclassifier.fr.model
+++ b/native/models/textclassifier.fr.model
Binary files differ
diff --git a/native/models/textclassifier.it.model b/native/models/textclassifier.it.model
index 5d40ef5..70bf151 100755
--- a/native/models/textclassifier.it.model
+++ b/native/models/textclassifier.it.model
Binary files differ
diff --git a/native/models/textclassifier.ja.model b/native/models/textclassifier.ja.model
index 9d65601..d28801a 100755
--- a/native/models/textclassifier.ja.model
+++ b/native/models/textclassifier.ja.model
Binary files differ
diff --git a/native/models/textclassifier.ko.model b/native/models/textclassifier.ko.model
index becba7a..c4bacdb 100755
--- a/native/models/textclassifier.ko.model
+++ b/native/models/textclassifier.ko.model
Binary files differ
diff --git a/native/models/textclassifier.nl.model b/native/models/textclassifier.nl.model
index bac8350..78e2f46 100755
--- a/native/models/textclassifier.nl.model
+++ b/native/models/textclassifier.nl.model
Binary files differ
diff --git a/native/models/textclassifier.pl.model b/native/models/textclassifier.pl.model
index 03b2825..6090f54 100755
--- a/native/models/textclassifier.pl.model
+++ b/native/models/textclassifier.pl.model
Binary files differ
diff --git a/native/models/textclassifier.pt.model b/native/models/textclassifier.pt.model
index 39f0b12..7ab45d8 100755
--- a/native/models/textclassifier.pt.model
+++ b/native/models/textclassifier.pt.model
Binary files differ
diff --git a/native/models/textclassifier.ru.model b/native/models/textclassifier.ru.model
index 6d08044..441d3fe 100755
--- a/native/models/textclassifier.ru.model
+++ b/native/models/textclassifier.ru.model
Binary files differ
diff --git a/native/models/textclassifier.th.model b/native/models/textclassifier.th.model
index 5e0f9dd..6a0a0d8 100755
--- a/native/models/textclassifier.th.model
+++ b/native/models/textclassifier.th.model
Binary files differ
diff --git a/native/models/textclassifier.tr.model b/native/models/textclassifier.tr.model
index 2dbc1d8..5752ba0 100755
--- a/native/models/textclassifier.tr.model
+++ b/native/models/textclassifier.tr.model
Binary files differ
diff --git a/native/models/textclassifier.universal.model b/native/models/textclassifier.universal.model
index 853e389..971c2af 100755
--- a/native/models/textclassifier.universal.model
+++ b/native/models/textclassifier.universal.model
Binary files differ
diff --git a/native/models/textclassifier.zh.model b/native/models/textclassifier.zh.model
index 8d989d7..ef9a536 100755
--- a/native/models/textclassifier.zh.model
+++ b/native/models/textclassifier.zh.model
Binary files differ
diff --git a/native/util/hash/hash.cc b/native/util/hash/hash.cc
deleted file mode 100644
index eaa85ae..0000000
--- a/native/util/hash/hash.cc
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "util/hash/hash.h"
-
-#include "utils/base/macros.h"
-
-namespace libtextclassifier2 {
-
-namespace {
-// Lower-level versions of Get... that read directly from a character buffer
-// without any bounds checking.
-inline uint32 DecodeFixed32(const char *ptr) {
- return ((static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) |
- (static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) |
- (static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) |
- (static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24));
-}
-
-// 0xff is in case char is signed.
-static inline uint32 ByteAs32(char c) { return static_cast<uint32>(c) & 0xff; }
-} // namespace
-
-uint32 Hash32(const char *data, size_t n, uint32 seed) {
- // 'm' and 'r' are mixing constants generated offline.
- // They're not really 'magic', they just happen to work well.
- const uint32 m = 0x5bd1e995;
- const int r = 24;
-
- // Initialize the hash to a 'random' value
- uint32 h = static_cast<uint32>(seed ^ n);
-
- // Mix 4 bytes at a time into the hash
- while (n >= 4) {
- uint32 k = DecodeFixed32(data);
- k *= m;
- k ^= k >> r;
- k *= m;
- h *= m;
- h ^= k;
- data += 4;
- n -= 4;
- }
-
- // Handle the last few bytes of the input array
- switch (n) {
- case 3:
- h ^= ByteAs32(data[2]) << 16;
- TC3_FALLTHROUGH_INTENDED;
- case 2:
- h ^= ByteAs32(data[1]) << 8;
- TC3_FALLTHROUGH_INTENDED;
- case 1:
- h ^= ByteAs32(data[0]);
- h *= m;
- }
-
- // Do a few final mixes of the hash to ensure the last few
- // bytes are well-incorporated.
- h ^= h >> 13;
- h *= m;
- h ^= h >> 15;
- return h;
-}
-
-} // namespace libtextclassifier2
diff --git a/native/util/hash/hash.h b/native/util/hash/hash.h
deleted file mode 100644
index 9353e5f..0000000
--- a/native/util/hash/hash.h
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
-#define LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
-
-#include <string>
-
-#include "utils/base/integral_types.h"
-
-namespace libtextclassifier2 {
-
-using namespace libtextclassifier3;
-
-uint32 Hash32(const char *data, size_t n, uint32 seed);
-
-static inline uint32 Hash32WithDefaultSeed(const char *data, size_t n) {
- return Hash32(data, n, 0xBEEF);
-}
-
-static inline uint32 Hash32WithDefaultSeed(const std::string &input) {
- return Hash32WithDefaultSeed(input.data(), input.size());
-}
-
-} // namespace libtextclassifier2
-
-#endif // LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
diff --git a/native/utils/base/arena.h b/native/utils/base/arena.h
index 28b6f6c..712deeb 100644
--- a/native/utils/base/arena.h
+++ b/native/utils/base/arena.h
@@ -204,7 +204,7 @@
// Allocates and initializes an object on the arena.
template <typename T, typename... Args>
- T* AllocAndInit(Args... args) {
+ T* AllocAndInit(Args&&... args) {
return new (reinterpret_cast<T*>(AllocAligned(sizeof(T), alignof(T))))
T(std::forward<Args>(args)...);
}
diff --git a/native/utils/base/logging.h b/native/utils/base/logging.h
index eae71b9..ce7cac8 100644
--- a/native/utils/base/logging.h
+++ b/native/utils/base/logging.h
@@ -24,7 +24,6 @@
#include "utils/base/logging_levels.h"
#include "utils/base/port.h"
-
namespace libtextclassifier3 {
namespace logging {
diff --git a/native/utils/base/status_macros.h b/native/utils/base/status_macros.h
index 40159fe..604b5b3 100644
--- a/native/utils/base/status_macros.h
+++ b/native/utils/base/status_macros.h
@@ -56,11 +56,20 @@
// TC3_RETURN_IF_ERROR(foo.Method(args...));
// return libtextclassifier3::Status();
// }
-#define TC3_RETURN_IF_ERROR(expr) \
+#define TC3_RETURN_IF_ERROR(expr) \
+ TC3_RETURN_IF_ERROR_INTERNAL(expr, std::move(adapter).status())
+
+#define TC3_RETURN_NULL_IF_ERROR(expr) \
+ TC3_RETURN_IF_ERROR_INTERNAL(expr, nullptr)
+
+#define TC3_RETURN_FALSE_IF_ERROR(expr) \
+ TC3_RETURN_IF_ERROR_INTERNAL(expr, false)
+
+#define TC3_RETURN_IF_ERROR_INTERNAL(expr, return_value) \
TC3_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
if (::libtextclassifier3::StatusAdapter adapter{expr}) { \
} else /* NOLINT */ \
- return std::move(adapter).status()
+ return return_value
// The GNU compiler emits a warning for code like:
//
diff --git a/native/utils/base/statusor.h b/native/utils/base/statusor.h
index dde9ecd..1bafcc7 100644
--- a/native/utils/base/statusor.h
+++ b/native/utils/base/statusor.h
@@ -34,7 +34,7 @@
inline StatusOr();
// Builds from a non-OK status. Crashes if an OK status is specified.
- inline StatusOr(const Status& status); // NOLINT
+ inline StatusOr(const Status& status); // NOLINT
// Builds from the specified value.
inline StatusOr(const T& value); // NOLINT
@@ -88,6 +88,8 @@
// Conversion assignment operator, T must be assignable from U
template <typename U>
inline StatusOr& operator=(const StatusOr<U>& other);
+ template <typename U>
+ inline StatusOr& operator=(StatusOr<U>&& other);
inline ~StatusOr();
@@ -136,6 +138,40 @@
friend class StatusOr;
private:
+ void Clear() {
+ if (ok()) {
+ value_.~T();
+ }
+ }
+
+ // Construct the value through placement new with the passed argument.
+ template <typename... Arg>
+ void MakeValue(Arg&&... arg) {
+ new (&value_) T(std::forward<Arg>(arg)...);
+ }
+
+ // Creates a valid instance of type T constructed with U and assigns it to
+ // value_. Handles how to properly assign to value_ if value_ was never
+ // actually initialized (if this is currently non-OK).
+ template <typename U>
+ void AssignValue(U&& value) {
+ if (ok()) {
+ value_ = std::forward<U>(value);
+ } else {
+ MakeValue(std::forward<U>(value));
+ status_ = Status::OK;
+ }
+ }
+
+ // Creates a status constructed with U and assigns it to status_. It also
+ // properly destroys value_ if this is OK and value_ represents a valid
+ // instance of T.
+ template <typename U>
+ void AssignStatus(U&& v) {
+ Clear();
+ status_ = static_cast<Status>(std::forward<U>(v));
+ }
+
Status status_;
// The members of unions do not require initialization and are not destructed
// unless specifically called. This allows us to construct instances of
@@ -212,35 +248,47 @@
template <typename T>
inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr& other) {
- status_ = other.status_;
- if (status_.ok()) {
- value_ = other.value_;
+ if (other.ok()) {
+ AssignValue(other.value_);
+ } else {
+ AssignStatus(other.status_);
}
return *this;
}
template <typename T>
inline StatusOr<T>& StatusOr<T>::operator=(StatusOr&& other) {
- status_ = other.status_;
- if (status_.ok()) {
- value_ = std::move(other.value_);
+ if (other.ok()) {
+ AssignValue(std::move(other.value_));
+ } else {
+ AssignStatus(std::move(other.status_));
}
return *this;
}
template <typename T>
inline StatusOr<T>::~StatusOr() {
- if (ok()) {
- value_.~T();
- }
+ Clear();
}
template <typename T>
template <typename U>
inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr<U>& other) {
- status_ = other.status_;
- if (status_.ok()) {
- value_ = other.value_;
+ if (other.ok()) {
+ AssignValue(other.value_);
+ } else {
+ AssignStatus(other.status_);
+ }
+ return *this;
+}
+
+template <typename T>
+template <typename U>
+inline StatusOr<T>& StatusOr<T>::operator=(StatusOr<U>&& other) {
+ if (other.ok()) {
+ AssignValue(std::move(other.value_));
+ } else {
+ AssignStatus(std::move(other.status_));
}
return *this;
}
@@ -259,7 +307,17 @@
#define TC3_ASSIGN_OR_RETURN_FALSE(lhs, rexpr) \
TC3_ASSIGN_OR_RETURN(lhs, rexpr, false)
-#define TC3_ASSIGN_OR_RETURN_0(lhs, rexpr) TC3_ASSIGN_OR_RETURN(lhs, rexpr, 0)
+#define TC3_ASSIGN_OR_RETURN_0(...) \
+ TC_STATUS_MACROS_IMPL_GET_VARIADIC_( \
+ (__VA_ARGS__, TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_0_3_, \
+ TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_0_2_)) \
+ (__VA_ARGS__)
+
+#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_0_2_(lhs, rexpr) \
+ TC3_ASSIGN_OR_RETURN(lhs, rexpr, 0)
+#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_0_3_(lhs, rexpr, \
+ log_expression) \
+ TC3_ASSIGN_OR_RETURN(lhs, rexpr, (log_expression, 0))
// =================================================================
// == Implementation details, do not rely on anything below here. ==
@@ -281,11 +339,11 @@
#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \
TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, _)
-#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, \
- error_expression) \
- TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \
- TC_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \
- error_expression)
+#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, \
+ error_expression) \
+ TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \
+ TC_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __COUNTER__), lhs, \
+ rexpr, error_expression)
#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \
error_expression) \
auto statusor = (rexpr); \
diff --git a/native/utils/base/statusor_test.cc b/native/utils/base/statusor_test.cc
index 23165b0..04ac8ee 100644
--- a/native/utils/base/statusor_test.cc
+++ b/native/utils/base/statusor_test.cc
@@ -103,6 +103,84 @@
EXPECT_FALSE(moved_error_status.ok());
}
+// Create a class that has validly defined copy and move operators, but will
+// cause a crash if assignment operators are invoked on an instance that was
+// never initialized.
+class Baz {
+ public:
+ Baz() : i_(new int), invalid_(false) {}
+ Baz(const Baz& other) {
+ i_ = new int;
+ *i_ = *other.i_;
+ invalid_ = false;
+ }
+ Baz(const Foo& other) { // NOLINT
+ i_ = new int;
+ *i_ = other.i();
+ invalid_ = false;
+ }
+ Baz(Baz&& other) {
+ // Copy other.i_ into this so that this holds it now. Mark other as invalid
+ // so that it doesn't destroy the int that this now owns when other is
+ // destroyed.
+ i_ = other.i_;
+ other.invalid_ = true;
+ invalid_ = false;
+ }
+ Baz& operator=(const Baz& rhs) {
+ // Copy rhs.i_ into tmp. Then swap this with tmp so that this now has the
+ // value that rhs had and tmp will destroy the value that this used to hold.
+ Baz tmp(rhs);
+ std::swap(i_, tmp.i_);
+ return *this;
+ }
+ Baz& operator=(Baz&& rhs) {
+ std::swap(i_, rhs.i_);
+ return *this;
+ }
+ ~Baz() {
+ if (!invalid_) delete i_;
+ }
+
+ private:
+ int* i_;
+ bool invalid_;
+};
+
+TEST(StatusOrTest, CopyAssignment) {
+ StatusOr<Baz> baz_or;
+ EXPECT_FALSE(baz_or.ok());
+ Baz b;
+ StatusOr<Baz> other(b);
+ baz_or = other;
+ EXPECT_TRUE(baz_or.ok());
+ EXPECT_TRUE(other.ok());
+}
+
+TEST(StatusOrTest, MoveAssignment) {
+ StatusOr<Baz> baz_or;
+ EXPECT_FALSE(baz_or.ok());
+ baz_or = StatusOr<Baz>(Baz());
+ EXPECT_TRUE(baz_or.ok());
+}
+
+TEST(StatusOrTest, CopyConversionAssignment) {
+ StatusOr<Baz> baz_or;
+ EXPECT_FALSE(baz_or.ok());
+ StatusOr<Foo> foo_or(Foo(12));
+ baz_or = foo_or;
+ EXPECT_TRUE(baz_or.ok());
+ EXPECT_TRUE(foo_or.ok());
+}
+
+TEST(StatusOrTest, MoveConversionAssignment) {
+ StatusOr<Baz> baz_or;
+ EXPECT_FALSE(baz_or.ok());
+ StatusOr<Foo> foo_or(Foo(12));
+ baz_or = std::move(foo_or);
+ EXPECT_TRUE(baz_or.ok());
+}
+
struct OkFn {
StatusOr<int> operator()() { return 42; }
};
diff --git a/native/utils/container/bit-vector.cc b/native/utils/container/bit-vector.cc
new file mode 100644
index 0000000..388e488
--- /dev/null
+++ b/native/utils/container/bit-vector.cc
@@ -0,0 +1,96 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/container/bit-vector.h"
+
+#include <math.h>
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/container/bit-vector_generated.h"
+
+namespace libtextclassifier3 {
+
+BitVector::BitVector(const BitVectorData* bit_vector_data)
+ : bit_vector_data_(bit_vector_data) {}
+
+bool BitVector::GetFromSparseData(int32 index) const {
+ return std::binary_search(
+ bit_vector_data_->sparse_data()->sorted_indices_32()->begin(),
+ bit_vector_data_->sparse_data()->sorted_indices_32()->end(), index);
+}
+
+bool BitVector::GetFromDenseData(int32 index) const {
+  const auto* dense = bit_vector_data_->dense_data();
+  // Reject bits outside [0, size) and bits whose byte is absent from `data`.
+  if (index < 0 || index >= dense->size() || dense->data() == nullptr ||
+      index / 8 >= static_cast<int32>(dense->data()->size())) {
+    return false;
+  }
+  const uint8 extracted_byte = dense->data()->Get(index / 8);
+  return (extracted_byte & (1 << (index % 8))) != 0;
+}
+
+bool BitVector::Get(int32 index) const {
+ TC3_DCHECK(index >= 0);
+
+ if (bit_vector_data_ == nullptr) {
+ return false;
+ }
+ if (bit_vector_data_->dense_data() != nullptr) {
+ return GetFromDenseData(index);
+ }
+ return GetFromSparseData(index);
+}
+
+std::unique_ptr<BitVectorDataT> BitVector::CreateSparseBitVectorData(
+ const std::vector<int32>& indices) {
+ auto bit_vector_data = std::make_unique<BitVectorDataT>();
+ bit_vector_data->sparse_data =
+ std::make_unique<libtextclassifier3::SparseBitVectorDataT>();
+ bit_vector_data->sparse_data->sorted_indices_32 = indices;
+ return bit_vector_data;
+}
+
+// Creates a BitVectorDataT using the dense representation: bit i of the
+// input is stored in bit (i % 8) of byte (i / 8), and `size` records the
+// number of bits.
+std::unique_ptr<BitVectorDataT> BitVector::CreateDenseBitVectorData(
+    const std::vector<bool>& data) {
+  // Allocate ceil(size / 8) zeroed bytes up front so that every bit has a
+  // backing byte. Emitting the trailing byte only when the length is not a
+  // multiple of 8 would drop the last full byte for inputs of 8, 16, ...
+  // bits.
+  std::vector<uint8_t> result((data.size() + 7) / 8, 0);
+  for (size_t i = 0; i < data.size(); i++) {
+    if (data[i]) {
+      result[i / 8] |= static_cast<uint8_t>(1 << (i % 8));
+    }
+  }
+
+  auto bit_vector_data = std::make_unique<BitVectorDataT>();
+  bit_vector_data->dense_data =
+      std::make_unique<libtextclassifier3::DenseBitVectorDataT>();
+  // `size` is the bit count; the byte vector is padded to a whole byte.
+  bit_vector_data->dense_data->data = result;
+  bit_vector_data->dense_data->size = data.size();
+  return bit_vector_data;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/container/bit-vector.fbs b/native/utils/container/bit-vector.fbs
new file mode 100755
index 0000000..d117ee5
--- /dev/null
+++ b/native/utils/container/bit-vector.fbs
@@ -0,0 +1,40 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// The data that is used to represent a BitVector.
+namespace libtextclassifier3;
+table BitVectorData {
+ dense_data:DenseBitVectorData;
+ sparse_data:SparseBitVectorData;
+}
+
+// A dense representation of a bit vector.
+namespace libtextclassifier3;
+table DenseBitVectorData {
+ // The bits.
+ data:[ubyte];
+
+ // Number of bits in this bit vector.
+ size:int;
+}
+
+// A sparse representation of a bit vector.
+namespace libtextclassifier3;
+table SparseBitVectorData {
+ // A vector of sorted indices of elements that are 1.
+ sorted_indices_32:[int];
+}
+
diff --git a/native/utils/container/bit-vector.h b/native/utils/container/bit-vector.h
new file mode 100644
index 0000000..f6716d5
--- /dev/null
+++ b/native/utils/container/bit-vector.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_CONTAINER_BIT_VECTOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_CONTAINER_BIT_VECTOR_H_
+
+#include <set>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/container/bit-vector_generated.h"
+
+namespace libtextclassifier3 {
+
+// A read-only bit vector. It does not own the data and it is like a view on
+// the given data. There are two internal representations, sparse and dense.
+// The dense one stores every bit. The sparse one stores only the indices of
+// elements that are 1.
+class BitVector {
+ public:
+ explicit BitVector(const BitVectorData* bit_vector_data);
+
+ // Gets a particular bit. If the underlying data does not contain the
+ // value of the asked bit, false is returned.
+ const bool operator[](int index) const { return Get(index); }
+
+ // Creates a BitVectorDataT using the dense representation.
+ static std::unique_ptr<BitVectorDataT> CreateDenseBitVectorData(
+ const std::vector<bool>& data);
+
+ // Creates a BitVectorDataT using the sparse representation.
+ static std::unique_ptr<BitVectorDataT> CreateSparseBitVectorData(
+ const std::vector<int32>& indices);
+
+ private:
+ const BitVectorData* bit_vector_data_;
+
+ bool Get(int index) const;
+ bool GetFromSparseData(int index) const;
+ bool GetFromDenseData(int index) const;
+};
+
+} // namespace libtextclassifier3
+#endif // LIBTEXTCLASSIFIER_UTILS_CONTAINER_BIT_VECTOR_H_
diff --git a/native/utils/container/bit-vector_test.cc b/native/utils/container/bit-vector_test.cc
new file mode 100644
index 0000000..dfa67e8
--- /dev/null
+++ b/native/utils/container/bit-vector_test.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/container/bit-vector.h"
+
+#include <memory>
+
+#include "utils/container/bit-vector_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(BitVectorTest, Dense) {
+ std::vector<bool> data = {false, true, true, true, false,
+ false, true, false, false, true};
+
+ std::unique_ptr<BitVectorDataT> mutable_bit_vector_data =
+ BitVector::CreateDenseBitVectorData(data);
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(BitVectorData::Pack(builder, mutable_bit_vector_data.get()));
+ const flatbuffers::DetachedBuffer buffer = builder.Release();
+ const BitVectorData* bit_vector_data =
+ flatbuffers::GetRoot<BitVectorData>(buffer.data());
+
+ BitVector bit_vector(bit_vector_data);
+ EXPECT_EQ(bit_vector[0], false);
+ EXPECT_EQ(bit_vector[1], true);
+ EXPECT_EQ(bit_vector[2], true);
+ EXPECT_EQ(bit_vector[3], true);
+ EXPECT_EQ(bit_vector[4], false);
+ EXPECT_EQ(bit_vector[5], false);
+ EXPECT_EQ(bit_vector[6], true);
+ EXPECT_EQ(bit_vector[7], false);
+ EXPECT_EQ(bit_vector[8], false);
+ EXPECT_EQ(bit_vector[9], true);
+ EXPECT_EQ(bit_vector[10], false);
+}
+
+TEST(BitVectorTest, Sparse) {
+ std::vector<int32> sorted_indices = {3, 7};
+
+ std::unique_ptr<BitVectorDataT> mutable_bit_vector_data =
+ BitVector::CreateSparseBitVectorData(sorted_indices);
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(BitVectorData::Pack(builder, mutable_bit_vector_data.get()));
+ const flatbuffers::DetachedBuffer buffer = builder.Release();
+ const BitVectorData* bit_vector_data =
+ flatbuffers::GetRoot<BitVectorData>(buffer.data());
+
+ BitVector bit_vector(bit_vector_data);
+ EXPECT_EQ(bit_vector[0], false);
+ EXPECT_EQ(bit_vector[1], false);
+ EXPECT_EQ(bit_vector[2], false);
+ EXPECT_EQ(bit_vector[3], true);
+ EXPECT_EQ(bit_vector[4], false);
+ EXPECT_EQ(bit_vector[5], false);
+ EXPECT_EQ(bit_vector[6], false);
+ EXPECT_EQ(bit_vector[7], true);
+ EXPECT_EQ(bit_vector[8], false);
+}
+
+TEST(BitVectorTest, Null) {
+ BitVector bit_vector(nullptr);
+
+ EXPECT_EQ(bit_vector[0], false);
+ EXPECT_EQ(bit_vector[1], false);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/container/double-array-trie_test.cc b/native/utils/container/double-array-trie_test.cc
new file mode 100644
index 0000000..b639d53
--- /dev/null
+++ b/native/utils/container/double-array-trie_test.cc
@@ -0,0 +1,132 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/container/double-array-trie.h"
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "utils/test-data-test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+std::string GetTestConfigPath() {
+ return GetTestDataPath("utils/container/test_data/test_trie.bin");
+}
+
+TEST(DoubleArrayTest, Lookup) {
+ // Test trie that contains pieces "hell", "hello", "o", "there".
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ DoubleArrayTrie trie(reinterpret_cast<const TrieNode*>(config.data()),
+ config.size() / sizeof(TrieNode));
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches("hello there", &matches));
+ EXPECT_EQ(matches.size(), 2);
+ EXPECT_EQ(matches[0].id, 0 /*hell*/);
+ EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
+ EXPECT_EQ(matches[1].id, 1 /*hello*/);
+ EXPECT_EQ(matches[1].match_length, 5 /*hello*/);
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches("he", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches("abcd", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches("", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches("hi there", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(trie.FindAllPrefixMatches(StringPiece("\0", 1), &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(
+ trie.FindAllPrefixMatches(StringPiece("\xff, \xfe", 2), &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(trie.LongestPrefixMatch("hella there", &match));
+ EXPECT_EQ(match.id, 0 /*hell*/);
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(trie.LongestPrefixMatch("hello there", &match));
+ EXPECT_EQ(match.id, 1 /*hello*/);
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(trie.LongestPrefixMatch("abcd", &match));
+ EXPECT_EQ(match.id, -1);
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(trie.LongestPrefixMatch("", &match));
+ EXPECT_EQ(match.id, -1);
+ }
+
+ {
+ int value;
+ EXPECT_TRUE(trie.Find("hell", &value));
+ EXPECT_EQ(value, 0);
+ }
+
+ {
+ int value;
+ EXPECT_FALSE(trie.Find("hella", &value));
+ }
+
+ {
+ int value;
+ EXPECT_TRUE(trie.Find("hello", &value));
+ EXPECT_EQ(value, 1 /*hello*/);
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/container/test_data/test_trie.bin b/native/utils/container/test_data/test_trie.bin
new file mode 100644
index 0000000..ade1f29
--- /dev/null
+++ b/native/utils/container/test_data/test_trie.bin
Binary files differ
diff --git a/native/utils/flatbuffers.fbs b/native/utils/flatbuffers/flatbuffers.fbs
similarity index 100%
rename from native/utils/flatbuffers.fbs
rename to native/utils/flatbuffers/flatbuffers.fbs
diff --git a/native/utils/flatbuffers/flatbuffers.h b/native/utils/flatbuffers/flatbuffers.h
new file mode 100644
index 0000000..f76e12d
--- /dev/null
+++ b/native/utils/flatbuffers/flatbuffers.h
@@ -0,0 +1,116 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
+#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
+
+#include <string>
+
+#include "annotator/model_generated.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3 {
+
+// Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
+// integrity. Returns nullptr if the buffer is null, too small, or invalid.
+template <typename FlatbufferMessage>
+const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
+  // GetRoot reads the root offset before any verification has run, and a
+  // negative size would wrap to a huge length inside the verifier.
+  if (buffer == nullptr ||
+      size < static_cast<int>(sizeof(flatbuffers::uoffset_t))) {
+    return nullptr;
+  }
+  const FlatbufferMessage* message =
+      flatbuffers::GetRoot<FlatbufferMessage>(buffer);
+  flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
+                                 size);
+  // Only hand out the root once the whole table has been verified.
+  return message != nullptr && message->Verify(verifier) ? message : nullptr;
+}
+
+// Same as above but takes string.
+template <typename FlatbufferMessage>
+const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) {
+ return LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer.c_str(),
+ buffer.size());
+}
+
+// Loads and interprets the buffer as 'FlatbufferMessage', verifies its
+// integrity and returns its mutable version.
+template <typename FlatbufferMessage>
+std::unique_ptr<typename FlatbufferMessage::NativeTableType>
+LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) {
+ const FlatbufferMessage* message =
+ LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size);
+ if (message == nullptr) {
+ return nullptr;
+ }
+ return std::unique_ptr<typename FlatbufferMessage::NativeTableType>(
+ message->UnPack());
+}
+
+// Same as above but takes string.
+template <typename FlatbufferMessage>
+std::unique_ptr<typename FlatbufferMessage::NativeTableType>
+LoadAndVerifyMutableFlatbuffer(const std::string& buffer) {
+ return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(buffer.c_str(),
+ buffer.size());
+}
+
+template <typename FlatbufferMessage>
+const char* FlatbufferFileIdentifier() {
+ return nullptr;
+}
+
+template <>
+inline const char* FlatbufferFileIdentifier<Model>() {
+ return ModelIdentifier();
+}
+
+// Packs the mutable flatbuffer message to string.
+template <typename FlatbufferMessage>
+std::string PackFlatbuffer(
+ const typename FlatbufferMessage::NativeTableType* mutable_message) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(FlatbufferMessage::Pack(builder, mutable_message),
+ FlatbufferFileIdentifier<FlatbufferMessage>());
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+// A convenience flatbuffer object with its underlying buffer.
+template <typename T, typename B = flatbuffers::DetachedBuffer>
+class OwnedFlatbuffer {
+ public:
+ explicit OwnedFlatbuffer(B&& buffer) : buffer_(std::move(buffer)) {}
+
+ // Cast as flatbuffer type.
+ const T* get() const { return flatbuffers::GetRoot<T>(buffer_.data()); }
+
+ const T* operator->() const {
+ return flatbuffers::GetRoot<T>(buffer_.data());
+ }
+
+ private:
+ B buffer_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
diff --git a/native/utils/flatbuffers.cc b/native/utils/flatbuffers/mutable.cc
similarity index 75%
rename from native/utils/flatbuffers.cc
rename to native/utils/flatbuffers/mutable.cc
index cf4c97f..b54732f 100644
--- a/native/utils/flatbuffers.cc
+++ b/native/utils/flatbuffers/mutable.cc
@@ -14,10 +14,11 @@
* limitations under the License.
*/
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include <vector>
+#include "utils/flatbuffers/reflection.h"
#include "utils/strings/numbers.h"
#include "utils/variant.h"
#include "flatbuffers/reflection_generated.h"
@@ -25,56 +26,6 @@
namespace libtextclassifier3 {
namespace {
-// Gets the field information for a field name, returns nullptr if the
-// field was not defined.
-const reflection::Field* GetFieldOrNull(const reflection::Object* type,
- const StringPiece field_name) {
- TC3_CHECK(type != nullptr && type->fields() != nullptr);
- return type->fields()->LookupByKey(field_name.data());
-}
-
-const reflection::Field* GetFieldOrNull(const reflection::Object* type,
- const int field_offset) {
- if (type->fields() == nullptr) {
- return nullptr;
- }
- for (const reflection::Field* field : *type->fields()) {
- if (field->offset() == field_offset) {
- return field;
- }
- }
- return nullptr;
-}
-
-const reflection::Field* GetFieldOrNull(const reflection::Object* type,
- const StringPiece field_name,
- const int field_offset) {
- // Lookup by name might be faster as the fields are sorted by name in the
- // schema data, so try that first.
- if (!field_name.empty()) {
- return GetFieldOrNull(type, field_name.data());
- }
- return GetFieldOrNull(type, field_offset);
-}
-
-const reflection::Field* GetFieldOrNull(const reflection::Object* type,
- const FlatbufferField* field) {
- TC3_CHECK(type != nullptr && field != nullptr);
- if (field->field_name() == nullptr) {
- return GetFieldOrNull(type, field->field_offset());
- }
- return GetFieldOrNull(
- type,
- StringPiece(field->field_name()->data(), field->field_name()->size()),
- field->field_offset());
-}
-
-const reflection::Field* GetFieldOrNull(const reflection::Object* type,
- const FlatbufferFieldT* field) {
- TC3_CHECK(type != nullptr && field != nullptr);
- return GetFieldOrNull(type, field->field_name, field->field_offset);
-}
-
bool Parse(const std::string& str_value, float* value) {
double double_value;
if (!ParseDouble(str_value.data(), &double_value)) {
@@ -103,8 +54,7 @@
template <typename T>
bool ParseAndSetField(const reflection::Field* field,
- const std::string& str_value,
- ReflectiveFlatbuffer* buffer) {
+ const std::string& str_value, MutableFlatbuffer* buffer) {
T value;
if (!Parse(str_value, &value)) {
TC3_LOG(ERROR) << "Could not parse '" << str_value << "'";
@@ -120,44 +70,48 @@
} // namespace
-template <>
-const char* FlatbufferFileIdentifier<Model>() {
- return ModelIdentifier();
+MutableFlatbufferBuilder::MutableFlatbufferBuilder(
+ const reflection::Schema* schema, StringPiece root_type)
+ : schema_(schema), root_type_(TypeForName(schema, root_type)) {}
+
+std::unique_ptr<MutableFlatbuffer> MutableFlatbufferBuilder::NewRoot() const {
+ return NewTable(root_type_);
}
-std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewRoot()
- const {
- if (!schema_->root_table()) {
- TC3_LOG(ERROR) << "No root table specified.";
+std::unique_ptr<MutableFlatbuffer> MutableFlatbufferBuilder::NewTable(
+ StringPiece table_name) const {
+ return NewTable(TypeForName(schema_, table_name));
+}
+
+std::unique_ptr<MutableFlatbuffer> MutableFlatbufferBuilder::NewTable(
+ const int type_id) const {
+ if (type_id < 0 || type_id >= schema_->objects()->size()) {
+ TC3_LOG(ERROR) << "Invalid type id: " << type_id;
return nullptr;
}
- return std::unique_ptr<ReflectiveFlatbuffer>(
- new ReflectiveFlatbuffer(schema_, schema_->root_table()));
+ return NewTable(schema_->objects()->Get(type_id));
}
-std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewTable(
- StringPiece table_name) const {
- for (const reflection::Object* object : *schema_->objects()) {
- if (table_name.Equals(object->name()->str())) {
- return std::unique_ptr<ReflectiveFlatbuffer>(
- new ReflectiveFlatbuffer(schema_, object));
- }
+std::unique_ptr<MutableFlatbuffer> MutableFlatbufferBuilder::NewTable(
+ const reflection::Object* type) const {
+ if (type == nullptr) {
+ return nullptr;
}
- return nullptr;
+ return std::make_unique<MutableFlatbuffer>(schema_, type);
}
-const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+const reflection::Field* MutableFlatbuffer::GetFieldOrNull(
const StringPiece field_name) const {
return libtextclassifier3::GetFieldOrNull(type_, field_name);
}
-const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+const reflection::Field* MutableFlatbuffer::GetFieldOrNull(
const FlatbufferField* field) const {
return libtextclassifier3::GetFieldOrNull(type_, field);
}
-bool ReflectiveFlatbuffer::GetFieldWithParent(
- const FlatbufferFieldPath* field_path, ReflectiveFlatbuffer** parent,
+bool MutableFlatbuffer::GetFieldWithParent(
+ const FlatbufferFieldPath* field_path, MutableFlatbuffer** parent,
reflection::Field const** field) {
const auto* path = field_path->field();
if (path == nullptr || path->size() == 0) {
@@ -178,13 +132,13 @@
return true;
}
-const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+const reflection::Field* MutableFlatbuffer::GetFieldOrNull(
const int field_offset) const {
return libtextclassifier3::GetFieldOrNull(type_, field_offset);
}
-bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
- const std::string& value) {
+bool MutableFlatbuffer::ParseAndSet(const reflection::Field* field,
+ const std::string& value) {
switch (field->type()->base_type() == reflection::Vector
? field->type()->element()
: field->type()->base_type()) {
@@ -204,9 +158,9 @@
}
}
-bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
- const std::string& value) {
- ReflectiveFlatbuffer* parent;
+bool MutableFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
+ const std::string& value) {
+ MutableFlatbuffer* parent;
const reflection::Field* field;
if (!GetFieldWithParent(path, &parent, &field)) {
return false;
@@ -214,7 +168,7 @@
return parent->ParseAndSet(field, value);
}
-ReflectiveFlatbuffer* ReflectiveFlatbuffer::Add(StringPiece field_name) {
+MutableFlatbuffer* MutableFlatbuffer::Add(StringPiece field_name) {
const reflection::Field* field = GetFieldOrNull(field_name);
if (field == nullptr) {
return nullptr;
@@ -227,16 +181,14 @@
return Add(field);
}
-ReflectiveFlatbuffer* ReflectiveFlatbuffer::Add(
- const reflection::Field* field) {
+MutableFlatbuffer* MutableFlatbuffer::Add(const reflection::Field* field) {
if (field == nullptr) {
return nullptr;
}
return Repeated(field)->Add();
}
-ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
- const StringPiece field_name) {
+MutableFlatbuffer* MutableFlatbuffer::Mutable(const StringPiece field_name) {
if (const reflection::Field* field = GetFieldOrNull(field_name)) {
return Mutable(field);
}
@@ -244,8 +196,7 @@
return nullptr;
}
-ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
- const reflection::Field* field) {
+MutableFlatbuffer* MutableFlatbuffer::Mutable(const reflection::Field* field) {
if (field->type()->base_type() != reflection::Obj) {
TC3_LOG(ERROR) << "Field is not of type Object.";
return nullptr;
@@ -258,12 +209,31 @@
/*hint=*/entry,
std::make_pair(
field,
- std::unique_ptr<ReflectiveFlatbuffer>(new ReflectiveFlatbuffer(
+ std::unique_ptr<MutableFlatbuffer>(new MutableFlatbuffer(
schema_, schema_->objects()->Get(field->type()->index())))));
return it->second.get();
}
-RepeatedField* ReflectiveFlatbuffer::Repeated(StringPiece field_name) {
+MutableFlatbuffer* MutableFlatbuffer::Mutable(const FlatbufferFieldPath* path) {
+ const auto* field_path = path->field();
+ if (field_path == nullptr || field_path->size() == 0) {
+ return this;
+ }
+ MutableFlatbuffer* object = this;
+ for (int i = 0; i < field_path->size(); i++) {
+ const reflection::Field* field = object->GetFieldOrNull(field_path->Get(i));
+ if (field == nullptr) {
+ return nullptr;
+ }
+ object = object->Mutable(field);
+ if (object == nullptr) {
+ return nullptr;
+ }
+ }
+ return object;
+}
+
+RepeatedField* MutableFlatbuffer::Repeated(StringPiece field_name) {
if (const reflection::Field* field = GetFieldOrNull(field_name)) {
return Repeated(field);
}
@@ -271,7 +241,7 @@
return nullptr;
}
-RepeatedField* ReflectiveFlatbuffer::Repeated(const reflection::Field* field) {
+RepeatedField* MutableFlatbuffer::Repeated(const reflection::Field* field) {
if (field->type()->base_type() != reflection::Vector) {
TC3_LOG(ERROR) << "Field is not of type Vector.";
return nullptr;
@@ -291,7 +261,7 @@
return it->second.get();
}
-flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
+flatbuffers::uoffset_t MutableFlatbuffer::Serialize(
flatbuffers::FlatBufferBuilder* builder) const {
// Build all children before we can start with this table.
std::vector<
@@ -380,7 +350,7 @@
return builder->EndTable(table_start);
}
-std::string ReflectiveFlatbuffer::Serialize() const {
+std::string MutableFlatbuffer::Serialize() const {
flatbuffers::FlatBufferBuilder builder;
builder.Finish(flatbuffers::Offset<void>(Serialize(&builder)));
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -388,7 +358,7 @@
}
template <>
-bool ReflectiveFlatbuffer::AppendFromVector<std::string>(
+bool MutableFlatbuffer::AppendFromVector<std::string>(
const flatbuffers::Table* from, const reflection::Field* field) {
auto* from_vector = from->GetPointer<
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
@@ -405,7 +375,7 @@
}
template <>
-bool ReflectiveFlatbuffer::AppendFromVector<ReflectiveFlatbuffer>(
+bool MutableFlatbuffer::AppendFromVector<MutableFlatbuffer>(
const flatbuffers::Table* from, const reflection::Field* field) {
auto* from_vector = from->GetPointer<const flatbuffers::Vector<
flatbuffers::Offset<const flatbuffers::Table>>*>(field->offset());
@@ -415,7 +385,7 @@
RepeatedField* to_repeated = Repeated(field);
for (const flatbuffers::Table* const from_element : *from_vector) {
- ReflectiveFlatbuffer* to_element = to_repeated->Add();
+ MutableFlatbuffer* to_element = to_repeated->Add();
if (to_element == nullptr) {
return false;
}
@@ -424,7 +394,7 @@
return true;
}
-bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
+bool MutableFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
// No fields to set.
if (type_->fields() == nullptr) {
return true;
@@ -479,7 +449,7 @@
->str());
break;
case reflection::Obj:
- if (ReflectiveFlatbuffer* nested_field = Mutable(field);
+ if (MutableFlatbuffer* nested_field = Mutable(field);
nested_field == nullptr ||
!nested_field->MergeFrom(
from->GetPointer<const flatbuffers::Table* const>(
@@ -511,7 +481,7 @@
AppendFromVector<std::string>(from, field);
break;
case reflection::Obj:
- AppendFromVector<ReflectiveFlatbuffer>(from, field);
+ AppendFromVector<MutableFlatbuffer>(from, field);
break;
case reflection::Double:
AppendFromVector<double>(from, field);
@@ -536,12 +506,12 @@
return true;
}
-bool ReflectiveFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
+bool MutableFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
return MergeFrom(flatbuffers::GetAnyRoot(
reinterpret_cast<const unsigned char*>(from.data())));
}
-void ReflectiveFlatbuffer::AsFlatMap(
+void MutableFlatbuffer::AsFlatMap(
const std::string& key_separator, const std::string& key_prefix,
std::map<std::string, Variant>* result) const {
// Add direct fields.
@@ -557,7 +527,7 @@
}
}
-std::string ReflectiveFlatbuffer::ToTextProto() const {
+std::string MutableFlatbuffer::ToTextProto() const {
std::string result;
std::string current_field_separator;
// Add direct fields.
@@ -584,47 +554,17 @@
return result;
}
-bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
- FlatbufferFieldPathT* path) {
- if (schema == nullptr || !schema->root_table()) {
- TC3_LOG(ERROR) << "Empty schema provided.";
- return false;
- }
-
- reflection::Object const* type = schema->root_table();
- for (int i = 0; i < path->field.size(); i++) {
- const reflection::Field* field = GetFieldOrNull(type, path->field[i].get());
- if (field == nullptr) {
- TC3_LOG(ERROR) << "Could not find field: " << path->field[i]->field_name;
- return false;
- }
- path->field[i]->field_name.clear();
- path->field[i]->field_offset = field->offset();
-
- // Descend.
- if (i < path->field.size() - 1) {
- if (field->type()->base_type() != reflection::Obj) {
- TC3_LOG(ERROR) << "Field: " << field->name()->str()
- << " is not of type `Object`.";
- return false;
- }
- type = schema->objects()->Get(field->type()->index());
- }
- }
- return true;
-}
-
//
// Repeated field methods.
//
-ReflectiveFlatbuffer* RepeatedField::Add() {
+MutableFlatbuffer* RepeatedField::Add() {
if (is_primitive_) {
TC3_LOG(ERROR) << "Trying to add sub-message on a primitive-typed field.";
return nullptr;
}
- object_items_.emplace_back(new ReflectiveFlatbuffer(
+ object_items_.emplace_back(new MutableFlatbuffer(
schema_, schema_->objects()->Get(field_->type()->index())));
return object_items_.back().get();
}
diff --git a/native/utils/flatbuffers.h b/native/utils/flatbuffers/mutable.h
similarity index 60%
rename from native/utils/flatbuffers.h
rename to native/utils/flatbuffers/mutable.h
index aaf248e..436462a 100644
--- a/native/utils/flatbuffers.h
+++ b/native/utils/flatbuffers/mutable.h
@@ -14,10 +14,8 @@
* limitations under the License.
*/
-// Utility functions for working with FlatBuffers.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
-#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
+#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
+#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
#include <memory>
#include <string>
@@ -25,7 +23,8 @@
#include "annotator/model_generated.h"
#include "utils/base/logging.h"
-#include "utils/flatbuffers_generated.h"
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/flatbuffers/reflection.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/flatbuffers.h"
@@ -34,122 +33,35 @@
namespace libtextclassifier3 {
-class ReflectiveFlatBuffer;
+class MutableFlatbuffer;
class RepeatedField;
-// Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
-// integrity.
-template <typename FlatbufferMessage>
-const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
- const FlatbufferMessage* message =
- flatbuffers::GetRoot<FlatbufferMessage>(buffer);
- if (message == nullptr) {
- return nullptr;
- }
- flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
- size);
- if (message->Verify(verifier)) {
- return message;
- } else {
- return nullptr;
- }
-}
-
-// Same as above but takes string.
-template <typename FlatbufferMessage>
-const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) {
- return LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer.c_str(),
- buffer.size());
-}
-
-// Loads and interprets the buffer as 'FlatbufferMessage', verifies its
-// integrity and returns its mutable version.
-template <typename FlatbufferMessage>
-std::unique_ptr<typename FlatbufferMessage::NativeTableType>
-LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) {
- const FlatbufferMessage* message =
- LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size);
- if (message == nullptr) {
- return nullptr;
- }
- return std::unique_ptr<typename FlatbufferMessage::NativeTableType>(
- message->UnPack());
-}
-
-// Same as above but takes string.
-template <typename FlatbufferMessage>
-std::unique_ptr<typename FlatbufferMessage::NativeTableType>
-LoadAndVerifyMutableFlatbuffer(const std::string& buffer) {
- return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(buffer.c_str(),
- buffer.size());
-}
-
-template <typename FlatbufferMessage>
-const char* FlatbufferFileIdentifier() {
- return nullptr;
-}
-
-template <>
-const char* FlatbufferFileIdentifier<Model>();
-
-// Packs the mutable flatbuffer message to string.
-template <typename FlatbufferMessage>
-std::string PackFlatbuffer(
- const typename FlatbufferMessage::NativeTableType* mutable_message) {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(FlatbufferMessage::Pack(builder, mutable_message),
- FlatbufferFileIdentifier<FlatbufferMessage>());
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-class ReflectiveFlatbuffer;
-
// Checks whether a variant value type agrees with a field type.
template <typename T>
bool IsMatchingType(const reflection::BaseType type) {
switch (type) {
- case reflection::Bool:
- return std::is_same<T, bool>::value;
- case reflection::Byte:
- return std::is_same<T, int8>::value;
- case reflection::UByte:
- return std::is_same<T, uint8>::value;
- case reflection::Int:
- return std::is_same<T, int32>::value;
- case reflection::UInt:
- return std::is_same<T, uint32>::value;
- case reflection::Long:
- return std::is_same<T, int64>::value;
- case reflection::ULong:
- return std::is_same<T, uint64>::value;
- case reflection::Float:
- return std::is_same<T, float>::value;
- case reflection::Double:
- return std::is_same<T, double>::value;
case reflection::String:
return std::is_same<T, std::string>::value ||
std::is_same<T, StringPiece>::value ||
std::is_same<T, const char*>::value;
case reflection::Obj:
- return std::is_same<T, ReflectiveFlatbuffer>::value;
+ return std::is_same<T, MutableFlatbuffer>::value;
default:
- return false;
+ return type == flatbuffers_base_type<T>::value;
}
}
-// A flatbuffer that can be built using flatbuffer reflection data of the
-// schema.
-// Normally, field information is hard-coded in code generated from a flatbuffer
-// schema. Here we lookup the necessary information for building a flatbuffer
-// from the provided reflection meta data.
-// When serializing a flatbuffer, the library requires that the sub messages
-// are already serialized, therefore we explicitly keep the field values and
-// serialize the message in (reverse) topological dependency order.
-class ReflectiveFlatbuffer {
+// A mutable flatbuffer that can be built using flatbuffer reflection data of
+// the schema. Normally, field information is hard-coded in code generated from
+// a flatbuffer schema. Here we lookup the necessary information for building a
+// flatbuffer from the provided reflection meta data. When serializing a
+// flatbuffer, the library requires that the sub messages are already
+// serialized, therefore we explicitly keep the field values and serialize the
+// message in (reverse) topological dependency order.
+class MutableFlatbuffer {
public:
- ReflectiveFlatbuffer(const reflection::Schema* schema,
- const reflection::Object* type)
+ MutableFlatbuffer(const reflection::Schema* schema,
+ const reflection::Object* type)
: schema_(schema), type_(type) {}
// Gets the field information for a field name, returns nullptr if the
@@ -160,7 +72,7 @@
// Gets a nested field and the message it is defined on.
bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
- ReflectiveFlatbuffer** parent,
+ MutableFlatbuffer** parent,
reflection::Field const** field);
// Sets a field to a specific value.
@@ -182,8 +94,13 @@
// Sets sub-message field (if not set yet), and returns a pointer to it.
// Returns nullptr if the field was not found, or the field type was not a
// table.
- ReflectiveFlatbuffer* Mutable(StringPiece field_name);
- ReflectiveFlatbuffer* Mutable(const reflection::Field* field);
+ MutableFlatbuffer* Mutable(StringPiece field_name);
+ MutableFlatbuffer* Mutable(const reflection::Field* field);
+
+ // Sets a sub-message field (if not set yet) specified by path, and returns a
+ // pointer to it. Returns nullptr if the field was not found, or the field
+ // type was not a table.
+ MutableFlatbuffer* Mutable(const FlatbufferFieldPath* path);
// Parses the value (according to the type) and sets a primitive field to the
// parsed value.
@@ -195,12 +112,12 @@
bool Add(StringPiece field_name, T value);
// Add a sub-message to the repeated field.
- ReflectiveFlatbuffer* Add(StringPiece field_name);
+ MutableFlatbuffer* Add(StringPiece field_name);
template <typename T>
bool Add(const reflection::Field* field, T value);
- ReflectiveFlatbuffer* Add(const reflection::Field* field);
+ MutableFlatbuffer* Add(const reflection::Field* field);
// Gets the reflective flatbuffer for a repeated field.
// Returns nullptr if the field was not found, or the field type was not a
@@ -236,6 +153,8 @@
return !fields_.empty() || !children_.empty() || !repeated_fields_.empty();
}
+ const reflection::Object* type() const { return type_; }
+
private:
// Helper function for merging given repeated field from given flatbuffer
// table. Appends the elements.
@@ -251,7 +170,7 @@
// Cached sub-messages.
std::unordered_map<const reflection::Field*,
- std::unique_ptr<ReflectiveFlatbuffer>>
+ std::unique_ptr<MutableFlatbuffer>>
children_;
// Cached repeated fields.
@@ -267,23 +186,34 @@
};
// A helper class to build flatbuffers based on schema reflection data.
-// Can be used to a `ReflectiveFlatbuffer` for the root message of the
+// Can be used to create a `MutableFlatbuffer` for the root message of the
// schema, or any defined table via name.
-class ReflectiveFlatbufferBuilder {
+class MutableFlatbufferBuilder {
public:
- explicit ReflectiveFlatbufferBuilder(const reflection::Schema* schema)
- : schema_(schema) {}
+ explicit MutableFlatbufferBuilder(const reflection::Schema* schema)
+ : schema_(schema), root_type_(schema->root_table()) {}
+ explicit MutableFlatbufferBuilder(const reflection::Schema* schema,
+ StringPiece root_type);
// Starts a new root table message.
- std::unique_ptr<ReflectiveFlatbuffer> NewRoot() const;
+ std::unique_ptr<MutableFlatbuffer> NewRoot() const;
- // Starts a new table message. Returns nullptr if no table with given name is
+ // Creates a new table message. Returns nullptr if no table with given name is
// found in the schema.
- std::unique_ptr<ReflectiveFlatbuffer> NewTable(
+ std::unique_ptr<MutableFlatbuffer> NewTable(
const StringPiece table_name) const;
+ // Creates a new message for the given type id. Returns nullptr if the type is
+ // invalid.
+ std::unique_ptr<MutableFlatbuffer> NewTable(int type_id) const;
+
+ // Creates a new message for the given type.
+ std::unique_ptr<MutableFlatbuffer> NewTable(
+ const reflection::Object* type) const;
+
private:
const reflection::Schema* const schema_;
+ const reflection::Object* const root_type_;
};
// Encapsulates a repeated field.
@@ -299,7 +229,7 @@
template <typename T>
bool Add(const T value);
- ReflectiveFlatbuffer* Add();
+ MutableFlatbuffer* Add();
template <typename T>
T Get(int index) const {
@@ -307,7 +237,7 @@
}
template <>
- ReflectiveFlatbuffer* Get(int index) const {
+ MutableFlatbuffer* Get(int index) const {
if (is_primitive_) {
TC3_LOG(ERROR) << "Trying to get primitive value out of non-primitive "
"repeated field.";
@@ -338,11 +268,11 @@
bool is_primitive_;
std::vector<Variant> items_;
- std::vector<std::unique_ptr<ReflectiveFlatbuffer>> object_items_;
+ std::vector<std::unique_ptr<MutableFlatbuffer>> object_items_;
};
template <typename T>
-bool ReflectiveFlatbuffer::Set(StringPiece field_name, T value) {
+bool MutableFlatbuffer::Set(StringPiece field_name, T value) {
if (const reflection::Field* field = GetFieldOrNull(field_name)) {
if (field->type()->base_type() == reflection::BaseType::Vector ||
field->type()->base_type() == reflection::BaseType::Obj) {
@@ -357,7 +287,7 @@
}
template <typename T>
-bool ReflectiveFlatbuffer::Set(const reflection::Field* field, T value) {
+bool MutableFlatbuffer::Set(const reflection::Field* field, T value) {
if (field == nullptr) {
TC3_LOG(ERROR) << "Expected non-null field.";
return false;
@@ -374,8 +304,8 @@
}
template <typename T>
-bool ReflectiveFlatbuffer::Set(const FlatbufferFieldPath* path, T value) {
- ReflectiveFlatbuffer* parent;
+bool MutableFlatbuffer::Set(const FlatbufferFieldPath* path, T value) {
+ MutableFlatbuffer* parent;
const reflection::Field* field;
if (!GetFieldWithParent(path, &parent, &field)) {
return false;
@@ -384,7 +314,7 @@
}
template <typename T>
-bool ReflectiveFlatbuffer::Add(StringPiece field_name, T value) {
+bool MutableFlatbuffer::Add(StringPiece field_name, T value) {
const reflection::Field* field = GetFieldOrNull(field_name);
if (field == nullptr) {
return false;
@@ -398,7 +328,7 @@
}
template <typename T>
-bool ReflectiveFlatbuffer::Add(const reflection::Field* field, T value) {
+bool MutableFlatbuffer::Add(const reflection::Field* field, T value) {
if (field == nullptr) {
return false;
}
@@ -416,13 +346,9 @@
return true;
}
-// Resolves field lookups by name to the concrete field offsets.
-bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
- FlatbufferFieldPathT* path);
-
template <typename T>
-bool ReflectiveFlatbuffer::AppendFromVector(const flatbuffers::Table* from,
- const reflection::Field* field) {
+bool MutableFlatbuffer::AppendFromVector(const flatbuffers::Table* from,
+ const reflection::Field* field) {
const flatbuffers::Vector<T>* from_vector =
from->GetPointer<const flatbuffers::Vector<T>*>(field->offset());
if (from_vector == nullptr) {
@@ -436,14 +362,6 @@
return true;
}
-inline logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream, flatbuffers::String* message) {
- if (message != nullptr) {
- stream.message.append(message->c_str(), message->size());
- }
- return stream;
-}
-
} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
+#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
diff --git a/native/utils/flatbuffers/reflection.cc b/native/utils/flatbuffers/reflection.cc
new file mode 100644
index 0000000..d569670
--- /dev/null
+++ b/native/utils/flatbuffers/reflection.cc
@@ -0,0 +1,137 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/flatbuffers/reflection.h"
+
+#include <string>
+
+namespace libtextclassifier3 {
+
+// Looks up a field of `type` by name; returns nullptr if no field has that
+// name. `type` and its field list must be non-null (enforced by TC3_CHECK).
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const StringPiece field_name) {
+  TC3_CHECK(type != nullptr && type->fields() != nullptr);
+  // The const char* lookup compares against a null-terminated key, which a
+  // StringPiece does not guarantee, so copy exactly size() bytes first.
+  const std::string key(field_name.data(), field_name.size());
+  return type->fields()->LookupByKey(key.c_str());
+}
+
+// Looks up a field of `type` by offset with a linear scan over its fields.
+// Returns nullptr if `type` has no fields or none of them has that offset.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const int field_offset) {
+  if (type->fields() == nullptr) {
+    return nullptr;
+  }
+  for (const reflection::Field* field : *type->fields()) {
+    if (field->offset() == field_offset) {
+      return field;
+    }
+  }
+  return nullptr;
+}
+
+// Looks up a field by name if `field_name` is non-empty, else by offset.
+// A non-empty name that is not found does not fall back to the offset.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const StringPiece field_name,
+                                        const int field_offset) {
+  // Lookup by name might be faster as the fields are sorted by name in the
+  // schema data, so try that first.
+  if (!field_name.empty()) {
+    // Forward the StringPiece itself so that its length is preserved.
+    return GetFieldOrNull(type, field_name);
+  }
+  return GetFieldOrNull(type, field_offset);
+}
+
+// Looks up a field from a `FlatbufferField` spec: by offset if the spec has no
+// name, otherwise through the name/offset overload. Both arguments must be
+// non-null (enforced by TC3_CHECK).
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const FlatbufferField* field) {
+  TC3_CHECK(type != nullptr && field != nullptr);
+  if (field->field_name() == nullptr) {
+    return GetFieldOrNull(type, field->field_offset());
+  }
+  return GetFieldOrNull(
+      type,
+      StringPiece(field->field_name()->data(), field->field_name()->size()),
+      field->field_offset());
+}
+
+// Looks up a field from a `FlatbufferFieldT` spec through the name/offset
+// overload. Both arguments must be non-null (enforced by TC3_CHECK).
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const FlatbufferFieldT* field) {
+  TC3_CHECK(type != nullptr && field != nullptr);
+  return GetFieldOrNull(type, field->field_name, field->field_offset);
+}
+
+// Returns the object in `schema` whose name equals `type_name`, or nullptr if
+// the schema is null, has no objects, or no object matches.
+const reflection::Object* TypeForName(const reflection::Schema* schema,
+                                      const StringPiece type_name) {
+  if (schema == nullptr || schema->objects() == nullptr) {
+    return nullptr;
+  }
+  for (const reflection::Object* object : *schema->objects()) {
+    if (type_name.Equals(object->name()->str())) {
+      return object;
+    }
+  }
+  return nullptr;
+}
+
+// Rewrites each entry of `path` to address its field by offset instead of by
+// name, resolving the entries starting from the root table of `schema`.
+// Returns false if the schema has no root table, a field cannot be found, or
+// an entry other than the last is not of type Object. Entries handled before
+// a failure stay rewritten. `path` must be non-null.
+bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
+                                    FlatbufferFieldPathT* path) {
+  if (schema == nullptr || !schema->root_table()) {
+    TC3_LOG(ERROR) << "Empty schema provided.";
+    return false;
+  }
+
+  reflection::Object const* type = schema->root_table();
+  const size_t num_fields = path->field.size();
+  for (size_t i = 0; i < num_fields; i++) {
+    const reflection::Field* field = GetFieldOrNull(type, path->field[i].get());
+    if (field == nullptr) {
+      TC3_LOG(ERROR) << "Could not find field: " << path->field[i]->field_name;
+      return false;
+    }
+    path->field[i]->field_name.clear();
+    path->field[i]->field_offset = field->offset();
+
+    // Descend into the type of the field for all but the last path entry.
+    if (i + 1 < num_fields) {
+      if (field->type()->base_type() != reflection::Obj) {
+        TC3_LOG(ERROR) << "Field: " << field->name()->str()
+                       << " is not of type `Object`.";
+        return false;
+      }
+      type = schema->objects()->Get(field->type()->index());
+    }
+  }
+  return true;
+}
+
+}  // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers/reflection.h b/native/utils/flatbuffers/reflection.h
new file mode 100644
index 0000000..cad3c5a
--- /dev/null
+++ b/native/utils/flatbuffers/reflection.h
@@ -0,0 +1,119 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_REFLECTION_H_
+#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_REFLECTION_H_
+
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/strings/stringpiece.h"
+#include "flatbuffers/reflection_generated.h"
+
+namespace libtextclassifier3 {
+
+template <typename T>
+struct flatbuffers_base_type {
+ static const reflection::BaseType value;
+};
+
+template <typename T>
+inline const reflection::BaseType flatbuffers_base_type<T>::value =
+ reflection::None;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<bool>::value =
+ reflection::Bool;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<int8>::value =
+ reflection::Byte;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<uint8>::value =
+ reflection::UByte;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<int16>::value =
+ reflection::Short;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<uint16>::value =
+ reflection::UShort;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<int32>::value =
+ reflection::Int;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<uint32>::value =
+ reflection::UInt;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<int64>::value =
+ reflection::Long;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<uint64>::value =
+ reflection::ULong;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<float>::value =
+ reflection::Float;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<double>::value =
+ reflection::Double;
+
+template <>
+inline const reflection::BaseType flatbuffers_base_type<StringPiece>::value =
+ reflection::String;
+
+// Gets the field information for a field name, returns nullptr if the
+// field was not defined.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const StringPiece field_name);
+
+// Gets the field information for a field offset, returns nullptr if no field
+// was defined with the given offset.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const int field_offset);
+
+// Gets a field by name or offset, returns nullptr if no field was found.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const StringPiece field_name,
+ const int field_offset);
+
+// Gets a field by a field spec, either by name or offset. Returns nullptr if no
+// such field was found.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const FlatbufferField* field);
+
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const FlatbufferFieldT* field);
+
+// Gets the type information for the given type name, or nullptr if the schema
+// defines no type with that name.
+const reflection::Object* TypeForName(const reflection::Schema* schema,
+ const StringPiece type_name);
+
+// Resolves field lookups by name to the concrete field offsets.
+bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
+ FlatbufferFieldPathT* path);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_REFLECTION_H_
diff --git a/native/utils/grammar/match.h b/native/utils/grammar/match.h
index 97edac9..f96703d 100644
--- a/native/utils/grammar/match.h
+++ b/native/utils/grammar/match.h
@@ -77,7 +77,7 @@
int16 type = kUnknownType;
// The span in codepoints.
- CodepointSpan codepoint_span;
+ CodepointSpan codepoint_span = CodepointSpan::kInvalid;
// The begin codepoint offset used during matching.
// This is usually including any prefix whitespace.
diff --git a/native/utils/grammar/next/semantics/expression.fbs b/native/utils/grammar/next/semantics/expression.fbs
new file mode 100755
index 0000000..e20fd00
--- /dev/null
+++ b/native/utils/grammar/next/semantics/expression.fbs
@@ -0,0 +1,94 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "utils/flatbuffers/flatbuffers.fbs";
+
+namespace libtextclassifier3.grammar.next.SemanticExpression_;
+union Expression {
+ ConstValueExpression,
+ ConstituentExpression,
+ ComposeExpression,
+ SpanTextExpression,
+ ParseNumberExpression,
+}
+
+// A semantic expression.
+namespace libtextclassifier3.grammar.next;
+table SemanticExpression {
+ expression:SemanticExpression_.Expression;
+}
+
+// A constant flatbuffer value.
+namespace libtextclassifier3.grammar.next;
+table ConstValueExpression {
+ // The base type of the value.
+ base_type:int;
+
+ // The id of the type of the value.
+ // The id is used for lookup in the semantic values type metadata.
+ type:int;
+
+ // The serialized value.
+ value:[ubyte];
+}
+
+// The value of a rule constituent.
+namespace libtextclassifier3.grammar.next;
+table ConstituentExpression {
+ // The id of the constituent.
+ id:ushort;
+}
+
+// The fields to set.
+namespace libtextclassifier3.grammar.next.ComposeExpression_;
+table Field {
+ // The field to set.
+ path:libtextclassifier3.FlatbufferFieldPath;
+
+ // The value.
+ value:SemanticExpression;
+
+ // Whether the field can be absent: If set to true, evaluation to null will
+ // not be treated as an error.
+ // A value of null represents a non-present value that can e.g. arise from
+ // optional parts of a rule that might not be present in a match.
+ optional:bool;
+}
+
+// A combination: Compose a result from arguments.
+// https://mitpress.mit.edu/sites/default/files/sicp/full-text/book/book-Z-H-4.html#%_toc_%_sec_1.1.1
+namespace libtextclassifier3.grammar.next;
+table ComposeExpression {
+ // The id of the type of the result.
+ type:int;
+
+ fields:[ComposeExpression_.Field];
+}
+
+// Lifts a span as a value.
+namespace libtextclassifier3.grammar.next;
+table SpanTextExpression {
+}
+
+// Parses a string as a number.
+namespace libtextclassifier3.grammar.next;
+table ParseNumberExpression {
+ // The base type of the value.
+ base_type:int;
+
+ value:SemanticExpression;
+}
+
diff --git a/native/utils/grammar/rules.fbs b/native/utils/grammar/rules.fbs
index 8052c11..fd681c5 100755
--- a/native/utils/grammar/rules.fbs
+++ b/native/utils/grammar/rules.fbs
@@ -14,8 +14,8 @@
// limitations under the License.
//
-include "utils/i18n/language-tag.fbs";
include "utils/zlib/buffer.fbs";
+include "utils/i18n/language-tag.fbs";
// The terminal rules map as sorted strings table.
// The sorted terminal strings table is represented as offsets into the
@@ -211,5 +211,9 @@
// If true, will compile the regexes only on first use.
lazy_regex_compilation:bool;
+ reserved_10:int16 (deprecated);
+
+ // The schema defining the semantic results.
+ semantic_values_schema:[ubyte];
}
diff --git a/native/utils/grammar/types.h b/native/utils/grammar/types.h
index a79532b..64a618d 100644
--- a/native/utils/grammar/types.h
+++ b/native/utils/grammar/types.h
@@ -38,6 +38,7 @@
kMapping = -3,
kExclusion = -4,
kRootRule = 1,
+ kSemanticExpression = 2,
};
// Special CallbackId indicating that there's no callback associated with a
diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc
index ce074b8..fc5c28e 100644
--- a/native/utils/grammar/utils/ir.cc
+++ b/native/utils/grammar/utils/ir.cc
@@ -70,7 +70,7 @@
}
}
- return false;
+ return true;
}
Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
diff --git a/native/utils/grammar/utils/ir_test.cc b/native/utils/grammar/utils/ir_test.cc
index d2438dd..4d12e76 100644
--- a/native/utils/grammar/utils/ir_test.cc
+++ b/native/utils/grammar/utils/ir_test.cc
@@ -24,6 +24,7 @@
namespace libtextclassifier3::grammar {
namespace {
+using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::Ne;
@@ -234,5 +235,21 @@
EXPECT_THAT(rules.rules[1]->binary_rules, SizeIs(3));
}
+TEST(IrTest, DeduplicatesLhsSets) {
+ Ir ir;
+
+ const Nonterm test = ir.AddUnshareableNonterminal();
+ ir.Add(test, "test");
+
+ // Add a second rule for the same nonterminal.
+ ir.Add(test, "again");
+
+ RulesSetT rules;
+ ir.Serialize(/*include_debug_information=*/false, &rules);
+
+ EXPECT_THAT(rules.lhs_set, SizeIs(1));
+ EXPECT_THAT(rules.lhs_set.front()->lhs, ElementsAre(test));
+}
+
} // namespace
} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/rules.cc b/native/utils/grammar/utils/rules.cc
index d6e4b76..d661a21 100644
--- a/native/utils/grammar/utils/rules.cc
+++ b/native/utils/grammar/utils/rules.cc
@@ -379,6 +379,14 @@
/*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
}
+void Rules::AddValueMapping(const int lhs, const std::vector<RhsElement>& rhs,
+ int64 value, const int8 max_whitespace_gap,
+ const bool case_sensitive, const int shard) {
+ Add(lhs, rhs,
+ /*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
+ /*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
+}
+
void Rules::AddRegex(const std::string& lhs, const std::string& regex_pattern) {
AddRegex(AddNonterminal(lhs), regex_pattern);
}
@@ -408,6 +416,12 @@
// multiple rules or that have a filter callback on some rule.
for (int i = 0; i < nonterminals_.size(); i++) {
const NontermInfo& nonterminal = nonterminals_[i];
+
+ // Skip predefined nonterminals; they have already been assigned.
+ if (rules.GetNonterminalForName(nonterminal.name) != kUnassignedNonterm) {
+ continue;
+ }
+
bool unmergeable =
(nonterminal.from_annotation || nonterminal.rules.size() > 1 ||
!nonterminal.regex_rules.empty());
diff --git a/native/utils/grammar/utils/rules.h b/native/utils/grammar/utils/rules.h
index 5a2cbc2..0c8d7da 100644
--- a/native/utils/grammar/utils/rules.h
+++ b/native/utils/grammar/utils/rules.h
@@ -66,15 +66,24 @@
struct RhsElement {
RhsElement() {}
explicit RhsElement(const std::string& terminal, const bool is_optional)
- : is_terminal(true), terminal(terminal), is_optional(is_optional) {}
- explicit RhsElement(const int nonterminal, const bool is_optional)
+ : is_terminal(true),
+ terminal(terminal),
+ is_optional(is_optional),
+ is_constituent(false) {}
+ explicit RhsElement(const int nonterminal, const bool is_optional,
+ const bool is_constituent = true)
: is_terminal(false),
nonterminal(nonterminal),
- is_optional(is_optional) {}
+ is_optional(is_optional),
+ is_constituent(is_constituent) {}
bool is_terminal;
std::string terminal;
int nonterminal;
bool is_optional;
+
+ // Whether the element is a constituent of a rule - these are the explicit
+ // nonterminals, but not terminals or implicitly added anchors.
+ bool is_constituent;
};
// Represents the right-hand side, and possibly callback, of one rule.
@@ -139,6 +148,9 @@
const std::vector<std::string>& rhs, int64 value,
int8 max_whitespace_gap = -1,
bool case_sensitive = false, int shard = 0);
+ void AddValueMapping(int lhs, const std::vector<RhsElement>& rhs, int64 value,
+ int8 max_whitespace_gap = -1,
+ bool case_sensitive = false, int shard = 0);
// Adds a regex rule.
void AddRegex(const std::string& lhs, const std::string& regex_pattern);
diff --git a/native/utils/intents/intent-generator.cc b/native/utils/intents/intent-generator.cc
index 4cb3e40..7edef41 100644
--- a/native/utils/intents/intent-generator.cc
+++ b/native/utils/intents/intent-generator.cc
@@ -18,21 +18,11 @@
#include <vector>
-#include "actions/types.h"
-#include "annotator/types.h"
#include "utils/base/logging.h"
-#include "utils/base/statusor.h"
-#include "utils/hash/farmhash.h"
-#include "utils/java/jni-base.h"
+#include "utils/intents/jni-lua.h"
#include "utils/java/jni-helper.h"
-#include "utils/java/string_utils.h"
-#include "utils/lua-utils.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/strings/substitute.h"
#include "utils/utf8/unicodetext.h"
-#include "utils/variant.h"
#include "utils/zlib/zlib.h"
-#include "flatbuffers/reflection_generated.h"
#ifdef __cplusplus
extern "C" {
@@ -47,696 +37,6 @@
namespace {
static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
-static constexpr const char* kHashKey = "hash";
-static constexpr const char* kUrlSchemaKey = "url_schema";
-static constexpr const char* kUrlHostKey = "url_host";
-static constexpr const char* kUrlEncodeKey = "urlencode";
-static constexpr const char* kPackageNameKey = "package_name";
-static constexpr const char* kDeviceLocaleKey = "device_locales";
-static constexpr const char* kFormatKey = "format";
-
-// An Android specific Lua environment with JNI backed callbacks.
-class JniLuaEnvironment : public LuaEnvironment {
- public:
- JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
- const jobject context,
- const std::vector<Locale>& device_locales);
- // Environment setup.
- bool Initialize();
-
- // Runs an intent generator snippet.
- bool RunIntentGenerator(const std::string& generator_snippet,
- std::vector<RemoteActionTemplate>* remote_actions);
-
- protected:
- virtual void SetupExternalHook();
-
- int HandleExternalCallback();
- int HandleAndroidCallback();
- int HandleUserRestrictionsCallback();
- int HandleUrlEncode();
- int HandleUrlSchema();
- int HandleHash();
- int HandleFormat();
- int HandleAndroidStringResources();
- int HandleUrlHost();
-
- // Checks and retrieves string resources from the model.
- bool LookupModelStringResource() const;
-
- // Reads and create a RemoteAction result from Lua.
- RemoteActionTemplate ReadRemoteActionTemplateResult() const;
-
- // Reads the extras from the Lua result.
- std::map<std::string, Variant> ReadExtras() const;
-
- // Retrieves user manager if not previously done.
- bool RetrieveUserManager();
-
- // Retrieves system resources if not previously done.
- bool RetrieveSystemResources();
-
- // Parse the url string by using Uri.parse from Java.
- StatusOr<ScopedLocalRef<jobject>> ParseUri(StringPiece url) const;
-
- // Read remote action templates from lua generator.
- int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
-
- const Resources& resources_;
- JNIEnv* jenv_;
- const JniCache* jni_cache_;
- const jobject context_;
- std::vector<Locale> device_locales_;
-
- ScopedGlobalRef<jobject> usermanager_;
- // Whether we previously attempted to retrieve the UserManager before.
- bool usermanager_retrieved_;
-
- ScopedGlobalRef<jobject> system_resources_;
- // Whether we previously attempted to retrieve the system resources.
- bool system_resources_resources_retrieved_;
-
- // Cached JNI references for Java strings `string` and `android`.
- ScopedGlobalRef<jstring> string_;
- ScopedGlobalRef<jstring> android_;
-};
-
-JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
- const JniCache* jni_cache,
- const jobject context,
- const std::vector<Locale>& device_locales)
- : resources_(resources),
- jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
- jni_cache_(jni_cache),
- context_(context),
- device_locales_(device_locales),
- usermanager_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
- usermanager_retrieved_(false),
- system_resources_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
- system_resources_resources_retrieved_(false),
- string_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
- android_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
-
-bool JniLuaEnvironment::Initialize() {
- TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> string_value,
- JniHelper::NewStringUTF(jenv_, "string"));
- string_ = MakeGlobalRef(string_value.get(), jenv_, jni_cache_->jvm);
- TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> android_value,
- JniHelper::NewStringUTF(jenv_, "android"));
- android_ = MakeGlobalRef(android_value.get(), jenv_, jni_cache_->jvm);
- if (string_ == nullptr || android_ == nullptr) {
- TC3_LOG(ERROR) << "Could not allocate constant strings references.";
- return false;
- }
- return (RunProtected([this] {
- LoadDefaultLibraries();
- SetupExternalHook();
- lua_setglobal(state_, "external");
- return LUA_OK;
- }) == LUA_OK);
-}
-
-void JniLuaEnvironment::SetupExternalHook() {
- // This exposes an `external` object with the following fields:
- // * entity: the bundle with all information about a classification.
- // * android: callbacks into specific android provided methods.
- // * android.user_restrictions: callbacks to check user permissions.
- // * android.R: callbacks to retrieve string resources.
- PushLazyObject(&JniLuaEnvironment::HandleExternalCallback);
-
- // android
- PushLazyObject(&JniLuaEnvironment::HandleAndroidCallback);
- {
- // android.user_restrictions
- PushLazyObject(&JniLuaEnvironment::HandleUserRestrictionsCallback);
- lua_setfield(state_, /*idx=*/-2, "user_restrictions");
-
- // android.R
- // Callback to access android string resources.
- PushLazyObject(&JniLuaEnvironment::HandleAndroidStringResources);
- lua_setfield(state_, /*idx=*/-2, "R");
- }
- lua_setfield(state_, /*idx=*/-2, "android");
-}
-
-int JniLuaEnvironment::HandleExternalCallback() {
- const StringPiece key = ReadString(kIndexStackTop);
- if (key.Equals(kHashKey)) {
- PushFunction(&JniLuaEnvironment::HandleHash);
- return 1;
- } else if (key.Equals(kFormatKey)) {
- PushFunction(&JniLuaEnvironment::HandleFormat);
- return 1;
- } else {
- TC3_LOG(ERROR) << "Undefined external access " << key;
- lua_error(state_);
- return 0;
- }
-}
-
-int JniLuaEnvironment::HandleAndroidCallback() {
- const StringPiece key = ReadString(kIndexStackTop);
- if (key.Equals(kDeviceLocaleKey)) {
- // Provide the locale as table with the individual fields set.
- lua_newtable(state_);
- for (int i = 0; i < device_locales_.size(); i++) {
- // Adjust index to 1-based indexing for Lua.
- lua_pushinteger(state_, i + 1);
- lua_newtable(state_);
- PushString(device_locales_[i].Language());
- lua_setfield(state_, -2, "language");
- PushString(device_locales_[i].Region());
- lua_setfield(state_, -2, "region");
- PushString(device_locales_[i].Script());
- lua_setfield(state_, -2, "script");
- lua_settable(state_, /*idx=*/-3);
- }
- return 1;
- } else if (key.Equals(kPackageNameKey)) {
- if (context_ == nullptr) {
- TC3_LOG(ERROR) << "Context invalid.";
- lua_error(state_);
- return 0;
- }
-
- StatusOr<ScopedLocalRef<jstring>> status_or_package_name_str =
- JniHelper::CallObjectMethod<jstring>(
- jenv_, context_, jni_cache_->context_get_package_name);
-
- if (!status_or_package_name_str.ok()) {
- TC3_LOG(ERROR) << "Error calling Context.getPackageName";
- lua_error(state_);
- return 0;
- }
- StatusOr<std::string> status_or_package_name_std_str =
- ToStlString(jenv_, status_or_package_name_str.ValueOrDie().get());
- if (!status_or_package_name_std_str.ok()) {
- lua_error(state_);
- return 0;
- }
- PushString(status_or_package_name_std_str.ValueOrDie());
- return 1;
- } else if (key.Equals(kUrlEncodeKey)) {
- PushFunction(&JniLuaEnvironment::HandleUrlEncode);
- return 1;
- } else if (key.Equals(kUrlHostKey)) {
- PushFunction(&JniLuaEnvironment::HandleUrlHost);
- return 1;
- } else if (key.Equals(kUrlSchemaKey)) {
- PushFunction(&JniLuaEnvironment::HandleUrlSchema);
- return 1;
- } else {
- TC3_LOG(ERROR) << "Undefined android reference " << key;
- lua_error(state_);
- return 0;
- }
-}
-
-int JniLuaEnvironment::HandleUserRestrictionsCallback() {
- if (jni_cache_->usermanager_class == nullptr ||
- jni_cache_->usermanager_get_user_restrictions == nullptr) {
- // UserManager is only available for API level >= 17 and
- // getUserRestrictions only for API level >= 18, so we just return false
- // normally here.
- lua_pushboolean(state_, false);
- return 1;
- }
-
- // Get user manager if not previously retrieved.
- if (!RetrieveUserManager()) {
- TC3_LOG(ERROR) << "Error retrieving user manager.";
- lua_error(state_);
- return 0;
- }
-
- StatusOr<ScopedLocalRef<jobject>> status_or_bundle =
- JniHelper::CallObjectMethod(
- jenv_, usermanager_.get(),
- jni_cache_->usermanager_get_user_restrictions);
- if (!status_or_bundle.ok() || status_or_bundle.ValueOrDie() == nullptr) {
- TC3_LOG(ERROR) << "Error calling getUserRestrictions";
- lua_error(state_);
- return 0;
- }
-
- const StringPiece key_str = ReadString(kIndexStackTop);
- if (key_str.empty()) {
- TC3_LOG(ERROR) << "Expected string, got null.";
- lua_error(state_);
- return 0;
- }
-
- const StatusOr<ScopedLocalRef<jstring>> status_or_key =
- jni_cache_->ConvertToJavaString(key_str);
- if (!status_or_key.ok()) {
- lua_error(state_);
- return 0;
- }
- const StatusOr<bool> status_or_permission = JniHelper::CallBooleanMethod(
- jenv_, status_or_bundle.ValueOrDie().get(),
- jni_cache_->bundle_get_boolean, status_or_key.ValueOrDie().get());
- if (!status_or_permission.ok()) {
- TC3_LOG(ERROR) << "Error getting bundle value";
- lua_pushboolean(state_, false);
- } else {
- lua_pushboolean(state_, status_or_permission.ValueOrDie());
- }
- return 1;
-}
-
-int JniLuaEnvironment::HandleUrlEncode() {
- const StringPiece input = ReadString(/*index=*/1);
- if (input.empty()) {
- TC3_LOG(ERROR) << "Expected string, got null.";
- lua_error(state_);
- return 0;
- }
-
- // Call Java URL encoder.
- const StatusOr<ScopedLocalRef<jstring>> status_or_input_str =
- jni_cache_->ConvertToJavaString(input);
- if (!status_or_input_str.ok()) {
- lua_error(state_);
- return 0;
- }
- StatusOr<ScopedLocalRef<jstring>> status_or_encoded_str =
- JniHelper::CallStaticObjectMethod<jstring>(
- jenv_, jni_cache_->urlencoder_class.get(),
- jni_cache_->urlencoder_encode, status_or_input_str.ValueOrDie().get(),
- jni_cache_->string_utf8.get());
-
- if (!status_or_encoded_str.ok()) {
- TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
- lua_error(state_);
- return 0;
- }
- const StatusOr<std::string> status_or_encoded_std_str =
- ToStlString(jenv_, status_or_encoded_str.ValueOrDie().get());
- if (!status_or_encoded_std_str.ok()) {
- lua_error(state_);
- return 0;
- }
- PushString(status_or_encoded_std_str.ValueOrDie());
- return 1;
-}
-
-StatusOr<ScopedLocalRef<jobject>> JniLuaEnvironment::ParseUri(
- StringPiece url) const {
- if (url.empty()) {
- return {Status::UNKNOWN};
- }
-
- // Call to Java URI parser.
- TC3_ASSIGN_OR_RETURN(
- const StatusOr<ScopedLocalRef<jstring>> status_or_url_str,
- jni_cache_->ConvertToJavaString(url));
-
- // Try to parse uri and get scheme.
- TC3_ASSIGN_OR_RETURN(
- ScopedLocalRef<jobject> uri,
- JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->uri_class.get(),
- jni_cache_->uri_parse,
- status_or_url_str.ValueOrDie().get()));
- if (uri == nullptr) {
- TC3_LOG(ERROR) << "Error calling Uri.parse";
- return {Status::UNKNOWN};
- }
- return uri;
-}
-
-int JniLuaEnvironment::HandleUrlSchema() {
- StringPiece url = ReadString(/*index=*/1);
-
- const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
- if (!status_or_parsed_uri.ok()) {
- lua_error(state_);
- return 0;
- }
-
- const StatusOr<ScopedLocalRef<jstring>> status_or_scheme_str =
- JniHelper::CallObjectMethod<jstring>(
- jenv_, status_or_parsed_uri.ValueOrDie().get(),
- jni_cache_->uri_get_scheme);
- if (!status_or_scheme_str.ok()) {
- TC3_LOG(ERROR) << "Error calling Uri.getScheme";
- lua_error(state_);
- return 0;
- }
- if (status_or_scheme_str.ValueOrDie() == nullptr) {
- lua_pushnil(state_);
- } else {
- const StatusOr<std::string> status_or_scheme_std_str =
- ToStlString(jenv_, status_or_scheme_str.ValueOrDie().get());
- if (!status_or_scheme_std_str.ok()) {
- lua_error(state_);
- return 0;
- }
- PushString(status_or_scheme_std_str.ValueOrDie());
- }
- return 1;
-}
-
-int JniLuaEnvironment::HandleUrlHost() {
- const StringPiece url = ReadString(kIndexStackTop);
-
- const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
- if (!status_or_parsed_uri.ok()) {
- lua_error(state_);
- return 0;
- }
-
- const StatusOr<ScopedLocalRef<jstring>> status_or_host_str =
- JniHelper::CallObjectMethod<jstring>(
- jenv_, status_or_parsed_uri.ValueOrDie().get(),
- jni_cache_->uri_get_host);
- if (!status_or_host_str.ok()) {
- TC3_LOG(ERROR) << "Error calling Uri.getHost";
- lua_error(state_);
- return 0;
- }
-
- if (status_or_host_str.ValueOrDie() == nullptr) {
- lua_pushnil(state_);
- } else {
- const StatusOr<std::string> status_or_host_std_str =
- ToStlString(jenv_, status_or_host_str.ValueOrDie().get());
- if (!status_or_host_std_str.ok()) {
- lua_error(state_);
- return 0;
- }
- PushString(status_or_host_std_str.ValueOrDie());
- }
- return 1;
-}
-
-int JniLuaEnvironment::HandleHash() {
- const StringPiece input = ReadString(kIndexStackTop);
- lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
- return 1;
-}
-
-int JniLuaEnvironment::HandleFormat() {
- const int num_args = lua_gettop(state_);
- std::vector<StringPiece> args(num_args - 1);
- for (int i = 0; i < num_args - 1; i++) {
- args[i] = ReadString(/*index=*/i + 2);
- }
- PushString(strings::Substitute(ReadString(/*index=*/1), args));
- return 1;
-}
-
-bool JniLuaEnvironment::LookupModelStringResource() const {
- // Handle only lookup by name.
- if (lua_type(state_, kIndexStackTop) != LUA_TSTRING) {
- return false;
- }
-
- const StringPiece resource_name = ReadString(kIndexStackTop);
- std::string resource_content;
- if (!resources_.GetResourceContent(device_locales_, resource_name,
- &resource_content)) {
- // Resource cannot be provided by the model.
- return false;
- }
-
- PushString(resource_content);
- return true;
-}
-
-int JniLuaEnvironment::HandleAndroidStringResources() {
- // Check whether the requested resource can be served from the model data.
- if (LookupModelStringResource()) {
- return 1;
- }
-
- // Get system resources if not previously retrieved.
- if (!RetrieveSystemResources()) {
- TC3_LOG(ERROR) << "Error retrieving system resources.";
- lua_error(state_);
- return 0;
- }
-
- int resource_id;
- switch (lua_type(state_, kIndexStackTop)) {
- case LUA_TNUMBER:
- resource_id = Read<int>(/*index=*/kIndexStackTop);
- break;
- case LUA_TSTRING: {
- const StringPiece resource_name_str = ReadString(kIndexStackTop);
- if (resource_name_str.empty()) {
- TC3_LOG(ERROR) << "No resource name provided.";
- lua_error(state_);
- return 0;
- }
- const StatusOr<ScopedLocalRef<jstring>> status_or_resource_name =
- jni_cache_->ConvertToJavaString(resource_name_str);
- if (!status_or_resource_name.ok()) {
- TC3_LOG(ERROR) << "Invalid resource name.";
- lua_error(state_);
- return 0;
- }
- StatusOr<int> status_or_resource_id = JniHelper::CallIntMethod(
- jenv_, system_resources_.get(), jni_cache_->resources_get_identifier,
- status_or_resource_name.ValueOrDie().get(), string_.get(),
- android_.get());
- if (!status_or_resource_id.ok()) {
- TC3_LOG(ERROR) << "Error calling getIdentifier.";
- lua_error(state_);
- return 0;
- }
- resource_id = status_or_resource_id.ValueOrDie();
- break;
- }
- default:
- TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
- lua_error(state_);
- return 0;
- }
- if (resource_id == 0) {
- TC3_LOG(ERROR) << "Resource not found.";
- lua_pushnil(state_);
- return 1;
- }
- StatusOr<ScopedLocalRef<jstring>> status_or_resource_str =
- JniHelper::CallObjectMethod<jstring>(jenv_, system_resources_.get(),
- jni_cache_->resources_get_string,
- resource_id);
- if (!status_or_resource_str.ok()) {
- TC3_LOG(ERROR) << "Error calling getString.";
- lua_error(state_);
- return 0;
- }
-
- if (status_or_resource_str.ValueOrDie() == nullptr) {
- lua_pushnil(state_);
- } else {
- StatusOr<std::string> status_or_resource_std_str =
- ToStlString(jenv_, status_or_resource_str.ValueOrDie().get());
- if (!status_or_resource_std_str.ok()) {
- lua_error(state_);
- return 0;
- }
- PushString(status_or_resource_std_str.ValueOrDie());
- }
- return 1;
-}
-
-bool JniLuaEnvironment::RetrieveSystemResources() {
- if (system_resources_resources_retrieved_) {
- return (system_resources_ != nullptr);
- }
- system_resources_resources_retrieved_ = true;
- TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jobject> system_resources_ref,
- JniHelper::CallStaticObjectMethod(
- jenv_, jni_cache_->resources_class.get(),
- jni_cache_->resources_get_system));
- system_resources_ =
- MakeGlobalRef(system_resources_ref.get(), jenv_, jni_cache_->jvm);
- return (system_resources_ != nullptr);
-}
-
-bool JniLuaEnvironment::RetrieveUserManager() {
- if (context_ == nullptr) {
- return false;
- }
- if (usermanager_retrieved_) {
- return (usermanager_ != nullptr);
- }
- usermanager_retrieved_ = true;
- TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> service,
- JniHelper::NewStringUTF(jenv_, "user"));
- TC3_ASSIGN_OR_RETURN_FALSE(
- const ScopedLocalRef<jobject> usermanager_ref,
- JniHelper::CallObjectMethod(jenv_, context_,
- jni_cache_->context_get_system_service,
- service.get()));
-
- usermanager_ = MakeGlobalRef(usermanager_ref.get(), jenv_, jni_cache_->jvm);
- return (usermanager_ != nullptr);
-}
-
-RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() const {
- RemoteActionTemplate result;
- // Read intent template.
- lua_pushnil(state_);
- while (Next(/*index=*/-2)) {
- const StringPiece key = ReadString(/*index=*/-2);
- if (key.Equals("title_without_entity")) {
- result.title_without_entity = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("title_with_entity")) {
- result.title_with_entity = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("description")) {
- result.description = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("description_with_app_name")) {
- result.description_with_app_name =
- Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("action")) {
- result.action = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("data")) {
- result.data = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("type")) {
- result.type = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("flags")) {
- result.flags = Read<int>(/*index=*/kIndexStackTop);
- } else if (key.Equals("package_name")) {
- result.package_name = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("request_code")) {
- result.request_code = Read<int>(/*index=*/kIndexStackTop);
- } else if (key.Equals("category")) {
- result.category = ReadVector<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("extra")) {
- result.extra = ReadExtras();
- } else {
- TC3_LOG(INFO) << "Unknown entry: " << key;
- }
- lua_pop(state_, 1);
- }
- lua_pop(state_, 1);
- return result;
-}
-
-std::map<std::string, Variant> JniLuaEnvironment::ReadExtras() const {
- if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected extras table, got: "
- << lua_type(state_, kIndexStackTop);
- lua_pop(state_, 1);
- return {};
- }
- std::map<std::string, Variant> extras;
- lua_pushnil(state_);
- while (Next(/*index=*/-2)) {
- // Each entry is a table specifying name and value.
- // The value is specified via a type specific field as Lua doesn't allow
- // to easily distinguish between different number types.
- if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected a table for an extra, got: "
- << lua_type(state_, kIndexStackTop);
- lua_pop(state_, 1);
- return {};
- }
- std::string name;
- Variant value;
-
- lua_pushnil(state_);
- while (Next(/*index=*/-2)) {
- const StringPiece key = ReadString(/*index=*/-2);
- if (key.Equals("name")) {
- name = Read<std::string>(/*index=*/kIndexStackTop);
- } else if (key.Equals("int_value")) {
- value = Variant(Read<int>(/*index=*/kIndexStackTop));
- } else if (key.Equals("long_value")) {
- value = Variant(Read<int64>(/*index=*/kIndexStackTop));
- } else if (key.Equals("float_value")) {
- value = Variant(Read<float>(/*index=*/kIndexStackTop));
- } else if (key.Equals("bool_value")) {
- value = Variant(Read<bool>(/*index=*/kIndexStackTop));
- } else if (key.Equals("string_value")) {
- value = Variant(Read<std::string>(/*index=*/kIndexStackTop));
- } else if (key.Equals("string_array_value")) {
- value = Variant(ReadVector<std::string>(/*index=*/kIndexStackTop));
- } else if (key.Equals("float_array_value")) {
- value = Variant(ReadVector<float>(/*index=*/kIndexStackTop));
- } else if (key.Equals("int_array_value")) {
- value = Variant(ReadVector<int>(/*index=*/kIndexStackTop));
- } else if (key.Equals("named_variant_array_value")) {
- value = Variant(ReadExtras());
- } else {
- TC3_LOG(INFO) << "Unknown extra field: " << key;
- }
- lua_pop(state_, 1);
- }
- if (!name.empty()) {
- extras[name] = value;
- } else {
- TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
- }
- lua_pop(state_, 1);
- }
- return extras;
-}
-
-int JniLuaEnvironment::ReadRemoteActionTemplates(
- std::vector<RemoteActionTemplate>* result) {
- // Read result.
- if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Unexpected result for snippet: "
- << lua_type(state_, kIndexStackTop);
- lua_error(state_);
- return LUA_ERRRUN;
- }
-
- // Read remote action templates array.
- lua_pushnil(state_);
- while (Next(/*index=*/-2)) {
- if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected intent table, got: "
- << lua_type(state_, kIndexStackTop);
- lua_pop(state_, 1);
- continue;
- }
- result->push_back(ReadRemoteActionTemplateResult());
- }
- lua_pop(state_, /*n=*/1);
- return LUA_OK;
-}
-
-bool JniLuaEnvironment::RunIntentGenerator(
- const std::string& generator_snippet,
- std::vector<RemoteActionTemplate>* remote_actions) {
- int status;
- status = luaL_loadbuffer(state_, generator_snippet.data(),
- generator_snippet.size(),
- /*name=*/nullptr);
- if (status != LUA_OK) {
- TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
- return false;
- }
- status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
- if (status != LUA_OK) {
- TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
- return false;
- }
- if (RunProtected(
- [this, remote_actions] {
- return ReadRemoteActionTemplates(remote_actions);
- },
- /*num_args=*/1) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not read results.";
- return false;
- }
- // Check that we correctly cleaned-up the state.
- const int stack_size = lua_gettop(state_);
- if (stack_size > 0) {
- TC3_LOG(ERROR) << "Unexpected stack size.";
- lua_settop(state_, 0);
- return false;
- }
- return true;
-}
// Lua environment for classfication result intent generation.
class AnnotatorJniEnvironment : public JniLuaEnvironment {
@@ -855,15 +155,15 @@
TC3_LOG(ERROR) << "No locales provided.";
return {};
}
- ScopedStringChars locales_str =
- GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
- if (locales_str == nullptr) {
- TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
+ StatusOr<std::string> status_or_locales_str =
+ JStringToUtf8String(jni_cache_->GetEnv(), device_locales);
+ if (!status_or_locales_str.ok()) {
+ TC3_LOG(ERROR)
+ << "JStringToUtf8String failed, cannot retrieve provided locales.";
return {};
}
std::vector<Locale> locales;
- if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
- &locales)) {
+ if (!ParseLocales(status_or_locales_str.ValueOrDie(), &locales)) {
TC3_LOG(ERROR) << "Cannot parse locales.";
return {};
}
diff --git a/native/utils/intents/intent-generator.h b/native/utils/intents/intent-generator.h
index 2a45191..a3c8898 100644
--- a/native/utils/intents/intent-generator.h
+++ b/native/utils/intents/intent-generator.h
@@ -18,6 +18,7 @@
#include "utils/resources.h"
#include "utils/resources_generated.h"
#include "utils/strings/stringpiece.h"
+#include "flatbuffers/reflection_generated.h"
namespace libtextclassifier3 {
diff --git a/native/utils/intents/jni-lua.cc b/native/utils/intents/jni-lua.cc
new file mode 100644
index 0000000..f151f4d
--- /dev/null
+++ b/native/utils/intents/jni-lua.cc
@@ -0,0 +1,670 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/intents/jni-lua.h"
+
+#include "utils/hash/farmhash.h"
+#include "utils/java/jni-helper.h"
+#include "utils/strings/substitute.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lua.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+
+// Field names that Lua snippets can look up: `hash` and `format` are served by
+// HandleExternalCallback, the remaining keys by HandleAndroidCallback.
+static constexpr const char* kHashKey = "hash";
+static constexpr const char* kUrlSchemaKey = "url_schema";
+static constexpr const char* kUrlHostKey = "url_host";
+static constexpr const char* kUrlEncodeKey = "urlencode";
+static constexpr const char* kPackageNameKey = "package_name";
+static constexpr const char* kDeviceLocaleKey = "device_locales";
+static constexpr const char* kFormatKey = "format";
+
+} // namespace
+
+// Stores the collaborators used by the JNI backed callbacks.
+// `jni_cache` may be null; in that case the JNI environment pointer and the
+// JVM handle given to each global reference are null as well.
+// The user manager, system resources and constant string references start out
+// empty and are filled in later by RetrieveUserManager(),
+// RetrieveSystemResources() and PreallocateConstantJniStrings().
+JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
+                                     const JniCache* jni_cache,
+                                     const jobject context,
+                                     const std::vector<Locale>& device_locales)
+    : LuaEnvironment(),
+      resources_(resources),
+      jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
+      jni_cache_(jni_cache),
+      context_(context),
+      device_locales_(device_locales),
+      usermanager_(/*object=*/nullptr,
+                   /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+      usermanager_retrieved_(false),
+      system_resources_(/*object=*/nullptr,
+                        /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+      system_resources_resources_retrieved_(false),
+      string_(/*object=*/nullptr,
+              /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+      android_(/*object=*/nullptr,
+               /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
+
+// Creates the Java strings "string" and "android" once and keeps them as
+// global references; they are passed to the getIdentifier call made by
+// HandleAndroidStringResources.
+// Returns false if no JNI environment is available or if a reference could not
+// be created.
+bool JniLuaEnvironment::PreallocateConstantJniStrings() {
+  // The constructor accepts a null `jni_cache`; bail out here instead of
+  // dereferencing it below.
+  if (jni_cache_ == nullptr || jenv_ == nullptr) {
+    TC3_LOG(ERROR) << "No JNI environment available.";
+    return false;
+  }
+  TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> string_value,
+                             JniHelper::NewStringUTF(jenv_, "string"));
+  string_ = MakeGlobalRef(string_value.get(), jenv_, jni_cache_->jvm);
+  TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> android_value,
+                             JniHelper::NewStringUTF(jenv_, "android"));
+  android_ = MakeGlobalRef(android_value.get(), jenv_, jni_cache_->jvm);
+  if (string_ == nullptr || android_ == nullptr) {
+    TC3_LOG(ERROR) << "Could not allocate constant strings references.";
+    return false;
+  }
+  return true;
+}
+
+// Prepares the environment: allocates the constant Java strings, then loads
+// the default libraries and binds the `external` global.
+// Returns true only if every step succeeded.
+bool JniLuaEnvironment::Initialize() {
+  if (!PreallocateConstantJniStrings()) {
+    return false;
+  }
+  // The setup runs through RunProtected; success is reported only if it
+  // returns LUA_OK.
+  const auto setup_status = RunProtected([this] {
+    LoadDefaultLibraries();
+    SetupExternalHook();
+    lua_setglobal(state_, "external");
+    return LUA_OK;
+  });
+  return setup_status == LUA_OK;
+}
+
+void JniLuaEnvironment::SetupExternalHook() {
+  // Builds the object that Initialize() binds to the `external` global and
+  // leaves it on top of the Lua stack. Each PushLazyObject call registers the
+  // member callback that serves field lookups on the pushed object:
+  // * external: HandleExternalCallback (`hash`, `format`).
+  // * android: callbacks into specific android provided methods.
+  // * android.user_restrictions: callbacks to check user permissions.
+  // * android.R: callbacks to retrieve string resources.
+  // NOTE(review): no `entity` field is installed here; presumably subclasses
+  // overriding this virtual hook add it -- TODO confirm.
+  PushLazyObject(&JniLuaEnvironment::HandleExternalCallback);
+
+  // android
+  PushLazyObject(&JniLuaEnvironment::HandleAndroidCallback);
+  {
+    // android.user_restrictions
+    PushLazyObject(&JniLuaEnvironment::HandleUserRestrictionsCallback);
+    lua_setfield(state_, /*idx=*/-2, "user_restrictions");
+
+    // android.R
+    // Callback to access android string resources.
+    PushLazyObject(&JniLuaEnvironment::HandleAndroidStringResources);
+    lua_setfield(state_, /*idx=*/-2, "R");
+  }
+  // Attach the finished `android` object to the `external` object below it.
+  lua_setfield(state_, /*idx=*/-2, "android");
+}
+
+// Resolves a field lookup on the `external` object: `hash` and `format` yield
+// the matching function, any other key raises a Lua error.
+int JniLuaEnvironment::HandleExternalCallback() {
+  const StringPiece field = ReadString(kIndexStackTop);
+  if (field.Equals(kHashKey)) {
+    PushFunction(&JniLuaEnvironment::HandleHash);
+    return 1;
+  }
+  if (field.Equals(kFormatKey)) {
+    PushFunction(&JniLuaEnvironment::HandleFormat);
+    return 1;
+  }
+  TC3_LOG(ERROR) << "Undefined external access " << field;
+  lua_error(state_);
+  return 0;
+}
+
+// Resolves a field lookup on the `external.android` object. Supported keys
+// are the device locales, the package name of the context and the url helper
+// functions; any other key raises a Lua error.
+int JniLuaEnvironment::HandleAndroidCallback() {
+  const StringPiece field = ReadString(kIndexStackTop);
+
+  if (field.Equals(kDeviceLocaleKey)) {
+    // Array of tables, one per locale, each with the individual fields set.
+    lua_newtable(state_);
+    for (int i = 0; i < device_locales_.size(); i++) {
+      const Locale& locale = device_locales_[i];
+      // Lua arrays are 1-based.
+      lua_pushinteger(state_, i + 1);
+      lua_newtable(state_);
+      PushString(locale.Language());
+      lua_setfield(state_, -2, "language");
+      PushString(locale.Region());
+      lua_setfield(state_, -2, "region");
+      PushString(locale.Script());
+      lua_setfield(state_, -2, "script");
+      lua_settable(state_, /*idx=*/-3);
+    }
+    return 1;
+  }
+
+  if (field.Equals(kPackageNameKey)) {
+    if (context_ == nullptr) {
+      TC3_LOG(ERROR) << "Context invalid.";
+      lua_error(state_);
+      return 0;
+    }
+
+    StatusOr<ScopedLocalRef<jstring>> java_package_name =
+        JniHelper::CallObjectMethod<jstring>(
+            jenv_, context_, jni_cache_->context_get_package_name);
+    if (!java_package_name.ok()) {
+      TC3_LOG(ERROR) << "Error calling Context.getPackageName";
+      lua_error(state_);
+      return 0;
+    }
+
+    StatusOr<std::string> package_name =
+        JStringToUtf8String(jenv_, java_package_name.ValueOrDie().get());
+    if (!package_name.ok()) {
+      lua_error(state_);
+      return 0;
+    }
+    PushString(package_name.ValueOrDie());
+    return 1;
+  }
+
+  if (field.Equals(kUrlEncodeKey)) {
+    PushFunction(&JniLuaEnvironment::HandleUrlEncode);
+    return 1;
+  }
+  if (field.Equals(kUrlHostKey)) {
+    PushFunction(&JniLuaEnvironment::HandleUrlHost);
+    return 1;
+  }
+  if (field.Equals(kUrlSchemaKey)) {
+    PushFunction(&JniLuaEnvironment::HandleUrlSchema);
+    return 1;
+  }
+
+  TC3_LOG(ERROR) << "Undefined android reference " << field;
+  lua_error(state_);
+  return 0;
+}
+
+// Serves a field lookup on `external.android.user_restrictions`: the key on
+// top of the Lua stack is looked up as a boolean in the bundle returned by
+// UserManager.getUserRestrictions and the result is pushed to Lua.
+// Pushes false when the UserManager API is not available or when reading the
+// bundle value fails; raises a Lua error on the other failures.
+int JniLuaEnvironment::HandleUserRestrictionsCallback() {
+  if (jni_cache_->usermanager_class == nullptr ||
+      jni_cache_->usermanager_get_user_restrictions == nullptr) {
+    // UserManager is only available for API level >= 17 and
+    // getUserRestrictions only for API level >= 18, so we just return false
+    // normally here.
+    lua_pushboolean(state_, false);
+    return 1;
+  }
+
+  // Get user manager if not previously retrieved.
+  if (!RetrieveUserManager()) {
+    TC3_LOG(ERROR) << "Error retrieving user manager.";
+    lua_error(state_);
+    return 0;
+  }
+
+  // Fetch the restrictions bundle; a null bundle is treated as an error.
+  StatusOr<ScopedLocalRef<jobject>> status_or_bundle =
+      JniHelper::CallObjectMethod(
+          jenv_, usermanager_.get(),
+          jni_cache_->usermanager_get_user_restrictions);
+  if (!status_or_bundle.ok() || status_or_bundle.ValueOrDie() == nullptr) {
+    TC3_LOG(ERROR) << "Error calling getUserRestrictions";
+    lua_error(state_);
+    return 0;
+  }
+
+  // The requested restriction name is the key on top of the Lua stack.
+  const StringPiece key_str = ReadString(kIndexStackTop);
+  if (key_str.empty()) {
+    TC3_LOG(ERROR) << "Expected string, got null.";
+    lua_error(state_);
+    return 0;
+  }
+
+  const StatusOr<ScopedLocalRef<jstring>> status_or_key =
+      jni_cache_->ConvertToJavaString(key_str);
+  if (!status_or_key.ok()) {
+    lua_error(state_);
+    return 0;
+  }
+  const StatusOr<bool> status_or_permission = JniHelper::CallBooleanMethod(
+      jenv_, status_or_bundle.ValueOrDie().get(),
+      jni_cache_->bundle_get_boolean, status_or_key.ValueOrDie().get());
+  if (!status_or_permission.ok()) {
+    // A failed lookup is reported to Lua as false rather than as an error.
+    TC3_LOG(ERROR) << "Error getting bundle value";
+    lua_pushboolean(state_, false);
+  } else {
+    lua_pushboolean(state_, status_or_permission.ValueOrDie());
+  }
+  return 1;
+}
+
+// Lua function `urlencode`: encodes the first argument with the Java
+// URLEncoder and pushes the encoded string. Raises a Lua error if the argument
+// is empty or any JNI step fails.
+int JniLuaEnvironment::HandleUrlEncode() {
+  const StringPiece raw_text = ReadString(/*index=*/1);
+  if (raw_text.empty()) {
+    TC3_LOG(ERROR) << "Expected string, got null.";
+    lua_error(state_);
+    return 0;
+  }
+
+  // Hand the text over to Java.
+  const StatusOr<ScopedLocalRef<jstring>> java_text =
+      jni_cache_->ConvertToJavaString(raw_text);
+  if (!java_text.ok()) {
+    lua_error(state_);
+    return 0;
+  }
+
+  // Run the Java URL encoder.
+  StatusOr<ScopedLocalRef<jstring>> java_encoded =
+      JniHelper::CallStaticObjectMethod<jstring>(
+          jenv_, jni_cache_->urlencoder_class.get(),
+          jni_cache_->urlencoder_encode, java_text.ValueOrDie().get(),
+          jni_cache_->string_utf8.get());
+  if (!java_encoded.ok()) {
+    TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
+    lua_error(state_);
+    return 0;
+  }
+
+  // Bring the result back as UTF-8 and return it to Lua.
+  const StatusOr<std::string> encoded =
+      JStringToUtf8String(jenv_, java_encoded.ValueOrDie().get());
+  if (!encoded.ok()) {
+    lua_error(state_);
+    return 0;
+  }
+  PushString(encoded.ValueOrDie());
+  return 1;
+}
+
+// Parses `url` with the Java Uri.parse method.
+// Returns the parsed Uri object, or an error status if `url` is empty, the
+// string conversion or the JNI call fails, or Uri.parse returns null.
+StatusOr<ScopedLocalRef<jobject>> JniLuaEnvironment::ParseUri(
+    StringPiece url) const {
+  if (url.empty()) {
+    return {Status::UNKNOWN};
+  }
+
+  // TC3_ASSIGN_OR_RETURN already unwraps the StatusOr, so hold the plain
+  // reference instead of wrapping it again.
+  TC3_ASSIGN_OR_RETURN(const ScopedLocalRef<jstring> url_str,
+                       jni_cache_->ConvertToJavaString(url));
+
+  // Call to Java URI parser.
+  TC3_ASSIGN_OR_RETURN(
+      ScopedLocalRef<jobject> uri,
+      JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->uri_class.get(),
+                                        jni_cache_->uri_parse, url_str.get()));
+  if (uri == nullptr) {
+    TC3_LOG(ERROR) << "Error calling Uri.parse";
+    return {Status::UNKNOWN};
+  }
+  return uri;
+}
+
+// Lua function `url_schema`: pushes the scheme of the url given as first
+// argument, or nil if Uri.getScheme returns null. Raises a Lua error if the
+// url cannot be parsed or a JNI step fails.
+int JniLuaEnvironment::HandleUrlSchema() {
+  StringPiece url = ReadString(/*index=*/1);
+
+  const StatusOr<ScopedLocalRef<jobject>> parsed_uri = ParseUri(url);
+  if (!parsed_uri.ok()) {
+    lua_error(state_);
+    return 0;
+  }
+
+  const StatusOr<ScopedLocalRef<jstring>> java_scheme =
+      JniHelper::CallObjectMethod<jstring>(jenv_,
+                                           parsed_uri.ValueOrDie().get(),
+                                           jni_cache_->uri_get_scheme);
+  if (!java_scheme.ok()) {
+    TC3_LOG(ERROR) << "Error calling Uri.getScheme";
+    lua_error(state_);
+    return 0;
+  }
+
+  // No scheme present: report nil.
+  if (java_scheme.ValueOrDie() == nullptr) {
+    lua_pushnil(state_);
+    return 1;
+  }
+
+  const StatusOr<std::string> scheme =
+      JStringToUtf8String(jenv_, java_scheme.ValueOrDie().get());
+  if (!scheme.ok()) {
+    lua_error(state_);
+    return 0;
+  }
+  PushString(scheme.ValueOrDie());
+  return 1;
+}
+
+// Lua function `url_host`: pushes the host of the url on top of the Lua stack,
+// or nil if Uri.getHost returns null. Raises a Lua error if the url cannot be
+// parsed or a JNI step fails.
+int JniLuaEnvironment::HandleUrlHost() {
+  const StringPiece url = ReadString(kIndexStackTop);
+
+  const StatusOr<ScopedLocalRef<jobject>> parsed_uri = ParseUri(url);
+  if (!parsed_uri.ok()) {
+    lua_error(state_);
+    return 0;
+  }
+
+  const StatusOr<ScopedLocalRef<jstring>> java_host =
+      JniHelper::CallObjectMethod<jstring>(jenv_,
+                                           parsed_uri.ValueOrDie().get(),
+                                           jni_cache_->uri_get_host);
+  if (!java_host.ok()) {
+    TC3_LOG(ERROR) << "Error calling Uri.getHost";
+    lua_error(state_);
+    return 0;
+  }
+
+  // No host present: report nil.
+  if (java_host.ValueOrDie() == nullptr) {
+    lua_pushnil(state_);
+    return 1;
+  }
+
+  const StatusOr<std::string> host =
+      JStringToUtf8String(jenv_, java_host.ValueOrDie().get());
+  if (!host.ok()) {
+    lua_error(state_);
+    return 0;
+  }
+  PushString(host.ValueOrDie());
+  return 1;
+}
+
+// Lua function `hash`: pushes the farmhash Hash32 value of the string on top
+// of the Lua stack.
+int JniLuaEnvironment::HandleHash() {
+  const StringPiece text = ReadString(kIndexStackTop);
+  const auto hash_value = tc3farmhash::Hash32(text.data(), text.length());
+  lua_pushinteger(state_, hash_value);
+  return 1;
+}
+
+// Lua function `format`: substitutes the arguments at stack indices 2..n into
+// the format string at index 1 via strings::Substitute and pushes the result.
+// Raises a Lua error when called without any argument.
+int JniLuaEnvironment::HandleFormat() {
+  const int num_args = lua_gettop(state_);
+  if (num_args < 1) {
+    // Without this check `num_args - 1` would be negative and turn into a
+    // huge size when used as the vector length below.
+    TC3_LOG(ERROR) << "Expected a format string.";
+    // The stack is empty here, so provide the error value for lua_error.
+    lua_pushstring(state_, "Expected a format string.");
+    lua_error(state_);
+    return 0;
+  }
+  std::vector<StringPiece> args(num_args - 1);
+  for (int i = 0; i < num_args - 1; i++) {
+    // Arguments follow the format string, hence the offset of 2.
+    args[i] = ReadString(/*index=*/i + 2);
+  }
+  PushString(strings::Substitute(ReadString(/*index=*/1), args));
+  return 1;
+}
+
+// Tries to serve the string resource named on top of the Lua stack from the
+// model resources. Pushes the content and returns true on success; returns
+// false without pushing anything otherwise.
+bool JniLuaEnvironment::LookupModelStringResource() const {
+  // Only lookups by name are supported.
+  const int key_type = lua_type(state_, kIndexStackTop);
+  if (key_type != LUA_TSTRING) {
+    return false;
+  }
+
+  std::string content;
+  const StringPiece name = ReadString(kIndexStackTop);
+  if (resources_.GetResourceContent(device_locales_, name, &content)) {
+    PushString(content);
+    return true;
+  }
+
+  // The model does not provide this resource.
+  return false;
+}
+
+// Serves a field lookup on `external.android.R`. The key on top of the Lua
+// stack is either a resource name or a numeric resource id. Model resources
+// are consulted first, then the Android system resources.
+// Pushes the resource string, or nil if the resource is not found or
+// getString returns null; raises a Lua error on JNI failures and on
+// unsupported key types.
+int JniLuaEnvironment::HandleAndroidStringResources() {
+  // Check whether the requested resource can be served from the model data.
+  if (LookupModelStringResource()) {
+    return 1;
+  }
+
+  // Get system resources if not previously retrieved.
+  if (!RetrieveSystemResources()) {
+    TC3_LOG(ERROR) << "Error retrieving system resources.";
+    lua_error(state_);
+    return 0;
+  }
+
+  // Determine the numeric resource id from the key.
+  int resource_id;
+  switch (lua_type(state_, kIndexStackTop)) {
+    case LUA_TNUMBER:
+      // The key already is the resource id.
+      resource_id = Read<int>(/*index=*/kIndexStackTop);
+      break;
+    case LUA_TSTRING: {
+      // Resolve the name through getIdentifier, passing the cached Java
+      // strings "string" and "android" as additional arguments.
+      const StringPiece resource_name_str = ReadString(kIndexStackTop);
+      if (resource_name_str.empty()) {
+        TC3_LOG(ERROR) << "No resource name provided.";
+        lua_error(state_);
+        return 0;
+      }
+      const StatusOr<ScopedLocalRef<jstring>> status_or_resource_name =
+          jni_cache_->ConvertToJavaString(resource_name_str);
+      if (!status_or_resource_name.ok()) {
+        TC3_LOG(ERROR) << "Invalid resource name.";
+        lua_error(state_);
+        return 0;
+      }
+      StatusOr<int> status_or_resource_id = JniHelper::CallIntMethod(
+          jenv_, system_resources_.get(), jni_cache_->resources_get_identifier,
+          status_or_resource_name.ValueOrDie().get(), string_.get(),
+          android_.get());
+      if (!status_or_resource_id.ok()) {
+        TC3_LOG(ERROR) << "Error calling getIdentifier.";
+        lua_error(state_);
+        return 0;
+      }
+      resource_id = status_or_resource_id.ValueOrDie();
+      break;
+    }
+    default:
+      TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
+      lua_error(state_);
+      return 0;
+  }
+  // An id of 0 is treated as "not found" and reported to Lua as nil.
+  if (resource_id == 0) {
+    TC3_LOG(ERROR) << "Resource not found.";
+    lua_pushnil(state_);
+    return 1;
+  }
+  // Fetch the string for the resolved id.
+  StatusOr<ScopedLocalRef<jstring>> status_or_resource_str =
+      JniHelper::CallObjectMethod<jstring>(jenv_, system_resources_.get(),
+                                           jni_cache_->resources_get_string,
+                                           resource_id);
+  if (!status_or_resource_str.ok()) {
+    TC3_LOG(ERROR) << "Error calling getString.";
+    lua_error(state_);
+    return 0;
+  }
+
+  if (status_or_resource_str.ValueOrDie() == nullptr) {
+    lua_pushnil(state_);
+  } else {
+    StatusOr<std::string> status_or_resource_std_str =
+        JStringToUtf8String(jenv_, status_or_resource_str.ValueOrDie().get());
+    if (!status_or_resource_std_str.ok()) {
+      lua_error(state_);
+      return 0;
+    }
+    PushString(status_or_resource_std_str.ValueOrDie());
+  }
+  return 1;
+}
+
+// Fetches the system Resources object on the first call and keeps it as a
+// global reference. Later calls only report whether that attempt produced a
+// usable reference.
+bool JniLuaEnvironment::RetrieveSystemResources() {
+  if (!system_resources_resources_retrieved_) {
+    // Mark the attempt up front so that a failure is not retried.
+    system_resources_resources_retrieved_ = true;
+    TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jobject> local_resources,
+                               JniHelper::CallStaticObjectMethod(
+                                   jenv_, jni_cache_->resources_class.get(),
+                                   jni_cache_->resources_get_system));
+    system_resources_ =
+        MakeGlobalRef(local_resources.get(), jenv_, jni_cache_->jvm);
+  }
+  return system_resources_ != nullptr;
+}
+
+// Fetches the "user" system service from the context on the first call and
+// keeps it as a global reference. Later calls only report whether that attempt
+// produced a usable reference. Returns false if there is no context.
+bool JniLuaEnvironment::RetrieveUserManager() {
+  if (context_ == nullptr) {
+    return false;
+  }
+  if (!usermanager_retrieved_) {
+    // Mark the attempt up front so that a failure is not retried.
+    usermanager_retrieved_ = true;
+    TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> service_name,
+                               JniHelper::NewStringUTF(jenv_, "user"));
+    TC3_ASSIGN_OR_RETURN_FALSE(
+        const ScopedLocalRef<jobject> local_usermanager,
+        JniHelper::CallObjectMethod(jenv_, context_,
+                                    jni_cache_->context_get_system_service,
+                                    service_name.get()));
+    usermanager_ =
+        MakeGlobalRef(local_usermanager.get(), jenv_, jni_cache_->jvm);
+  }
+  return usermanager_ != nullptr;
+}
+
+// Converts the intent template table on top of the Lua stack into a
+// RemoteActionTemplate. Known keys are copied into the matching field, unknown
+// keys are logged and skipped. The table is popped before returning.
+RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() const {
+  RemoteActionTemplate result;
+  // Read intent template.
+  // The nil pushed here is the start key for the table iteration; inside the
+  // loop the key is at index -2 and the value on top of the stack.
+  lua_pushnil(state_);
+  while (Next(/*index=*/-2)) {
+    const StringPiece key = ReadString(/*index=*/-2);
+    if (key.Equals("title_without_entity")) {
+      result.title_without_entity = Read<std::string>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("title_with_entity")) {
+      result.title_with_entity = Read<std::string>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("description")) {
+      result.description = Read<std::string>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("description_with_app_name")) {
+      result.description_with_app_name =
+          Read<std::string>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("action")) {
+      result.action = Read<std::string>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("data")) {
+      result.data = Read<std::string>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("type")) {
+      result.type = Read<std::string>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("flags")) {
+      result.flags = Read<int>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("package_name")) {
+      result.package_name = Read<std::string>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("request_code")) {
+      result.request_code = Read<int>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("category")) {
+      result.category = ReadVector<std::string>(/*index=*/kIndexStackTop);
+    } else if (key.Equals("extra")) {
+      result.extra = ReadExtras();
+    } else {
+      TC3_LOG(INFO) << "Unknown entry: " << key;
+    }
+    // Drop the value and keep the key for the next iteration step.
+    lua_pop(state_, 1);
+  }
+  // Drop the template table itself.
+  lua_pop(state_, 1);
+  return result;
+}
+
+// Reads the extras array on top of the Lua stack into a name -> value map.
+// Each entry is a table with a `name` field and one type specific value field.
+// Returns an empty map if the value on top of the stack, or any entry, is not
+// a table; entries without a name are skipped.
+// The value on top of the stack is left in place on every path: callers pop it
+// themselves after this function returns.
+std::map<std::string, Variant> JniLuaEnvironment::ReadExtras() const {
+  if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+    TC3_LOG(ERROR) << "Expected extras table, got: "
+                   << lua_type(state_, kIndexStackTop);
+    // Nothing was pushed yet, so there is nothing to pop here; popping would
+    // remove the value the caller is about to pop.
+    return {};
+  }
+  std::map<std::string, Variant> extras;
+  lua_pushnil(state_);
+  while (Next(/*index=*/-2)) {
+    // Each entry is a table specifying name and value.
+    // The value is specified via a type specific field as Lua doesn't allow
+    // to easily distinguish between different number types.
+    if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+      TC3_LOG(ERROR) << "Expected a table for an extra, got: "
+                     << lua_type(state_, kIndexStackTop);
+      // Remove both the entry and its iteration key so that only the extras
+      // table remains for the caller.
+      lua_pop(state_, 2);
+      return {};
+    }
+    std::string name;
+    Variant value;
+
+    lua_pushnil(state_);
+    while (Next(/*index=*/-2)) {
+      const StringPiece key = ReadString(/*index=*/-2);
+      if (key.Equals("name")) {
+        name = Read<std::string>(/*index=*/kIndexStackTop);
+      } else if (key.Equals("int_value")) {
+        value = Variant(Read<int>(/*index=*/kIndexStackTop));
+      } else if (key.Equals("long_value")) {
+        value = Variant(Read<int64>(/*index=*/kIndexStackTop));
+      } else if (key.Equals("float_value")) {
+        value = Variant(Read<float>(/*index=*/kIndexStackTop));
+      } else if (key.Equals("bool_value")) {
+        value = Variant(Read<bool>(/*index=*/kIndexStackTop));
+      } else if (key.Equals("string_value")) {
+        value = Variant(Read<std::string>(/*index=*/kIndexStackTop));
+      } else if (key.Equals("string_array_value")) {
+        value = Variant(ReadVector<std::string>(/*index=*/kIndexStackTop));
+      } else if (key.Equals("float_array_value")) {
+        value = Variant(ReadVector<float>(/*index=*/kIndexStackTop));
+      } else if (key.Equals("int_array_value")) {
+        value = Variant(ReadVector<int>(/*index=*/kIndexStackTop));
+      } else if (key.Equals("named_variant_array_value")) {
+        // Nested extras use the same layout.
+        value = Variant(ReadExtras());
+      } else {
+        TC3_LOG(INFO) << "Unknown extra field: " << key;
+      }
+      // Drop the field value and keep the key for the next iteration step.
+      lua_pop(state_, 1);
+    }
+    if (!name.empty()) {
+      extras[name] = value;
+    } else {
+      TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
+    }
+    // Drop the entry table and keep the key for the next iteration step.
+    lua_pop(state_, 1);
+  }
+  return extras;
+}
+
+// Appends one RemoteActionTemplate to `result` for every table entry of the
+// array on top of the Lua stack; entries that are not tables are logged and
+// skipped. Pops the array and returns LUA_OK, or raises a Lua error if the
+// value on top of the stack is not a table.
+int JniLuaEnvironment::ReadRemoteActionTemplates(
+    std::vector<RemoteActionTemplate>* result) {
+  // The snippet result must be a table.
+  if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+    TC3_LOG(ERROR) << "Unexpected result for snippet: "
+                   << lua_type(state_, kIndexStackTop);
+    lua_error(state_);
+    return LUA_ERRRUN;
+  }
+
+  // Walk the remote action templates array.
+  lua_pushnil(state_);
+  while (Next(/*index=*/-2)) {
+    if (lua_type(state_, kIndexStackTop) == LUA_TTABLE) {
+      // Reading the template also removes it from the stack.
+      result->push_back(ReadRemoteActionTemplateResult());
+    } else {
+      TC3_LOG(ERROR) << "Expected intent table, got: "
+                     << lua_type(state_, kIndexStackTop);
+      lua_pop(state_, 1);
+    }
+  }
+  lua_pop(state_, /*n=*/1);
+  return LUA_OK;
+}
+
+// Loads and runs `generator_snippet` and appends the remote action templates
+// it returns to `remote_actions`.
+// Returns false if the snippet cannot be loaded or run, if its result cannot
+// be read, or if values are left on the Lua stack afterwards. The Lua stack is
+// empty when this function returns, on success as well as on failure.
+bool JniLuaEnvironment::RunIntentGenerator(
+    const std::string& generator_snippet,
+    std::vector<RemoteActionTemplate>* remote_actions) {
+  int status = luaL_loadbuffer(state_, generator_snippet.data(),
+                               generator_snippet.size(),
+                               /*name=*/nullptr);
+  if (status != LUA_OK) {
+    TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
+    // Drop the error value left behind by the loader.
+    lua_settop(state_, 0);
+    return false;
+  }
+  status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
+  if (status != LUA_OK) {
+    TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
+    // Drop the error value left behind by the failed call.
+    lua_settop(state_, 0);
+    return false;
+  }
+  if (RunProtected(
+          [this, remote_actions] {
+            return ReadRemoteActionTemplates(remote_actions);
+          },
+          /*num_args=*/1) != LUA_OK) {
+    TC3_LOG(ERROR) << "Could not read results.";
+    // Drop whatever the aborted read left on the stack.
+    lua_settop(state_, 0);
+    return false;
+  }
+  // Check that we correctly cleaned-up the state.
+  const int stack_size = lua_gettop(state_);
+  if (stack_size > 0) {
+    TC3_LOG(ERROR) << "Unexpected stack size.";
+    lua_settop(state_, 0);
+    return false;
+  }
+  return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/intents/jni-lua.h b/native/utils/intents/jni-lua.h
new file mode 100644
index 0000000..ab7bc96
--- /dev/null
+++ b/native/utils/intents/jni-lua.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_LUA_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_LUA_H_
+
+#include <map>
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/i18n/locale.h"
+#include "utils/intents/remote-action-template.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-cache.h"
+#include "utils/lua-utils.h"
+#include "utils/resources.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
+
+namespace libtextclassifier3 {
+
+// An Android specific Lua environment with JNI backed callbacks.
+class JniLuaEnvironment : public LuaEnvironment {
+ public:
+  // `jni_cache` may be null; Initialize() needs a valid one to create its JNI
+  // strings.
+  JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
+                    const jobject context,
+                    const std::vector<Locale>& device_locales);
+  // Environment setup.
+  // Returns false if the constant JNI strings could not be created or the Lua
+  // side setup failed.
+  bool Initialize();
+
+  // Runs an intent generator snippet.
+  // Appends the templates produced by the snippet to `remote_actions` and
+  // returns whether loading, running and reading the result all succeeded.
+  bool RunIntentGenerator(const std::string& generator_snippet,
+                          std::vector<RemoteActionTemplate>* remote_actions);
+
+ protected:
+  // Pushes the object that Initialize() binds to the `external` global.
+  virtual void SetupExternalHook();
+  // Creates the cached Java strings `string` and `android`.
+  bool PreallocateConstantJniStrings();
+
+  // Callbacks invoked from Lua. Each returns the number of values it pushed
+  // onto the Lua stack.
+  int HandleExternalCallback();
+  int HandleAndroidCallback();
+  int HandleUserRestrictionsCallback();
+  int HandleUrlEncode();
+  int HandleUrlSchema();
+  int HandleHash();
+  int HandleFormat();
+  int HandleAndroidStringResources();
+  int HandleUrlHost();
+
+  // Checks and retrieves string resources from the model.
+  bool LookupModelStringResource() const;
+
+  // Reads and create a RemoteAction result from Lua.
+  RemoteActionTemplate ReadRemoteActionTemplateResult() const;
+
+  // Reads the extras from the Lua result.
+  std::map<std::string, Variant> ReadExtras() const;
+
+  // Retrieves user manager if not previously done.
+  bool RetrieveUserManager();
+
+  // Retrieves system resources if not previously done.
+  bool RetrieveSystemResources();
+
+  // Parse the url string by using Uri.parse from Java.
+  StatusOr<ScopedLocalRef<jobject>> ParseUri(StringPiece url) const;
+
+  // Read remote action templates from lua generator.
+  int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
+
+  // Model resources, consulted first for string resource lookups.
+  const Resources& resources_;
+  // JNI environment taken from `jni_cache_` at construction; null if no cache
+  // was provided.
+  JNIEnv* jenv_;
+  const JniCache* jni_cache_;
+  // Android context; may be null, in which case the package name and user
+  // manager lookups fail.
+  const jobject context_;
+  // Locales exposed to Lua as `device_locales` and used for model resource
+  // lookups.
+  std::vector<Locale> device_locales_;
+
+  ScopedGlobalRef<jobject> usermanager_;
+  // Whether we previously attempted to retrieve the UserManager before.
+  bool usermanager_retrieved_;
+
+  ScopedGlobalRef<jobject> system_resources_;
+  // Whether we previously attempted to retrieve the system resources.
+  bool system_resources_resources_retrieved_;
+
+  // Cached JNI references for Java strings `string` and `android`.
+  ScopedGlobalRef<jstring> string_;
+  ScopedGlobalRef<jstring> android_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_LUA_H_
diff --git a/native/utils/intents/jni.cc b/native/utils/intents/jni.cc
index 051d078..c95f03b 100644
--- a/native/utils/intents/jni.cc
+++ b/native/utils/intents/jni.cc
@@ -18,6 +18,7 @@
#include <memory>
+#include "utils/base/status_macros.h"
#include "utils/base/statusor.h"
#include "utils/java/jni-base.h"
#include "utils/java/jni-helper.h"
@@ -27,19 +28,19 @@
// The macros below are intended to reduce the boilerplate and avoid
// easily introduced copy/paste errors.
#define TC3_CHECK_JNI_PTR(PTR) TC3_CHECK((PTR) != nullptr)
-#define TC3_GET_CLASS(FIELD, NAME) \
- { \
- StatusOr<ScopedLocalRef<jclass>> status_or_clazz = \
- JniHelper::FindClass(env, NAME); \
- handler->FIELD = MakeGlobalRef(status_or_clazz.ValueOrDie().release(), \
- env, jni_cache->jvm); \
- TC3_CHECK_JNI_PTR(handler->FIELD) << "Error finding class: " << NAME; \
+#define TC3_GET_CLASS(FIELD, NAME) \
+ { \
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jclass> clazz, \
+ JniHelper::FindClass(env, NAME)); \
+ handler->FIELD = MakeGlobalRef(clazz.release(), env, jni_cache->jvm); \
+ TC3_CHECK_JNI_PTR(handler->FIELD) << "Error finding class: " << NAME; \
}
-#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- handler->FIELD = env->GetMethodID(handler->CLASS.get(), NAME, SIGNATURE); \
- TC3_CHECK(handler->FIELD) << "Error finding method: " << NAME;
+#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ TC3_ASSIGN_OR_RETURN( \
+ handler->FIELD, \
+ JniHelper::GetMethodID(env, handler->CLASS.get(), NAME, SIGNATURE));
-std::unique_ptr<RemoteActionTemplatesHandler>
+StatusOr<std::unique_ptr<RemoteActionTemplatesHandler>>
RemoteActionTemplatesHandler::Create(
const std::shared_ptr<JniCache>& jni_cache) {
JNIEnv* env = jni_cache->GetEnv();
@@ -127,8 +128,8 @@
for (int k = 0; k < values.size(); k++) {
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> value_str,
jni_cache_->ConvertToJavaString(values[k]));
- jni_cache_->GetEnv()->SetObjectArrayElement(result.get(), k,
- value_str.get());
+ TC3_RETURN_IF_ERROR(JniHelper::SetObjectArrayElement(
+ jni_cache_->GetEnv(), result.get(), k, value_str.get()));
}
return result;
}
@@ -144,9 +145,9 @@
ScopedLocalRef<jfloatArray> result,
JniHelper::NewFloatArray(jni_cache_->GetEnv(), values.size()));
- jni_cache_->GetEnv()->SetFloatArrayRegion(result.get(), /*start=*/0,
- /*len=*/values.size(),
- &(values[0]));
+ TC3_RETURN_IF_ERROR(JniHelper::SetFloatArrayRegion(
+ jni_cache_->GetEnv(), result.get(), /*start=*/0,
+ /*len=*/values.size(), &(values[0])));
return result;
}
@@ -160,8 +161,9 @@
ScopedLocalRef<jintArray> result,
JniHelper::NewIntArray(jni_cache_->GetEnv(), values.size()));
- jni_cache_->GetEnv()->SetIntArrayRegion(result.get(), /*start=*/0,
- /*len=*/values.size(), &(values[0]));
+ TC3_RETURN_IF_ERROR(JniHelper::SetIntArrayRegion(
+ jni_cache_->GetEnv(), result.get(), /*start=*/0,
+ /*len=*/values.size(), &(values[0])));
return result;
}
@@ -275,8 +277,8 @@
TC3_ASSIGN_OR_RETURN(
StatusOr<ScopedLocalRef<jobject>> named_extra,
AsNamedVariant(key_value_pair.first, key_value_pair.second));
- env->SetObjectArrayElement(result.get(), element_index,
- named_extra.ValueOrDie().get());
+ TC3_RETURN_IF_ERROR(JniHelper::SetObjectArrayElement(
+ env, result.get(), element_index, named_extra.ValueOrDie().get()));
element_index++;
}
return result;
@@ -335,7 +337,8 @@
type.ValueOrDie().get(), flags.ValueOrDie().get(),
category.ValueOrDie().get(), package.ValueOrDie().get(),
extra.ValueOrDie().get(), request_code.ValueOrDie().get()));
- env->SetObjectArrayElement(results.get(), i, result.get());
+ TC3_RETURN_IF_ERROR(
+ JniHelper::SetObjectArrayElement(env, results.get(), i, result.get()));
}
return results;
}
@@ -344,8 +347,8 @@
RemoteActionTemplatesHandler::EntityDataAsNamedVariantArray(
const reflection::Schema* entity_data_schema,
const std::string& serialized_entity_data) const {
- ReflectiveFlatbufferBuilder entity_data_builder(entity_data_schema);
- std::unique_ptr<ReflectiveFlatbuffer> buffer = entity_data_builder.NewRoot();
+ MutableFlatbufferBuilder entity_data_builder(entity_data_schema);
+ std::unique_ptr<MutableFlatbuffer> buffer = entity_data_builder.NewRoot();
buffer->MergeFromSerializedFlatbuffer(serialized_entity_data);
std::map<std::string, Variant> entity_data_map = buffer->AsFlatMap();
return AsNamedVariantArray(entity_data_map);
diff --git a/native/utils/intents/jni.h b/native/utils/intents/jni.h
index ada2631..895c63d 100644
--- a/native/utils/intents/jni.h
+++ b/native/utils/intents/jni.h
@@ -25,7 +25,8 @@
#include <vector>
#include "utils/base/statusor.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/intents/remote-action-template.h"
#include "utils/java/jni-base.h"
#include "utils/java/jni-cache.h"
@@ -51,7 +52,7 @@
// A helper class to create RemoteActionTemplate object from model results.
class RemoteActionTemplatesHandler {
public:
- static std::unique_ptr<RemoteActionTemplatesHandler> Create(
+ static StatusOr<std::unique_ptr<RemoteActionTemplatesHandler>> Create(
const std::shared_ptr<JniCache>& jni_cache);
StatusOr<ScopedLocalRef<jstring>> AsUTF8String(
diff --git a/native/utils/java/jni-base.cc b/native/utils/java/jni-base.cc
index e0829b7..42de67e 100644
--- a/native/utils/java/jni-base.cc
+++ b/native/utils/java/jni-base.cc
@@ -17,7 +17,6 @@
#include "utils/java/jni-base.h"
#include "utils/base/status.h"
-#include "utils/java/string_utils.h"
namespace libtextclassifier3 {
@@ -35,12 +34,4 @@
return result;
}
-StatusOr<std::string> ToStlString(JNIEnv* env, const jstring& str) {
- std::string result;
- if (!JStringToUtf8String(env, str, &result)) {
- return {Status::UNKNOWN};
- }
- return result;
-}
-
} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-base.h b/native/utils/java/jni-base.h
index c7b04e6..0bc46fa 100644
--- a/native/utils/java/jni-base.h
+++ b/native/utils/java/jni-base.h
@@ -67,8 +67,6 @@
// Returns true if there was an exception. Also it clears the exception.
bool JniExceptionCheckAndClear(JNIEnv* env);
-StatusOr<std::string> ToStlString(JNIEnv* env, const jstring& str);
-
// A deleter to be used with std::unique_ptr to delete JNI global references.
class GlobalRefDeleter {
public:
diff --git a/native/utils/java/jni-cache.cc b/native/utils/java/jni-cache.cc
index 0be769d..58d3369 100644
--- a/native/utils/java/jni-cache.cc
+++ b/native/utils/java/jni-cache.cc
@@ -17,6 +17,7 @@
#include "utils/java/jni-cache.h"
#include "utils/base/logging.h"
+#include "utils/base/status_macros.h"
#include "utils/java/jni-base.h"
#include "utils/java/jni-helper.h"
@@ -72,59 +73,61 @@
} \
}
-#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- result->CLASS##_##FIELD = \
- env->GetMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
- << "Error finding method: " << NAME;
+// Looks up the instance method NAME with JNI signature SIGNATURE on the cached
+// class result->CLASS_class and stores the id in result->CLASS_FIELD.
+// Expects `env` and `result` to be in scope where the macro is expanded.
+// NOTE(review): TC3_ASSIGN_OR_RETURN_NULL is defined outside this chunk; by
+// its name it makes the enclosing function return null when the lookup
+// fails -- confirm against its definition.
+#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ result->CLASS##_##FIELD, \
+ JniHelper::GetMethodID(env, result->CLASS##_class.get(), NAME, \
+ SIGNATURE));
-#define TC3_GET_OPTIONAL_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- if (result->CLASS##_class != nullptr) { \
- result->CLASS##_##FIELD = \
- env->GetMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- env->ExceptionClear(); \
+// Best-effort lookup of an instance method: forwards to
+// TC3_GET_OPTIONAL_METHOD_INTERNAL with JniHelper::GetMethodID as the lookup
+// function, so a missing class or method does not abort the caller.
+#define TC3_GET_OPTIONAL_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ TC3_GET_OPTIONAL_METHOD_INTERNAL(CLASS, FIELD, NAME, SIGNATURE, GetMethodID)
+
+// Best-effort lookup of a static method: forwards to
+// TC3_GET_OPTIONAL_METHOD_INTERNAL with JniHelper::GetStaticMethodID as the
+// lookup function.
+#define TC3_GET_OPTIONAL_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ TC3_GET_OPTIONAL_METHOD_INTERNAL(CLASS, FIELD, NAME, SIGNATURE, \
+ GetStaticMethodID)
+
+// Shared body of the optional-method macros. METHOD_NAME names the JniHelper
+// lookup function to call (GetMethodID or GetStaticMethodID).
+// Does nothing when result->CLASS_class is null. Otherwise performs the lookup
+// and assigns result->CLASS_FIELD only when the returned StatusOr is ok(); on
+// failure the field is left untouched and no error is reported.
+// Uses an if-statement with initializer, so `status_or_method_id` is scoped to
+// that statement. Expects `env` and `result` to be in scope.
+// NOTE(review): the macro this replaces called env->ExceptionClear() after the
+// lookup; this version relies on the JniHelper lookup to leave no exception
+// pending when it fails -- confirm against TC3_NO_EXCEPTION_OR_RETURN.
+#define TC3_GET_OPTIONAL_METHOD_INTERNAL(CLASS, FIELD, NAME, SIGNATURE, \
+ METHOD_NAME) \
+ if (result->CLASS##_class != nullptr) { \
+ if (StatusOr<jmethodID> status_or_method_id = JniHelper::METHOD_NAME( \
+ env, result->CLASS##_class.get(), NAME, SIGNATURE); \
+ status_or_method_id.ok()) { \
+ result->CLASS##_##FIELD = status_or_method_id.ValueOrDie(); \
+ } \
}
-#define TC3_GET_OPTIONAL_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- if (result->CLASS##_class != nullptr) { \
- result->CLASS##_##FIELD = \
- env->GetStaticMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- env->ExceptionClear(); \
+// Looks up the static method NAME with JNI signature SIGNATURE on the cached
+// class result->CLASS_class and stores the id in result->CLASS_FIELD.
+// Expects `env` and `result` to be in scope where the macro is expanded.
+// NOTE(review): failure handling comes from TC3_ASSIGN_OR_RETURN_NULL, which is
+// defined outside this chunk (presumably returns null from the caller).
+#define TC3_GET_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ result->CLASS##_##FIELD, \
+ JniHelper::GetStaticMethodID(env, result->CLASS##_class.get(), NAME, \
+ SIGNATURE));
+
+// Reads the static object field NAME (JNI type signature SIGNATURE) of the
+// cached class result->CLASS_class and stores a global reference to its value
+// in result->CLASS_FIELD.
+// Steps: resolve the field id, read the field into a scoped local reference,
+// then promote it with MakeGlobalRef. If the resulting global reference is
+// null, an error is logged and the enclosing function returns nullptr.
+// The body is wrapped in braces so the temporaries it declares do not leak
+// into the caller's scope. Expects `env`, `jvm` and `result` to be in scope.
+// NOTE(review): the two lookups bail out through TC3_ASSIGN_OR_RETURN_NULL,
+// which is defined outside this chunk -- confirm it returns null on error.
+#define TC3_GET_STATIC_OBJECT_FIELD_OR_RETURN_NULL(CLASS, FIELD, NAME, \
+ SIGNATURE) \
+ { \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ const jfieldID CLASS##_##FIELD##_field, \
+ JniHelper::GetStaticFieldID(env, result->CLASS##_class.get(), NAME, \
+ SIGNATURE)); \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ ScopedLocalRef<jobject> static_object, \
+ JniHelper::GetStaticObjectField(env, result->CLASS##_class.get(), \
+ CLASS##_##FIELD##_field)); \
+ result->CLASS##_##FIELD = MakeGlobalRef(static_object.get(), env, jvm); \
+ if (result->CLASS##_##FIELD == nullptr) { \
+ TC3_LOG(ERROR) << "Error finding field: " << NAME; \
+ return nullptr; \
+ } \
}
-#define TC3_GET_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- result->CLASS##_##FIELD = \
- env->GetStaticMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
- << "Error finding method: " << NAME;
-
-#define TC3_GET_STATIC_OBJECT_FIELD_OR_RETURN_NULL(CLASS, FIELD, NAME, \
- SIGNATURE) \
- { \
- const jfieldID CLASS##_##FIELD##_field = \
- env->GetStaticFieldID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- TC3_CHECK_JNI_RESULT(CLASS##_##FIELD##_field) \
- << "Error finding field id: " << NAME; \
- TC3_ASSIGN_OR_RETURN_NULL( \
- ScopedLocalRef<jobject> static_object, \
- JniHelper::GetStaticObjectField(env, result->CLASS##_class.get(), \
- CLASS##_##FIELD##_field)); \
- result->CLASS##_##FIELD = MakeGlobalRef(static_object.get(), env, jvm); \
- if (result->CLASS##_##FIELD == nullptr) { \
- TC3_LOG(ERROR) << "Error finding field: " << NAME; \
- return nullptr; \
- } \
- }
-
-#define TC3_GET_STATIC_INT_FIELD(CLASS, FIELD, NAME) \
- const jfieldID CLASS##_##FIELD##_field = \
- env->GetStaticFieldID(result->CLASS##_class.get(), NAME, "I"); \
- TC3_CHECK_JNI_RESULT(CLASS##_##FIELD##_field) \
- << "Error finding field id: " << NAME; \
- result->CLASS##_##FIELD = env->GetStaticIntField( \
- result->CLASS##_class.get(), CLASS##_##FIELD##_field); \
- TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
- << "Error finding field: " << NAME;
+// Reads the static int field NAME (JNI signature "I") of the cached class
+// result->CLASS_class and stores its value in result->CLASS_FIELD.
+// Unlike TC3_GET_STATIC_OBJECT_FIELD_OR_RETURN_NULL this expansion is not
+// wrapped in braces: it declares the local `CLASS_FIELD_field` directly in the
+// caller's scope, so each CLASS/FIELD pair can be used only once per scope.
+// Expects `env` and `result` to be in scope.
+// NOTE(review): failure handling comes from TC3_ASSIGN_OR_RETURN_NULL, which is
+// defined outside this chunk (presumably returns null from the caller).
+#define TC3_GET_STATIC_INT_FIELD(CLASS, FIELD, NAME) \
+ TC3_ASSIGN_OR_RETURN_NULL(const jfieldID CLASS##_##FIELD##_field, \
+ JniHelper::GetStaticFieldID( \
+ env, result->CLASS##_class.get(), NAME, "I")); \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ result->CLASS##_##FIELD, \
+ JniHelper::GetStaticIntField(env, result->CLASS##_class.get(), \
+ CLASS##_##FIELD##_field));
std::unique_ptr<JniCache> JniCache::Create(JNIEnv* env) {
if (env == nullptr) {
@@ -290,8 +293,9 @@
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jbyteArray> text_java_utf8,
JniHelper::NewByteArray(jenv, utf8_text_size_bytes));
- jenv->SetByteArrayRegion(text_java_utf8.get(), 0, utf8_text_size_bytes,
- reinterpret_cast<const jbyte*>(utf8_text));
+ TC3_RETURN_IF_ERROR(JniHelper::SetByteArrayRegion(
+ jenv, text_java_utf8.get(), 0, utf8_text_size_bytes,
+ reinterpret_cast<const jbyte*>(utf8_text)));
// Create the string with a UTF-8 charset.
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> result,
diff --git a/native/utils/java/jni-helper.cc b/native/utils/java/jni-helper.cc
index d1677e4..c7d8012 100644
--- a/native/utils/java/jni-helper.cc
+++ b/native/utils/java/jni-helper.cc
@@ -16,6 +16,8 @@
#include "utils/java/jni-helper.h"
+#include "utils/base/status_macros.h"
+
namespace libtextclassifier3 {
StatusOr<ScopedLocalRef<jclass>> JniHelper::FindClass(JNIEnv* env,
@@ -27,10 +29,46 @@
return result;
}
+// Returns the class of `object` wrapped in a ScopedLocalRef bound to `env`.
+// NOTE(review): the TC3_* guard macros are defined outside this chunk. By
+// their names they return an error status when local-reference capacity cannot
+// be ensured, when a Java exception is pending, or when `result` is null, and
+// they appear to rely on the local names `env` and `result` -- confirm.
+StatusOr<ScopedLocalRef<jclass>> JniHelper::GetObjectClass(JNIEnv* env,
+ jobject object) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ ScopedLocalRef<jclass> result(env->GetObjectClass(object), env);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ TC3_NOT_NULL_OR_RETURN;
+ return result;
+}
+
+// Resolves the id of the instance method `method_name` with JNI type signature
+// `signature` on `clazz`. The last parameter was previously named
+// `return_type`, although it is passed to JNI as the full method signature.
+// NOTE(review): TC3_NO_EXCEPTION_OR_RETURN / TC3_NOT_NULL_OR_RETURN are defined
+// outside this chunk; presumably they turn a pending Java exception or a null
+// `result` into an error status -- confirm.
StatusOr<jmethodID> JniHelper::GetMethodID(JNIEnv* env, jclass clazz,
const char* method_name,
- const char* return_type) {
- jmethodID result = env->GetMethodID(clazz, method_name, return_type);
+ const char* signature) {
+ jmethodID result = env->GetMethodID(clazz, method_name, signature);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ TC3_NOT_NULL_OR_RETURN;
+ return result;
+}
+
+// Resolves the id of the static method `method_name` with JNI type signature
+// `signature` on `clazz`.
+// NOTE(review): error reporting is delegated to the TC3_* guard macros defined
+// outside this chunk (presumably: pending exception or null id -> error
+// status) -- confirm.
+StatusOr<jmethodID> JniHelper::GetStaticMethodID(JNIEnv* env, jclass clazz,
+ const char* method_name,
+ const char* signature) {
+ jmethodID result = env->GetStaticMethodID(clazz, method_name, signature);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ TC3_NOT_NULL_OR_RETURN;
+ return result;
+}
+
+// Resolves the id of the instance field `field_name` with JNI type signature
+// `signature` on `clazz`.
+// NOTE(review): error reporting is delegated to the TC3_* guard macros defined
+// outside this chunk (presumably: pending exception or null id -> error
+// status) -- confirm.
+StatusOr<jfieldID> JniHelper::GetFieldID(JNIEnv* env, jclass clazz,
+ const char* field_name,
+ const char* signature) {
+ jfieldID result = env->GetFieldID(clazz, field_name, signature);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ TC3_NOT_NULL_OR_RETURN;
+ return result;
+}
+
+StatusOr<jfieldID> JniHelper::GetStaticFieldID(JNIEnv* env, jclass clazz,
+ const char* field_name,
+ const char* signature) {
+ jfieldID result = env->GetStaticFieldID(clazz, field_name, signature);
TC3_NO_EXCEPTION_OR_RETURN;
TC3_NOT_NULL_OR_RETURN;
return result;
@@ -46,6 +84,14 @@
return result;
}
+// Reads the static int field identified by `field_id` on `class_name`.
+// Only the exception guard follows the read; there is no null/zero check on
+// the result, so every jint value (including 0) is returned as a valid value.
+// NOTE(review): the TC3_* guard macros are defined outside this chunk --
+// confirm what status they return on failure.
+StatusOr<jint> JniHelper::GetStaticIntField(JNIEnv* env, jclass class_name,
+ jfieldID field_id) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ jint result = env->GetStaticIntField(class_name, field_id);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
StatusOr<ScopedLocalRef<jbyteArray>> JniHelper::NewByteArray(JNIEnv* env,
jsize length) {
TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
@@ -147,6 +193,46 @@
return Status::OK;
}
+// Returns the number of elements in the Java array `array`.
+// `array` is passed to JNI unchecked; callers must not pass a null reference.
+// NOTE(review): the TC3_* guard macros are defined outside this chunk --
+// confirm what status they return on failure.
+StatusOr<jsize> JniHelper::GetArrayLength(JNIEnv* env, jarray array) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ jsize result = env->GetArrayLength(array);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+// Copies `len` bytes starting at index `start` of the Java byte[] `array` into
+// the caller-provided buffer `buf`, which must have room for `len` elements.
+// Returns Status::OK when the copy completes.
+// NOTE(review): a failed copy (e.g. an out-of-range region raising a Java
+// exception) is presumably converted to an error status by
+// TC3_NO_EXCEPTION_OR_RETURN, defined outside this chunk -- confirm.
+Status JniHelper::GetByteArrayRegion(JNIEnv* env, jbyteArray array, jsize start,
+ jsize len, jbyte* buf) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ env->GetByteArrayRegion(array, start, len, buf);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return Status::OK;
+}
+
+// Copies `len` bytes from `buf` into the Java byte[] `array`, starting at index
+// `start`. Returns Status::OK when the copy completes.
+// NOTE(review): a failed copy is presumably converted to an error status by
+// TC3_NO_EXCEPTION_OR_RETURN, defined outside this chunk -- confirm.
+Status JniHelper::SetByteArrayRegion(JNIEnv* env, jbyteArray array, jsize start,
+ jsize len, const jbyte* buf) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ env->SetByteArrayRegion(array, start, len, buf);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return Status::OK;
+}
+
+// Copies `len` ints from `buf` into the Java int[] `array`, starting at index
+// `start`. Returns Status::OK when the copy completes.
+// NOTE(review): a failed copy is presumably converted to an error status by
+// TC3_NO_EXCEPTION_OR_RETURN, defined outside this chunk -- confirm.
+Status JniHelper::SetIntArrayRegion(JNIEnv* env, jintArray array, jsize start,
+ jsize len, const jint* buf) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ env->SetIntArrayRegion(array, start, len, buf);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return Status::OK;
+}
+
+// Copies `len` floats from `buf` into the Java float[] `array`, starting at
+// index `start`. Returns Status::OK when the copy completes.
+// NOTE(review): a failed copy is presumably converted to an error status by
+// TC3_NO_EXCEPTION_OR_RETURN, defined outside this chunk -- confirm.
+Status JniHelper::SetFloatArrayRegion(JNIEnv* env, jfloatArray array,
+ jsize start, jsize len,
+ const jfloat* buf) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ env->SetFloatArrayRegion(array, start, len, buf);
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return Status::OK;
+}
+
StatusOr<ScopedLocalRef<jobjectArray>> JniHelper::NewObjectArray(
JNIEnv* env, jsize length, jclass element_class, jobject initial_element) {
TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
@@ -157,14 +243,6 @@
return result;
}
-StatusOr<jsize> JniHelper::GetArrayLength(JNIEnv* env,
- jarray jinput_fragments) {
- TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
- jsize result = env->GetArrayLength(jinput_fragments);
- TC3_NO_EXCEPTION_OR_RETURN;
- return result;
-}
-
StatusOr<ScopedLocalRef<jstring>> JniHelper::NewStringUTF(JNIEnv* env,
const char* bytes) {
TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
@@ -174,4 +252,37 @@
return result;
}
+// Copies the contents of the Java byte[] `array` into a std::string.
+// The string is sized from the array length, so the bytes are copied verbatim
+// (embedded zero bytes included) and an empty array yields an empty string.
+// `array` is passed to JNI unchecked; callers must not pass a null reference.
+// NOTE(review): errors from the two JniHelper calls are presumably propagated
+// to the caller by TC3_ASSIGN_OR_RETURN / TC3_RETURN_IF_ERROR, which are
+// defined outside this chunk -- confirm.
+StatusOr<std::string> JByteArrayToString(JNIEnv* env, jbyteArray array) {
+ std::string result;
+ TC3_ASSIGN_OR_RETURN(const int array_length,
+ JniHelper::GetArrayLength(env, array));
+ result.resize(array_length);
+ // Let JNI write straight into the string's storage; the const_cast makes the
+ // pointer returned by data() writable.
+ TC3_RETURN_IF_ERROR(JniHelper::GetByteArrayRegion(
+ env, array, 0, array_length,
+ reinterpret_cast<jbyte*>(const_cast<char*>(result.data()))));
+ return result;
+}
+
+// Converts the Java string `jstr` to a UTF-8 encoded std::string by calling
+// String.getBytes("UTF-8") on it and copying the resulting byte[].
+// A null `jstr` is not an error: it yields an empty string.
+// The String class, the getBytes method id and the "UTF-8" charset-name string
+// are looked up / created on every call; nothing is cached here.
+// NOTE(review): any failing JNI step is presumably returned to the caller as
+// an error status by TC3_ASSIGN_OR_RETURN, defined outside this chunk.
+StatusOr<std::string> JStringToUtf8String(JNIEnv* env, jstring jstr) {
+ if (jstr == nullptr) {
+ return "";
+ }
+
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jclass> string_class,
+ JniHelper::FindClass(env, "java/lang/String"));
+ // byte[] String.getBytes(String charsetName)
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_bytes_id,
+ JniHelper::GetMethodID(env, string_class.get(), "getBytes",
+ "(Ljava/lang/String;)[B"));
+
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> encoding,
+ JniHelper::NewStringUTF(env, "UTF-8"));
+
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jbyteArray> array,
+ JniHelper::CallObjectMethod<jbyteArray>(
+ env, jstr, get_bytes_id, encoding.get()));
+
+ return JByteArrayToString(env, array.get());
+}
+
} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-helper.h b/native/utils/java/jni-helper.h
index 55d4696..952fe95 100644
--- a/native/utils/java/jni-helper.h
+++ b/native/utils/java/jni-helper.h
@@ -74,16 +74,31 @@
static StatusOr<ScopedLocalRef<jclass>> FindClass(JNIEnv* env,
const char* class_name);
+ static StatusOr<ScopedLocalRef<jclass>> GetObjectClass(JNIEnv* env,
+ jobject object);
+
template <typename T = jobject>
static StatusOr<ScopedLocalRef<T>> GetObjectArrayElement(JNIEnv* env,
jobjectArray array,
jsize index);
static StatusOr<jmethodID> GetMethodID(JNIEnv* env, jclass clazz,
const char* method_name,
- const char* return_type);
+ const char* signature);
+ static StatusOr<jmethodID> GetStaticMethodID(JNIEnv* env, jclass clazz,
+ const char* method_name,
+ const char* signature);
+
+ static StatusOr<jfieldID> GetFieldID(JNIEnv* env, jclass clazz,
+ const char* field_name,
+ const char* signature);
+ static StatusOr<jfieldID> GetStaticFieldID(JNIEnv* env, jclass clazz,
+ const char* field_name,
+ const char* signature);
static StatusOr<ScopedLocalRef<jobject>> GetStaticObjectField(
JNIEnv* env, jclass class_name, jfieldID field_id);
+ static StatusOr<jint> GetStaticIntField(JNIEnv* env, jclass class_name,
+ jfieldID field_id);
// New* methods.
TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(NewObject, jobject, jclass,
@@ -100,11 +115,23 @@
static StatusOr<ScopedLocalRef<jfloatArray>> NewFloatArray(JNIEnv* env,
jsize length);
- static StatusOr<jsize> GetArrayLength(JNIEnv* env, jarray jinput_fragments);
+ static StatusOr<jsize> GetArrayLength(JNIEnv* env, jarray array);
static Status SetObjectArrayElement(JNIEnv* env, jobjectArray array,
jsize index, jobject val);
+ static Status GetByteArrayRegion(JNIEnv* env, jbyteArray array, jsize start,
+ jsize len, jbyte* buf);
+
+ static Status SetByteArrayRegion(JNIEnv* env, jbyteArray array, jsize start,
+ jsize len, const jbyte* buf);
+
+ static Status SetIntArrayRegion(JNIEnv* env, jintArray array, jsize start,
+ jsize len, const jint* buf);
+
+ static Status SetFloatArrayRegion(JNIEnv* env, jfloatArray array, jsize start,
+ jsize len, const jfloat* buf);
+
// Call* methods.
TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(CallObjectMethod, jobject,
jobject, TC3_JNI_NO_CHECK);
@@ -153,6 +180,12 @@
return result;
}
+// Converts Java byte[] object to std::string.
+StatusOr<std::string> JByteArrayToString(JNIEnv* env, jbyteArray array);
+
+// Converts Java String object to UTF8-encoded std::string.
+StatusOr<std::string> JStringToUtf8String(JNIEnv* env, jstring jstr);
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_HELPER_H_
diff --git a/native/utils/java/string_utils.cc b/native/utils/java/string_utils.cc
deleted file mode 100644
index ca518a0..0000000
--- a/native/utils/java/string_utils.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/java/string_utils.h"
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
- std::string* result) {
- jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
- if (array_bytes == nullptr) {
- return false;
- }
-
- const int array_length = env->GetArrayLength(array);
- *result = std::string(reinterpret_cast<char*>(array_bytes), array_length);
-
- env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
-
- return true;
-}
-
-bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
- std::string* result) {
- if (jstr == nullptr) {
- *result = std::string();
- return true;
- }
-
- jclass string_class = env->FindClass("java/lang/String");
- if (!string_class) {
- TC3_LOG(ERROR) << "Can't find String class";
- return false;
- }
-
- jmethodID get_bytes_id =
- env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
-
- jstring encoding = env->NewStringUTF("UTF-8");
-
- jbyteArray array = reinterpret_cast<jbyteArray>(
- env->CallObjectMethod(jstr, get_bytes_id, encoding));
-
- JByteArrayToString(env, array, result);
-
- // Release the array.
- env->DeleteLocalRef(array);
- env->DeleteLocalRef(string_class);
- env->DeleteLocalRef(encoding);
-
- return true;
-}
-
-ScopedStringChars GetScopedStringChars(JNIEnv* env, jstring string,
- jboolean* is_copy) {
- return ScopedStringChars(env->GetStringUTFChars(string, is_copy),
- StringCharsReleaser(env, string));
-}
-
-} // namespace libtextclassifier3
diff --git a/native/utils/java/string_utils.h b/native/utils/java/string_utils.h
deleted file mode 100644
index 172a938..0000000
--- a/native/utils/java/string_utils.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_STRING_UTILS_H_
-#define LIBTEXTCLASSIFIER_UTILS_JAVA_STRING_UTILS_H_
-
-#include <jni.h>
-#include <memory>
-#include <string>
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
- std::string* result);
-bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, std::string* result);
-
-// A deleter to be used with std::unique_ptr to release Java string chars.
-class StringCharsReleaser {
- public:
- StringCharsReleaser() : env_(nullptr) {}
-
- StringCharsReleaser(JNIEnv* env, jstring jstr) : env_(env), jstr_(jstr) {}
-
- StringCharsReleaser(const StringCharsReleaser& orig) = default;
-
- // Copy assignment to allow move semantics in StringCharsReleaser.
- StringCharsReleaser& operator=(const StringCharsReleaser& rhs) {
- // As the releaser and its state are thread-local, it's enough to only
- // ensure the envs are consistent but do nothing.
- TC3_CHECK_EQ(env_, rhs.env_);
- return *this;
- }
-
- // The delete operator.
- void operator()(const char* chars) const {
- if (env_ != nullptr) {
- env_->ReleaseStringUTFChars(jstr_, chars);
- }
- }
-
- private:
- // The env_ stashed to use for deletion. Thread-local, don't share!
- JNIEnv* const env_;
-
- // The referenced jstring.
- jstring jstr_;
-};
-
-// A smart pointer that releases string chars when it goes out of scope.
-// of scope.
-// Note that this class is not thread-safe since it caches JNIEnv in
-// the deleter. Do not use the same jobject across different threads.
-using ScopedStringChars = std::unique_ptr<const char, StringCharsReleaser>;
-
-// Returns a scoped pointer to the array of Unicode characters of a string.
-ScopedStringChars GetScopedStringChars(JNIEnv* env, jstring string,
- jboolean* is_copy = nullptr);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_STRING_UTILS_H_
diff --git a/native/utils/lua-utils.cc b/native/utils/lua-utils.cc
index fe3d12d..37d5e0d 100644
--- a/native/utils/lua-utils.cc
+++ b/native/utils/lua-utils.cc
@@ -221,7 +221,7 @@
}
int LuaEnvironment::ReadFlatbuffer(const int index,
- ReflectiveFlatbuffer* buffer) const {
+ MutableFlatbuffer* buffer) const {
if (buffer == nullptr) {
TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer: " << index;
lua_error(state_);
@@ -322,8 +322,8 @@
buffer->Repeated(field));
break;
case reflection::Obj:
- ReadRepeatedField<ReflectiveFlatbuffer>(/*index=*/kIndexStackTop,
- buffer->Repeated(field));
+ ReadRepeatedField<MutableFlatbuffer>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
default:
TC3_LOG(ERROR) << "Unsupported repeated field type: "
@@ -542,7 +542,7 @@
classification.serialized_entity_data =
Read<std::string>(/*index=*/kIndexStackTop);
} else if (key.Equals(kEntityKey)) {
- auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot();
+ auto buffer = MutableFlatbufferBuilder(entity_data_schema).NewRoot();
ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
classification.serialized_entity_data = buffer->Serialize();
} else {
@@ -610,7 +610,7 @@
ReadAnnotations(actions_entity_data_schema, &action.annotations);
} else if (key.Equals(kEntityKey)) {
auto buffer =
- ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot();
+ MutableFlatbufferBuilder(actions_entity_data_schema).NewRoot();
ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
action.serialized_entity_data = buffer->Serialize();
} else {
diff --git a/native/utils/lua-utils.h b/native/utils/lua-utils.h
index b01471a..98c451c 100644
--- a/native/utils/lua-utils.h
+++ b/native/utils/lua-utils.h
@@ -21,7 +21,7 @@
#include "actions/types.h"
#include "annotator/types.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/reflection_generated.h"
@@ -65,7 +65,7 @@
class LuaEnvironment {
public:
virtual ~LuaEnvironment();
- LuaEnvironment();
+ explicit LuaEnvironment();
// Compile a lua snippet into binary bytecode.
// NOTE: The compiled bytecode might not be compatible across Lua versions
@@ -213,7 +213,7 @@
}
// Reads a flatbuffer from the stack.
- int ReadFlatbuffer(int index, ReflectiveFlatbuffer* buffer) const;
+ int ReadFlatbuffer(int index, MutableFlatbuffer* buffer) const;
// Pushes an iterator.
template <typename ItemCallback, typename KeyCallback>
@@ -513,8 +513,8 @@
}
template <>
- void ReadRepeatedField<ReflectiveFlatbuffer>(const int index,
- RepeatedField* result) const {
+ void ReadRepeatedField<MutableFlatbuffer>(const int index,
+ RepeatedField* result) const {
lua_pushnil(state_);
while (Next(index - 1)) {
ReadFlatbuffer(index, result->Add());
diff --git a/native/utils/lua-utils_test.cc b/native/utils/lua-utils_test.cc
index 8c9f8de..b4f6181 100644
--- a/native/utils/lua-utils_test.cc
+++ b/native/utils/lua-utils_test.cc
@@ -18,7 +18,8 @@
#include <string>
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -124,7 +125,7 @@
const std::string serialized_flatbuffer_schema_;
const reflection::Schema* schema_;
- ReflectiveFlatbufferBuilder flatbuffer_builder_;
+ MutableFlatbufferBuilder flatbuffer_builder_;
};
TEST_F(LuaUtilsTest, HandlesVectors) {
@@ -184,7 +185,7 @@
)lua");
// Read the flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
ReadFlatbuffer(/*index=*/-1, buffer.get());
const std::string serialized_buffer = buffer->Serialize();
@@ -235,7 +236,7 @@
TEST_F(LuaUtilsTest, HandlesSimpleFlatbufferFields) {
// Create test flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
buffer->Set("float_field", 42.f);
const std::string serialized_buffer = buffer->Serialize();
PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
@@ -252,7 +253,7 @@
TEST_F(LuaUtilsTest, HandlesRepeatedFlatbufferFields) {
// Create test flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
RepeatedField* repeated_field = buffer->Repeated("repeated_string_field");
repeated_field->Add("this");
repeated_field->Add("is");
@@ -274,11 +275,11 @@
TEST_F(LuaUtilsTest, HandlesRepeatedNestedFlatbufferFields) {
// Create test flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
RepeatedField* repeated_field = buffer->Repeated("repeated_nested_field");
repeated_field->Add()->Set("string_field", "hello");
repeated_field->Add()->Set("string_field", "my");
- ReflectiveFlatbuffer* nested = repeated_field->Add();
+ MutableFlatbuffer* nested = repeated_field->Add();
nested->Set("string_field", "old");
RepeatedField* nested_repeated = nested->Repeated("repeated_string_field");
nested_repeated->Add("friend");
@@ -308,14 +309,14 @@
TEST_F(LuaUtilsTest, CorrectlyReadsTwoFlatbuffersSimultaneously) {
// The first flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
buffer->Set("string_field", "first");
const std::string serialized_buffer = buffer->Serialize();
PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
serialized_buffer.data()));
lua_setglobal(state_, "arg");
// The second flatbuffer.
- std::unique_ptr<ReflectiveFlatbuffer> buffer2 = flatbuffer_builder_.NewRoot();
+ std::unique_ptr<MutableFlatbuffer> buffer2 = flatbuffer_builder_.NewRoot();
buffer2->Set("string_field", "second");
const std::string serialized_buffer2 = buffer2->Serialize();
PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
diff --git a/native/utils/optional.h b/native/utils/optional.h
index 15d2619..572350d 100644
--- a/native/utils/optional.h
+++ b/native/utils/optional.h
@@ -62,7 +62,7 @@
return value_;
}
- T const& value_or(T&& default_value) {
+ // Returns value_ when init_ is true, otherwise `default_value`.
+ // The result is a reference, not a copy. When the default is chosen it refers
+ // to the caller's argument, so for a temporary argument it is valid only
+ // until the end of the calling full-expression: copy the result instead of
+ // binding it to a longer-lived reference.
+ T const& value_or(T&& default_value) const& {
return (init_ ? value_ : default_value);
}
diff --git a/native/utils/resources.cc b/native/utils/resources.cc
index 2ae2def..24b3a6f 100644
--- a/native/utils/resources.cc
+++ b/native/utils/resources.cc
@@ -18,7 +18,6 @@
#include "utils/base/logging.h"
#include "utils/zlib/buffer_generated.h"
-#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
namespace {
@@ -128,121 +127,8 @@
if (resource->content() != nullptr) {
*result = resource->content()->str();
return true;
- } else if (resource->compressed_content() != nullptr) {
- std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(
- resources_->compression_dictionary()->data(),
- resources_->compression_dictionary()->size());
- if (decompressor != nullptr &&
- decompressor->MaybeDecompress(resource->compressed_content(), result)) {
- return true;
- }
}
return false;
}
-bool CompressResources(ResourcePoolT* resources,
- const bool build_compression_dictionary,
- const int dictionary_sample_every) {
- std::vector<unsigned char> dictionary;
- if (build_compression_dictionary) {
- {
- // Build up a compression dictionary.
- std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
- int i = 0;
- for (auto& entry : resources->resource_entry) {
- for (auto& resource : entry->resource) {
- if (resource->content.empty()) {
- continue;
- }
- i++;
-
- // Use a sample of the entries to build up a custom compression
- // dictionary. Using all entries will generally not give a benefit
- // for small data sizes, so we subsample here.
- if (i % dictionary_sample_every != 0) {
- continue;
- }
- CompressedBufferT compressed_content;
- compressor->Compress(resource->content, &compressed_content);
- }
- }
- compressor->GetDictionary(&dictionary);
- resources->compression_dictionary.assign(
- dictionary.data(), dictionary.data() + dictionary.size());
- }
- }
-
- for (auto& entry : resources->resource_entry) {
- for (auto& resource : entry->resource) {
- if (resource->content.empty()) {
- continue;
- }
- // Try compressing the data.
- std::unique_ptr<ZlibCompressor> compressor =
- build_compression_dictionary
- ? ZlibCompressor::Instance(dictionary.data(), dictionary.size())
- : ZlibCompressor::Instance();
- if (!compressor) {
- TC3_LOG(ERROR) << "Cannot create zlib compressor.";
- return false;
- }
-
- CompressedBufferT compressed_content;
- compressor->Compress(resource->content, &compressed_content);
-
- // Only keep compressed version if smaller.
- if (compressed_content.uncompressed_size >
- compressed_content.buffer.size()) {
- resource->content.clear();
- resource->compressed_content.reset(new CompressedBufferT);
- *resource->compressed_content = compressed_content;
- }
- }
- }
- return true;
-}
-
-std::string CompressSerializedResources(const std::string& resources,
- const int dictionary_sample_every) {
- std::unique_ptr<ResourcePoolT> unpacked_resources(
- flatbuffers::GetRoot<ResourcePool>(resources.data())->UnPack());
- TC3_CHECK(unpacked_resources != nullptr);
- TC3_CHECK(
- CompressResources(unpacked_resources.get(), dictionary_sample_every));
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(ResourcePool::Pack(builder, unpacked_resources.get()));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-bool DecompressResources(ResourcePoolT* resources,
- const bool build_compression_dictionary) {
- std::vector<unsigned char> dictionary;
-
- for (auto& entry : resources->resource_entry) {
- for (auto& resource : entry->resource) {
- if (resource->compressed_content == nullptr) {
- continue;
- }
-
- std::unique_ptr<ZlibDecompressor> zlib_decompressor =
- build_compression_dictionary
- ? ZlibDecompressor::Instance(dictionary.data(), dictionary.size())
- : ZlibDecompressor::Instance();
- if (!zlib_decompressor) {
- TC3_LOG(ERROR) << "Cannot initialize decompressor.";
- return false;
- }
-
- if (!zlib_decompressor->MaybeDecompress(
- resource->compressed_content.get(), &resource->content)) {
- TC3_LOG(ERROR) << "Cannot decompress resource.";
- return false;
- }
- resource->compressed_content.reset(nullptr);
- }
- }
- return true;
-}
-
} // namespace libtextclassifier3
diff --git a/native/utils/resources.fbs b/native/utils/resources.fbs
index aae57cf..0a05718 100755
--- a/native/utils/resources.fbs
+++ b/native/utils/resources.fbs
@@ -14,14 +14,13 @@
// limitations under the License.
//
-include "utils/i18n/language-tag.fbs";
include "utils/zlib/buffer.fbs";
+include "utils/i18n/language-tag.fbs";
namespace libtextclassifier3;
table Resource {
locale:[int];
content:string (shared);
- compressed_content:CompressedBuffer;
}
namespace libtextclassifier3;
@@ -34,6 +33,5 @@
table ResourcePool {
locale:[LanguageTag];
resource_entry:[ResourceEntry];
- compression_dictionary:[ubyte];
}
diff --git a/native/utils/resources.h b/native/utils/resources.h
index 96f9683..ca601fe 100644
--- a/native/utils/resources.h
+++ b/native/utils/resources.h
@@ -63,18 +63,6 @@
const ResourcePool* resources_;
};
-// Compresses resources in place.
-bool CompressResources(ResourcePoolT* resources,
- const bool build_compression_dictionary = false,
- const int dictionary_sample_every = 1);
-std::string CompressSerializedResources(
- const std::string& resources,
- const bool build_compression_dictionary = false,
- const int dictionary_sample_every = 1);
-
-bool DecompressResources(ResourcePoolT* resources,
- const bool build_compression_dictionary = false);
-
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
diff --git a/native/utils/resources_test.cc b/native/utils/resources_test.cc
index c385f39..6e3d0a1 100644
--- a/native/utils/resources_test.cc
+++ b/native/utils/resources_test.cc
@@ -15,6 +15,7 @@
*/
#include "utils/resources.h"
+
#include "utils/i18n/locale.h"
#include "utils/resources_generated.h"
#include "gmock/gmock.h"
@@ -23,8 +24,7 @@
namespace libtextclassifier3 {
namespace {
-class ResourcesTest
- : public testing::TestWithParam<testing::tuple<bool, bool>> {
+class ResourcesTest : public testing::Test {
protected:
ResourcesTest() {}
@@ -57,7 +57,7 @@
test_resources.locale.back()->language = "zh";
test_resources.locale.emplace_back(new LanguageTagT);
test_resources.locale.back()->language = "fr";
- test_resources.locale.back()->language = "fr-CA";
+ test_resources.locale.back()->region = "CA";
if (add_default_language) {
test_resources.locale.emplace_back(new LanguageTagT); // default
}
@@ -115,12 +115,6 @@
test_resources.resource_entry.back()->resource.back()->content = "龍";
test_resources.resource_entry.back()->resource.back()->locale.push_back(7);
- if (compress()) {
- EXPECT_TRUE(CompressResources(
- &test_resources,
- /*build_compression_dictionary=*/build_dictionary()));
- }
-
flatbuffers::FlatBufferBuilder builder;
builder.Finish(ResourcePool::Pack(builder, &test_resources));
@@ -128,16 +122,9 @@
reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
}
-
- bool compress() const { return testing::get<0>(GetParam()); }
-
- bool build_dictionary() const { return testing::get<1>(GetParam()); }
};
-INSTANTIATE_TEST_SUITE_P(Compression, ResourcesTest,
- testing::Combine(testing::Bool(), testing::Bool()));
-
-TEST_P(ResourcesTest, CorrectlyHandlesExactMatch) {
+TEST_F(ResourcesTest, CorrectlyHandlesExactMatch) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
@@ -162,7 +149,7 @@
EXPECT_EQ("localiser", content);
}
-TEST_P(ResourcesTest, CorrectlyHandlesTie) {
+TEST_F(ResourcesTest, CorrectlyHandlesTie) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
@@ -173,7 +160,7 @@
EXPECT_EQ("localize", content);
}
-TEST_P(ResourcesTest, RequiresLanguageMatch) {
+TEST_F(ResourcesTest, RequiresLanguageMatch) {
{
std::string test_resources =
BuildTestResources(/*add_default_language=*/false);
@@ -196,7 +183,7 @@
}
}
-TEST_P(ResourcesTest, HandlesFallback) {
+TEST_F(ResourcesTest, HandlesFallback) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
@@ -217,7 +204,7 @@
EXPECT_EQ("localize", content);
}
-TEST_P(ResourcesTest, HandlesFallbackMultipleLocales) {
+TEST_F(ResourcesTest, HandlesFallbackMultipleLocales) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
@@ -252,7 +239,7 @@
EXPECT_EQ("localize", content);
}
-TEST_P(ResourcesTest, PreferGenericCallback) {
+TEST_F(ResourcesTest, PreferGenericCallback) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
@@ -271,7 +258,7 @@
EXPECT_EQ("龍", content); // Falls back to zh, not zh-Hans-CN.
}
-TEST_P(ResourcesTest, PreferGenericWhenGeneric) {
+TEST_F(ResourcesTest, PreferGenericWhenGeneric) {
std::string test_resources = BuildTestResources();
Resources resources(
flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
diff --git a/native/utils/sentencepiece/normalizer.cc b/native/utils/sentencepiece/normalizer.cc
index 4cee507..d2b0c06 100644
--- a/native/utils/sentencepiece/normalizer.cc
+++ b/native/utils/sentencepiece/normalizer.cc
@@ -124,8 +124,8 @@
}
const bool no_match = match.match_length <= 0;
if (no_match) {
- const int char_length = ValidUTF8CharLength(input.data(), input.size());
- if (char_length <= 0) {
+ int char_length;
+ if (!IsValidChar(input.data(), input.size(), &char_length)) {
// Found a malformed utf8.
// The rune is set to be 0xFFFD (REPLACEMENT CHARACTER),
// which is a valid Unicode of three bytes in utf8,
diff --git a/native/utils/sentencepiece/normalizer_test.cc b/native/utils/sentencepiece/normalizer_test.cc
new file mode 100644
index 0000000..57debe3
--- /dev/null
+++ b/native/utils/sentencepiece/normalizer_test.cc
@@ -0,0 +1,199 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/normalizer.h"
+
+#include <fstream>
+#include <string>
+
+#include "utils/container/double-array-trie.h"
+#include "utils/sentencepiece/test_utils.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/test-data-test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Returns the path of the test data file `nmt_nfkc_charsmap.bin`, resolved via
+// GetTestDataPath(). The tests below read it as the serialized normalizer spec.
+std::string GetTestConfigPath() {
+ return GetTestDataPath("utils/sentencepiece/test_data/nmt_nfkc_charsmap.bin");
+}
+
+// Checks the normalizer output for typical inputs when the dummy prefix,
+// extra whitespace removal and whitespace escaping are all enabled.
+TEST(NormalizerTest, NormalizesAsReferenceNormalizer) {
+ // Read the whole serialized spec into memory.
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/true,
+ /*remove_extra_whitespaces=*/true,
+ /*escape_whitespaces=*/true);
+ // Plain ASCII: "▁" is prepended and the space is escaped to "▁".
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "▁hello▁there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "▁when▁is▁the▁world▁cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "▁general▁kenobi");
+ }
+
+ // NFKC char to multi-char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("㍿", &normalized));
+ EXPECT_EQ(normalized, "▁株式会社");
+ }
+
+ // Half width katakana, character composition happens.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize(" グーグル ", &normalized));
+ EXPECT_EQ(normalized, "▁グーグル");
+ }
+
+ // NFKC char to char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("①②③", &normalized));
+ EXPECT_EQ(normalized, "▁123");
+ }
+}
+
+// Same inputs as the all-options-enabled test, but with add_dummy_prefix
+// disabled: the expected outputs carry no leading "▁".
+TEST(NormalizerTest, NoDummyPrefix) {
+ // Read the whole serialized spec into memory.
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/true,
+ /*escape_whitespaces=*/true);
+
+ // Plain ASCII: only the inner space is escaped, nothing is prepended.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "hello▁there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "when▁is▁the▁world▁cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "general▁kenobi");
+ }
+
+ // NFKC char to multi-char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("㍿", &normalized));
+ EXPECT_EQ(normalized, "株式会社");
+ }
+
+ // Half width katakana, character composition happens.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize(" グーグル ", &normalized));
+ EXPECT_EQ(normalized, "グーグル");
+ }
+
+ // NFKC char to char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("①②③", &normalized));
+ EXPECT_EQ(normalized, "123");
+ }
+}
+
+// With remove_extra_whitespaces disabled (and no dummy prefix), repeated
+// whitespace is kept and each whitespace character is escaped to "▁".
+TEST(NormalizerTest, NoRemoveExtraWhitespace) {
+ // Read the whole serialized spec into memory.
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/false,
+ /*escape_whitespaces=*/true);
+
+ // Plain ASCII with single spaces.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "hello▁there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "when▁is▁▁the▁▁world▁cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "general▁kenobi");
+ }
+}
+
+// With escape_whitespaces disabled (and neither dummy prefix nor extra
+// whitespace removal), whitespace shows up as a plain space in the output.
+TEST(NormalizerTest, NoEscapeWhitespaces) {
+ // Read the whole serialized spec into memory.
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/false,
+ /*escape_whitespaces=*/false);
+
+ // Plain ASCII with single spaces is returned unchanged.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "hello there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "when is the world cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "general kenobi");
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/sentencepiece/test_data/nmt_nfkc_charsmap.bin b/native/utils/sentencepiece/test_data/nmt_nfkc_charsmap.bin
new file mode 100644
index 0000000..74da62d
--- /dev/null
+++ b/native/utils/sentencepiece/test_data/nmt_nfkc_charsmap.bin
Binary files differ
diff --git a/native/utils/sentencepiece/test_utils.cc b/native/utils/sentencepiece/test_utils.cc
new file mode 100644
index 0000000..f277a14
--- /dev/null
+++ b/native/utils/sentencepiece/test_utils.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/test_utils.h"
+
+#include <memory>
+
+#include "utils/base/integral_types.h"
+#include "utils/container/double-array-trie.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Builds a SentencePieceNormalizer from a serialized `spec` laid out as:
+//   [uint32: trie size in bytes][trie nodes][normalized charsmap]
+// The three flags are forwarded unchanged to the normalizer.
+// NOTE(review): `spec` is not validated here; this assumes it is at least as
+// long as the encoded sizes and suitably aligned for the casts below - confirm.
+// NOTE(review): the trie and the charsmap point into the buffer behind `spec`;
+// presumably that buffer must outlive the returned normalizer - confirm.
+SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
+ bool add_dummy_prefix,
+ bool remove_extra_whitespaces,
+ bool escape_whitespaces) {
+ // The leading uint32 holds the size of the trie blob in bytes.
+ const uint32 trie_blob_size = reinterpret_cast<const uint32*>(spec.data())[0];
+ spec.RemovePrefix(sizeof(trie_blob_size));
+ // The trie nodes follow directly after the size field.
+ const TrieNode* trie_blob = reinterpret_cast<const TrieNode*>(spec.data());
+ spec.RemovePrefix(trie_blob_size);
+ const int num_nodes = trie_blob_size / sizeof(TrieNode);
+ // Whatever remains after the trie is handed over as the normalized charsmap.
+ return SentencePieceNormalizer(
+ DoubleArrayTrie(trie_blob, num_nodes),
+ /*charsmap_normalized=*/StringPiece(spec.data(), spec.size()),
+ add_dummy_prefix, remove_extra_whitespaces, escape_whitespaces);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/sentencepiece/test_utils.h b/native/utils/sentencepiece/test_utils.h
new file mode 100644
index 0000000..0c833da
--- /dev/null
+++ b/native/utils/sentencepiece/test_utils.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "utils/sentencepiece/normalizer.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Test helper that creates a SentencePieceNormalizer from a serialized spec.
+// `spec` starts with a uint32 giving the trie size in bytes, followed by the
+// trie nodes and then the normalized charsmap. The boolean flags are passed
+// through to the normalizer.
+SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
+ bool add_dummy_prefix,
+ bool remove_extra_whitespaces,
+ bool escape_whitespaces);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
diff --git a/native/utils/strings/utf8.cc b/native/utils/strings/utf8.cc
index 932e2a5..b3ed0af 100644
--- a/native/utils/strings/utf8.cc
+++ b/native/utils/strings/utf8.cc
@@ -19,10 +19,11 @@
#include "utils/base/logging.h"
namespace libtextclassifier3 {
+
bool IsValidUTF8(const char *src, int size) {
+ int char_length;
for (int i = 0; i < size;) {
- const int char_length = ValidUTF8CharLength(src + i, size - i);
- if (char_length <= 0) {
+ if (!IsValidChar(src + i, size - i, &char_length)) {
return false;
}
i += char_length;
@@ -30,27 +31,6 @@
return true;
}
-int ValidUTF8CharLength(const char *src, int size) {
- // Unexpected trail byte.
- if (IsTrailByte(src[0])) {
- return -1;
- }
-
- const int num_codepoint_bytes = GetNumBytesForUTF8Char(&src[0]);
- if (num_codepoint_bytes <= 0 || num_codepoint_bytes > size) {
- return -1;
- }
-
- // Check that remaining bytes in the codepoint are trailing bytes.
- for (int k = 1; k < num_codepoint_bytes; k++) {
- if (!IsTrailByte(src[k])) {
- return -1;
- }
- }
-
- return num_codepoint_bytes;
-}
-
int SafeTruncateLength(const char *str, int truncate_at) {
// Always want to truncate at the start of a character, so if
// it's in a middle, back up toward the start
@@ -88,6 +68,47 @@
((byte3 & 0x3F) << 6) | (byte4 & 0x3F);
}
+// Validates the single UTF-8 encoded character that starts at `str`.
+//
+// `size` is the number of readable bytes at `str`. Returns true iff those
+// bytes begin with a well-formed character; `*num_bytes` then holds its
+// encoded length. On failure `*num_bytes` may be left untouched or hold the
+// length implied by the lead byte, so callers must not rely on it.
+//
+// A character is rejected if the input is empty, starts with a trail byte, is
+// truncated, has a non-trail byte where a trail byte is expected, is an
+// overlong encoding, or encodes a value above U+10FFFF.
+// NOTE(review): surrogate code points (U+D800..U+DFFF) are not rejected here.
+bool IsValidChar(const char *str, int size, int *num_bytes) {
+  // Nothing to read, or unexpected trail byte.
+  if (size <= 0 || IsTrailByte(str[0])) {
+    return false;
+  }
+
+  *num_bytes = GetNumBytesForUTF8Char(str);
+  if (*num_bytes <= 0 || *num_bytes > size) {
+    return false;
+  }
+
+  // Check that remaining bytes in the codepoint are trailing bytes.
+  for (int k = 1; k < *num_bytes; k++) {
+    if (!IsTrailByte(str[k])) {
+      return false;
+    }
+  }
+
+  // Exclude overlong encodings.
+  // Check that the codepoint is encoded with the minimum number of required
+  // bytes. An ascii value could be encoded in 4, 3 or 2 bytes but requires
+  // only 1. There is a unique valid encoding for each code point.
+  // This ensures that string comparisons and searches are well-defined.
+  // See: https://en.wikipedia.org/wiki/UTF-8
+  const char32 codepoint = ValidCharToRune(str);
+  switch (*num_bytes) {
+    case 1:
+      return true;
+    case 2:
+      // Everything below 128 can be encoded in one byte.
+      return (codepoint >= (1 << 7 /* num. payload bits in one byte */));
+    case 3:
+      return (codepoint >= (1 << 11 /* num. payload bits in two utf8 bytes */));
+    case 4:
+      // U+10FFFF itself is a valid code point, so the upper bound is
+      // inclusive.
+      return (codepoint >=
+              (1 << 16 /* num. payload bits in three utf8 bytes */)) &&
+             (codepoint <= 0x10FFFF /* maximum rune value */);
+  }
+  return false;
+}
+
int ValidRuneToChar(const char32 rune, char *dest) {
// Convert to unsigned for range check.
uint32 c;
diff --git a/native/utils/strings/utf8.h b/native/utils/strings/utf8.h
index e871731..370cf23 100644
--- a/native/utils/strings/utf8.h
+++ b/native/utils/strings/utf8.h
@@ -41,10 +41,6 @@
// Returns true iff src points to a well-formed UTF-8 string.
bool IsValidUTF8(const char *src, int size);
-// Returns byte length of the first valid codepoint in the string, otherwise -1
-// if pointing to an ill-formed UTF-8 character.
-int ValidUTF8CharLength(const char *src, int size);
-
// Helper to ensure that strings are not truncated in the middle of
// multi-byte UTF-8 characters.
// Given a string, and a position at which to truncate, returns the
@@ -55,6 +51,10 @@
// Gets a unicode codepoint from a valid utf8 encoding.
char32 ValidCharToRune(const char *str);
+// Checks whether the first `size` bytes at `str` start with a valid utf8
+// encoded codepoint. On success, stores the number of bytes of that codepoint
+// in `num_bytes`.
+bool IsValidChar(const char *str, int size, int *num_bytes);
+
// Converts a valid codepoint to utf8.
// Returns the length of the encoding.
int ValidRuneToChar(const char32 rune, char *dest);
diff --git a/native/utils/strings/utf8_test.cc b/native/utils/strings/utf8_test.cc
index 28d971b..5b4b748 100644
--- a/native/utils/strings/utf8_test.cc
+++ b/native/utils/strings/utf8_test.cc
@@ -34,25 +34,18 @@
EXPECT_TRUE(IsValidUTF8("\u304A\u00B0\u106B", 8));
EXPECT_TRUE(IsValidUTF8("this is a test😋😋😋", 26));
EXPECT_TRUE(IsValidUTF8("\xf0\x9f\x98\x8b", 4));
+ // Example with first byte payload of zero.
+ EXPECT_TRUE(IsValidUTF8("\xf0\x90\x80\x80", 4));
// Too short (string is too short).
EXPECT_FALSE(IsValidUTF8("\xf0\x9f", 2));
// Too long (too many trailing bytes).
EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x8b\x8b", 5));
// Too short (too few trailing bytes).
EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x61\x61", 5));
-}
-
-TEST(Utf8Test, ValidUTF8CharLength) {
- EXPECT_EQ(ValidUTF8CharLength("1234😋hello", 13), 1);
- EXPECT_EQ(ValidUTF8CharLength("\u304A\u00B0\u106B", 8), 3);
- EXPECT_EQ(ValidUTF8CharLength("this is a test😋😋😋", 26), 1);
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x8b", 4), 4);
- // Too short (string is too short).
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f", 2), -1);
- // Too long (too many trailing bytes). First character is valid.
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x8b\x8b", 5), 4);
- // Too short (too few trailing bytes).
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x61\x61", 5), -1);
+ // Overlong encoding (can be encoded in fewer bytes).
+ EXPECT_FALSE(IsValidUTF8("\xc0\x81", 2));
+ // Overlong encoding (can be encoded in fewer bytes).
+ EXPECT_FALSE(IsValidUTF8("\xf0\x8a\x85\x8f", 4));
}
TEST(Utf8Test, CorrectlyTruncatesStrings) {
diff --git a/native/utils/test-data-test-utils.h b/native/utils/test-data-test-utils.h
new file mode 100644
index 0000000..8bafbeb
--- /dev/null
+++ b/native/utils/test-data-test-utils.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utilities for accessing test data.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_DATA_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_TEST_DATA_TEST_UTILS_H_
+
+#include "gtest/gtest.h"
+#include "android-base/file.h"
+
+namespace libtextclassifier3 {
+
+// Returns the path to a test data file: `relative_path` appended, after a "/",
+// to the directory reported by android::base::GetExecutableDirectory().
+inline std::string GetTestDataPath(const std::string& relative_path) {
+ return android::base::GetExecutableDirectory() + "/" +
+ relative_path;
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TEST_DATA_TEST_UTILS_H_
diff --git a/native/utils/test-utils.cc b/native/utils/test-utils.cc
index 8996a4a..9e3216a 100644
--- a/native/utils/test-utils.cc
+++ b/native/utils/test-utils.cc
@@ -31,7 +31,8 @@
}
std::vector<Token> TokenizeOnDelimiters(
- const std::string& text, const std::unordered_set<char32>& delimiters) {
+ const std::string& text, const std::unordered_set<char32>& delimiters,
+ bool create_tokens_for_non_space_delimiters) {
const UnicodeText unicode_text = UTF8ToUnicodeText(text, /*do_copy=*/false);
std::vector<Token> result;
@@ -48,6 +49,10 @@
result.push_back(Token{UnicodeText::UTF8Substring(token_start_it, it),
token_start_codepoint, codepoint_idx});
}
+ if (create_tokens_for_non_space_delimiters && *it != ' ') {
+ result.push_back(
+ Token{std::string(1, *it), codepoint_idx, codepoint_idx + 1});
+ }
token_start_codepoint = codepoint_idx + 1;
token_start_it = it;
diff --git a/native/utils/test-utils.h b/native/utils/test-utils.h
index 0e75190..184e60a 100644
--- a/native/utils/test-utils.h
+++ b/native/utils/test-utils.h
@@ -30,8 +30,13 @@
// Returns a list of Tokens for a given input string, by tokenizing on the
// given set of delimiter codepoints.
+// If create_tokens_for_non_space_delimiters is true, create tokens for
+// delimiters which are not white spaces. For example "This, is" -> {"This",
+// ",", "is"}.
+
std::vector<Token> TokenizeOnDelimiters(
- const std::string& text, const std::unordered_set<char32>& delimiters);
+ const std::string& text, const std::unordered_set<char32>& delimiters,
+ bool create_tokens_for_non_space_delimiters = false);
} // namespace libtextclassifier3
diff --git a/native/utils/test-utils_test.cc b/native/utils/test-utils_test.cc
index bdaa285..88a3ec1 100644
--- a/native/utils/test-utils_test.cc
+++ b/native/utils/test-utils_test.cc
@@ -96,5 +96,49 @@
EXPECT_EQ(tokens[5].end, 35);
}
+// With create_tokens_for_non_space_delimiters enabled, the delimiters '?' and
+// '!' are expected as tokens of their own, while spaces produce no token.
+// ':' is not a delimiter; it forms a token because delimiters surround it.
+// NOTE(review): the expected offsets (e.g. "might" starting at 7) imply runs of
+// several spaces between some words of the input - confirm the literal.
+TEST(TestUtilTest, TokenizeOnDelimitersKeepNoSpace) {
+ std::vector<Token> tokens = TokenizeOnDelimiters(
+ "This might be čomplíčateď?!: Oder?", {' ', '?', '!'},
+ /* create_tokens_for_non_space_delimiters =*/true);
+
+ EXPECT_EQ(tokens.size(), 9);
+
+ EXPECT_EQ(tokens[0].value, "This");
+ EXPECT_EQ(tokens[0].start, 0);
+ EXPECT_EQ(tokens[0].end, 4);
+
+ EXPECT_EQ(tokens[1].value, "might");
+ EXPECT_EQ(tokens[1].start, 7);
+ EXPECT_EQ(tokens[1].end, 12);
+
+ EXPECT_EQ(tokens[2].value, "be");
+ EXPECT_EQ(tokens[2].start, 13);
+ EXPECT_EQ(tokens[2].end, 15);
+
+ EXPECT_EQ(tokens[3].value, "čomplíčateď");
+ EXPECT_EQ(tokens[3].start, 16);
+ EXPECT_EQ(tokens[3].end, 27);
+
+ // Delimiter tokens span exactly one codepoint.
+ EXPECT_EQ(tokens[4].value, "?");
+ EXPECT_EQ(tokens[4].start, 27);
+ EXPECT_EQ(tokens[4].end, 28);
+
+ EXPECT_EQ(tokens[5].value, "!");
+ EXPECT_EQ(tokens[5].start, 28);
+ EXPECT_EQ(tokens[5].end, 29);
+
+ EXPECT_EQ(tokens[6].value, ":");
+ EXPECT_EQ(tokens[6].start, 29);
+ EXPECT_EQ(tokens[6].end, 30);
+
+ EXPECT_EQ(tokens[7].value, "Oder");
+ EXPECT_EQ(tokens[7].start, 31);
+ EXPECT_EQ(tokens[7].end, 35);
+
+ EXPECT_EQ(tokens[8].value, "?");
+ EXPECT_EQ(tokens[8].start, 35);
+ EXPECT_EQ(tokens[8].end, 36);
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 55faea5..31ed414 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -80,14 +80,14 @@
resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
tflite::ops::builtin::Register_CONV_2D(),
/*min_version=*/1,
- /*max_version=*/3);
+ /*max_version=*/5);
resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL,
::tflite::ops::builtin::Register_EQUAL());
resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
tflite::ops::builtin::Register_FULLY_CONNECTED(),
/*min_version=*/1,
- /*max_version=*/4);
+ /*max_version=*/9);
resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER_EQUAL,
::tflite::ops::builtin::Register_GREATER_EQUAL());
resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
@@ -186,7 +186,12 @@
namespace libtextclassifier3 {
-inline std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
+// Builds the op resolver without extra ops: delegates to the overload taking
+// a customization callback and passes a callback that does nothing.
+std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
+ return BuildOpResolver([](tflite::MutableOpResolver* mutable_resolver) {});
+}
+
+std::unique_ptr<tflite::OpResolver> BuildOpResolver(
+ const std::function<void(tflite::MutableOpResolver*)>& customize_fn) {
#ifdef TC3_USE_SELECTIVE_REGISTRATION
std::unique_ptr<tflite::MutableOpResolver> resolver(
new tflite::MutableOpResolver);
@@ -203,6 +208,7 @@
resolver->AddCustom("TokenEncoder",
tflite::ops::custom::Register_TOKEN_ENCODER());
#endif // TC3_WITH_ACTIONS_OPS
+ customize_fn(resolver.get());
return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
}
diff --git a/native/utils/tflite-model-executor.h b/native/utils/tflite-model-executor.h
index a4432ff..faa1295 100644
--- a/native/utils/tflite-model-executor.h
+++ b/native/utils/tflite-model-executor.h
@@ -28,12 +28,21 @@
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"
namespace libtextclassifier3 {
+// Creates a TF.Lite Op resolver in default configuration, with ops for
+// Annotator and Actions models.
std::unique_ptr<tflite::OpResolver> BuildOpResolver();
+
+// Like above, but allows passage of a function that can register additional
+// ops.
+std::unique_ptr<tflite::OpResolver> BuildOpResolver(
+ const std::function<void(tflite::MutableOpResolver*)>& customize_fn);
+
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
const tflite::Model*);
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
diff --git a/native/utils/tokenizer.cc b/native/utils/tokenizer.cc
index bd47592..da66ff6 100644
--- a/native/utils/tokenizer.cc
+++ b/native/utils/tokenizer.cc
@@ -234,15 +234,12 @@
if (!break_iterator) {
return false;
}
- int last_break_index = 0;
- int break_index = 0;
int last_unicode_index = 0;
int unicode_index = 0;
auto token_begin_it = context_unicode.begin();
- while ((break_index = break_iterator->Next()) !=
+ while ((unicode_index = break_iterator->Next()) !=
UniLib::BreakIterator::kDone) {
- const int token_length = break_index - last_break_index;
- unicode_index = last_unicode_index + token_length;
+ const int token_length = unicode_index - last_unicode_index;
auto token_end_it = token_begin_it;
std::advance(token_end_it, token_length);
@@ -264,7 +261,6 @@
/*is_padding=*/false, is_whitespace));
}
- last_break_index = break_index;
last_unicode_index = unicode_index;
token_begin_it = token_end_it;
}
diff --git a/native/utils/utf8/unilib-common.cc b/native/utils/utf8/unilib-common.cc
index de52086..30149af 100644
--- a/native/utils/utf8/unilib-common.cc
+++ b/native/utils/utf8/unilib-common.cc
@@ -61,6 +61,12 @@
0x1F501, 0x1F502, 0x1F503, 0x1F504, 0x1F5D8, 0x1F5DE};
constexpr int kNumWhitespaces = ARRAYSIZE(kWhitespaces);
+// Codepoints treated as bidirectional text control characters, listed in
+// ascending order (presumably required by the GetMatchIndex lookup - confirm).
+// https://en.wikipedia.org/wiki/Bidirectional_text
+constexpr char32 kBidirectional[] = {0x061C, 0x200E, 0x200F, 0x202A,
+ 0x202B, 0x202C, 0x202D, 0x202E,
+ 0x2066, 0x2067, 0x2068, 0x2069};
+// Number of entries in kBidirectional.
+constexpr int kNumBidirectional = ARRAYSIZE(kBidirectional);
+
// grep -E "Nd" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
// As the name suggests, these ranges are always 10 codepoints long, so we just
// store the end of the range.
@@ -502,6 +508,10 @@
return GetMatchIndex(kWhitespaces, kNumWhitespaces, codepoint) >= 0;
}
+// Returns true iff `codepoint` is one of the entries of kBidirectional.
+bool IsBidirectional(char32 codepoint) {
+ return GetMatchIndex(kBidirectional, kNumBidirectional, codepoint) >= 0;
+}
+
bool IsDigit(char32 codepoint) {
return GetOverlappingRangeIndex(kDecimalDigitRangesEnd,
kNumDecimalDigitRangesEnd,
diff --git a/native/utils/utf8/unilib-common.h b/native/utils/utf8/unilib-common.h
index 4f03de7..eeffe9c 100644
--- a/native/utils/utf8/unilib-common.h
+++ b/native/utils/utf8/unilib-common.h
@@ -25,6 +25,7 @@
bool IsOpeningBracket(char32 codepoint);
bool IsClosingBracket(char32 codepoint);
bool IsWhitespace(char32 codepoint);
+bool IsBidirectional(char32 codepoint);
bool IsDigit(char32 codepoint);
bool IsLower(char32 codepoint);
bool IsUpper(char32 codepoint);
diff --git a/native/utils/utf8/unilib-javaicu.cc b/native/utils/utf8/unilib-javaicu.cc
index de6b5ed..a0f4cb7 100644
--- a/native/utils/utf8/unilib-javaicu.cc
+++ b/native/utils/utf8/unilib-javaicu.cc
@@ -25,7 +25,6 @@
#include "utils/base/logging.h"
#include "utils/base/statusor.h"
#include "utils/java/jni-base.h"
-#include "utils/java/string_utils.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib-common.h"
@@ -420,14 +419,14 @@
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
- std::string result;
- if (!JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get(),
- &result)) {
+ StatusOr<std::string> status_or_result =
+ JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get());
+ if (!status_or_result.ok()) {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
*status = kNoError;
- return UTF8ToUnicodeText(result, /*do_copy=*/true);
+ return UTF8ToUnicodeText(status_or_result.ValueOrDie(), /*do_copy=*/true);
} else {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
@@ -455,14 +454,14 @@
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
- std::string result;
- if (!JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get(),
- &result)) {
+ StatusOr<std::string> status_or_result =
+ JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get());
+ if (!status_or_result.ok()) {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
*status = kNoError;
- return UTF8ToUnicodeText(result, /*do_copy=*/true);
+ return UTF8ToUnicodeText(status_or_result.ValueOrDie(), /*do_copy=*/true);
} else {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
diff --git a/native/utils/utf8/unilib-javaicu.h b/native/utils/utf8/unilib-javaicu.h
index d208730..4845de0 100644
--- a/native/utils/utf8/unilib-javaicu.h
+++ b/native/utils/utf8/unilib-javaicu.h
@@ -31,7 +31,6 @@
#include "utils/java/jni-base.h"
#include "utils/java/jni-cache.h"
#include "utils/java/jni-helper.h"
-#include "utils/java/string_utils.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -115,9 +114,13 @@
// Returns the matched text (the 0th capturing group).
std::string Text() const {
- ScopedStringChars text_str =
- GetScopedStringChars(jni_cache_->GetEnv(), text_.get());
- return text_str.get();
+ StatusOr<std::string> status_or_result =
+ JStringToUtf8String(jni_cache_->GetEnv(), text_.get());
+ if (!status_or_result.ok()) {
+ TC3_LOG(ERROR) << "JStringToUtf8String failed.";
+ return "";
+ }
+ return status_or_result.ValueOrDie();
}
private:
diff --git a/native/utils/zlib/zlib.cc b/native/utils/zlib/zlib.cc
index 4cb7760..6c8c5fd 100644
--- a/native/utils/zlib/zlib.cc
+++ b/native/utils/zlib/zlib.cc
@@ -16,22 +16,20 @@
#include "utils/zlib/zlib.h"
-#include "utils/flatbuffers.h"
+#include "utils/base/logging.h"
+#include "utils/flatbuffers/flatbuffers.h"
namespace libtextclassifier3 {
-std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance(
- const unsigned char* dictionary, const unsigned int dictionary_size) {
- std::unique_ptr<ZlibDecompressor> result(
- new ZlibDecompressor(dictionary, dictionary_size));
+std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance() {
+ std::unique_ptr<ZlibDecompressor> result(new ZlibDecompressor());
if (!result->initialized_) {
result.reset();
}
return result;
}
-ZlibDecompressor::ZlibDecompressor(const unsigned char* dictionary,
- const unsigned int dictionary_size) {
+ZlibDecompressor::ZlibDecompressor() {
memset(&stream_, 0, sizeof(stream_));
stream_.zalloc = Z_NULL;
stream_.zfree = Z_NULL;
@@ -40,11 +38,6 @@
TC3_LOG(ERROR) << "Could not initialize decompressor.";
return;
}
- if (dictionary != nullptr &&
- inflateSetDictionary(&stream_, dictionary, dictionary_size) != Z_OK) {
- TC3_LOG(ERROR) << "Could not set dictionary.";
- return;
- }
initialized_ = true;
}
@@ -61,7 +54,8 @@
return false;
}
out->resize(uncompressed_size);
- stream_.next_in = reinterpret_cast<const Bytef*>(buffer);
+ stream_.next_in =
+ const_cast<z_const Bytef*>(reinterpret_cast<const Bytef*>(buffer));
stream_.avail_in = buffer_size;
stream_.next_out = reinterpret_cast<Bytef*>(const_cast<char*>(out->c_str()));
stream_.avail_out = uncompressed_size;
@@ -110,19 +104,15 @@
return MaybeDecompress(compressed_buffer, out);
}
-std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance(
- const unsigned char* dictionary, const unsigned int dictionary_size) {
- std::unique_ptr<ZlibCompressor> result(
- new ZlibCompressor(dictionary, dictionary_size));
+std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance() {
+ std::unique_ptr<ZlibCompressor> result(new ZlibCompressor());
if (!result->initialized_) {
result.reset();
}
return result;
}
-ZlibCompressor::ZlibCompressor(const unsigned char* dictionary,
- const unsigned int dictionary_size,
- const int level, const int tmp_buffer_size) {
+ZlibCompressor::ZlibCompressor(const int level, const int tmp_buffer_size) {
memset(&stream_, 0, sizeof(stream_));
stream_.zalloc = Z_NULL;
stream_.zfree = Z_NULL;
@@ -133,11 +123,6 @@
TC3_LOG(ERROR) << "Could not initialize compressor.";
return;
}
- if (dictionary != nullptr &&
- deflateSetDictionary(&stream_, dictionary, dictionary_size) != Z_OK) {
- TC3_LOG(ERROR) << "Could not set dictionary.";
- return;
- }
initialized_ = true;
}
@@ -147,8 +132,8 @@
CompressedBufferT* out) {
out->uncompressed_size = uncompressed_content.size();
out->buffer.clear();
- stream_.next_in =
- reinterpret_cast<const Bytef*>(uncompressed_content.c_str());
+ stream_.next_in = const_cast<z_const Bytef*>(
+ reinterpret_cast<const Bytef*>(uncompressed_content.c_str()));
stream_.avail_in = uncompressed_content.size();
stream_.next_out = buffer_.get();
stream_.avail_out = buffer_size_;
@@ -177,14 +162,4 @@
} while (status == Z_OK);
}
-bool ZlibCompressor::GetDictionary(std::vector<unsigned char>* dictionary) {
- // Retrieve first the size of the dictionary.
- unsigned int size;
- if (deflateGetDictionary(&stream_, /*dictionary=*/Z_NULL, &size) != Z_OK) {
- return false;
- }
- dictionary->resize(size);
- return deflateGetDictionary(&stream_, dictionary->data(), &size) == Z_OK;
-}
-
} // namespace libtextclassifier3
diff --git a/native/utils/zlib/zlib.h b/native/utils/zlib/zlib.h
index f773c27..1f4d18a 100644
--- a/native/utils/zlib/zlib.h
+++ b/native/utils/zlib/zlib.h
@@ -29,9 +29,7 @@
class ZlibDecompressor {
public:
- static std::unique_ptr<ZlibDecompressor> Instance(
- const unsigned char* dictionary = nullptr,
- unsigned int dictionary_size = 0);
+ static std::unique_ptr<ZlibDecompressor> Instance();
~ZlibDecompressor();
bool Decompress(const uint8* buffer, const int buffer_size,
@@ -48,28 +46,21 @@
const CompressedBuffer* compressed_buffer, std::string* out);
private:
- ZlibDecompressor(const unsigned char* dictionary,
- const unsigned int dictionary_size);
+ explicit ZlibDecompressor();
z_stream stream_;
bool initialized_;
};
class ZlibCompressor {
public:
- static std::unique_ptr<ZlibCompressor> Instance(
- const unsigned char* dictionary = nullptr,
- unsigned int dictionary_size = 0);
+ static std::unique_ptr<ZlibCompressor> Instance();
~ZlibCompressor();
void Compress(const std::string& uncompressed_content,
CompressedBufferT* out);
- bool GetDictionary(std::vector<unsigned char>* dictionary);
-
private:
- explicit ZlibCompressor(const unsigned char* dictionary = nullptr,
- const unsigned int dictionary_size = 0,
- const int level = Z_BEST_COMPRESSION,
+ explicit ZlibCompressor(const int level = Z_BEST_COMPRESSION,
// Tmp. buffer size was set based on the current set
// of patterns to be compressed.
const int tmp_buffer_size = 64 * 1024);
diff --git a/native/utils/zlib/zlib_regex.cc b/native/utils/zlib/zlib_regex.cc
index 73b6d30..4822d6f 100644
--- a/native/utils/zlib/zlib_regex.cc
+++ b/native/utils/zlib/zlib_regex.cc
@@ -19,7 +19,7 @@
#include <memory>
#include "utils/base/logging.h"
-#include "utils/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
index 0a2cce7..0fee3b3 100644
--- a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
@@ -35,11 +35,9 @@
import android.util.Pair;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
-import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassificationManager;
import android.view.textclassifier.TextClassifier;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
@@ -172,10 +170,7 @@
} else {
SmartSuggestionsLogSession session =
new SmartSuggestionsLogSession(
- resultId,
- repliesScore,
- textClassifier,
- textClassificationContext);
+ resultId, repliesScore, textClassifier, textClassificationContext);
session.onSuggestionsGenerated(conversationActions);
// Store the session if we expect more logging from it, destroy it otherwise.
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
index 9d0a720..84cf4fb 100644
--- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
@@ -42,6 +42,7 @@
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.LargeTest;
import com.google.common.collect.ImmutableList;
+import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -426,8 +427,8 @@
long expectedReferenceTime) {
assertThat(subject.getText().toString()).isEqualTo(expectedMessage);
assertThat(subject.getAuthor()).isEqualTo(expectedAuthor);
- assertThat(subject.getReferenceTime().toInstant().toEpochMilli())
- .isEqualTo(expectedReferenceTime);
+ assertThat(subject.getReferenceTime().toInstant())
+ .isEqualTo(Instant.ofEpochMilli(expectedReferenceTime));
}
private static void assertAdjustmentWithSmartReply(SmartSuggestions smartSuggestions) {