Import libtextclassifier
Test: atest TextClassifierServiceTest
Test: atest TextClassifierTest
Test: m libtextclassifier_tests
Change-Id: I0e1aee3ea6b32b1d06d8ef638091fc2ac3d57763
diff --git a/PREUPLOAD.cfg b/PREUPLOAD.cfg
index f3db20e..f6fa187 100644
--- a/PREUPLOAD.cfg
+++ b/PREUPLOAD.cfg
@@ -1,2 +1,2 @@
[Hook Scripts]
-checkstyle_hook = ${REPO_ROOT}/prebuilts/checkstyle/checkstyle.py --sha ${PREUPLOAD_COMMIT}
+checkstyle_hook = ${REPO_ROOT}/prebuilts/checkstyle/checkstyle.py --sha ${PREUPLOAD_COMMIT} --file_whitelist java/
diff --git a/java/tests/unittests/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java b/java/tests/unittests/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java
index 5e959a1..6ae9e5a 100644
--- a/java/tests/unittests/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java
+++ b/java/tests/unittests/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java
@@ -65,8 +65,10 @@
null,
null,
null,
- 0,
- 0);
+ null,
+ 0L,
+ 0L,
+ 0d);
List<LabeledIntent> intents =
mLegacyIntentClassificationFactory.create(
@@ -102,8 +104,10 @@
null,
null,
null,
- 0,
- 0);
+ null,
+ 0L,
+ 0L,
+ 0d);
List<LabeledIntent> intents =
mLegacyIntentClassificationFactory.create(
diff --git a/java/tests/unittests/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java b/java/tests/unittests/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java
index 5965b5f..32840d0 100644
--- a/java/tests/unittests/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java
+++ b/java/tests/unittests/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java
@@ -81,9 +81,11 @@
null,
null,
null,
+ null,
createRemoteActionTemplates(),
- 0,
- 0);
+ 0L,
+ 0L,
+ 0d);
List<LabeledIntent> intents =
mTemplateClassificationIntentFactory.create(
@@ -122,9 +124,11 @@
null,
null,
null,
+ null,
createRemoteActionTemplates(),
- 0,
- 0);
+ 0L,
+ 0L,
+ 0d);
List<LabeledIntent> intents =
mTemplateClassificationIntentFactory.create(
@@ -160,8 +164,10 @@
null,
null,
null,
- 0,
- 0);
+ null,
+ 0L,
+ 0L,
+ 0d);
mTemplateClassificationIntentFactory.create(
InstrumentationRegistry.getContext(),
@@ -197,9 +203,11 @@
null,
null,
null,
+ null,
new RemoteActionTemplate[0],
- 0,
- 0);
+ 0L,
+ 0L,
+ 0d);
mTemplateClassificationIntentFactory.create(
InstrumentationRegistry.getContext(),
diff --git a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
index 9132b1f..84b5c3d 100644
--- a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -246,12 +246,21 @@
private static native long nativeNewActionsModelFromPath(
String path, byte[] preconditionsOverwrite);
+ private static native long nativeNewActionsModelWithOffset(
+ int fd, long offset, long size, byte[] preconditionsOverwrite);
+
private static native String nativeGetLocales(int fd);
+ private static native String nativeGetLocalesWithOffset(int fd, long offset, long size);
+
private static native int nativeGetVersion(int fd);
+ private static native int nativeGetVersionWithOffset(int fd, long offset, long size);
+
private static native String nativeGetName(int fd);
+ private static native String nativeGetNameWithOffset(int fd, long offset, long size);
+
private native ActionSuggestion[] nativeSuggestActions(
long context,
Conversation conversation,
diff --git a/jni/com/google/android/textclassifier/AnnotatorModel.java b/jni/com/google/android/textclassifier/AnnotatorModel.java
index 5f99f74..2e32cc0 100644
--- a/jni/com/google/android/textclassifier/AnnotatorModel.java
+++ b/jni/com/google/android/textclassifier/AnnotatorModel.java
@@ -240,6 +240,7 @@
private final byte[] serializedKnowledgeResult;
private final String contactName;
private final String contactGivenName;
+ private final String contactFamilyName;
private final String contactNickname;
private final String contactEmailAddress;
private final String contactPhoneNumber;
@@ -251,6 +252,7 @@
private final RemoteActionTemplate[] remoteActionTemplates;
private final long durationMs;
private final long numericValue;
+ private final double numericDoubleValue;
public ClassificationResult(
String collection,
@@ -259,6 +261,7 @@
byte[] serializedKnowledgeResult,
String contactName,
String contactGivenName,
+ String contactFamilyName,
String contactNickname,
String contactEmailAddress,
String contactPhoneNumber,
@@ -269,13 +272,15 @@
byte[] serializedEntityData,
RemoteActionTemplate[] remoteActionTemplates,
long durationMs,
- long numericValue) {
+ long numericValue,
+ double numericDoubleValue) {
this.collection = collection;
this.score = score;
this.datetimeResult = datetimeResult;
this.serializedKnowledgeResult = serializedKnowledgeResult;
this.contactName = contactName;
this.contactGivenName = contactGivenName;
+ this.contactFamilyName = contactFamilyName;
this.contactNickname = contactNickname;
this.contactEmailAddress = contactEmailAddress;
this.contactPhoneNumber = contactPhoneNumber;
@@ -287,6 +292,7 @@
this.remoteActionTemplates = remoteActionTemplates;
this.durationMs = durationMs;
this.numericValue = numericValue;
+ this.numericDoubleValue = numericDoubleValue;
}
/** Returns the classified entity type. */
@@ -315,6 +321,10 @@
return contactGivenName;
}
+ public String getContactFamilyName() {
+ return contactFamilyName;
+ }
+
public String getContactNickname() {
return contactNickname;
}
@@ -358,6 +368,10 @@
public long getNumericValue() {
return numericValue;
}
+
+ public double getNumericDoubleValue() {
+ return numericDoubleValue;
+ }
}
/** Represents a result of Annotate call. */
@@ -556,12 +570,20 @@
private static native long nativeNewAnnotatorFromPath(String path);
+ private static native long nativeNewAnnotatorWithOffset(int fd, long offset, long size);
+
private static native String nativeGetLocales(int fd);
+ private static native String nativeGetLocalesWithOffset(int fd, long offset, long size);
+
private static native int nativeGetVersion(int fd);
+ private static native int nativeGetVersionWithOffset(int fd, long offset, long size);
+
private static native String nativeGetName(int fd);
+ private static native String nativeGetNameWithOffset(int fd, long offset, long size);
+
private native long nativeGetNativeModelPtr(long context);
private native boolean nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig);
diff --git a/jni/com/google/android/textclassifier/LangIdModel.java b/jni/com/google/android/textclassifier/LangIdModel.java
index d3e166f..e3f7a79 100644
--- a/jni/com/google/android/textclassifier/LangIdModel.java
+++ b/jni/com/google/android/textclassifier/LangIdModel.java
@@ -103,6 +103,11 @@
return nativeGetVersionFromFd(fd);
}
+ // Visible for testing.
+ float getLangIdNoiseThreshold() {
+ return nativeGetLangIdNoiseThreshold(modelPtr);
+ }
+
private static native long nativeNew(int fd);
private static native long nativeNewFromPath(String path);
@@ -116,4 +121,6 @@
private static native int nativeGetVersionFromFd(int fd);
private native float nativeGetLangIdThreshold(long nativePtr);
+
+ private native float nativeGetLangIdNoiseThreshold(long nativePtr);
}
diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs
index 4ed68bb..21584b6 100755
--- a/native/actions/actions-entity-data.fbs
+++ b/native/actions/actions-entity-data.fbs
@@ -18,7 +18,7 @@
namespace libtextclassifier3;
table ActionsEntityData {
// Extracted text.
- text:string;
+ text:string (shared);
}
root_type libtextclassifier3.ActionsEntityData;
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index 29a4424..e651d19 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -24,6 +24,7 @@
#include "utils/base/logging.h"
#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"
+#include "utils/optional.h"
#include "utils/regex-match.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
@@ -1178,14 +1179,13 @@
bool ActionsSuggestions::FillAnnotationFromMatchGroup(
const UniLib::RegexMatcher* matcher,
const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
- const int message_index, ActionSuggestionAnnotation* annotation) const {
+ const std::string& group_match_text, const int message_index,
+ ActionSuggestionAnnotation* annotation) const {
if (group->annotation_name() != nullptr ||
group->annotation_type() != nullptr) {
int status = UniLib::RegexMatcher::kNoError;
const CodepointSpan span = {matcher->Start(group->group_id(), &status),
matcher->End(group->group_id(), &status)};
- std::string text =
- matcher->Group(group->group_id(), &status).ToUTF8String();
if (status != UniLib::RegexMatcher::kNoError) {
TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
return false;
@@ -1197,7 +1197,7 @@
}
annotation->span.span = span;
annotation->span.message_index = message_index;
- annotation->span.text = text;
+ annotation->span.text = group_match_text;
if (group->annotation_name() != nullptr) {
annotation->name = group->annotation_name()->str();
}
@@ -1244,12 +1244,18 @@
if (rule_action->capturing_group() != nullptr) {
for (const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup*
group : *rule_action->capturing_group()) {
+ Optional<std::string> group_match_text =
+ GetCapturingGroupText(matcher.get(), group->group_id());
+ if (!group_match_text.has_value()) {
+ // The group was not part of the match, ignore and continue.
+ continue;
+ }
+
if (group->entity_field() != nullptr) {
TC3_CHECK(entity_data != nullptr);
sets_entity_data = true;
- if (!SetFieldFromCapturingGroup(
- group->group_id(), group->entity_field(), matcher.get(),
- entity_data.get())) {
+ if (!entity_data->ParseAndSet(group->entity_field(),
+ group_match_text.value())) {
TC3_LOG(ERROR)
<< "Could not set entity data from rule capturing group.";
return false;
@@ -1259,27 +1265,17 @@
// Create a text annotation for the group span.
ActionSuggestionAnnotation annotation;
if (FillAnnotationFromMatchGroup(matcher.get(), group,
+ group_match_text.value(),
message_index, &annotation)) {
annotations.push_back(annotation);
}
// Create text reply.
if (group->text_reply() != nullptr) {
- int status = UniLib::RegexMatcher::kNoError;
- const std::string group_text =
- matcher->Group(group->group_id(), &status).ToUTF8String();
- if (status != UniLib::RegexMatcher::kNoError) {
- TC3_LOG(ERROR) << "Could get text from capturing group.";
- return false;
- }
- if (group_text.empty()) {
- // The group was not part of the match, ignore and continue.
- continue;
- }
actions->push_back(SuggestionFromSpec(
group->text_reply(),
/*default_type=*/model_->smart_reply_action_type()->str(),
- /*default_response_text=*/group_text));
+ /*default_response_text=*/group_match_text.value()));
}
}
}
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 2dde133..60f7204 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -258,7 +258,8 @@
bool FillAnnotationFromMatchGroup(
const UniLib::RegexMatcher* matcher,
const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
- const int message_index, ActionSuggestionAnnotation* annotation) const;
+ const std::string& group_match_text, const int message_index,
+ ActionSuggestionAnnotation* annotation) const;
std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index e0cfbaa..0dc627b 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -35,7 +35,6 @@
namespace libtextclassifier3 {
namespace {
-using testing::_;
constexpr char kModelFileName[] = "actions_suggestions_test.model";
constexpr char kHashGramModelFileName[] =
@@ -167,7 +166,7 @@
flight_annotation2.span = {35, 39};
flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
AnnotatedSpan email_annotation;
- email_annotation.span = {55, 68};
+ email_annotation.span = {43, 56};
email_annotation.classification = {ClassificationResult("email", 2.0)};
const ActionsSuggestionsResponse& response =
@@ -208,7 +207,7 @@
flight_annotation2.span = {35, 39};
flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
AnnotatedSpan email_annotation;
- email_annotation.span = {55, 68};
+ email_annotation.span = {43, 56};
email_annotation.classification = {ClassificationResult("email", 2.0)};
const ActionsSuggestionsResponse& response =
diff --git a/native/actions/actions_jni.cc b/native/actions/actions_jni.cc
index 20891fa..8284921 100644
--- a/native/actions/actions_jni.cc
+++ b/native/actions/actions_jni.cc
@@ -166,7 +166,8 @@
std::vector<RemoteActionTemplate> remote_action_templates;
if (context->intent_generator()->GenerateIntents(
device_locales, action_result[i], conversation, app_context,
- actions_entity_data_schema, annotations_entity_data_schema,
+ /*annotations_entity_data_schema=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr,
&remote_action_templates)) {
remote_action_templates_result =
context->template_handler()->RemoteActionTemplatesToJObjectArray(
@@ -355,6 +356,31 @@
#endif // TC3_UNILIB_JAVAICU
}
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size,
+ jbyteArray serialized_preconditions) {
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
+ libtextclassifier3::JniCache::Create(env);
+ std::string preconditions;
+ if (serialized_preconditions != nullptr &&
+ !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
+ &preconditions)) {
+ TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
+ return 0;
+ }
+#ifdef TC3_UNILIB_JAVAICU
+ return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+ jni_cache,
+ ActionsSuggestions::FromFileDescriptor(
+ fd, offset, size, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+ preconditions)));
+#else
+ return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+ jni_cache, ActionsSuggestions::FromFileDescriptor(
+ fd, offset, size, /*unilib=*/nullptr, preconditions)));
+#endif // TC3_UNILIB_JAVAICU
+}
+
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
jlong annotatorPtr, jobject app_context, jstring device_locales,
@@ -393,6 +419,13 @@
return libtextclassifier3::GetLocalesFromMmap(env, mmap.get());
}
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocalesWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd, offset, size));
+ return libtextclassifier3::GetLocalesFromMmap(env, mmap.get());
+}
+
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
@@ -400,9 +433,23 @@
return libtextclassifier3::GetNameFromMmap(env, mmap.get());
}
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetNameWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd, offset, size));
+ return libtextclassifier3::GetNameFromMmap(env, mmap.get());
+}
+
TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd));
return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
}
+
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd, offset, size));
+ return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
+}
diff --git a/native/actions/actions_jni.h b/native/actions/actions_jni.h
index fe2b998..276e361 100644
--- a/native/actions/actions_jni.h
+++ b/native/actions/actions_jni.h
@@ -37,6 +37,10 @@
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions);
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size,
+ jbyteArray serialized_preconditions);
+
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation, jobject joptions,
jlong annotatorPtr, jobject app_context, jstring device_locales,
@@ -48,12 +52,21 @@
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd);
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocalesWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd);
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetNameWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd);
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
#ifdef __cplusplus
}
#endif
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 42c7d88..d939d7c 100755
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -255,16 +255,16 @@
namespace libtextclassifier3;
table ActionSuggestionSpec {
// Type of the action suggestion.
- type:string;
+ type:string (shared);
// Text of a smart reply action.
- response_text:string;
+ response_text:string (shared);
// Score.
score:float;
// Serialized entity information.
- serialized_entity_data:string;
+ serialized_entity_data:string (shared);
// Priority score used for internal conflict resolution.
priority_score:float = 0;
@@ -274,7 +274,7 @@
namespace libtextclassifier3;
table ActionTypeOptions {
// The name of the predicted action.
- name:string;
+ name:string (shared);
// Triggering behaviour.
// Whether the action class is considered in the model output or not.
@@ -290,7 +290,7 @@
namespace libtextclassifier3.AnnotationActionsSpec_;
table AnnotationMapping {
// The annotation collection.
- annotation_collection:string;
+ annotation_collection:string (shared);
// The action name to use.
action:ActionSuggestionSpec;
@@ -353,7 +353,7 @@
// * input: (optionally deduplicated) action suggestions, via the `actions`
// global
// * output: indices of the actions to keep in the provided order.
- lua_ranking_script:string;
+ lua_ranking_script:string (shared);
compressed_lua_ranking_script:CompressedBuffer;
@@ -376,9 +376,9 @@
// If set, the capturing group will be used to create a text annotation
// with the given name and type.
- annotation_type:string;
+ annotation_type:string (shared);
- annotation_name:string;
+ annotation_name:string (shared);
// If set, the capturing group text will be used to create a text
// reply.
@@ -398,13 +398,13 @@
namespace libtextclassifier3.RulesModel_;
table Rule {
// The regular expression pattern.
- pattern:string;
+ pattern:string (shared);
compressed_pattern:CompressedBuffer;
actions:[Rule_.RuleActionSpec];
// Patterns for post-checking the outputs.
- output_pattern:string;
+ output_pattern:string (shared);
compressed_output_pattern:CompressedBuffer;
}
@@ -421,18 +421,18 @@
namespace libtextclassifier3;
table ActionsModel {
// Comma-separated list of locales supported by the model as BCP 47 tags.
- locales:string;
+ locales:string (shared);
// Version of the actions model.
version:int;
// A name for the model that can be used e.g. for logging.
- name:string;
+ name:string (shared);
tflite_model_spec:TensorflowLiteModelSpec;
// Output classes.
- smart_reply_action_type:string;
+ smart_reply_action_type:string (shared);
action_type:[ActionTypeOptions];
@@ -464,7 +464,7 @@
ranking_options:RankingOptions;
// Lua based actions.
- lua_actions_script:string;
+ lua_actions_script:string (shared);
compressed_lua_actions_script:CompressedBuffer;
diff --git a/native/actions/test_data/actions_suggestions_test.default.model b/native/actions/test_data/actions_suggestions_test.default.model
deleted file mode 100644
index 60f10e6..0000000
--- a/native/actions/test_data/actions_suggestions_test.default.model
+++ /dev/null
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.hashgram.model b/native/actions/test_data/actions_suggestions_test.hashgram.model
deleted file mode 100644
index cdc6bdc..0000000
--- a/native/actions/test_data/actions_suggestions_test.hashgram.model
+++ /dev/null
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
deleted file mode 100644
index 6cec2b7..0000000
--- a/native/actions/test_data/actions_suggestions_test.model
+++ /dev/null
Binary files differ
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 53c8d8a..867eea0 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -29,6 +29,7 @@
#include "utils/base/logging.h"
#include "utils/checksum.h"
#include "utils/math/softmax.h"
+#include "utils/optional.h"
#include "utils/regex-match.h"
#include "utils/utf8/unicodetext.h"
#include "utils/zlib/zlib_regex.h"
@@ -499,8 +500,7 @@
bool Annotator::InitializeKnowledgeEngine(
const std::string& serialized_config) {
- std::unique_ptr<KnowledgeEngine> knowledge_engine(
- new KnowledgeEngine(unilib_));
+ std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
if (!knowledge_engine->Initialize(serialized_config)) {
TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
return false;
@@ -617,15 +617,20 @@
classification[0].collection == Collections::Other();
}
-float GetPriorityScore(
- const std::vector<ClassificationResult>& classification) {
+} // namespace
+
+float Annotator::GetPriorityScore(
+ const std::vector<ClassificationResult>& classification) const {
if (!classification.empty() && !ClassifiedAsOther(classification)) {
return classification[0].priority_score;
} else {
- return -1.0;
+ if (model_->triggering_options() != nullptr) {
+ return model_->triggering_options()->other_collection_priority_score();
+ } else {
+ return -1000.0;
+ }
}
}
-} // namespace
bool Annotator::VerifyRegexMatchCandidate(
const std::string& context, const VerificationOptions* verification_options,
@@ -735,7 +740,8 @@
return original_click_indices;
}
if (knowledge_engine_ != nullptr &&
- !knowledge_engine_->Chunk(context, &candidates)) {
+ !knowledge_engine_->Chunk(context, options.annotation_usecase,
+ &candidates)) {
TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
return original_click_indices;
}
@@ -779,7 +785,7 @@
}
std::sort(candidate_indices.begin(), candidate_indices.end(),
- [&candidates](int a, int b) {
+ [this, &candidates](int a, int b) {
return GetPriorityScore(candidates[a].classification) >
GetPriorityScore(candidates[b].classification);
});
@@ -1421,6 +1427,23 @@
static_cast<EntityData_::Datetime_::Granularity>(
parse_result.granularity);
+ for (const auto& c : parse_result.datetime_components) {
+ EntityData_::Datetime_::DatetimeComponentT datetime_component;
+ datetime_component.absolute_value = c.value;
+ datetime_component.relative_count = c.relative_count;
+ datetime_component.component_type =
+ static_cast<EntityData_::Datetime_::DatetimeComponent_::ComponentType>(
+ c.component_type);
+ datetime_component.relation_type =
+ EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE;
+ if (c.relative_qualifier !=
+ DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
+ datetime_component.relation_type =
+ EntityData_::Datetime_::DatetimeComponent_::RelationType_RELATIVE;
+ }
+ entity_data.datetime->datetime_component.emplace_back(
+ new EntityData_::Datetime_::DatetimeComponentT(datetime_component));
+ }
flatbuffers::FlatBufferBuilder builder;
FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -1515,11 +1538,14 @@
// TODO(b/126579108): Propagate error status.
ClassificationResult knowledge_result;
if (knowledge_engine_ && knowledge_engine_->ClassifyText(
- context, selection_indices, &knowledge_result)) {
+ context, selection_indices,
+ options.annotation_usecase, &knowledge_result)) {
candidates.push_back({selection_indices, {knowledge_result}});
candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
}
+ AddContactMetadataToKnowledgeClassificationResults(&candidates);
+
// Try the contact engine.
// TODO(b/126579108): Propagate error status.
ClassificationResult contact_result;
@@ -1652,7 +1678,9 @@
if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
lines.push_back({context_unicode.begin(), context_unicode.end()});
} else {
- lines = selection_feature_processor_->SplitContext(context_unicode);
+ lines = selection_feature_processor_->SplitContext(
+ context_unicode, selection_feature_processor_->GetOptions()
+ ->use_pipe_character_for_newline());
}
const float min_annotate_confidence =
@@ -1769,6 +1797,19 @@
annotated_spans->end());
}
+void Annotator::AddContactMetadataToKnowledgeClassificationResults(
+ std::vector<AnnotatedSpan>* candidates) const {
+ if (candidates == nullptr || contact_engine_ == nullptr) {
+ return;
+ }
+ for (auto& candidate : *candidates) {
+ for (auto& classification_result : candidate.classification) {
+ contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
+ &classification_result);
+ }
+ }
+}
+
std::vector<AnnotatedSpan> Annotator::Annotate(
const std::string& context, const AnnotationOptions& options) const {
std::vector<AnnotatedSpan> candidates;
@@ -1824,16 +1865,28 @@
options.locales, ModeFlag_ANNOTATION,
options.annotation_usecase,
options.is_serialized_entity_data_enabled, &candidates)) {
- TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
+ TC3_LOG(ERROR) << "Couldn't run DatetimeChunk.";
return {};
}
- // Annotate with the knowledge engine.
- if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) {
+ // Annotate with the knowledge engine into a temporary vector.
+ std::vector<AnnotatedSpan> knowledge_candidates;
+ if (knowledge_engine_ &&
+ !knowledge_engine_->Chunk(context, options.annotation_usecase,
+ &knowledge_candidates)) {
TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk.";
return {};
}
+ AddContactMetadataToKnowledgeClassificationResults(&knowledge_candidates);
+
+ // Move the knowledge candidates to the full candidate list, and erase
+ // knowledge_candidates.
+ candidates.insert(candidates.end(),
+ std::make_move_iterator(knowledge_candidates.begin()),
+ std::make_move_iterator(knowledge_candidates.end()));
+ knowledge_candidates.clear();
+
// Annotate with the contact engine.
if (contact_engine_ &&
!contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
@@ -1970,6 +2023,9 @@
if (group->entity_field_path() != nullptr) {
return true;
}
+ if (group->serialized_entity_data() != nullptr) {
+ return true;
+ }
}
}
return false;
@@ -1991,7 +2047,6 @@
// Set static entity data.
if (pattern->serialized_entity_data() != nullptr) {
- TC3_CHECK(entity_data != nullptr);
entity_data->MergeFromSerializedFlatbuffer(
StringPiece(pattern->serialized_entity_data()->c_str(),
pattern->serialized_entity_data()->size()));
@@ -2001,17 +2056,31 @@
if (pattern->capturing_group() != nullptr) {
const int num_groups = pattern->capturing_group()->size();
for (int i = 0; i < num_groups; i++) {
- const FlatbufferFieldPath* field_path =
- pattern->capturing_group()->Get(i)->entity_field_path();
- if (field_path == nullptr) {
+ const RegexModel_::Pattern_::CapturingGroup* group =
+ pattern->capturing_group()->Get(i);
+
+ // Check whether the group matched.
+ Optional<std::string> group_match_text =
+ GetCapturingGroupText(matcher, /*group_id=*/i);
+ if (!group_match_text.has_value()) {
continue;
}
- TC3_CHECK(entity_data != nullptr);
- if (!SetFieldFromCapturingGroup(/*group_id=*/i, field_path, matcher,
- entity_data.get())) {
- TC3_LOG(ERROR)
- << "Could not set entity data from rule capturing group.";
- return false;
+
+ // Set static entity data from capturing group match.
+ if (group->serialized_entity_data() != nullptr) {
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(group->serialized_entity_data()->c_str(),
+ group->serialized_entity_data()->size()));
+ }
+
+ // Set entity field from capturing group text.
+ if (group->entity_field_path() != nullptr) {
+ if (!entity_data->ParseAndSet(group->entity_field_path(),
+ group_match_text.value())) {
+ TC3_LOG(ERROR)
+ << "Could not set entity data from rule capturing group.";
+ return false;
+ }
}
}
}
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
index 0b1c9f9..dabd894 100644
--- a/native/annotator/annotator.h
+++ b/native/annotator/annotator.h
@@ -137,7 +137,7 @@
};
// Holds TFLite interpreters for selection and classification models.
-// NOTE: his class is not thread-safe, thus should NOT be re-used across
+// NOTE: This class is not thread-safe, thus should NOT be re-used across
// threads.
class InterpreterManager {
public:
@@ -453,6 +453,15 @@
const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
std::string* serialized_entity_data) const;
+ // For knowledge candidates which have a ContactPointer, fill in the
+ // appropriate contact metadata, if possible.
+ void AddContactMetadataToKnowledgeClassificationResults(
+ std::vector<AnnotatedSpan>* candidates) const;
+
+ // Gets priority score from the list of classification results.
+ float GetPriorityScore(
+ const std::vector<ClassificationResult>& classification) const;
+
// Verifies a regex match and returns true if verification was successful.
bool VerifyRegexMatchCandidate(
const std::string& context,
diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc
index 9118f30..e5b7833 100644
--- a/native/annotator/annotator_jni.cc
+++ b/native/annotator/annotator_jni.cc
@@ -155,6 +155,12 @@
env->NewStringUTF(classification_result.contact_given_name.c_str());
}
+ jstring contact_family_name = nullptr;
+ if (!classification_result.contact_family_name.empty()) {
+ contact_family_name =
+ env->NewStringUTF(classification_result.contact_family_name.c_str());
+ }
+
jstring contact_nickname = nullptr;
if (!classification_result.contact_nickname.empty()) {
contact_nickname =
@@ -228,10 +234,11 @@
result_class, result_class_constructor, row_string,
static_cast<jfloat>(classification_result.score), row_datetime_parse,
serialized_knowledge_result, contact_name, contact_given_name,
- contact_nickname, contact_email_address, contact_phone_number, contact_id,
- app_name, app_package_name, extras, serialized_entity_data,
- remote_action_templates_result, classification_result.duration_ms,
- classification_result.numeric_value);
+ contact_family_name, contact_nickname, contact_email_address,
+ contact_phone_number, contact_id, app_name, app_package_name, extras,
+ serialized_entity_data, remote_action_templates_result,
+ classification_result.duration_ms, classification_result.numeric_value,
+ classification_result.numeric_double_value);
}
jobjectArray ClassificationResultsWithIntentsToJObjectArray(
@@ -262,9 +269,9 @@
"(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
"$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/String;"
"Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;"
- "Ljava/lang/String;[L" TC3_PACKAGE_PATH TC3_NAMED_VARIANT_CLASS_NAME_STR
- ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR
- ";JJ)V");
+ "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
+ "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH
+ "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V");
const jmethodID datetime_parse_class_constructor =
env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
@@ -435,12 +442,10 @@
#endif
}
-TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME,
- nativeNewAnnotatorFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
- const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
#ifdef TC3_USE_JAVAICU
return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
jni_cache,
@@ -652,10 +657,8 @@
return GetLocalesFromMmap(env, mmap.get());
}
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetLocalesFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
- const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
return GetLocalesFromMmap(env, mmap.get());
@@ -668,10 +671,8 @@
return GetVersionFromMmap(env, mmap.get());
}
-TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetVersionFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
- const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
+TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
return GetVersionFromMmap(env, mmap.get());
@@ -684,10 +685,8 @@
return GetNameFromMmap(env, mmap.get());
}
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetNameFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
- const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
return GetNameFromMmap(env, mmap.get());
diff --git a/native/annotator/annotator_jni.h b/native/annotator/annotator_jni.h
index bca1dcd..0789e76 100644
--- a/native/annotator/annotator_jni.h
+++ b/native/annotator/annotator_jni.h
@@ -34,9 +34,8 @@
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
(JNIEnv* env, jobject thiz, jstring path);
-TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME,
- nativeNewAnnotatorFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
nativeInitializeKnowledgeEngine)
@@ -79,23 +78,20 @@
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd);
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetLocalesFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd);
-TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetVersionFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
+TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd);
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetNameFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
#ifdef __cplusplus
}
diff --git a/native/annotator/collections.h b/native/annotator/collections.h
index a23623e..92fa984 100644
--- a/native/annotator/collections.h
+++ b/native/annotator/collections.h
@@ -104,6 +104,11 @@
*[]() { return new std::string("payment_card"); }();
return value;
}
+ static const std::string& Percentage() {
+ static const std::string& value =
+ *[]() { return new std::string("percentage"); }();
+ return value;
+ }
static const std::string& Phone() {
static const std::string& value =
*[]() { return new std::string("phone"); }();
diff --git a/native/annotator/contact/contact-engine-dummy.h b/native/annotator/contact/contact-engine-dummy.h
index c7a389d..436ac94 100644
--- a/native/annotator/contact/contact-engine-dummy.h
+++ b/native/annotator/contact/contact-engine-dummy.h
@@ -49,6 +49,9 @@
std::vector<AnnotatedSpan>* result) const {
return true;
}
+
+ void AddContactMetadataToKnowledgeClassificationResult(
+ ClassificationResult* classification_result) const {}
};
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc
index b9d0c30..0d3c202 100644
--- a/native/annotator/datetime/extractor.cc
+++ b/native/annotator/datetime/extractor.cc
@@ -20,15 +20,21 @@
namespace libtextclassifier3 {
-bool DatetimeExtractor::Extract(DateParseData* result,
+bool DatetimeExtractor::Extract(DatetimeParsedData* result,
CodepointSpan* result_span) const {
- result->field_set_mask = 0;
*result_span = {kInvalidIndex, kInvalidIndex};
if (rule_.regex->groups() == nullptr) {
return false;
}
+ // In the current implementation of extractor, the assumption is that there
+ // can only be one relative field.
+ DatetimeComponent::ComponentType component_type;
+ DatetimeComponent::RelativeQualifier relative_qualifier =
+ DatetimeComponent::RelativeQualifier::UNSPECIFIED;
+ int relative_count = 0;
+
for (int group_id = 0; group_id < rule_.regex->groups()->size(); group_id++) {
UnicodeText group_text;
const int group_type = rule_.regex->groups()->Get(group_id);
@@ -44,85 +50,115 @@
if (group_text.empty()) {
continue;
}
+
switch (group_type) {
case DatetimeGroupType_GROUP_YEAR: {
- if (!ParseYear(group_text, &(result->year))) {
+ int year;
+ if (!ParseYear(group_text, &(year))) {
TC3_LOG(ERROR) << "Couldn't extract YEAR.";
return false;
}
- result->field_set_mask |= DateParseData::YEAR_FIELD;
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, year);
break;
}
case DatetimeGroupType_GROUP_MONTH: {
- if (!ParseMonth(group_text, &(result->month))) {
+ int month;
+ if (!ParseMonth(group_text, &(month))) {
TC3_LOG(ERROR) << "Couldn't extract MONTH.";
return false;
}
- result->field_set_mask |= DateParseData::MONTH_FIELD;
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::MONTH,
+ month);
break;
}
case DatetimeGroupType_GROUP_DAY: {
- if (!ParseDigits(group_text, &(result->day_of_month))) {
+ int day_of_month;
+ if (!ParseDigits(group_text, &(day_of_month))) {
TC3_LOG(ERROR) << "Couldn't extract DAY.";
return false;
}
- result->field_set_mask |= DateParseData::DAY_FIELD;
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_MONTH,
+ day_of_month);
break;
}
case DatetimeGroupType_GROUP_HOUR: {
- if (!ParseDigits(group_text, &(result->hour))) {
+ int hour;
+ if (!ParseDigits(group_text, &(hour))) {
TC3_LOG(ERROR) << "Couldn't extract HOUR.";
return false;
}
- result->field_set_mask |= DateParseData::HOUR_FIELD;
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, hour);
break;
}
case DatetimeGroupType_GROUP_MINUTE: {
- if (!ParseDigits(group_text, &(result->minute))) {
+ int minute;
+ if (!ParseDigits(group_text, &(minute))) {
TC3_LOG(ERROR) << "Couldn't extract MINUTE.";
return false;
}
- result->field_set_mask |= DateParseData::MINUTE_FIELD;
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE,
+ minute);
break;
}
case DatetimeGroupType_GROUP_SECOND: {
- if (!ParseDigits(group_text, &(result->second))) {
+ int second;
+ if (!ParseDigits(group_text, &(second))) {
TC3_LOG(ERROR) << "Couldn't extract SECOND.";
return false;
}
- result->field_set_mask |= DateParseData::SECOND_FIELD;
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::SECOND,
+ second);
break;
}
case DatetimeGroupType_GROUP_AMPM: {
- if (!ParseAMPM(group_text, &(result->ampm))) {
+ int meridiem;
+ if (!ParseMeridiem(group_text, &(meridiem))) {
TC3_LOG(ERROR) << "Couldn't extract AMPM.";
return false;
}
- result->field_set_mask |= DateParseData::AMPM_FIELD;
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::MERIDIEM,
+ meridiem);
break;
}
case DatetimeGroupType_GROUP_RELATIONDISTANCE: {
- if (!ParseRelationDistance(group_text, &(result->relation_distance))) {
+ relative_count = 0;
+ if (!ParseRelationDistance(group_text, &(relative_count))) {
TC3_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD.";
return false;
}
- result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD;
break;
}
case DatetimeGroupType_GROUP_RELATION: {
- if (!ParseRelation(group_text, &(result->relation))) {
+ if (!ParseRelativeValue(group_text, &relative_qualifier)) {
TC3_LOG(ERROR) << "Couldn't extract RELATION_FIELD.";
return false;
}
- result->field_set_mask |= DateParseData::RELATION_FIELD;
+ ParseRelationAndConvertToRelativeCount(group_text, &relative_count);
+ if (relative_qualifier ==
+ DatetimeComponent::RelativeQualifier::TOMORROW ||
+ relative_qualifier == DatetimeComponent::RelativeQualifier::NOW ||
+ relative_qualifier ==
+ DatetimeComponent::RelativeQualifier::YESTERDAY) {
+ if (!ParseFieldType(group_text, &component_type)) {
+ TC3_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
+ return false;
+ }
+ }
break;
}
case DatetimeGroupType_GROUP_RELATIONTYPE: {
- if (!ParseRelationType(group_text, &(result->relation_type))) {
+ if (!ParseFieldType(group_text, &component_type)) {
TC3_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
return false;
}
- result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD;
+ if (component_type == DatetimeComponent::ComponentType::DAY_OF_WEEK) {
+ int day_of_week;
+ if (!ParseDayOfWeek(group_text, &day_of_week)) {
+ TC3_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
+ return false;
+ }
+ result->SetAbsoluteValue(component_type, day_of_week);
+ }
break;
}
case DatetimeGroupType_GROUP_DUMMY1:
@@ -138,6 +174,11 @@
}
}
+ if (relative_qualifier != DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
+ result->SetRelativeValue(component_type, relative_qualifier);
+ result->SetRelativeCount(component_type, relative_count);
+ }
+
if (result_span->first == kInvalidIndex ||
result_span->second == kInvalidIndex) {
*result_span = {kInvalidIndex, kInvalidIndex};
@@ -280,7 +321,6 @@
if (!matcher) {
return false;
}
-
int status;
while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
int span_start = matcher->Start(&status);
@@ -336,6 +376,7 @@
return false;
}
+ // Logic to decide if XX will be 20XX or 19XX
if (*parsed_year < 100) {
if (*parsed_year < 50) {
*parsed_year += 2000;
@@ -375,14 +416,14 @@
return false;
}
-bool DatetimeExtractor::ParseAMPM(const UnicodeText& input,
- DateParseData::AMPM* parsed_ampm) const {
+bool DatetimeExtractor::ParseMeridiem(const UnicodeText& input,
+ int* parsed_meridiem) const {
return MapInput(input,
{
- {DatetimeExtractorType_AM, DateParseData::AMPM::AM},
- {DatetimeExtractorType_PM, DateParseData::AMPM::PM},
+ {DatetimeExtractorType_AM, 0 /* AM */},
+ {DatetimeExtractorType_PM, 1 /* PM */},
},
- parsed_ampm);
+ parsed_meridiem);
}
bool DatetimeExtractor::ParseRelationDistance(const UnicodeText& input,
@@ -396,49 +437,99 @@
return false;
}
-bool DatetimeExtractor::ParseRelation(
- const UnicodeText& input, DateParseData::Relation* parsed_relation) const {
- return MapInput(
- input,
- {
- {DatetimeExtractorType_NOW, DateParseData::Relation::NOW},
- {DatetimeExtractorType_YESTERDAY, DateParseData::Relation::YESTERDAY},
- {DatetimeExtractorType_TOMORROW, DateParseData::Relation::TOMORROW},
- {DatetimeExtractorType_NEXT, DateParseData::Relation::NEXT},
- {DatetimeExtractorType_NEXT_OR_SAME,
- DateParseData::Relation::NEXT_OR_SAME},
- {DatetimeExtractorType_LAST, DateParseData::Relation::LAST},
- {DatetimeExtractorType_PAST, DateParseData::Relation::PAST},
- {DatetimeExtractorType_FUTURE, DateParseData::Relation::FUTURE},
- },
- parsed_relation);
+bool DatetimeExtractor::ParseRelativeValue(
+ const UnicodeText& input,
+ DatetimeComponent::RelativeQualifier* parsed_relative_value) const {
+ return MapInput(input,
+ {
+ {DatetimeExtractorType_NOW,
+ DatetimeComponent::RelativeQualifier::NOW},
+ {DatetimeExtractorType_YESTERDAY,
+ DatetimeComponent::RelativeQualifier::YESTERDAY},
+ {DatetimeExtractorType_TOMORROW,
+ DatetimeComponent::RelativeQualifier::TOMORROW},
+ {DatetimeExtractorType_NEXT,
+ DatetimeComponent::RelativeQualifier::NEXT},
+ {DatetimeExtractorType_NEXT_OR_SAME,
+ DatetimeComponent::RelativeQualifier::THIS},
+ {DatetimeExtractorType_LAST,
+ DatetimeComponent::RelativeQualifier::LAST},
+ {DatetimeExtractorType_PAST,
+ DatetimeComponent::RelativeQualifier::PAST},
+ {DatetimeExtractorType_FUTURE,
+ DatetimeComponent::RelativeQualifier::FUTURE},
+ },
+ parsed_relative_value);
}
-bool DatetimeExtractor::ParseRelationType(
+bool DatetimeExtractor::ParseRelationAndConvertToRelativeCount(
+ const UnicodeText& input, int* relative_count) const {
+ return MapInput(input,
+ {
+ {DatetimeExtractorType_NOW, 0},
+ {DatetimeExtractorType_YESTERDAY, -1},
+ {DatetimeExtractorType_TOMORROW, 1},
+ {DatetimeExtractorType_NEXT, 1},
+ {DatetimeExtractorType_NEXT_OR_SAME, 1},
+ {DatetimeExtractorType_LAST, -1},
+ },
+ relative_count);
+}
+
+bool DatetimeExtractor::ParseDayOfWeek(const UnicodeText& input,
+ int* parsed_day_of_week) const {
+ return MapInput(input,
+ {
+ {DatetimeExtractorType_SUNDAY, kSunday},
+ {DatetimeExtractorType_MONDAY, kMonday},
+ {DatetimeExtractorType_TUESDAY, kTuesday},
+ {DatetimeExtractorType_WEDNESDAY, kWednesday},
+ {DatetimeExtractorType_THURSDAY, kThursday},
+ {DatetimeExtractorType_FRIDAY, kFriday},
+ {DatetimeExtractorType_SATURDAY, kSaturday},
+ },
+ parsed_day_of_week);
+}
+
+bool DatetimeExtractor::ParseFieldType(
const UnicodeText& input,
- DateParseData::RelationType* parsed_relation_type) const {
+ DatetimeComponent::ComponentType* parsed_field_type) const {
return MapInput(
input,
{
- {DatetimeExtractorType_MONDAY, DateParseData::RelationType::MONDAY},
- {DatetimeExtractorType_TUESDAY, DateParseData::RelationType::TUESDAY},
+ {DatetimeExtractorType_MONDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_TUESDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
{DatetimeExtractorType_WEDNESDAY,
- DateParseData::RelationType::WEDNESDAY},
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
{DatetimeExtractorType_THURSDAY,
- DateParseData::RelationType::THURSDAY},
- {DatetimeExtractorType_FRIDAY, DateParseData::RelationType::FRIDAY},
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_FRIDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
{DatetimeExtractorType_SATURDAY,
- DateParseData::RelationType::SATURDAY},
- {DatetimeExtractorType_SUNDAY, DateParseData::RelationType::SUNDAY},
- {DatetimeExtractorType_SECONDS, DateParseData::RelationType::SECOND},
- {DatetimeExtractorType_MINUTES, DateParseData::RelationType::MINUTE},
- {DatetimeExtractorType_HOURS, DateParseData::RelationType::HOUR},
- {DatetimeExtractorType_DAY, DateParseData::RelationType::DAY},
- {DatetimeExtractorType_WEEK, DateParseData::RelationType::WEEK},
- {DatetimeExtractorType_MONTH, DateParseData::RelationType::MONTH},
- {DatetimeExtractorType_YEAR, DateParseData::RelationType::YEAR},
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_SUNDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_SECONDS,
+ DatetimeComponent::ComponentType::SECOND},
+ {DatetimeExtractorType_MINUTES,
+ DatetimeComponent::ComponentType::MINUTE},
+ {DatetimeExtractorType_NOW,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH},
+ {DatetimeExtractorType_HOURS, DatetimeComponent::ComponentType::HOUR},
+ {DatetimeExtractorType_DAY,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH},
+ {DatetimeExtractorType_TOMORROW,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH},
+ {DatetimeExtractorType_YESTERDAY,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH},
+ {DatetimeExtractorType_WEEK, DatetimeComponent::ComponentType::WEEK},
+ {DatetimeExtractorType_MONTH,
+ DatetimeComponent::ComponentType::MONTH},
+ {DatetimeExtractorType_YEAR, DatetimeComponent::ComponentType::YEAR},
},
- parsed_relation_type);
+ parsed_field_type);
}
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/extractor.h b/native/annotator/datetime/extractor.h
index 95e7f7c..097dd95 100644
--- a/native/annotator/datetime/extractor.h
+++ b/native/annotator/datetime/extractor.h
@@ -58,7 +58,7 @@
unilib_(unilib),
rules_(extractor_rules),
type_and_locale_to_rule_(type_and_locale_to_extractor_rule) {}
- bool Extract(DateParseData* result, CodepointSpan* result_span) const;
+ bool Extract(DatetimeParsedData* result, CodepointSpan* result_span) const;
private:
bool RuleIdForType(DatetimeExtractorType type, int* rule_id) const;
@@ -86,19 +86,18 @@
bool ParseWrittenNumber(const UnicodeText& input, int* parsed_number) const;
bool ParseYear(const UnicodeText& input, int* parsed_year) const;
bool ParseMonth(const UnicodeText& input, int* parsed_month) const;
- bool ParseAMPM(const UnicodeText& input,
- DateParseData::AMPM* parsed_ampm) const;
- bool ParseRelation(const UnicodeText& input,
- DateParseData::Relation* parsed_relation) const;
+ bool ParseMeridiem(const UnicodeText& input, int* parsed_meridiem) const;
+ bool ParseRelativeValue(
+ const UnicodeText& input,
+ DatetimeComponent::RelativeQualifier* parsed_relative_value) const;
bool ParseRelationDistance(const UnicodeText& input,
int* parsed_distance) const;
- bool ParseTimeUnit(const UnicodeText& input,
- DateParseData::TimeUnit* parsed_time_unit) const;
- bool ParseRelationType(
+ bool ParseFieldType(
const UnicodeText& input,
- DateParseData::RelationType* parsed_relation_type) const;
- bool ParseWeekday(const UnicodeText& input,
- DateParseData::RelationType* parsed_weekday) const;
+ DatetimeComponent::ComponentType* parsed_field_type) const;
+ bool ParseDayOfWeek(const UnicodeText& input, int* parsed_day_of_week) const;
+ bool ParseRelationAndConvertToRelativeCount(const UnicodeText& input,
+ int* relative_count) const;
const CompiledRule& rule_;
const UniLib::RegexMatcher& matcher_;
diff --git a/native/annotator/datetime/parser.cc b/native/annotator/datetime/parser.cc
index 6d844f4..0f222bd 100644
--- a/native/annotator/datetime/parser.cc
+++ b/native/annotator/datetime/parser.cc
@@ -346,36 +346,47 @@
}
void DatetimeParser::FillInterpretations(
- const DateParseData& parse,
- std::vector<DateParseData>* interpretations) const {
+ const DatetimeParsedData& parse,
+ std::vector<DatetimeParsedData>* interpretations) const {
DatetimeGranularity granularity = calendarlib_.GetGranularity(parse);
- DateParseData modified_parse(parse);
+ DatetimeParsedData modified_parse(parse);
// If the relation field is not set, but relation_type field *is*, assume
// the relation field is NEXT_OR_SAME. This is necessary to handle e.g.
// "monday 3pm" (otherwise only "this monday 3pm" would work).
- if (!(modified_parse.field_set_mask &
- DateParseData::Fields::RELATION_FIELD) &&
- (modified_parse.field_set_mask &
- DateParseData::Fields::RELATION_TYPE_FIELD)) {
- modified_parse.relation = DateParseData::Relation::NEXT_OR_SAME;
- modified_parse.field_set_mask |= DateParseData::Fields::RELATION_FIELD;
+ if (parse.HasFieldType(DatetimeComponent::ComponentType::DAY_OF_WEEK)) {
+ DatetimeComponent::RelativeQualifier relative_value;
+ if (parse.GetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ &relative_value)) {
+ if (relative_value == DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
+ modified_parse.SetRelativeValue(
+ DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ DatetimeComponent::RelativeQualifier::THIS);
+ }
+ }
}
// Multiple interpretations of ambiguous datetime expressions are generated
// here.
if (granularity > DatetimeGranularity::GRANULARITY_DAY &&
- (modified_parse.field_set_mask & DateParseData::Fields::HOUR_FIELD) &&
- modified_parse.hour <= 12 &&
- !(modified_parse.field_set_mask & DateParseData::Fields::AMPM_FIELD)) {
- // If it's not clear if the time is AM or PM, generate all variants.
- interpretations->push_back(modified_parse);
- interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
- interpretations->back().ampm = DateParseData::AMPM::AM;
-
- interpretations->push_back(modified_parse);
- interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
- interpretations->back().ampm = DateParseData::AMPM::PM;
+ modified_parse.HasFieldType(DatetimeComponent::ComponentType::HOUR) &&
+ !modified_parse.HasRelativeValue(
+ DatetimeComponent::ComponentType::HOUR) &&
+ !modified_parse.HasFieldType(
+ DatetimeComponent::ComponentType::MERIDIEM)) {
+ int hour_value;
+ modified_parse.GetFieldValue(DatetimeComponent::ComponentType::HOUR,
+ &hour_value);
+ if (hour_value <= 12) {
+ modified_parse.SetAbsoluteValue(
+ DatetimeComponent::ComponentType::MERIDIEM, 0);
+ interpretations->push_back(modified_parse);
+ modified_parse.SetAbsoluteValue(
+ DatetimeComponent::ComponentType::MERIDIEM, 1);
+ interpretations->push_back(modified_parse);
+ } else {
+ interpretations->push_back(modified_parse);
+ }
} else {
// Otherwise just generate 1 variant.
interpretations->push_back(modified_parse);
@@ -394,15 +405,14 @@
int locale_id,
std::vector<DatetimeParseResult>* results,
CodepointSpan* result_span) const {
- DateParseData parse;
+ DatetimeParsedData parse;
DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
extractor_rules_,
type_and_locale_to_extractor_rule_);
if (!extractor.Extract(&parse, result_span)) {
return false;
}
-
- std::vector<DateParseData> interpretations;
+ std::vector<DatetimeParsedData> interpretations;
if (generate_alternative_interpretations_when_ambiguous_) {
FillInterpretations(parse, &interpretations);
} else {
@@ -410,13 +420,29 @@
}
results->reserve(results->size() + interpretations.size());
- for (const DateParseData& interpretation : interpretations) {
+ for (const DatetimeParsedData& interpretation : interpretations) {
+ std::vector<DatetimeComponent> date_components;
+ interpretation.GetDatetimeComponents(&date_components);
DatetimeParseResult result;
+ // TODO(hassan): Text classifier only provides ambiguity limited to “AM/PM”
+ // which is encoded in the pair of DatetimeParseResult; both
+ // corresponding to the same date, but one corresponding to
+ // “AM” and the other one corresponding to “PM”.
+ // Remove multiple DatetimeParseResult per datetime span,
+ // once the ambiguities/DatetimeComponents are added in the
+ // response. For Details see b/130355975
if (!calendarlib_.InterpretParseData(
interpretation, reference_time_ms_utc, reference_timezone,
reference_locale, &(result.time_ms_utc), &(result.granularity))) {
return false;
}
+
+ // Sort the date time units by component type.
+ std::sort(date_components.begin(), date_components.end(),
+ [](DatetimeComponent a, DatetimeComponent b) {
+ return a.component_type > b.component_type;
+ });
+ result.datetime_components.swap(date_components);
results->push_back(result);
}
return true;
diff --git a/native/annotator/datetime/parser.h b/native/annotator/datetime/parser.h
index 3f0c143..4e995bd 100644
--- a/native/annotator/datetime/parser.h
+++ b/native/annotator/datetime/parser.h
@@ -28,6 +28,7 @@
#include "annotator/types.h"
#include "utils/base/integral_types.h"
#include "utils/calendar/calendar.h"
+#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "utils/zlib/zlib.h"
@@ -91,8 +92,9 @@
bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* result) const;
- void FillInterpretations(const DateParseData& parse,
- std::vector<DateParseData>* interpretations) const;
+ void FillInterpretations(
+ const DatetimeParsedData& parse,
+ std::vector<DatetimeParsedData>* interpretations) const;
// Converts the current match in 'matcher' into DatetimeParseResult.
bool ExtractDatetime(const CompiledRule& rule,
diff --git a/native/annotator/datetime/parser_test.cc b/native/annotator/datetime/parser_test.cc
index 8196fa7..35c725f 100644
--- a/native/annotator/datetime/parser_test.cc
+++ b/native/annotator/datetime/parser_test.cc
@@ -29,10 +29,48 @@
#include "annotator/types-test-util.h"
#include "utils/testing/annotator.h"
+using std::vector;
using testing::ElementsAreArray;
namespace libtextclassifier3 {
namespace {
+// Builder class to construct the DatetimeComponents and make the test readable.
+class DatetimeComponentsBuilder {
+ public:
+ DatetimeComponentsBuilder Add(DatetimeComponent::ComponentType type,
+ int value) {
+ DatetimeComponent component;
+ component.component_type = type;
+ component.value = value;
+ return AddComponent(component);
+ }
+
+ DatetimeComponentsBuilder Add(
+ DatetimeComponent::ComponentType type, int value,
+ DatetimeComponent::RelativeQualifier relative_qualifier,
+ int relative_count) {
+ DatetimeComponent component;
+ component.component_type = type;
+ component.value = value;
+ component.relative_qualifier = relative_qualifier;
+ component.relative_count = relative_count;
+ return AddComponent(component);
+ }
+
+ std::vector<DatetimeComponent> Build() {
+ std::vector<DatetimeComponent> result(datetime_components_);
+ datetime_components_.clear();
+ return result;
+ }
+
+ private:
+ DatetimeComponentsBuilder AddComponent(
+ const DatetimeComponent& datetime_component) {
+ datetime_components_.push_back(datetime_component);
+ return *this;
+ }
+ std::vector<DatetimeComponent> datetime_components_;
+};
std::string GetModelPath() {
return TC3_TEST_DATA_DIR;
@@ -76,8 +114,9 @@
}
bool ParsesCorrectly(const std::string& marked_text,
- const std::vector<int64>& expected_ms_utcs,
+ const vector<int64>& expected_ms_utcs,
DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components,
bool anchor_start_end = false,
const std::string& timezone = "Europe/Zurich",
const std::string& locales = "en-US",
@@ -124,7 +163,6 @@
filtered_results.push_back(result);
}
}
-
std::vector<DatetimeParseResultSpan> expected{
{{expected_start_index, expected_end_index},
{},
@@ -132,7 +170,8 @@
/*priority_score=*/0.1}};
expected[0].data.resize(expected_ms_utcs.size());
for (int i = 0; i < expected_ms_utcs.size(); i++) {
- expected[0].data[i] = {expected_ms_utcs[i], expected_granularity};
+ expected[0].data[i] = {expected_ms_utcs[i], expected_granularity,
+ datetime_components[i]};
}
const bool matches =
@@ -151,28 +190,34 @@
bool ParsesCorrectly(const std::string& marked_text,
const int64 expected_ms_utc,
DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components,
bool anchor_start_end = false,
const std::string& timezone = "Europe/Zurich",
const std::string& locales = "en-US",
AnnotationUsecase annotation_usecase =
AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- return ParsesCorrectly(marked_text, std::vector<int64>{expected_ms_utc},
- expected_granularity, anchor_start_end, timezone,
- locales, annotation_usecase);
+ return ParsesCorrectly(marked_text, vector<int64>{expected_ms_utc},
+ expected_granularity, datetime_components,
+ anchor_start_end, timezone, locales,
+ annotation_usecase);
}
- bool ParsesCorrectlyGerman(const std::string& marked_text,
- const std::vector<int64>& expected_ms_utcs,
- DatetimeGranularity expected_granularity) {
+ bool ParsesCorrectlyGerman(
+ const std::string& marked_text, const vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
+ datetime_components,
/*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich", /*locales=*/"de");
}
- bool ParsesCorrectlyGerman(const std::string& marked_text,
- const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity) {
+ bool ParsesCorrectlyGerman(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+ datetime_components,
/*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich", /*locales=*/"de");
}
@@ -186,52 +231,275 @@
// Test with just a few cases to make debugging of general failures easier.
TEST_F(ParserTest, ParseShort) {
- EXPECT_TRUE(
- ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
}
TEST_F(ParserTest, Parse) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 31 2018}", 1517353200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 31)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "lorem {1 january 2018} ipsum", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{09/Mar/2004 22:02:40}", 1078866160000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 40)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 02)
+ .Add(DatetimeComponent::ComponentType::HOUR, 22)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2004)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Dec 2, 2010 2:39:58 AM}", 1291253998000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 58)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 39)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Jun 09 2011 15:28:14}", 1307626094000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 14)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 28)
+ .Add(DatetimeComponent::ComponentType::HOUR, 15)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Mar 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{2010-06-26 02:31:29}", {1277512289000, 1277555489000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{2006/01/22 04:11:05}", {1137899465000, 1137942665000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build()}));
EXPECT_TRUE(
- ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
- EXPECT_TRUE(
- ParsesCorrectly("{january 31 2018}", 1517353200000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
- GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectly("{09/Mar/2004 22:02:40}", 1078866160000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{Dec 2, 2010 2:39:58 AM}", 1291253998000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{Jun 09 2011 15:28:14}", 1307626094000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{Mar 16 08:12:04}", {6419524000, 6462724000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29}",
- {1277512289000, 1277555489000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}",
- {1137899465000, 1137942665000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(
- ParsesCorrectly("{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23/Apr 11:42:35}", {9715355000, 9758555000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000,
- GRANULARITY_SECOND));
+ ParsesCorrectly("{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23/Apr/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23-Apr-2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23 Apr 2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{04/23/15 11:42:35}", {1429782155000, 1429825355000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{04/23/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{9/28/2011 2:23:15 PM}", 1317212595000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 23)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 28)
+ .Add(DatetimeComponent::ComponentType::MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
EXPECT_TRUE(ParsesCorrectly(
"Are sentiments apartments decisively the especially alteration. "
"Thrown shy denote ten ladies though ask saw. Or by to he going "
@@ -243,40 +511,179 @@
"think order event music. Incommode so intention defective at "
"convinced. Led income months itself and houses you. After nor "
"you leave might share court balls. ",
- {1271651775000, 1271694975000}, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}",
- {1514777400000, 1514820600000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30 am}", 1514777400000,
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000,
- GRANULARITY_HOUR));
+ {1271651775000, 1271694975000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30}", {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30 am}", 1514777400000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4pm}", 1514818800000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
- EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", {-3600000, 39600000},
- GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{today at 0:00}", {-3600000, 39600000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()}));
EXPECT_TRUE(ParsesCorrectly(
"{today at 0:00}", {-57600000, -14400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()},
/*anchor_start_end=*/false, "America/Los_Angeles"));
- EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4:00}", {97200000, 140400000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4am}", 97200000, GRANULARITY_HOUR));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 4:00}", {97200000, 140400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 4am}", 97200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{wednesday at 4am}", 529200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 4,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "last seen {today at 9:01 PM}", 72060000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 9)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "set an alarm for {7am tomorrow}", 108000000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 7)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
EXPECT_TRUE(
- ParsesCorrectly("{wednesday at 4am}", 529200000, GRANULARITY_HOUR));
- EXPECT_TRUE(ParsesCorrectly("last seen {today at 9:01 PM}", 72060000,
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("set an alarm for {7am tomorrow}", 108000000,
- GRANULARITY_HOUR));
- EXPECT_TRUE(
- ParsesCorrectly("set an alarm for {7 a.m}", 21600000, GRANULARITY_HOUR));
+ ParsesCorrectly("set an alarm for {7 a.m}", 21600000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 7)
+ .Build()}));
}
TEST_F(ParserTest, ParseWithAnchor) {
- EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
- GRANULARITY_DAY, /*anchor_start_end=*/false));
- EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
- GRANULARITY_DAY, /*anchor_start_end=*/true));
- EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
- GRANULARITY_DAY, /*anchor_start_end=*/false));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()},
+ /*anchor_start_end=*/false));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()},
+ /*anchor_start_end=*/true));
+ EXPECT_TRUE(ParsesCorrectly(
+ "lorem {1 january 2018} ipsum", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()},
+ /*anchor_start_end=*/false));
EXPECT_TRUE(HasNoResult("lorem 1 january 2018 ipsum",
/*anchor_start_end=*/true));
}
@@ -284,29 +691,50 @@
TEST_F(ParserTest, ParseWithRawUsecase) {
// Annotated for RAW usecase.
EXPECT_TRUE(ParsesCorrectly(
- "{tomorrow}", 82800000, GRANULARITY_DAY, /*anchor_start_end=*/false,
+ "{tomorrow}", 82800000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()},
+ /*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
EXPECT_TRUE(ParsesCorrectly(
"call me {in two hours}", 7200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::HOUR, 0,
+ DatetimeComponent::RelativeQualifier::FUTURE, 2)
+ .Build()},
/*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
EXPECT_TRUE(ParsesCorrectly(
"call me {next month}", 2674800000, GRANULARITY_MONTH,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NEXT, 1)
+ .Build()},
/*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
EXPECT_TRUE(ParsesCorrectly(
"what's the time {now}", -3600000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()},
/*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
EXPECT_TRUE(ParsesCorrectly(
"call me on {Saturday}", 169200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 7,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()},
/*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
@@ -319,105 +747,496 @@
}
TEST_F(ParserTest, ParsesNoonAndMidnightCorrectly) {
- EXPECT_TRUE(ParsesCorrectly("{January 1, 1988 12:30am}", 567991800000,
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{January 1, 1988 12:30pm}", 568035000000,
- GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988 12:30am}", 567991800000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988 12:30pm}", 568035000000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 12:00 am}", 82800000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
}
TEST_F(ParserTest, ParseGerman) {
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{Januar 1 2018}", 1514761200000, GRANULARITY_DAY));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{1 2 2018}", 1517439600000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectlyGerman("lorem {1 Januar 2018} ipsum",
- 1514761200000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectlyGerman("{19/Apr/2010:06:36:15}",
- {1271651775000, 1271694975000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{09/März/2004 22:02:40}", 1078866160000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{Dez 2, 2010 2:39:58}",
- {1291253998000, 1291297198000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{Juni 09 2011 15:28:14}", 1307626094000,
- GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectlyGerman(
- "{März 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29}",
- {1277512289000, 1277555489000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}",
- {1137899465000, 1137942665000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{11:42:35}", {38555000, 81755000},
- GRANULARITY_SECOND));
+ "{Januar 1 2018}", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
EXPECT_TRUE(ParsesCorrectlyGerman(
- "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}",
- {1271651775000, 1271694975000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}",
- {1514777400000, 1514820600000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30 nachm}",
- 1514820600000, GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4 nachm}", 1514818800000,
- GRANULARITY_HOUR));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{14.03.2017}", 1489446000000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectlyGerman("{morgen 0:00}", {82800000, 126000000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectlyGerman("{morgen um 4:00}", {97200000, 140400000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR));
+ "{1 2 2018}", 1517439600000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "lorem {1 Januar 2018} ipsum", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{19/Apr/2010:06:36:15}", {1271651775000, 1271694975000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{09/März/2004 22:02:40}", 1078866160000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 40)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 02)
+ .Add(DatetimeComponent::ComponentType::HOUR, 22)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2004)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{Dez 2, 2010 2:39:58}", {1291253998000, 1291297198000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 58)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 39)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 58)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 39)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{Juni 09 2011 15:28:14}", 1307626094000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 14)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 28)
+ .Add(DatetimeComponent::ComponentType::HOUR, 15)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{März 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{2010-06-26 02:31:29}", {1277512289000, 1277555489000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{2006/01/22 04:11:05}", {1137899465000, 1137942665000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr/2015:11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23-Apr-2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23 Apr 2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{04/23/15 11:42:35}", {1429782155000, 1429825355000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{04/23/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{19/apr/2010:06:36:15}", {1271651775000, 1271694975000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{januar 1 2018 um 4:30}", {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{januar 1 2018 um 4:30 nachm}", 1514820600000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{januar 1 2018 um 4 nachm}", 1514818800000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{14.03.2017}", 1489446000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 14)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2017)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{morgen 0:00}", {82800000, 126000000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{morgen um 4:00}", {97200000, 140400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
}
TEST_F(ParserTest, ParseNonUs) {
+ auto first_may_2015 =
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 5)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build();
+
EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
+ {first_may_2015},
/*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich",
/*locales=*/"en-GB"));
EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
+ {first_may_2015},
/*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich", /*locales=*/"en"));
}
TEST_F(ParserTest, ParseUs) {
+ auto five_january_2015 =
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 5)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build();
+
EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
+ {five_january_2015},
/*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich",
/*locales=*/"en-US"));
EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
+ {five_january_2015},
/*anchor_start_end=*/false,
/*timezone=*/"Europe/Zurich",
/*locales=*/"es-US"));
}
TEST_F(ParserTest, ParseUnknownLanguage) {
- EXPECT_TRUE(ParsesCorrectly("bylo to {31. 12. 2015} v 6 hodin", 1451516400000,
- GRANULARITY_DAY,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
+ EXPECT_TRUE(ParsesCorrectly(
+ "bylo to {31. 12. 2015} v 6 hodin", 1451516400000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 31)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
}
TEST_F(ParserTest, WhenAlternativesEnabledGeneratesAlternatives) {
@@ -426,12 +1245,49 @@
true;
});
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}",
- {1514777400000, 1514820600000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{monday 3pm}", 396000000, GRANULARITY_HOUR));
- EXPECT_TRUE(ParsesCorrectly("{monday 3:00}", {352800000, 396000000},
- GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30}", {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{monday 3pm}", 396000000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 3)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{monday 3:00}", {352800000, 396000000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 3)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 3)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()}));
}
TEST_F(ParserTest, WhenAlternativesDisabledDoesNotGenerateAlternatives) {
@@ -440,8 +1296,15 @@
false;
});
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000,
- GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30}", 1514777400000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
}
class ParserLocaleTest : public testing::Test {
diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc
index d442dc6..3529691 100644
--- a/native/annotator/duration/duration.cc
+++ b/native/annotator/duration/duration.cc
@@ -76,6 +76,20 @@
return result;
}
+std::unordered_set<int32> BuildInt32Set(
+ const flatbuffers::Vector<int32>* ints) {
+ std::unordered_set<int32> result;
+ if (ints == nullptr) {
+ return result;
+ }
+
+ for (const int32 int_value : *ints) {
+ result.insert(int_value);
+ }
+
+ return result;
+}
+
} // namespace internal
bool DurationAnnotator::ClassifyText(
@@ -152,7 +166,8 @@
start_index = token.start;
}
end_index = token.end;
- } else if (ParseDurationUnitToken(token, &parsed_duration.unit)) {
+ } else if (ParseDurationUnitToken(token, &parsed_duration.unit) ||
+ ParseQuantityDurationUnitToken(token, &parsed_duration)) {
if (start_index == kInvalidIndex) {
start_index = token.start;
}
@@ -223,7 +238,7 @@
break;
}
- int value = atom.value;
+ int64 value = atom.value;
// This condition handles expressions like "an hour", where the quantity is
// not specified. In this case we assume quantity 1. Except for cases like
// "half hour".
@@ -275,6 +290,31 @@
return true;
}
+bool DurationAnnotator::ParseQuantityDurationUnitToken(
+ const Token& token, ParsedDurationAtom* value) const {
+ if (token.value.empty()) {
+ return false;
+ }
+
+ Token sub_token;
+ bool has_quantity = false;
+ for (const char c : token.value) {
+ if (sub_token_separator_codepoints_.find(c) !=
+ sub_token_separator_codepoints_.end()) {
+ if (has_quantity || !ParseQuantityToken(sub_token, value)) {
+ return false;
+ }
+ has_quantity = true;
+
+ sub_token = Token();
+ } else {
+ sub_token.value += c;
+ }
+ }
+
+ return ParseDurationUnitToken(sub_token, &(value->unit));
+}
+
bool DurationAnnotator::ParseFillerToken(const Token& token) const {
std::string token_value_buffer;
const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h
index 4311afc..2242259 100644
--- a/native/annotator/duration/duration.h
+++ b/native/annotator/duration/duration.h
@@ -50,7 +50,11 @@
// Creates a set of strings from a flatbuffer string vector.
std::unordered_set<std::string> BuildStringSet(
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*);
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
+ strings);
+
+// Creates a set of ints from a flatbuffer int vector.
+std::unordered_set<int32> BuildInt32Set(const flatbuffers::Vector<int32>* ints);
} // namespace internal
@@ -66,7 +70,9 @@
filler_expressions_(
internal::BuildStringSet(options->filler_expressions())),
half_expressions_(
- internal::BuildStringSet(options->half_expressions())) {}
+ internal::BuildStringSet(options->half_expressions())),
+ sub_token_separator_codepoints_(internal::BuildInt32Set(
+ options->sub_token_separator_codepoints())) {}
// Classifies given text, and if it is a duration, it passes the result in
// 'classification_result' and returns true, otherwise returns false.
@@ -110,6 +116,8 @@
bool ParseQuantityToken(const Token& token, ParsedDurationAtom* value) const;
bool ParseDurationUnitToken(const Token& token,
internal::DurationUnit* duration_unit) const;
+ bool ParseQuantityDurationUnitToken(const Token& token,
+ ParsedDurationAtom* value) const;
bool ParseFillerToken(const Token& token) const;
int64 ParsedDurationAtomsToMillis(
@@ -121,6 +129,7 @@
token_value_to_duration_unit_;
const std::unordered_set<std::string> filler_expressions_;
const std::unordered_set<std::string> half_expressions_;
+ const std::unordered_set<int32> sub_token_separator_codepoints_;
};
} // namespace libtextclassifier3
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
index 78548fe..3fc25e6 100644
--- a/native/annotator/duration/duration_test.cc
+++ b/native/annotator/duration/duration_test.cc
@@ -63,6 +63,8 @@
options.half_expressions.push_back("half");
+ options.sub_token_separator_codepoints.push_back('-');
+
flatbuffers::FlatBufferBuilder builder;
builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
return new flatbuffers::DetachedBuffer(builder.Release());
@@ -71,7 +73,7 @@
return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
}
-FeatureProcessor BuildFeatureProcessor(const UniLib* unilib) {
+std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
static const flatbuffers::DetachedBuffer* options_data = []() {
FeatureProcessorOptionsT options;
options.context_size = 1;
@@ -94,7 +96,8 @@
const FeatureProcessorOptions* feature_processor_options =
flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
- return FeatureProcessor(feature_processor_options, unilib);
+ return std::unique_ptr<FeatureProcessor>(
+ new FeatureProcessor(feature_processor_options, unilib));
}
class DurationAnnotatorTest : public ::testing::Test {
@@ -103,14 +106,14 @@
: INIT_UNILIB_FOR_TESTING(unilib_),
feature_processor_(BuildFeatureProcessor(&unilib_)),
duration_annotator_(TestingDurationAnnotatorOptions(),
- &feature_processor_) {}
+ feature_processor_.get()) {}
std::vector<Token> Tokenize(const UnicodeText& text) {
- return feature_processor_.Tokenize(text);
+ return feature_processor_->Tokenize(text);
}
UniLib unilib_;
- FeatureProcessor feature_processor_;
+ std::unique_ptr<FeatureProcessor> feature_processor_;
DurationAnnotator duration_annotator_;
};
@@ -316,5 +319,36 @@
10 * 60 * 1000 + 2 * 1000)))))));
}
+TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) {
+ const UnicodeText text = UTF8ToUnicodeText("Show 5-minute timer.");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(5, 13)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 5 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest,
+ DoesNotIntOverflowWithDurationThatHasMoreThanInt32Millis) {
+ ClassificationResult classification;
+ EXPECT_TRUE(duration_annotator_.ClassifyText(
+ UTF8ToUnicodeText("1400 hours"), {0, 10},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
+
+ EXPECT_THAT(classification,
+ AllOf(Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 1400L * 60L * 60L * 1000L)));
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/native/annotator/entity-data.fbs b/native/annotator/entity-data.fbs
index 2143e28..6da3dd5 100755
--- a/native/annotator/entity-data.fbs
+++ b/native/annotator/entity-data.fbs
@@ -26,26 +26,103 @@
GRANULARITY_SECOND = 6,
}
+namespace libtextclassifier3.EntityData_.Datetime_.DatetimeComponent_;
+enum ComponentType : int {
+ UNSPECIFIED = 0,
+ YEAR = 1,
+ MONTH = 2,
+ WEEK = 3,
+ DAY_OF_WEEK = 4,
+ DAY_OF_MONTH = 5,
+ HOUR = 6,
+ MINUTE = 7,
+ SECOND = 8,
+ MERIDIEM = 9,
+ ZONE_OFFSET = 10,
+ DST_OFFSET = 11,
+}
+
+// Enum to identify if the datetime component are relative or absolute.
+namespace libtextclassifier3.EntityData_.Datetime_.DatetimeComponent_;
+enum RelationType : int {
+ RELATION_UNSPECIFIED = 0,
+
+ // Absolute represents the datetime component that need no further
+ // calculation e.g. in a datetime span "21-03-2019" components
+ // year=2019, month=3 and day=21 is explicitly mentioned in the span
+ ABSOLUTE = 1,
+
+ // Identify datetime component where datetime expressions are relative.
+ // e.g. "three days ago", "2 days after March 1st", "next monday",
+ // "last Mondays".
+ RELATIVE = 2,
+}
+
+namespace libtextclassifier3.EntityData_.Datetime_;
+table DatetimeComponent {
+ component_type:DatetimeComponent_.ComponentType = UNSPECIFIED;
+ absolute_value:int;
+ relative_count:int;
+ relation_type:DatetimeComponent_.RelationType = RELATION_UNSPECIFIED;
+}
+
namespace libtextclassifier3.EntityData_;
table Datetime {
time_ms_utc:long;
granularity:Datetime_.Granularity = GRANULARITY_UNKNOWN;
+ datetime_component:[Datetime_.DatetimeComponent];
}
namespace libtextclassifier3.EntityData_;
table Contact {
- name:string;
- given_name:string;
- nickname:string;
- email_address:string;
- phone_number:string;
- contact_id:string;
+ name:string (shared);
+ given_name:string (shared);
+ nickname:string (shared);
+ email_address:string (shared);
+ phone_number:string (shared);
+ contact_id:string (shared);
}
namespace libtextclassifier3.EntityData_;
table App {
- name:string;
- package_name:string;
+ name:string (shared);
+ package_name:string (shared);
+}
+
+// The issuer/network of the payment card.
+namespace libtextclassifier3.EntityData_.PaymentCard_;
+enum CardNetwork : int {
+ UNKNOWN_CARD_NETWORK = 0,
+ AMEX = 1,
+ DINERS_CLUB = 2,
+ DISCOVER = 3,
+ INTER_PAYMENT = 4,
+ JCB = 5,
+ MAESTRO = 6,
+ MASTERCARD = 7,
+ MIR = 8,
+ TROY = 9,
+ UNIONPAY = 10,
+ VISA = 11,
+}
+
+// Details about a payment card.
+namespace libtextclassifier3.EntityData_;
+table PaymentCard {
+ card_network:PaymentCard_.CardNetwork;
+
+ // The card number.
+ card_number:string (shared);
+}
+
+// Details about a flight number.
+namespace libtextclassifier3.EntityData_;
+table Flight {
+ // The IATA or ICAO airline code of the flight number.
+ airline_code:string (shared);
+
+ // The flight number.
+ flight_number:string (shared);
}
// Represents an entity annotated in text.
@@ -58,12 +135,14 @@
end:int;
// The entity type, as in the TextClassifier APIs.
- type:string;
+ type:string (shared);
datetime:EntityData_.Datetime;
reserved_5:int (deprecated);
contact:EntityData_.Contact;
app:EntityData_.App;
+ payment_card:EntityData_.PaymentCard;
+ flight:EntityData_.Flight;
}
root_type libtextclassifier3.EntityData;
diff --git a/native/annotator/feature-processor.cc b/native/annotator/feature-processor.cc
index c0f5c82..1d3b8f5 100644
--- a/native/annotator/feature-processor.cc
+++ b/native/annotator/feature-processor.cc
@@ -147,7 +147,8 @@
void FeatureProcessor::StripTokensFromOtherLines(
const UnicodeText& context_unicode, CodepointSpan span,
std::vector<Token>* tokens) const {
- std::vector<UnicodeTextRange> lines = SplitContext(context_unicode);
+ std::vector<UnicodeTextRange> lines =
+ SplitContext(context_unicode, options_->use_pipe_character_for_newline());
auto span_start = context_unicode.begin();
if (span.first > 0) {
@@ -485,6 +486,15 @@
const UnicodeText::const_iterator& span_start,
const UnicodeText::const_iterator& span_end,
bool count_from_beginning) const {
+ return CountIgnoredSpanBoundaryCodepoints(span_start, span_end,
+ count_from_beginning,
+ ignored_span_boundary_codepoints_);
+}
+
+int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end, bool count_from_beginning,
+ const std::unordered_set<int>& ignored_span_boundary_codepoints) const {
if (span_start == span_end) {
return 0;
}
@@ -507,8 +517,8 @@
// Move until we encounter a non-ignored character.
int num_ignored = 0;
- while (ignored_span_boundary_codepoints_.find(*it) !=
- ignored_span_boundary_codepoints_.end()) {
+ while (ignored_span_boundary_codepoints.find(*it) !=
+ ignored_span_boundary_codepoints.end()) {
++num_ignored;
if (it == it_last) {
@@ -549,22 +559,48 @@
} // namespace
std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
- const UnicodeText& context_unicode) const {
+ const UnicodeText& context_unicode,
+ const bool use_pipe_character_for_newline) const {
std::vector<UnicodeTextRange> lines;
- const std::set<char32> codepoints{{'\n', '|'}};
+ std::set<char32> codepoints{'\n'};
+ if (use_pipe_character_for_newline) {
+ codepoints.insert('|');
+ }
FindSubstrings(context_unicode, codepoints, &lines);
return lines;
}
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
const std::string& context, CodepointSpan span) const {
+ return StripBoundaryCodepoints(context, span,
+ ignored_span_boundary_codepoints_,
+ ignored_span_boundary_codepoints_);
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const std::string& context, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const {
const UnicodeText context_unicode =
UTF8ToUnicodeText(context, /*do_copy=*/false);
- return StripBoundaryCodepoints(context_unicode, span);
+ return StripBoundaryCodepoints(context_unicode, span,
+ ignored_prefix_span_boundary_codepoints,
+ ignored_suffix_span_boundary_codepoints);
}
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
const UnicodeText& context_unicode, CodepointSpan span) const {
+ return StripBoundaryCodepoints(context_unicode, span,
+ ignored_span_boundary_codepoints_,
+ ignored_span_boundary_codepoints_);
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const UnicodeText& context_unicode, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const {
if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
return span;
}
@@ -574,20 +610,35 @@
UnicodeText::const_iterator span_end = context_unicode.begin();
std::advance(span_end, span.second);
- return StripBoundaryCodepoints(span_begin, span_end, span);
+ return StripBoundaryCodepoints(span_begin, span_end, span,
+ ignored_prefix_span_boundary_codepoints,
+ ignored_suffix_span_boundary_codepoints);
}
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
const UnicodeText::const_iterator& span_begin,
const UnicodeText::const_iterator& span_end, CodepointSpan span) const {
+ return StripBoundaryCodepoints(span_begin, span_end, span,
+ ignored_span_boundary_codepoints_,
+ ignored_span_boundary_codepoints_);
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const {
if (!ValidNonEmptySpan(span) || span_begin == span_end) {
return span;
}
const int start_offset = CountIgnoredSpanBoundaryCodepoints(
- span_begin, span_end, /*count_from_beginning=*/true);
+ span_begin, span_end, /*count_from_beginning=*/true,
+ ignored_prefix_span_boundary_codepoints);
const int end_offset = CountIgnoredSpanBoundaryCodepoints(
- span_begin, span_end, /*count_from_beginning=*/false);
+ span_begin, span_end, /*count_from_beginning=*/false,
+ ignored_suffix_span_boundary_codepoints);
if (span.first + start_offset < span.second - end_offset) {
return {span.first + start_offset, span.second - end_offset};
@@ -615,10 +666,21 @@
const std::string& FeatureProcessor::StripBoundaryCodepoints(
const std::string& value, std::string* buffer) const {
+ return StripBoundaryCodepoints(value, buffer,
+ ignored_span_boundary_codepoints_,
+ ignored_span_boundary_codepoints_);
+}
+
+const std::string& FeatureProcessor::StripBoundaryCodepoints(
+ const std::string& value, std::string* buffer,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const {
const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
- const CodepointSpan stripped_span =
- StripBoundaryCodepoints(value_unicode, initial_span);
+ const CodepointSpan stripped_span = StripBoundaryCodepoints(
+ value_unicode, initial_span, ignored_prefix_span_boundary_codepoints,
+ ignored_suffix_span_boundary_codepoints);
if (initial_span != stripped_span) {
const UnicodeText stripped_token_value =
diff --git a/native/annotator/feature-processor.h b/native/annotator/feature-processor.h
index 4a753b0..2245b66 100644
--- a/native/annotator/feature-processor.h
+++ b/native/annotator/feature-processor.h
@@ -169,7 +169,8 @@
// Splits context to several segments.
std::vector<UnicodeTextRange> SplitContext(
- const UnicodeText& context_unicode) const;
+ const UnicodeText& context_unicode,
+ const bool use_pipe_character_for_newline) const;
// Strips boundary codepoints from the span in context and returns the new
// start and end indices. If the span comprises entirely of boundary
@@ -177,21 +178,50 @@
CodepointSpan StripBoundaryCodepoints(const std::string& context,
CodepointSpan span) const;
+ // Same as previous, but also takes the ignored span boundary codepoints.
+ CodepointSpan StripBoundaryCodepoints(
+ const std::string& context, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const;
+
// Same as above but takes UnicodeText.
CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
CodepointSpan span) const;
+ // Same as the previous, but also takes the ignored span boundary codepoints.
+ CodepointSpan StripBoundaryCodepoints(
+ const UnicodeText& context_unicode, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const;
+
// Same as above but takes a pair of iterators for the span, for efficiency.
CodepointSpan StripBoundaryCodepoints(
const UnicodeText::const_iterator& span_begin,
const UnicodeText::const_iterator& span_end, CodepointSpan span) const;
+ // Same as previous, but also takes the ignored span boundary codepoints.
+ CodepointSpan StripBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const;
+
// Same as above, but takes an optional buffer for saving the modified value.
// As an optimization, returns pointer to 'value' if nothing was stripped, or
// pointer to 'buffer' if something was stripped.
const std::string& StripBoundaryCodepoints(const std::string& value,
std::string* buffer) const;
+ // Same as previous, but also takes the ignored span boundary codepoints.
+ const std::string& StripBoundaryCodepoints(
+ const std::string& value, std::string* buffer,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const;
+
protected:
// Returns the class id corresponding to the given string collection
// identifier. There is a catch-all class id that the function returns for
@@ -237,6 +267,12 @@
const UnicodeText::const_iterator& span_end,
bool count_from_beginning) const;
+ // Same as previous, but also takes the ignored span boundary codepoints.
+ int CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end, bool count_from_beginning,
+ const std::unordered_set<int>& ignored_span_boundary_codepoints) const;
+
// Finds the center token index in tokens vector, using the method defined
// in options_.
int FindCenterToken(CodepointSpan span,
@@ -271,7 +307,7 @@
private:
// Set of codepoints that will be stripped from beginning and end of
// predicted spans.
- std::set<int32> ignored_span_boundary_codepoints_;
+ std::unordered_set<int32> ignored_span_boundary_codepoints_;
const FeatureProcessorOptions* const options_;
diff --git a/native/annotator/feature-processor_test.cc b/native/annotator/feature-processor_test.cc
deleted file mode 100644
index 5337776..0000000
--- a/native/annotator/feature-processor_test.cc
+++ /dev/null
@@ -1,975 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/feature-processor.h"
-
-#include "annotator/model-executor.h"
-#include "utils/tensor-view.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAreArray;
-using testing::FloatEq;
-using testing::Matcher;
-
-flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
- const FeatureProcessorOptionsT& options) {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateFeatureProcessorOptions(builder, &options));
- return builder.Release();
-}
-
-template <typename T>
-std::vector<T> Subvector(const std::vector<T>& vector, int start, int end) {
- return std::vector<T>(vector.begin() + start, vector.begin() + end);
-}
-
-Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
- std::vector<Matcher<float>> matchers;
- for (const float value : values) {
- matchers.push_back(FloatEq(value));
- }
- return ElementsAreArray(matchers);
-}
-
-class TestingFeatureProcessor : public FeatureProcessor {
- public:
- using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
- using FeatureProcessor::FeatureProcessor;
- using FeatureProcessor::SpanToLabel;
- using FeatureProcessor::StripTokensFromOtherLines;
- using FeatureProcessor::supported_codepoint_ranges_;
- using FeatureProcessor::SupportedCodepointsRatio;
-};
-
-// EmbeddingExecutor that always returns features based on
-class FakeEmbeddingExecutor : public EmbeddingExecutor {
- public:
- bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- int dest_size) const override {
- TC3_CHECK_GE(dest_size, 4);
- EXPECT_EQ(sparse_features.size(), 1);
- dest[0] = sparse_features.data()[0];
- dest[1] = sparse_features.data()[0];
- dest[2] = -sparse_features.data()[0];
- dest[3] = -sparse_features.data()[0];
- return true;
- }
-
- private:
- std::vector<float> storage_;
-};
-
-class FeatureProcessorTest : public ::testing::Test {
- protected:
- FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- UniLib unilib_;
-};
-
-TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
- std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens);
-
- // clang-format off
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5),
- Token("fěě", 6, 9),
- Token("bař", 9, 12),
- Token("@google.com", 12, 23),
- Token("heře!", 24, 29)}));
- // clang-format on
-}
-
-TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) {
- std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens);
-
- // clang-format off
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5),
- Token("fěěbař", 6, 12),
- Token("@google.com", 12, 23),
- Token("heře!", 24, 29)}));
- // clang-format on
-}
-
-TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) {
- std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens);
-
- // clang-format off
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5),
- Token("fěě", 6, 9),
- Token("bař@google.com", 9, 23),
- Token("heře!", 24, 29)}));
- // clang-format on
-}
-
-TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) {
- std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens);
-
- // clang-format off
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)}));
- // clang-format on
-}
-
-TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) {
- std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens);
-
- // clang-format off
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hě", 0, 2),
- Token("lló", 2, 5),
- Token("fěě", 6, 9),
- Token("bař@google.com", 9, 23),
- Token("heře!", 24, 29)}));
- // clang-format on
-}
-
-TEST_F(FeatureProcessorTest, KeepLineWithClickFirst) {
- FeatureProcessorOptionsT options;
- options.only_use_line_with_click = true;
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
- const CodepointSpan span = {0, 5};
- // clang-format off
- std::vector<Token> tokens = {Token("Fiřst", 0, 5),
- Token("Lině", 6, 10),
- Token("Sěcond", 11, 17),
- Token("Lině", 18, 22),
- Token("Thiřd", 23, 28),
- Token("Lině", 29, 33)};
- // clang-format on
-
- // Keeps the first line.
- feature_processor.StripTokensFromOtherLines(context, span, &tokens);
- EXPECT_THAT(tokens,
- ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));
-}
-
-TEST_F(FeatureProcessorTest, KeepLineWithClickSecond) {
- FeatureProcessorOptionsT options;
- options.only_use_line_with_click = true;
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
- const CodepointSpan span = {18, 22};
- // clang-format off
- std::vector<Token> tokens = {Token("Fiřst", 0, 5),
- Token("Lině", 6, 10),
- Token("Sěcond", 11, 17),
- Token("Lině", 18, 22),
- Token("Thiřd", 23, 28),
- Token("Lině", 29, 33)};
- // clang-format on
-
- // Keeps the first line.
- feature_processor.StripTokensFromOtherLines(context, span, &tokens);
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
-}
-
-TEST_F(FeatureProcessorTest, KeepLineWithClickThird) {
- FeatureProcessorOptionsT options;
- options.only_use_line_with_click = true;
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
- const CodepointSpan span = {24, 33};
- // clang-format off
- std::vector<Token> tokens = {Token("Fiřst", 0, 5),
- Token("Lině", 6, 10),
- Token("Sěcond", 11, 17),
- Token("Lině", 18, 22),
- Token("Thiřd", 23, 28),
- Token("Lině", 29, 33)};
- // clang-format on
-
- // Keeps the first line.
- feature_processor.StripTokensFromOtherLines(context, span, &tokens);
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
-}
-
-TEST_F(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
- FeatureProcessorOptionsT options;
- options.only_use_line_with_click = true;
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
- const CodepointSpan span = {18, 22};
- // clang-format off
- std::vector<Token> tokens = {Token("Fiřst", 0, 5),
- Token("Lině", 6, 10),
- Token("Sěcond", 11, 17),
- Token("Lině", 18, 22),
- Token("Thiřd", 23, 28),
- Token("Lině", 29, 33)};
- // clang-format on
-
- // Keeps the first line.
- feature_processor.StripTokensFromOtherLines(context, span, &tokens);
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
-}
-
-TEST_F(FeatureProcessorTest, KeepLineWithCrosslineClick) {
- FeatureProcessorOptionsT options;
- options.only_use_line_with_click = true;
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
- const CodepointSpan span = {5, 23};
- // clang-format off
- std::vector<Token> tokens = {Token("Fiřst", 0, 5),
- Token("Lině", 6, 10),
- Token("Sěcond", 18, 23),
- Token("Lině", 19, 23),
- Token("Thiřd", 23, 28),
- Token("Lině", 29, 33)};
- // clang-format on
-
- // Keeps the first line.
- feature_processor.StripTokensFromOtherLines(context, span, &tokens);
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Fiřst", 0, 5), Token("Lině", 6, 10),
- Token("Sěcond", 18, 23), Token("Lině", 19, 23),
- Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
-}
-
-TEST_F(FeatureProcessorTest, SpanToLabel) {
- FeatureProcessorOptionsT options;
- options.context_size = 1;
- options.max_selection_span = 1;
- options.snap_label_span_boundaries_to_containing_tokens = false;
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
- std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
- ASSERT_EQ(3, tokens.size());
- int label;
- ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
- EXPECT_EQ(kInvalidLabel, label);
- ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
- EXPECT_NE(kInvalidLabel, label);
- TokenSpan token_span;
- feature_processor.LabelToTokenSpan(label, &token_span);
- EXPECT_EQ(0, token_span.first);
- EXPECT_EQ(0, token_span.second);
-
- // Reconfigure with snapping enabled.
- options.snap_label_span_boundaries_to_containing_tokens = true;
- flatbuffers::DetachedBuffer options2_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
- &unilib_);
- int label2;
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
- EXPECT_EQ(label, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
- EXPECT_EQ(label, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
- EXPECT_EQ(label, label2);
-
- // Cross a token boundary.
- ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
- EXPECT_EQ(kInvalidLabel, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
- EXPECT_EQ(kInvalidLabel, label2);
-
- // Multiple tokens.
- options.context_size = 2;
- options.max_selection_span = 2;
- flatbuffers::DetachedBuffer options3_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor3(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
- &unilib_);
- tokens = feature_processor3.Tokenize("zero, one, two, three, four");
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
- EXPECT_NE(kInvalidLabel, label2);
- feature_processor3.LabelToTokenSpan(label2, &token_span);
- EXPECT_EQ(1, token_span.first);
- EXPECT_EQ(0, token_span.second);
-
- int label3;
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
- EXPECT_EQ(label2, label3);
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
- EXPECT_EQ(label2, label3);
- ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
- EXPECT_EQ(label2, label3);
-}
-
-TEST_F(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
- FeatureProcessorOptionsT options;
- options.context_size = 1;
- options.max_selection_span = 1;
- options.snap_label_span_boundaries_to_containing_tokens = false;
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
- std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
- ASSERT_EQ(3, tokens.size());
- int label;
- ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
- EXPECT_EQ(kInvalidLabel, label);
- ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
- EXPECT_NE(kInvalidLabel, label);
- TokenSpan token_span;
- feature_processor.LabelToTokenSpan(label, &token_span);
- EXPECT_EQ(0, token_span.first);
- EXPECT_EQ(0, token_span.second);
-
- // Reconfigure with snapping enabled.
- options.snap_label_span_boundaries_to_containing_tokens = true;
- flatbuffers::DetachedBuffer options2_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
- &unilib_);
- int label2;
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
- EXPECT_EQ(label, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
- EXPECT_EQ(label, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
- EXPECT_EQ(label, label2);
-
- // Cross a token boundary.
- ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
- EXPECT_EQ(kInvalidLabel, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
- EXPECT_EQ(kInvalidLabel, label2);
-
- // Multiple tokens.
- options.context_size = 2;
- options.max_selection_span = 2;
- flatbuffers::DetachedBuffer options3_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor3(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
- &unilib_);
- tokens = feature_processor3.Tokenize("zero, one, two, three, four");
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
- EXPECT_NE(kInvalidLabel, label2);
- feature_processor3.LabelToTokenSpan(label2, &token_span);
- EXPECT_EQ(1, token_span.first);
- EXPECT_EQ(0, token_span.second);
-
- int label3;
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
- EXPECT_EQ(label2, label3);
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
- EXPECT_EQ(label2, label3);
- ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
- EXPECT_EQ(label2, label3);
-}
-
-TEST_F(FeatureProcessorTest, CenterTokenFromClick) {
- int token_index;
-
- // Exactly aligned indices.
- token_index = internal::CenterTokenFromClick(
- {6, 11},
- {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
- EXPECT_EQ(token_index, 1);
-
- // Click is contained in a token.
- token_index = internal::CenterTokenFromClick(
- {13, 17},
- {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
- EXPECT_EQ(token_index, 2);
-
- // Click spans two tokens.
- token_index = internal::CenterTokenFromClick(
- {6, 17},
- {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
- EXPECT_EQ(token_index, kInvalidIndex);
-}
-
-TEST_F(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) {
- int token_index;
-
- // Selection of length 3. Exactly aligned indices.
- token_index = internal::CenterTokenFromMiddleOfSelection(
- {7, 27},
- {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
- Token("Token4", 21, 27), Token("Token5", 28, 34)});
- EXPECT_EQ(token_index, 2);
-
- // Selection of length 1 token. Exactly aligned indices.
- token_index = internal::CenterTokenFromMiddleOfSelection(
- {21, 27},
- {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
- Token("Token4", 21, 27), Token("Token5", 28, 34)});
- EXPECT_EQ(token_index, 3);
-
- // Selection marks sub-token range, with no tokens in it.
- token_index = internal::CenterTokenFromMiddleOfSelection(
- {29, 33},
- {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
- Token("Token4", 21, 27), Token("Token5", 28, 34)});
- EXPECT_EQ(token_index, kInvalidIndex);
-
- // Selection of length 2. Sub-token indices.
- token_index = internal::CenterTokenFromMiddleOfSelection(
- {3, 25},
- {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
- Token("Token4", 21, 27), Token("Token5", 28, 34)});
- EXPECT_EQ(token_index, 1);
-
- // Selection of length 1. Sub-token indices.
- token_index = internal::CenterTokenFromMiddleOfSelection(
- {22, 34},
- {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
- Token("Token4", 21, 27), Token("Token5", 28, 34)});
- EXPECT_EQ(token_index, 4);
-
- // Some invalid ones.
- token_index = internal::CenterTokenFromMiddleOfSelection({7, 27}, {});
- EXPECT_EQ(token_index, -1);
-}
-
-TEST_F(FeatureProcessorTest, SupportedCodepointsRatio) {
- FeatureProcessorOptionsT options;
- options.context_size = 2;
- options.max_selection_span = 2;
- options.snap_label_span_boundaries_to_containing_tokens = false;
- options.feature_version = 2;
- options.embedding_size = 4;
- options.bounds_sensitive_features.reset(
- new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
- options.bounds_sensitive_features->enabled = true;
- options.bounds_sensitive_features->num_tokens_before = 5;
- options.bounds_sensitive_features->num_tokens_inside_left = 3;
- options.bounds_sensitive_features->num_tokens_inside_right = 3;
- options.bounds_sensitive_features->num_tokens_after = 5;
- options.bounds_sensitive_features->include_inside_bag = true;
- options.bounds_sensitive_features->include_inside_length = true;
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- {
- options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
- auto& range = options.supported_codepoint_ranges.back();
- range->start = 0;
- range->end = 128;
- }
-
- {
- options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
- auto& range = options.supported_codepoint_ranges.back();
- range->start = 10000;
- range->end = 10001;
- }
-
- {
- options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
- auto& range = options.supported_codepoint_ranges.back();
- range->start = 20000;
- range->end = 30000;
- }
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
- EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
- {0, 3}, feature_processor.Tokenize("aaa bbb ccc")),
- FloatEq(1.0));
- EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
- {0, 3}, feature_processor.Tokenize("aaa bbb ěěě")),
- FloatEq(2.0 / 3));
- EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
- {0, 3}, feature_processor.Tokenize("ěěě řřř ěěě")),
- FloatEq(0.0));
- EXPECT_FALSE(
- IsCodepointInRanges(-1, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(
- IsCodepointInRanges(0, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(
- IsCodepointInRanges(10, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(
- IsCodepointInRanges(127, feature_processor.supported_codepoint_ranges_));
- EXPECT_FALSE(
- IsCodepointInRanges(128, feature_processor.supported_codepoint_ranges_));
- EXPECT_FALSE(
- IsCodepointInRanges(9999, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(IsCodepointInRanges(
- 10000, feature_processor.supported_codepoint_ranges_));
- EXPECT_FALSE(IsCodepointInRanges(
- 10001, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(IsCodepointInRanges(
- 25000, feature_processor.supported_codepoint_ranges_));
-
- const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7),
- Token("eee", 8, 11)};
-
- options.min_supported_codepoint_ratio = 0.0;
- flatbuffers::DetachedBuffer options2_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
- &unilib_);
- EXPECT_TRUE(feature_processor2.HasEnoughSupportedCodepoints(
- tokens, /*token_span=*/{0, 3}));
-
- options.min_supported_codepoint_ratio = 0.2;
- flatbuffers::DetachedBuffer options3_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor3(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
- &unilib_);
- EXPECT_TRUE(feature_processor3.HasEnoughSupportedCodepoints(
- tokens, /*token_span=*/{0, 3}));
-
- options.min_supported_codepoint_ratio = 0.5;
- flatbuffers::DetachedBuffer options4_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor4(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()),
- &unilib_);
- EXPECT_FALSE(feature_processor4.HasEnoughSupportedCodepoints(
- tokens, /*token_span=*/{0, 3}));
-}
-
-TEST_F(FeatureProcessorTest, InSpanFeature) {
- FeatureProcessorOptionsT options;
- options.context_size = 2;
- options.max_selection_span = 2;
- options.snap_label_span_boundaries_to_containing_tokens = false;
- options.feature_version = 2;
- options.embedding_size = 4;
- options.extract_selection_mask_feature = true;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- std::unique_ptr<CachedFeatures> cached_features;
-
- FakeEmbeddingExecutor embedding_executor;
-
- const std::vector<Token> tokens = {Token("aaa", 0, 3), Token("bbb", 4, 7),
- Token("ccc", 8, 11), Token("ddd", 12, 15)};
-
- EXPECT_TRUE(feature_processor.ExtractFeatures(
- tokens, /*token_span=*/{0, 4},
- /*selection_span_for_feature=*/{4, 11}, &embedding_executor,
- /*embedding_cache=*/nullptr, /*feature_vector_size=*/5,
- &cached_features));
- std::vector<float> features;
- cached_features->AppendClickContextFeaturesForClick(1, &features);
- ASSERT_EQ(features.size(), 25);
- EXPECT_THAT(features[4], FloatEq(0.0));
- EXPECT_THAT(features[9], FloatEq(0.0));
- EXPECT_THAT(features[14], FloatEq(1.0));
- EXPECT_THAT(features[19], FloatEq(1.0));
- EXPECT_THAT(features[24], FloatEq(0.0));
-}
-
-TEST_F(FeatureProcessorTest, EmbeddingCache) {
- FeatureProcessorOptionsT options;
- options.context_size = 2;
- options.max_selection_span = 2;
- options.snap_label_span_boundaries_to_containing_tokens = false;
- options.feature_version = 2;
- options.embedding_size = 4;
- options.bounds_sensitive_features.reset(
- new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
- options.bounds_sensitive_features->enabled = true;
- options.bounds_sensitive_features->num_tokens_before = 3;
- options.bounds_sensitive_features->num_tokens_inside_left = 2;
- options.bounds_sensitive_features->num_tokens_inside_right = 2;
- options.bounds_sensitive_features->num_tokens_after = 3;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- std::unique_ptr<CachedFeatures> cached_features;
-
- FakeEmbeddingExecutor embedding_executor;
-
- const std::vector<Token> tokens = {
- Token("aaa", 0, 3), Token("bbb", 4, 7), Token("ccc", 8, 11),
- Token("ddd", 12, 15), Token("eee", 16, 19), Token("fff", 20, 23)};
-
- // We pre-populate the cache with dummy embeddings, to make sure they are
- // used when populating the features vector.
- const std::vector<float> cached_padding_features = {10.0, -10.0, 10.0, -10.0};
- const std::vector<float> cached_features1 = {1.0, 2.0, 3.0, 4.0};
- const std::vector<float> cached_features2 = {5.0, 6.0, 7.0, 8.0};
- FeatureProcessor::EmbeddingCache embedding_cache = {
- {{kInvalidIndex, kInvalidIndex}, cached_padding_features},
- {{4, 7}, cached_features1},
- {{12, 15}, cached_features2},
- };
-
- EXPECT_TRUE(feature_processor.ExtractFeatures(
- tokens, /*token_span=*/{0, 6},
- /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
- &embedding_executor, &embedding_cache, /*feature_vector_size=*/4,
- &cached_features));
- std::vector<float> features;
- cached_features->AppendBoundsSensitiveFeaturesForSpan({2, 4}, &features);
- ASSERT_EQ(features.size(), 40);
- // Check that the dummy embeddings were used.
- EXPECT_THAT(Subvector(features, 0, 4),
- ElementsAreFloat(cached_padding_features));
- EXPECT_THAT(Subvector(features, 8, 12), ElementsAreFloat(cached_features1));
- EXPECT_THAT(Subvector(features, 16, 20), ElementsAreFloat(cached_features2));
- EXPECT_THAT(Subvector(features, 24, 28), ElementsAreFloat(cached_features2));
- EXPECT_THAT(Subvector(features, 36, 40),
- ElementsAreFloat(cached_padding_features));
- // Check that the real embeddings were cached.
- EXPECT_EQ(embedding_cache.size(), 7);
- EXPECT_THAT(Subvector(features, 4, 8),
- ElementsAreFloat(embedding_cache.at({0, 3})));
- EXPECT_THAT(Subvector(features, 12, 16),
- ElementsAreFloat(embedding_cache.at({8, 11})));
- EXPECT_THAT(Subvector(features, 20, 24),
- ElementsAreFloat(embedding_cache.at({8, 11})));
- EXPECT_THAT(Subvector(features, 28, 32),
- ElementsAreFloat(embedding_cache.at({16, 19})));
- EXPECT_THAT(Subvector(features, 32, 36),
- ElementsAreFloat(embedding_cache.at({20, 23})));
-}
-
-TEST_F(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
- std::vector<Token> tokens_orig{
- Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
- Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
- Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
- Token("12", 0, 0)};
-
- std::vector<Token> tokens;
- int click_index;
-
- // Try to click first token and see if it gets padded from left.
- tokens = tokens_orig;
- click_index = 0;
- internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token(),
- Token(),
- Token("0", 0, 0),
- Token("1", 0, 0),
- Token("2", 0, 0)}));
- // clang-format on
- EXPECT_EQ(click_index, 2);
-
- // When we click the second token nothing should get padded.
- tokens = tokens_orig;
- click_index = 2;
- internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0),
- Token("1", 0, 0),
- Token("2", 0, 0),
- Token("3", 0, 0),
- Token("4", 0, 0)}));
- // clang-format on
- EXPECT_EQ(click_index, 2);
-
- // When we click the last token tokens should get padded from the right.
- tokens = tokens_orig;
- click_index = 12;
- internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0),
- Token("11", 0, 0),
- Token("12", 0, 0),
- Token(),
- Token()}));
- // clang-format on
- EXPECT_EQ(click_index, 2);
-}
-
-TEST_F(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) {
- std::vector<Token> tokens_orig{
- Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
- Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
- Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
- Token("12", 0, 0)};
-
- std::vector<Token> tokens;
- int click_index;
-
- // Try to click first token and see if it gets padded from left to maximum
- // context_size.
- tokens = tokens_orig;
- click_index = 0;
- internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token(),
- Token(),
- Token("0", 0, 0),
- Token("1", 0, 0),
- Token("2", 0, 0),
- Token("3", 0, 0),
- Token("4", 0, 0),
- Token("5", 0, 0)}));
- // clang-format on
- EXPECT_EQ(click_index, 2);
-
- // Clicking to the middle with enough context should not produce any padding.
- tokens = tokens_orig;
- click_index = 6;
- internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0),
- Token("2", 0, 0),
- Token("3", 0, 0),
- Token("4", 0, 0),
- Token("5", 0, 0),
- Token("6", 0, 0),
- Token("7", 0, 0),
- Token("8", 0, 0),
- Token("9", 0, 0)}));
- // clang-format on
- EXPECT_EQ(click_index, 5);
-
- // Clicking at the end should pad right to maximum context_size.
- tokens = tokens_orig;
- click_index = 11;
- internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
- Token("7", 0, 0),
- Token("8", 0, 0),
- Token("9", 0, 0),
- Token("10", 0, 0),
- Token("11", 0, 0),
- Token("12", 0, 0),
- Token(),
- Token()}));
- // clang-format on
- EXPECT_EQ(click_index, 5);
-}
-
-TEST_F(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
- FeatureProcessorOptionsT options;
- options.ignored_span_boundary_codepoints.push_back('.');
- options.ignored_span_boundary_codepoints.push_back(',');
- options.ignored_span_boundary_codepoints.push_back('[');
- options.ignored_span_boundary_codepoints.push_back(']');
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string text1_utf8 = "ěščř";
- const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text1.begin(), text1.end(),
- /*count_from_beginning=*/true),
- 0);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text1.begin(), text1.end(),
- /*count_from_beginning=*/false),
- 0);
-
- const std::string text2_utf8 = ".,abčd";
- const UnicodeText text2 = UTF8ToUnicodeText(text2_utf8, /*do_copy=*/false);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text2.begin(), text2.end(),
- /*count_from_beginning=*/true),
- 2);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text2.begin(), text2.end(),
- /*count_from_beginning=*/false),
- 0);
-
- const std::string text3_utf8 = ".,abčd[]";
- const UnicodeText text3 = UTF8ToUnicodeText(text3_utf8, /*do_copy=*/false);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text3.begin(), text3.end(),
- /*count_from_beginning=*/true),
- 2);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text3.begin(), text3.end(),
- /*count_from_beginning=*/false),
- 2);
-
- const std::string text4_utf8 = "[abčd]";
- const UnicodeText text4 = UTF8ToUnicodeText(text4_utf8, /*do_copy=*/false);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text4.begin(), text4.end(),
- /*count_from_beginning=*/true),
- 1);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text4.begin(), text4.end(),
- /*count_from_beginning=*/false),
- 1);
-
- const std::string text5_utf8 = "";
- const UnicodeText text5 = UTF8ToUnicodeText(text5_utf8, /*do_copy=*/false);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text5.begin(), text5.end(),
- /*count_from_beginning=*/true),
- 0);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text5.begin(), text5.end(),
- /*count_from_beginning=*/false),
- 0);
-
- const std::string text6_utf8 = "012345ěščř";
- const UnicodeText text6 = UTF8ToUnicodeText(text6_utf8, /*do_copy=*/false);
- UnicodeText::const_iterator text6_begin = text6.begin();
- std::advance(text6_begin, 6);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text6_begin, text6.end(),
- /*count_from_beginning=*/true),
- 0);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text6_begin, text6.end(),
- /*count_from_beginning=*/false),
- 0);
-
- const std::string text7_utf8 = "012345.,ěščř";
- const UnicodeText text7 = UTF8ToUnicodeText(text7_utf8, /*do_copy=*/false);
- UnicodeText::const_iterator text7_begin = text7.begin();
- std::advance(text7_begin, 6);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text7_begin, text7.end(),
- /*count_from_beginning=*/true),
- 2);
- UnicodeText::const_iterator text7_end = text7.begin();
- std::advance(text7_end, 8);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text7.begin(), text7_end,
- /*count_from_beginning=*/false),
- 2);
-
- // Test not stripping.
- EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
- "Hello [[[Wořld]] or not?", {0, 24}),
- std::make_pair(0, 24));
- // Test basic stripping.
- EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
- "Hello [[[Wořld]] or not?", {6, 16}),
- std::make_pair(9, 14));
- // Test stripping when everything is stripped.
- EXPECT_EQ(
- feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
- std::make_pair(6, 6));
- // Test stripping empty string.
- EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
- std::make_pair(0, 0));
-}
-
-TEST_F(FeatureProcessorTest, CodepointSpanToTokenSpan) {
- const std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- // Spans matching the tokens exactly.
- EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}));
- EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}));
- EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}));
- EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}));
- EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}));
- EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}));
-
- // Snapping to containing tokens has no effect.
- EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}, true));
- EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}, true));
- EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}, true));
- EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}, true));
- EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}, true));
- EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}, true));
-
- // Span boundaries inside tokens.
- EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {1, 28}));
- EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {1, 28}, true));
-
- // Tokens adjacent to the span, but not overlapping.
- EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}));
- EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}, true));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
index 96d77c5..1787353 100644
--- a/native/annotator/knowledge/knowledge-engine-dummy.h
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -19,24 +19,23 @@
#include <string>
+#include "annotator/model_generated.h"
#include "annotator/types.h"
-#include "utils/utf8/unilib.h"
namespace libtextclassifier3 {
// A dummy implementation of the knowledge engine.
class KnowledgeEngine {
public:
- explicit KnowledgeEngine(const UniLib* unilib) {}
-
bool Initialize(const std::string& serialized_config) { return true; }
bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
ClassificationResult* classification_result) const {
return false;
}
- bool Chunk(const std::string& context,
+ bool Chunk(const std::string& context, AnnotationUsecase annotation_usecase,
std::vector<AnnotatedSpan>* result) const {
return true;
}
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index 9d18779..181a8aa 100755
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -205,16 +205,20 @@
// If set, the text of the capturing group will be used to set a field in
// the classfication result entity data.
entity_field_path:FlatbufferFieldPath;
+
+ // If set, the serialized entity data will be merged with the
+ // classification result entity data.
+ serialized_entity_data:string (shared);
}
// List of regular expression matchers to check.
namespace libtextclassifier3.RegexModel_;
table Pattern {
// The name of the collection of a match.
- collection_name:string;
+ collection_name:string (shared);
// The pattern to check.
- pattern:string;
+ pattern:string (shared);
// The modes for which to apply the patterns.
enabled_modes:ModeFlag = ALL;
@@ -238,7 +242,7 @@
capturing_group:[Pattern_.CapturingGroup];
// Serialized entity data to set for a match.
- serialized_entity_data:string;
+ serialized_entity_data:string (shared);
}
namespace libtextclassifier3;
@@ -263,7 +267,7 @@
// List of regex patterns.
namespace libtextclassifier3.DatetimeModelPattern_;
table Regex {
- pattern:string;
+ pattern:string (shared);
// The ith entry specifies the type of the ith capturing group.
// This is used to decide how the matched content has to be parsed.
@@ -297,7 +301,7 @@
namespace libtextclassifier3;
table DatetimeModelExtractor {
extractor:DatetimeExtractorType;
- pattern:string;
+ pattern:string (shared);
locales:[int];
compressed_pattern:CompressedBuffer;
}
@@ -329,7 +333,7 @@
namespace libtextclassifier3.DatetimeModelLibrary_;
table Item {
- key:string;
+ key:string (shared);
value:DatetimeModel;
}
@@ -350,12 +354,15 @@
// Comma-separated list of locales (BCP 47 tags) that dictionary
// classification supports.
- dictionary_locales:string;
+ dictionary_locales:string (shared);
// Comma-separated list of locales (BCP 47 tags) that the model supports, that
// are used to prevent triggering on input in unsupported languages. If
// empty, the model will trigger on all inputs.
- locales:string;
+ locales:string (shared);
+
+ // Priority score assigned to the "other" class from ML model.
+ other_collection_priority_score:float = -1000;
}
// Options controlling the output of the classifier.
@@ -392,12 +399,12 @@
namespace libtextclassifier3;
table Model {
// Comma-separated list of locales supported by the model as BCP 47 tags.
- locales:string;
+ locales:string (shared);
version:int;
// A name for the model that can be used for e.g. logging.
- name:string;
+ name:string (shared);
selection_feature_options:FeatureProcessorOptions;
classification_feature_options:FeatureProcessorOptions;
@@ -448,7 +455,7 @@
// Comma-separated list of locales (BCP 47 tags) that the model supports, that
// are used to prevent triggering on input in unsupported languages. If
// empty, the model will trigger on all inputs.
- triggering_locales:string;
+ triggering_locales:string (shared);
embedding_pruning_mask:Model_.EmbeddingPruningMask;
}
@@ -621,6 +628,10 @@
// If true, tokens will be also split when the codepoint's script_id changes
// as defined in TokenizationCodepointRange.
tokenize_on_script_change:bool = false;
+
+ // If true, the pipe character '|' will be used as a newline character when
+ // splitting lines.
+ use_pipe_character_for_newline:bool = true;
}
namespace libtextclassifier3;
@@ -628,10 +639,10 @@
// If true, number annotations will be produced.
enabled:bool = false;
- // Score to assign to the annotated numbers from the annotator.
+ // Score to assign to the annotated numbers and percentages in the annotator.
score:float = 1;
- // Priority score used for conflict resolution with the other models.
+ // Number priority score used for conflict resolution with the other models.
priority_score:float = 0;
// The modes in which to enable number annotations.
@@ -646,6 +657,24 @@
// A list of codepoints that can form a suffix of a valid number.
allowed_suffix_codepoints:[int];
+
+ // List of codepoints that will be stripped from beginning of predicted spans.
+ ignored_prefix_span_boundary_codepoints:[int];
+
+ // List of codepoints that will be stripped from end of predicted spans.
+ ignored_suffix_span_boundary_codepoints:[int];
+
+ // If true, percent annotations will be produced.
+ enable_percentage:bool = false;
+
+ // Zero separated and ordered list of suffixes that mark a percent.
+ percentage_pieces_string:string (shared);
+
+ // List of suffixes offsets in the percent_pieces_string string.
+ percentage_pieces_offsets:[int];
+
+ // Priority score for the percentage annotation.
+ percentage_priority_score:float = 1;
}
// DurationAnnotator is so far tailored for English only.
@@ -684,6 +713,10 @@
// List of expressions that mean half of a unit of duration (e.g. "half an
// hour").
half_expressions:[string];
+
+ // Set of condepoints that can split the Annotator tokens to sub-tokens for
+ // sub-token matching.
+ sub_token_separator_codepoints:[int];
}
root_type libtextclassifier3.Model;
diff --git a/native/annotator/number/number.cc b/native/annotator/number/number.cc
index bc3a2fe..7af63fa 100644
--- a/native/annotator/number/number.cc
+++ b/native/annotator/number/number.cc
@@ -28,21 +28,68 @@
const UnicodeText& context, CodepointSpan selection_indices,
AnnotationUsecase annotation_usecase,
ClassificationResult* classification_result) const {
- int64 parsed_value;
+ if (!options_->enabled() || ((1 << annotation_usecase) &
+ options_->enabled_annotation_usecases()) == 0) {
+ return false;
+ }
+
+ int64 parsed_int_value;
+ double parsed_double_value;
int num_prefix_codepoints;
int num_suffix_codepoints;
- if (ParseNumber(UnicodeText::Substring(context, selection_indices.first,
- selection_indices.second),
- &parsed_value, &num_prefix_codepoints,
- &num_suffix_codepoints)) {
- ClassificationResult classification{Collections::Number(), 1.0};
+ const UnicodeText substring_selected = UnicodeText::Substring(
+ context, selection_indices.first, selection_indices.second);
+ if (ParseNumber(substring_selected, &parsed_int_value, &parsed_double_value,
+ &num_prefix_codepoints, &num_suffix_codepoints)) {
TC3_CHECK(classification_result != nullptr);
- classification_result->collection = Collections::Number();
classification_result->score = options_->score();
classification_result->priority_score = options_->priority_score();
- classification_result->numeric_value = parsed_value;
- return true;
+ classification_result->numeric_value = parsed_int_value;
+ classification_result->numeric_double_value = parsed_double_value;
+
+ if (num_suffix_codepoints == 0) {
+ classification_result->collection = Collections::Number();
+ return true;
+ }
+
+ // If the selection ends in %, the parseNumber returns true with
+ // num_suffix_codepoints = 1 => percent
+ if (options_->enable_percentage() &&
+ GetPercentSuffixLength(
+ context, context.size_codepoints(),
+ selection_indices.second - num_suffix_codepoints) ==
+ num_suffix_codepoints) {
+ classification_result->collection = Collections::Percentage();
+ classification_result->priority_score =
+ options_->percentage_priority_score();
+ return true;
+ }
+ } else if (options_->enable_percentage()) {
+ // If the substring selected is a percent matching the form: 5 percent,
+ // 5 pct or 5 pc => percent.
+ std::vector<AnnotatedSpan> results;
+ FindAll(substring_selected, annotation_usecase, &results);
+ for (auto& result : results) {
+ if (result.classification.empty() ||
+ result.classification[0].collection != Collections::Percentage()) {
+ continue;
+ }
+ if (result.span.first == 0 &&
+ result.span.second == substring_selected.size_codepoints()) {
+ TC3_CHECK(classification_result != nullptr);
+ classification_result->collection = Collections::Percentage();
+ classification_result->score = options_->score();
+ classification_result->priority_score =
+ options_->percentage_priority_score();
+ classification_result->numeric_value =
+ result.classification[0].numeric_value;
+ classification_result->numeric_double_value =
+ result.classification[0].numeric_double_value;
+ return true;
+ }
+ }
}
+
return false;
}
@@ -58,14 +105,16 @@
for (const Token& token : tokens) {
const UnicodeText token_text =
UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- int64 parsed_value;
+ int64 parsed_int_value;
+ double parsed_double_value;
int num_prefix_codepoints;
int num_suffix_codepoints;
- if (ParseNumber(token_text, &parsed_value, &num_prefix_codepoints,
- &num_suffix_codepoints)) {
+ if (ParseNumber(token_text, &parsed_int_value, &parsed_double_value,
+ &num_prefix_codepoints, &num_suffix_codepoints)) {
ClassificationResult classification{Collections::Number(),
options_->score()};
- classification.numeric_value = parsed_value;
+ classification.numeric_value = parsed_int_value;
+ classification.numeric_double_value = parsed_double_value;
classification.priority_score = options_->priority_score();
AnnotatedSpan annotated_span;
@@ -77,27 +126,45 @@
}
}
+ if (options_->enable_percentage()) {
+ FindPercentages(context, result);
+ }
+
return true;
}
-std::unordered_set<int> NumberAnnotator::FlatbuffersVectorToSet(
- const flatbuffers::Vector<int32_t>* codepoints) {
- if (codepoints == nullptr) {
- return std::unordered_set<int>{};
+std::unordered_set<int> NumberAnnotator::FlatbuffersIntVectorToSet(
+ const flatbuffers::Vector<int32_t>* ints) {
+ if (ints == nullptr) {
+ return {};
}
+ return {ints->begin(), ints->end()};
+}
- std::unordered_set<int> result;
- for (const int codepoint : *codepoints) {
- result.insert(codepoint);
+std::vector<uint32> NumberAnnotator::FlatbuffersIntVectorToStdVector(
+ const flatbuffers::Vector<int32_t>* ints) {
+ if (ints == nullptr) {
+ return {};
}
- return result;
+ return {ints->begin(), ints->end()};
}
namespace {
+bool ParseNextNumericCodepoint(int32 codepoint, int64* current_value) {
+ if (*current_value > INT64_MAX / 10) {
+ return false;
+ }
+
+ // NOTE: This currently just works with ASCII numbers.
+ *current_value = *current_value * 10 + codepoint - '0';
+ return true;
+}
+
UnicodeText::const_iterator ConsumeAndParseNumber(
const UnicodeText::const_iterator& it_begin,
- const UnicodeText::const_iterator& it_end, int64* result) {
- *result = 0;
+ const UnicodeText::const_iterator& it_end, int64* int_result,
+ double* double_result) {
+ *int_result = 0;
// See if there's a sign in the beginning of the number.
int sign = 1;
@@ -112,31 +179,68 @@
}
}
+ enum class State {
+ PARSING_WHOLE_PART = 1,
+ PARSING_FLOATING_PART = 2,
+ PARSING_DONE = 3,
+ };
+ State state = State::PARSING_WHOLE_PART;
+ int64 decimal_result = 0;
+ int64 decimal_result_denominator = 1;
+ int number_digits = 0;
while (it != it_end) {
- if (*it >= '0' && *it <= '9') {
- // When overflow is imminent we'll fail to parse the number.
- if (*result > INT64_MAX / 10) {
- return it_begin;
- }
- *result *= 10;
- *result += *it - '0';
- } else {
- *result *= sign;
- return it;
+ switch (state) {
+ case State::PARSING_WHOLE_PART:
+ if (*it >= '0' && *it <= '9') {
+ if (!ParseNextNumericCodepoint(*it, int_result)) {
+ return it_begin;
+ }
+ } else if (*it == '.' || *it == ',') {
+ state = State::PARSING_FLOATING_PART;
+ } else {
+ state = State::PARSING_DONE;
+ }
+ break;
+ case State::PARSING_FLOATING_PART:
+ if (*it >= '0' && *it <= '9') {
+ if (!ParseNextNumericCodepoint(*it, &decimal_result)) {
+ state = State::PARSING_DONE;
+ break;
+ }
+ decimal_result_denominator *= 10;
+ } else {
+ state = State::PARSING_DONE;
+ }
+ break;
+ case State::PARSING_DONE:
+ break;
}
+ if (state == State::PARSING_DONE) {
+ break;
+ }
+ ++number_digits;
++it;
}
- *result *= sign;
- return it_end;
+ if (number_digits == 0) {
+ return it_begin;
+ }
+
+ *int_result *= sign;
+ *double_result =
+ *int_result + decimal_result * 1.0 / decimal_result_denominator;
+
+ return it;
}
} // namespace
-bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* result,
+bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* int_result,
+ double* double_result,
int* num_prefix_codepoints,
int* num_suffix_codepoints) const {
- TC3_CHECK(result != nullptr && num_prefix_codepoints != nullptr &&
+ TC3_CHECK(int_result != nullptr && double_result != nullptr &&
+ num_prefix_codepoints != nullptr &&
num_suffix_codepoints != nullptr);
auto it = text.begin();
auto it_end = text.end();
@@ -144,14 +248,24 @@
// Strip boundary codepoints from both ends.
const CodepointSpan original_span{0, text.size_codepoints()};
const CodepointSpan stripped_span =
- feature_processor_->StripBoundaryCodepoints(text, original_span);
+ feature_processor_->StripBoundaryCodepoints(
+ text, original_span, ignored_prefix_span_boundary_codepoints_,
+ ignored_suffix_span_boundary_codepoints_);
+
const int num_stripped_end = (original_span.second - stripped_span.second);
std::advance(it, stripped_span.first);
std::advance(it_end, -num_stripped_end);
// Consume prefix codepoints.
*num_prefix_codepoints = stripped_span.first;
- while (it != text.end()) {
+ bool valid_prefix = true;
+ // Makes valid_prefix=false for cases like: "#10" where it points to '1'. In
+ // this case the adjacent prefix is not an allowed one.
+ if (it != text.begin() && allowed_prefix_codepoints_.find(*std::prev(it)) ==
+ allowed_prefix_codepoints_.end()) {
+ valid_prefix = false;
+ }
+ while (it != it_end) {
if (allowed_prefix_codepoints_.find(*it) ==
allowed_prefix_codepoints_.end()) {
break;
@@ -162,7 +276,7 @@
}
auto it_start = it;
- it = ConsumeAndParseNumber(it, text.end(), result);
+ it = ConsumeAndParseNumber(it, it_end, int_result, double_result);
if (it == it_start) {
return false;
}
@@ -181,7 +295,56 @@
++(*num_suffix_codepoints);
}
*num_suffix_codepoints += num_stripped_end;
- return valid_suffix;
+
+ // Makes valid_suffix=false for cases like: "10@", when it == it_end and
+ // points to '@'. This adjacent character is not an allowed suffix.
+ if (it == it_end && it != text.end() &&
+ allowed_suffix_codepoints_.find(*it) ==
+ allowed_suffix_codepoints_.end()) {
+ valid_suffix = false;
+ }
+
+ return valid_suffix && valid_prefix;
+}
+
+int NumberAnnotator::GetPercentSuffixLength(const UnicodeText& context,
+ int context_size_codepoints,
+ int index_codepoints) const {
+ auto context_it = context.begin();
+ std::advance(context_it, index_codepoints);
+ const StringPiece suffix_context(
+ context_it.utf8_data(),
+ std::distance(context_it.utf8_data(), context.end().utf8_data()));
+ TrieMatch match;
+ percentage_suffixes_trie_.LongestPrefixMatch(suffix_context, &match);
+
+ if (match.match_length == -1) {
+ return match.match_length;
+ } else {
+ return UTF8ToUnicodeText(context_it.utf8_data(), match.match_length,
+ /*do_copy=*/false)
+ .size_codepoints();
+ }
+}
+
+void NumberAnnotator::FindPercentages(
+ const UnicodeText& context, std::vector<AnnotatedSpan>* result) const {
+ int context_size_codepoints = context.size_codepoints();
+ for (auto& res : *result) {
+ if (res.classification.empty() ||
+ res.classification[0].collection != Collections::Number()) {
+ continue;
+ }
+
+ const int match_length = GetPercentSuffixLength(
+ context, context_size_codepoints, res.span.second);
+ if (match_length > 0) {
+ res.classification[0].collection = Collections::Percentage();
+ res.classification[0].priority_score =
+ options_->percentage_priority_score();
+ res.span = {res.span.first, res.span.second + match_length};
+ }
+ }
}
} // namespace libtextclassifier3
diff --git a/native/annotator/number/number.h b/native/annotator/number/number.h
index 488f5ea..3debd09 100644
--- a/native/annotator/number/number.h
+++ b/native/annotator/number/number.h
@@ -24,6 +24,8 @@
#include "annotator/feature-processor.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/sentencepiece/sorted_strings_table.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -32,7 +34,8 @@
//
// Only supports values in range [-999 999 999, 999 999 999] (inclusive).
//
-// TODO(zilka): Add support for non-ASCII digits.
+// TODO(b/138639937): Add support for non-ASCII digits and multiple languages
+// percent.
// TODO(zilka): Add support for written-out numbers.
class NumberAnnotator {
public:
@@ -41,9 +44,24 @@
: options_(options),
feature_processor_(feature_processor),
allowed_prefix_codepoints_(
- FlatbuffersVectorToSet(options->allowed_prefix_codepoints())),
+ FlatbuffersIntVectorToSet(options->allowed_prefix_codepoints())),
allowed_suffix_codepoints_(
- FlatbuffersVectorToSet(options->allowed_suffix_codepoints())) {}
+ FlatbuffersIntVectorToSet(options->allowed_suffix_codepoints())),
+ ignored_prefix_span_boundary_codepoints_(FlatbuffersIntVectorToSet(
+ options->ignored_prefix_span_boundary_codepoints())),
+ ignored_suffix_span_boundary_codepoints_(FlatbuffersIntVectorToSet(
+ options->ignored_suffix_span_boundary_codepoints())),
+ percentage_pieces_string_(
+ (options->percentage_pieces_string() == nullptr)
+ ? StringPiece()
+ : StringPiece(options->percentage_pieces_string()->data(),
+ options->percentage_pieces_string()->size())),
+ percentage_pieces_offsets_(FlatbuffersIntVectorToStdVector(
+ options->percentage_pieces_offsets())),
+ percentage_suffixes_trie_(
+ SortedStringsTable(/*num_pieces=*/percentage_pieces_offsets_.size(),
+ /*offsets=*/percentage_pieces_offsets_.data(),
+ /*pieces=*/percentage_pieces_string_)) {}
// Classifies given text, and if it is a number, it passes the result in
// 'classification_result' and returns true, otherwise returns false.
@@ -57,20 +75,39 @@
std::vector<AnnotatedSpan>* result) const;
private:
- static std::unordered_set<int> FlatbuffersVectorToSet(
- const flatbuffers::Vector<int32_t>* codepoints);
+ static std::unordered_set<int> FlatbuffersIntVectorToSet(
+ const flatbuffers::Vector<int32_t>* ints);
+
+ static std::vector<uint32> FlatbuffersIntVectorToStdVector(
+ const flatbuffers::Vector<int32_t>* ints);
// Parses the text to an int64 value and returns true if succeeded, otherwise
// false. Also returns the number of prefix/suffix codepoints that were
// stripped from the number.
- bool ParseNumber(const UnicodeText& text, int64* result,
- int* num_prefix_codepoints,
+ bool ParseNumber(const UnicodeText& text, int64* int_result,
+ double* double_result, int* num_prefix_codepoints,
int* num_suffix_codepoints) const;
+ // Get the length of the percent suffix at the specified index in the context.
+ int GetPercentSuffixLength(const UnicodeText& context,
+ int context_size_codepoints,
+ int index_codepoints) const;
+
+ // Checks if the annotated numbers from the context represent percentages.
+ // If yes, replaces the collection type and the annotation boundary in the
+ // result.
+ void FindPercentages(const UnicodeText& context,
+ std::vector<AnnotatedSpan>* result) const;
+
const NumberAnnotatorOptions* options_;
const FeatureProcessor* feature_processor_;
const std::unordered_set<int> allowed_prefix_codepoints_;
const std::unordered_set<int> allowed_suffix_codepoints_;
+ const std::unordered_set<int> ignored_prefix_span_boundary_codepoints_;
+ const std::unordered_set<int> ignored_suffix_span_boundary_codepoints_;
+ const StringPiece percentage_pieces_string_;
+ const std::vector<uint32> percentage_pieces_offsets_;
+ const SortedStringsTable percentage_suffixes_trie_;
};
} // namespace libtextclassifier3
diff --git a/native/annotator/number/number_test.cc b/native/annotator/number/number_test.cc
deleted file mode 100644
index d3b2e8c..0000000
--- a/native/annotator/number/number_test.cc
+++ /dev/null
@@ -1,258 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/number/number.h"
-
-#include <string>
-#include <vector>
-
-#include "annotator/collections.h"
-#include "annotator/model_generated.h"
-#include "annotator/types-test-util.h"
-#include "annotator/types.h"
-#include "utils/test-utils.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::AllOf;
-using testing::ElementsAre;
-using testing::Field;
-
-const NumberAnnotatorOptions* TestingNumberAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- NumberAnnotatorOptionsT options;
- options.enabled = true;
- options.allowed_prefix_codepoints.push_back('$');
- options.allowed_suffix_codepoints.push_back('%');
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(NumberAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- return flatbuffers::GetRoot<NumberAnnotatorOptions>(options_data->data());
-}
-
-FeatureProcessor BuildFeatureProcessor(const UniLib* unilib) {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- FeatureProcessorOptionsT options;
- options.context_size = 1;
- options.max_selection_span = 1;
- options.snap_label_span_boundaries_to_containing_tokens = false;
- options.ignored_span_boundary_codepoints.push_back(',');
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- const FeatureProcessorOptions* feature_processor_options =
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
-
- return FeatureProcessor(feature_processor_options, unilib);
-}
-
-class NumberAnnotatorTest : public ::testing::Test {
- protected:
- NumberAnnotatorTest()
- : INIT_UNILIB_FOR_TESTING(unilib_),
- feature_processor_(BuildFeatureProcessor(&unilib_)),
- number_annotator_(TestingNumberAnnotatorOptions(),
- &feature_processor_) {}
-
- UniLib unilib_;
- FeatureProcessor feature_processor_;
- NumberAnnotator number_annotator_;
-};
-
-TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberCorrectly) {
- ClassificationResult classification_result;
- EXPECT_TRUE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("... 12345 ..."), {4, 9},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-
- EXPECT_EQ(classification_result.collection, "number");
- EXPECT_EQ(classification_result.numeric_value, 12345);
-}
-
-TEST_F(NumberAnnotatorTest, ClassifiesNonNumberCorrectly) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("... 123a45 ..."), {4, 10},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("... 12345 ... 9 is my number and I paid $99 and "
- "sometimes 27% but not 68# nor #68"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- ASSERT_EQ(result.size(), 4);
- ASSERT_EQ(result[0].classification.size(), 1);
- EXPECT_EQ(result[0].classification[0].collection, "number");
- EXPECT_EQ(result[0].classification[0].numeric_value, 12345);
- ASSERT_EQ(result[1].classification.size(), 1);
- EXPECT_EQ(result[1].classification[0].collection, "number");
- EXPECT_EQ(result[1].classification[0].numeric_value, 9);
- ASSERT_EQ(result[2].classification.size(), 1);
- EXPECT_EQ(result[2].classification[0].collection, "number");
- EXPECT_EQ(result[2].classification[0].numeric_value, 99);
- ASSERT_EQ(result[3].classification.size(), 1);
- EXPECT_EQ(result[3].classification[0].collection, "number");
- EXPECT_EQ(result[3].classification[0].numeric_value, 27);
-}
-
-TEST_F(NumberAnnotatorTest, FindsNumberWithPunctuation) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("Come at 9, ok?"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(8, 9)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "number"),
- Field(&ClassificationResult::numeric_value, 9)))))));
-}
-
-TEST_F(NumberAnnotatorTest, HandlesNumbersAtBeginning) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("-5"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 2)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "number"),
- Field(&ClassificationResult::numeric_value, -5)))))));
-}
-
-TEST_F(NumberAnnotatorTest, WhenLowestSupportedNumberParsesIt) {
- ClassificationResult classification_result;
- EXPECT_TRUE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("-999999999999999999"), {0, 19},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-
- EXPECT_THAT(
- classification_result,
- AllOf(Field(&ClassificationResult::collection, "number"),
- Field(&ClassificationResult::numeric_value, -999999999999999999L)));
-}
-
-TEST_F(NumberAnnotatorTest, WhenLargestSupportedNumberParsesIt) {
- ClassificationResult classification_result;
- EXPECT_TRUE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("999999999999999999"), {0, 18},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-
- EXPECT_THAT(
- classification_result,
- AllOf(Field(&ClassificationResult::collection, "number"),
- Field(&ClassificationResult::numeric_value, 999999999999999999L)));
-}
-
-TEST_F(NumberAnnotatorTest, WhenFirstLowestNonSupportedNumberDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("-10000000000000000000"), {0, 21},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenFirstLargestNonSupportedNumberDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("10000000000000000000"), {0, 20},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenLargeNumberDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("1234567890123456789012345678901234567890"), {0, 40},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenMultipleMinusSignsDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("--10"), {0, 4},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenMinusSignSuffixDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("10-"), {0, 3},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenMinusInTheMiddleDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("2016-2017"), {0, 9},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenSuffixWithoutNumberDoesNotParseIt) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("... % ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
-
- ASSERT_EQ(result.size(), 0);
-}
-
-TEST_F(NumberAnnotatorTest, WhenPrefixWithoutNumberDoesNotParseIt) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("... $ ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
-
- ASSERT_EQ(result.size(), 0);
-}
-
-TEST_F(NumberAnnotatorTest, WhenPrefixAndSuffixWithoutNumberDoesNotParseIt) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("... $% ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
-
- ASSERT_EQ(result.size(), 0);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/test_data/test_model.fb b/native/annotator/test_data/test_model.fb
index 0f2ec16..ce5f72f 100644
--- a/native/annotator/test_data/test_model.fb
+++ b/native/annotator/test_data/test_model.fb
Binary files differ
diff --git a/native/annotator/test_data/wrong_embeddings.fb b/native/annotator/test_data/wrong_embeddings.fb
index 5439623..efefa3c 100644
--- a/native/annotator/test_data/wrong_embeddings.fb
+++ b/native/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/native/annotator/types-test-util.h b/native/annotator/types-test-util.h
index c0b0980..1d018a1 100644
--- a/native/annotator/types-test-util.h
+++ b/native/annotator/types-test-util.h
@@ -34,7 +34,7 @@
TC3_DECLARE_PRINT_OPERATOR(AnnotatedSpan)
TC3_DECLARE_PRINT_OPERATOR(ClassificationResult)
-TC3_DECLARE_PRINT_OPERATOR(DateParseData)
+TC3_DECLARE_PRINT_OPERATOR(DatetimeParsedData)
TC3_DECLARE_PRINT_OPERATOR(DatetimeParseResultSpan)
TC3_DECLARE_PRINT_OPERATOR(Token)
diff --git a/native/annotator/types.cc b/native/annotator/types.cc
index ee150c8..c31097d 100644
--- a/native/annotator/types.cc
+++ b/native/annotator/types.cc
@@ -28,6 +28,24 @@
}
}
+bool DatetimeComponent::ShouldRoundToGranularity() const {
+ // Don't round to the granularity for relative expressions that specify the
+ // distance. So that, e.g. "in 2 hours" when it's 8:35:03 will result in
+ // 10:35:03.
+ if (relative_qualifier == RelativeQualifier::UNSPECIFIED) {
+ return false;
+ }
+ if (relative_qualifier == RelativeQualifier::NEXT ||
+ relative_qualifier == RelativeQualifier::TOMORROW ||
+ relative_qualifier == RelativeQualifier::YESTERDAY ||
+ relative_qualifier == RelativeQualifier::LAST ||
+ relative_qualifier == RelativeQualifier::THIS ||
+ relative_qualifier == RelativeQualifier::NOW) {
+ return true;
+ }
+ return false;
+}
+
namespace {
std::string FormatMillis(int64 time_ms_utc) {
long time_seconds = time_ms_utc / 1000; // NOLINT
@@ -82,25 +100,170 @@
}
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const DateParseData& data) {
- // TODO(zilka): Add human-readable form of field_set_mask and the enum fields.
- stream = stream << "DateParseData {\n";
- stream = stream << " field_set_mask: " << data.field_set_mask << "\n";
- stream = stream << " year: " << data.year << "\n";
- stream = stream << " month: " << data.month << "\n";
- stream = stream << " day_of_month: " << data.day_of_month << "\n";
- stream = stream << " hour: " << data.hour << "\n";
- stream = stream << " minute: " << data.minute << "\n";
- stream = stream << " second: " << data.second << "\n";
- stream = stream << " ampm: " << static_cast<int>(data.ampm) << "\n";
- stream = stream << " zone_offset: " << data.zone_offset << "\n";
- stream = stream << " dst_offset: " << data.dst_offset << "\n";
- stream = stream << " relation: " << static_cast<int>(data.relation) << "\n";
- stream = stream << " relation_type: " << static_cast<int>(data.relation_type)
- << "\n";
- stream = stream << " relation_distance: " << data.relation_distance << "\n";
+ const DatetimeParsedData& data) {
+ std::vector<DatetimeComponent> date_time_components;
+ data.GetDatetimeComponents(&date_time_components);
+ stream = stream << "DatetimeParsedData { \n";
+ for (const DatetimeComponent& c : date_time_components) {
+ stream = stream << " DatetimeComponent { \n";
+ stream = stream << " Component Type:" << static_cast<int>(c.component_type)
+ << "\n";
+ stream = stream << " Value:" << c.value << "\n";
+ stream = stream << " Relative Qualifier:"
+ << static_cast<int>(c.relative_qualifier) << "\n";
+ stream = stream << " Relative Count:" << c.relative_count << "\n";
+ stream = stream << " } \n";
+ }
stream = stream << "}";
return stream;
}
+void DatetimeParsedData::SetAbsoluteValue(
+ const DatetimeComponent::ComponentType& field_type, int value) {
+ GetOrCreateDatetimeComponent(field_type).value = value;
+}
+
+void DatetimeParsedData::SetRelativeValue(
+ const DatetimeComponent::ComponentType& field_type,
+ const DatetimeComponent::RelativeQualifier& relative_value) {
+ GetOrCreateDatetimeComponent(field_type).relative_qualifier = relative_value;
+}
+
+void DatetimeParsedData::SetRelativeCount(
+ const DatetimeComponent::ComponentType& field_type, int relative_count) {
+ GetOrCreateDatetimeComponent(field_type).relative_count = relative_count;
+}
+
+bool DatetimeParsedData::HasFieldType(
+ const DatetimeComponent::ComponentType& field_type) const {
+ if (date_time_components_.find(field_type) == date_time_components_.end()) {
+ return false;
+ }
+ return true;
+}
+
+bool DatetimeParsedData::GetFieldValue(
+ const DatetimeComponent::ComponentType& field_type,
+ int* field_value) const {
+ if (HasFieldType(field_type)) {
+ *field_value = date_time_components_.at(field_type).value;
+ return true;
+ }
+ return false;
+}
+
+bool DatetimeParsedData::GetRelativeValue(
+ const DatetimeComponent::ComponentType& field_type,
+ DatetimeComponent::RelativeQualifier* relative_value) const {
+ if (HasFieldType(field_type)) {
+ *relative_value = date_time_components_.at(field_type).relative_qualifier;
+ return true;
+ }
+ return false;
+}
+
+bool DatetimeParsedData::HasRelativeValue(
+ const DatetimeComponent::ComponentType& field_type) const {
+ if (HasFieldType(field_type)) {
+ return date_time_components_.at(field_type).relative_qualifier !=
+ DatetimeComponent::RelativeQualifier::UNSPECIFIED;
+ }
+ return false;
+}
+
+bool DatetimeParsedData::HasAbsoluteValue(
+ const DatetimeComponent::ComponentType& field_type) const {
+ return HasFieldType(field_type) && !HasRelativeValue(field_type);
+}
+
+void DatetimeParsedData::GetRelativeDatetimeComponents(
+ std::vector<DatetimeComponent>* date_time_components) const {
+ for (auto it = date_time_components_.begin();
+ it != date_time_components_.end(); it++) {
+ if (it->second.relative_qualifier !=
+ DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
+ date_time_components->push_back(it->second);
+ }
+ }
+}
+
+void DatetimeParsedData::GetDatetimeComponents(
+ std::vector<DatetimeComponent>* date_time_components) const {
+ for (auto it = date_time_components_.begin();
+ it != date_time_components_.end(); it++) {
+ date_time_components->push_back(it->second);
+ }
+}
+
+DatetimeGranularity DatetimeParsedData::GetFinestGranularity() const {
+ DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_UNKNOWN;
+ for (auto it = date_time_components_.begin();
+ it != date_time_components_.end(); it++) {
+ switch (it->first) {
+ case DatetimeComponent::ComponentType::YEAR:
+ if (granularity < DatetimeGranularity::GRANULARITY_YEAR) {
+ granularity = DatetimeGranularity::GRANULARITY_YEAR;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::MONTH:
+ if (granularity < DatetimeGranularity::GRANULARITY_MONTH) {
+ granularity = DatetimeGranularity::GRANULARITY_MONTH;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::WEEK:
+ if (granularity < DatetimeGranularity::GRANULARITY_WEEK) {
+ granularity = DatetimeGranularity::GRANULARITY_WEEK;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::DAY_OF_WEEK:
+ case DatetimeComponent::ComponentType::DAY_OF_MONTH:
+ if (granularity < DatetimeGranularity::GRANULARITY_DAY) {
+ granularity = DatetimeGranularity::GRANULARITY_DAY;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::HOUR:
+ if (granularity < DatetimeGranularity::GRANULARITY_HOUR) {
+ granularity = DatetimeGranularity::GRANULARITY_HOUR;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::MINUTE:
+ if (granularity < DatetimeGranularity::GRANULARITY_MINUTE) {
+ granularity = DatetimeGranularity::GRANULARITY_MINUTE;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::SECOND:
+ if (granularity < DatetimeGranularity::GRANULARITY_SECOND) {
+ granularity = DatetimeGranularity::GRANULARITY_SECOND;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::MERIDIEM:
+ case DatetimeComponent::ComponentType::ZONE_OFFSET:
+ case DatetimeComponent::ComponentType::DST_OFFSET:
+ default:
+ break;
+ }
+ }
+ return granularity;
+}
+
+DatetimeComponent& DatetimeParsedData::GetOrCreateDatetimeComponent(
+ const DatetimeComponent::ComponentType& component_type) {
+ auto result =
+ date_time_components_
+ .insert(
+ {component_type,
+ DatetimeComponent(
+ component_type,
+ DatetimeComponent::RelativeQualifier::UNSPECIFIED, 0, 0)})
+ .first;
+ return result->second;
+}
+
} // namespace libtextclassifier3
diff --git a/native/annotator/types.h b/native/annotator/types.h
index 48fefe4..ac24e24 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -18,6 +18,7 @@
#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
#include <time.h>
+
#include <algorithm>
#include <cmath>
#include <functional>
@@ -36,6 +37,13 @@
namespace libtextclassifier3 {
constexpr int kInvalidIndex = -1;
+constexpr int kSunday = 1;
+constexpr int kMonday = 2;
+constexpr int kTuesday = 3;
+constexpr int kWednesday = 4;
+constexpr int kThursday = 5;
+constexpr int kFriday = 6;
+constexpr int kSaturday = 7;
// Index for a 0-based array of tokens.
using TokenIndex = int;
@@ -165,6 +173,83 @@
GRANULARITY_SECOND = 6
};
+// This struct represents a unit of date and time expression.
+// Examples include:
+// - In {March 21, 2019} datetime components are month: {March},
+// day of month: {21} and year: {2019}.
+// - {8:00 am} contains hour: {8}, minutes: {0} and am/pm: {am}
+struct DatetimeComponent {
+ enum class ComponentType {
+ UNSPECIFIED = 0,
+ // Year of the date seen in the text match.
+ YEAR = 1,
+ // Month of the year starting with January = 1.
+ MONTH = 2,
+ // Week (7 days).
+ WEEK = 3,
+ // Day of week, start of the week is Sunday & its value is 1.
+ DAY_OF_WEEK = 4,
+ // Day of the month starting with 1.
+ DAY_OF_MONTH = 5,
+ // Hour of the day with a range of 0-23,
+ // values less than 12 need the AMPM field below or heuristics
+ // to definitively determine the time.
+ HOUR = 6,
+ // Minute of the hour with a range of 0-59.
+ MINUTE = 7,
+ // Seconds of the minute with a range of 0-59.
+ SECOND = 8,
+ // Meridiem field where 0 == AM, 1 == PM.
+ MERIDIEM = 9,
+ // Number of hours offset from UTC this date time is in.
+ ZONE_OFFSET = 10,
+ // Number of hours offest for DST.
+ DST_OFFSET = 11,
+ };
+
+ // TODO(hassan): Remove RelativeQualifier as in the presence of relative
+ // count RelativeQualifier is redundant.
+ // Enum to represent the relative DateTimeComponent e.g. "next Monday",
+ // "the following day", "tomorrow".
+ enum class RelativeQualifier {
+ UNSPECIFIED = 0,
+ NEXT = 1,
+ THIS = 2,
+ LAST = 3,
+ NOW = 4,
+ TOMORROW = 5,
+ YESTERDAY = 6,
+ PAST = 7,
+ FUTURE = 8
+ };
+
+ bool operator==(const DatetimeComponent& other) const {
+ return component_type == other.component_type &&
+ relative_qualifier == other.relative_qualifier &&
+ relative_count == other.relative_count && value == other.value;
+ }
+
+ bool ShouldRoundToGranularity() const;
+
+ ComponentType component_type = ComponentType::UNSPECIFIED;
+ RelativeQualifier relative_qualifier = RelativeQualifier::UNSPECIFIED;
+
+ // Represents the absolute value of DateTime components.
+ int value = 0;
+ // The number of units of change present in the relative DateTimeComponent.
+ int relative_count = 0;
+
+ DatetimeComponent() = default;
+
+ explicit DatetimeComponent(ComponentType arg_component_type,
+ RelativeQualifier arg_relative_qualifier,
+ int arg_value, int arg_relative_count)
+ : component_type(arg_component_type),
+ relative_qualifier(arg_relative_qualifier),
+ value(arg_value),
+ relative_count(arg_relative_count) {}
+};
+
struct DatetimeParseResult {
// The absolute time in milliseconds since the epoch in UTC.
int64 time_ms_utc;
@@ -172,16 +257,24 @@
// The precision of the estimate then in to calculating the milliseconds
DatetimeGranularity granularity;
+ // List of parsed DateTimeComponent.
+ std::vector<DatetimeComponent> datetime_components;
+
DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
DatetimeParseResult(int64 arg_time_ms_utc,
- DatetimeGranularity arg_granularity)
- : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {}
+ DatetimeGranularity arg_granularity,
+ std::vector<DatetimeComponent> arg_datetime__components)
+ : time_ms_utc(arg_time_ms_utc),
+ granularity(arg_granularity),
+ datetime_components(arg_datetime__components) {}
bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
bool operator==(const DatetimeParseResult& other) const {
- return granularity == other.granularity && time_ms_utc == other.time_ms_utc;
+ return granularity == other.granularity &&
+ time_ms_utc == other.time_ms_utc &&
+ datetime_components == other.datetime_components;
}
};
@@ -193,6 +286,19 @@
float target_classification_score;
float priority_score;
+ DatetimeParseResultSpan()
+ : target_classification_score(-1.0), priority_score(-1.0) {}
+
+ DatetimeParseResultSpan(const CodepointSpan& span,
+ const std::vector<DatetimeParseResult>& data,
+ const float target_classification_score,
+ const float priority_score) {
+ this->span = span;
+ this->data = data;
+ this->target_classification_score = target_classification_score;
+ this->priority_score = priority_score;
+ }
+
bool operator==(const DatetimeParseResultSpan& other) const {
return span == other.span && data == other.data &&
std::abs(target_classification_score -
@@ -206,15 +312,36 @@
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
const DatetimeParseResultSpan& value);
+// This struct contains information intended to uniquely identify a device
+// contact. Instances are created by the Knowledge Engine, and dereferenced by
+// the Contact Engine.
+struct ContactPointer {
+ std::string focus_contact_id;
+ std::string device_id;
+ std::string device_contact_id;
+ std::string contact_name;
+ std::string contact_name_hash;
+
+ bool operator==(const ContactPointer& other) const {
+ return focus_contact_id == other.focus_contact_id &&
+ device_id == other.device_id &&
+ device_contact_id == other.device_contact_id &&
+ contact_name == other.contact_name &&
+ contact_name_hash == other.contact_name_hash;
+ }
+};
+
struct ClassificationResult {
std::string collection;
float score;
DatetimeParseResult datetime_parse_result;
std::string serialized_knowledge_result;
- std::string contact_name, contact_given_name, contact_nickname,
- contact_email_address, contact_phone_number, contact_id;
+ ContactPointer contact_pointer;
+ std::string contact_name, contact_given_name, contact_family_name,
+ contact_nickname, contact_email_address, contact_phone_number, contact_id;
std::string app_name, app_package_name;
int64 numeric_value;
+ double numeric_double_value;
// Length of the parsed duration in milliseconds.
int64 duration_ms;
@@ -230,18 +357,54 @@
serialized_entity_data.size());
}
- explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
+ explicit ClassificationResult()
+ : score(-1.0f),
+ numeric_value(0),
+ numeric_double_value(0.),
+ duration_ms(0),
+ priority_score(-1.0) {}
ClassificationResult(const std::string& arg_collection, float arg_score)
: collection(arg_collection),
score(arg_score),
+ numeric_value(0),
+ numeric_double_value(0.),
+ duration_ms(0),
priority_score(arg_score) {}
ClassificationResult(const std::string& arg_collection, float arg_score,
float arg_priority_score)
: collection(arg_collection),
score(arg_score),
+ numeric_value(0),
+ numeric_double_value(0.),
+ duration_ms(0),
priority_score(arg_priority_score) {}
+
+ bool operator!=(const ClassificationResult& other) const {
+ return !(*this == other);
+ }
+
+ bool operator==(const ClassificationResult& other) const {
+ return collection == other.collection &&
+ fabs(score - other.score) < 0.001 &&
+ datetime_parse_result == other.datetime_parse_result &&
+ serialized_knowledge_result == other.serialized_knowledge_result &&
+ contact_pointer == other.contact_pointer &&
+ contact_name == other.contact_name &&
+ contact_given_name == other.contact_given_name &&
+ contact_family_name == other.contact_family_name &&
+ contact_nickname == other.contact_nickname &&
+ contact_email_address == other.contact_email_address &&
+ contact_phone_number == other.contact_phone_number &&
+ contact_id == other.contact_id &&
+ app_package_name == other.app_package_name &&
+ fabs(priority_score - other.priority_score) < 0.001 &&
+ numeric_value == other.numeric_value &&
+ fabs(numeric_double_value - other.numeric_double_value) < 0.001 &&
+ duration_ms == other.duration_ms &&
+ serialized_entity_data == other.serialized_entity_data;
+ }
};
// Pretty-printing function for ClassificationResult.
@@ -302,121 +465,74 @@
typename std::vector<T>::const_iterator end_;
};
-struct DateParseData {
- enum class Relation {
- UNSPECIFIED = 0,
- NEXT = 1,
- NEXT_OR_SAME = 2,
- LAST = 3,
- NOW = 4,
- TOMORROW = 5,
- YESTERDAY = 6,
- PAST = 7,
- FUTURE = 8
- };
+// Class to provide representation of date and time expressions
+class DatetimeParsedData {
+ public:
+ // Function to set the absolute value of DateTimeComponent for the given
+ // FieldType, if the field is not present it will create the field and set
+ // the value.
+ void SetAbsoluteValue(const DatetimeComponent::ComponentType& field_type,
+ int value);
- enum class RelationType {
- UNSPECIFIED = 0,
- SUNDAY = 1,
- MONDAY = 2,
- TUESDAY = 3,
- WEDNESDAY = 4,
- THURSDAY = 5,
- FRIDAY = 6,
- SATURDAY = 7,
- DAY = 8,
- WEEK = 9,
- MONTH = 10,
- YEAR = 11,
- HOUR = 12,
- MINUTE = 13,
- SECOND = 14,
- };
+ // Function to set the relative value of DateTimeComponent, if the field is
+ // not present the function will create the field and set the relative value.
+ void SetRelativeValue(
+ const DatetimeComponent::ComponentType& field_type,
+ const DatetimeComponent::RelativeQualifier& relative_value);
- enum Fields {
- YEAR_FIELD = 1 << 0,
- MONTH_FIELD = 1 << 1,
- DAY_FIELD = 1 << 2,
- HOUR_FIELD = 1 << 3,
- MINUTE_FIELD = 1 << 4,
- SECOND_FIELD = 1 << 5,
- AMPM_FIELD = 1 << 6,
- ZONE_OFFSET_FIELD = 1 << 7,
- DST_OFFSET_FIELD = 1 << 8,
- RELATION_FIELD = 1 << 9,
- RELATION_TYPE_FIELD = 1 << 10,
- RELATION_DISTANCE_FIELD = 1 << 11
- };
+ // Function to set the relative count of DateTimeComponent, if the field is
+ // not present the function will create the field and set the count.
+ void SetRelativeCount(const DatetimeComponent::ComponentType& field_type,
+ int relative_count);
- enum class AMPM { AM = 0, PM = 1 };
+ // Function to populate the absolute value of the FieldType and return true.
+ // In case of no FieldType function will return false.
+ bool GetFieldValue(const DatetimeComponent::ComponentType& field_type,
+ int* field_value) const;
- enum class TimeUnit {
- DAYS = 1,
- WEEKS = 2,
- MONTHS = 3,
- HOURS = 4,
- MINUTES = 5,
- SECONDS = 6,
- YEARS = 7
- };
+ // Function to populate the relative value of the FieldType and return true.
+ // In case of no relative value function will return false.
+ bool GetRelativeValue(
+ const DatetimeComponent::ComponentType& field_type,
+ DatetimeComponent::RelativeQualifier* relative_value) const;
- // Bit mask of fields which have been set on the struct
- int field_set_mask = 0;
+ // Returns relative DateTimeComponent from the parsed DateTime span.
+ void GetRelativeDatetimeComponents(
+ std::vector<DatetimeComponent>* date_time_components) const;
- // Fields describing absolute date fields.
- // Year of the date seen in the text match.
- int year = 0;
- // Month of the year starting with January = 1.
- int month = 0;
- // Day of the month starting with 1.
- int day_of_month = 0;
- // Hour of the day with a range of 0-23,
- // values less than 12 need the AMPM field below or heuristics
- // to definitively determine the time.
- int hour = 0;
- // Hour of the day with a range of 0-59.
- int minute = 0;
- // Hour of the day with a range of 0-59.
- int second = 0;
- // 0 == AM, 1 == PM
- AMPM ampm = AMPM::AM;
- // Number of hours offset from UTC this date time is in.
- int zone_offset = 0;
- // Number of hours offest for DST
- int dst_offset = 0;
+ // Returns DateTimeComponent from the parsed DateTime span.
+ void GetDatetimeComponents(
+ std::vector<DatetimeComponent>* date_time_components) const;
- // The permutation from now that was made to find the date time.
- Relation relation = Relation::UNSPECIFIED;
- // The unit of measure of the change to the date time.
- RelationType relation_type = RelationType::UNSPECIFIED;
- // The number of units of change that were made.
- int relation_distance = 0;
+ // Represent the granularity of the Parsed DateTime span. The function will
+ // return “GRANULARITY_UNKNOWN” if no datetime field is set.
+ DatetimeGranularity GetFinestGranularity() const;
- DateParseData() = default;
+ // Utility function to check if DateTimeParsedData has FieldType initialized.
+ bool HasFieldType(const DatetimeComponent::ComponentType& field_type) const;
- DateParseData(int field_set_mask, int year, int month, int day_of_month,
- int hour, int minute, int second, AMPM ampm, int zone_offset,
- int dst_offset, Relation relation, RelationType relation_type,
- int relation_distance) {
- this->field_set_mask = field_set_mask;
- this->year = year;
- this->month = month;
- this->day_of_month = day_of_month;
- this->hour = hour;
- this->minute = minute;
- this->second = second;
- this->ampm = ampm;
- this->zone_offset = zone_offset;
- this->dst_offset = dst_offset;
- this->relation = relation;
- this->relation_type = relation_type;
- this->relation_distance = relation_distance;
- }
+ // Function to check if DateTimeParsedData has relative DateTimeComponent for
+ // given FieldType.
+ bool HasRelativeValue(
+ const DatetimeComponent::ComponentType& field_type) const;
+
+ // Function to check if DateTimeParsedData has absolute value
+ // DateTimeComponent for given FieldType.
+ bool HasAbsoluteValue(
+ const DatetimeComponent::ComponentType& field_type) const;
+
+ private:
+ DatetimeComponent& GetOrCreateDatetimeComponent(
+
+ const DatetimeComponent::ComponentType& component_type);
+
+ std::map<DatetimeComponent::ComponentType, DatetimeComponent>
+ date_time_components_;
};
-// Pretty-printing function for DateParseData.
+// Pretty-printing function for DateTimeParsedData.
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const DateParseData& data);
+ const DatetimeParsedData& data);
} // namespace libtextclassifier3
diff --git a/native/lang_id/common/embedding-feature-extractor.cc b/native/lang_id/common/embedding-feature-extractor.cc
index 6235f89..a2e3cdf 100644
--- a/native/lang_id/common/embedding-feature-extractor.cc
+++ b/native/lang_id/common/embedding-feature-extractor.cc
@@ -35,10 +35,10 @@
bool GenericEmbeddingFeatureExtractor::Setup(TaskContext *context) {
// Don't use version to determine how to get feature FML.
- const string features = context->Get(GetParamName("features"), "");
- const string embedding_names =
+ const std::string features = context->Get(GetParamName("features"), "");
+ const std::string embedding_names =
context->Get(GetParamName("embedding_names"), "");
- const string embedding_dims =
+ const std::string embedding_dims =
context->Get(GetParamName("embedding_dims"), "");
// NOTE: unfortunately, LiteStrSplit returns a vector of StringPieces pointing
diff --git a/native/lang_id/common/embedding-feature-extractor.h b/native/lang_id/common/embedding-feature-extractor.h
index f51b6e5..ba4f858 100644
--- a/native/lang_id/common/embedding-feature-extractor.h
+++ b/native/lang_id/common/embedding-feature-extractor.h
@@ -46,7 +46,7 @@
//
// |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
// avoid name clashes. See GetParamName().
- explicit GenericEmbeddingFeatureExtractor(const string &arg_prefix)
+ explicit GenericEmbeddingFeatureExtractor(const std::string &arg_prefix)
: arg_prefix_(arg_prefix) {}
virtual ~GenericEmbeddingFeatureExtractor() {}
@@ -65,11 +65,13 @@
// Returns number of embedding spaces.
int NumEmbeddings() const { return embedding_dims_.size(); }
- const std::vector<string> &embedding_fml() const { return embedding_fml_; }
+ const std::vector<std::string> &embedding_fml() const {
+ return embedding_fml_;
+ }
// Get parameter name by concatenating the prefix and the original name.
- string GetParamName(const string ¶m_name) const {
- string full_name = arg_prefix_;
+ std::string GetParamName(const std::string ¶m_name) const {
+ std::string full_name = arg_prefix_;
full_name.push_back('_');
full_name.append(param_name);
return full_name;
@@ -77,13 +79,13 @@
private:
// Prefix for TaskContext parameters.
- const string arg_prefix_;
+ const std::string arg_prefix_;
// Embedding space names for parameter sharing.
- std::vector<string> embedding_names_;
+ std::vector<std::string> embedding_names_;
// FML strings for each feature extractor.
- std::vector<string> embedding_fml_;
+ std::vector<std::string> embedding_fml_;
// Size of each of the embedding spaces (maximum predicate id).
std::vector<int> embedding_sizes_;
@@ -106,7 +108,7 @@
//
// |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
// avoid name clashes. See GetParamName().
- explicit EmbeddingFeatureExtractor(const string &arg_prefix)
+ explicit EmbeddingFeatureExtractor(const std::string &arg_prefix)
: GenericEmbeddingFeatureExtractor(arg_prefix) {}
// Sets up all predicate maps, feature extractors, and flags.
diff --git a/native/lang_id/common/embedding-feature-interface.h b/native/lang_id/common/embedding-feature-interface.h
index 87576c6..75d0c98 100644
--- a/native/lang_id/common/embedding-feature-interface.h
+++ b/native/lang_id/common/embedding-feature-interface.h
@@ -36,7 +36,7 @@
//
// |arg_prefix| is a string prefix for the TaskContext parameters, passed to
// |the underlying EmbeddingFeatureExtractor.
- explicit EmbeddingFeatureInterface(const string &arg_prefix)
+ explicit EmbeddingFeatureInterface(const std::string &arg_prefix)
: feature_extractor_(arg_prefix) {}
// Sets up feature extractors and flags for processing (inference).
diff --git a/native/lang_id/common/embedding-network-params.cc b/native/lang_id/common/embedding-network-params.cc
index be7c80e..8b48fce 100644
--- a/native/lang_id/common/embedding-network-params.cc
+++ b/native/lang_id/common/embedding-network-params.cc
@@ -16,11 +16,13 @@
#include "lang_id/common/embedding-network-params.h"
+#include <string>
+
#include "lang_id/common/lite_base/logging.h"
namespace libtextclassifier3 {
-QuantizationType ParseQuantizationType(const string &s) {
+QuantizationType ParseQuantizationType(const std::string &s) {
if (s == "NONE") {
return QuantizationType::NONE;
}
diff --git a/native/lang_id/common/embedding-network-params.h b/native/lang_id/common/embedding-network-params.h
index f43c653..6ad147c 100755
--- a/native/lang_id/common/embedding-network-params.h
+++ b/native/lang_id/common/embedding-network-params.h
@@ -44,7 +44,7 @@
};
// Converts "UINT8" -> QuantizationType::UINT8, and so on.
-QuantizationType ParseQuantizationType(const string &s);
+QuantizationType ParseQuantizationType(const std::string &s);
// API for accessing parameters for a feed-forward neural network with
// embeddings.
@@ -303,7 +303,7 @@
virtual bool is_precomputed() const = 0;
protected:
- void CheckIndex(int index, int size, const string &description) const {
+ void CheckIndex(int index, int size, const std::string &description) const {
SAFTM_CHECK_GE(index, 0)
<< "Out-of-range index for " << description << ": " << index;
SAFTM_CHECK_LT(index, size)
diff --git a/native/lang_id/common/fel/feature-descriptors.cc b/native/lang_id/common/fel/feature-descriptors.cc
index bf03dd5..1293399 100644
--- a/native/lang_id/common/fel/feature-descriptors.cc
+++ b/native/lang_id/common/fel/feature-descriptors.cc
@@ -16,12 +16,15 @@
#include "lang_id/common/fel/feature-descriptors.h"
+#include <string>
+
#include "lang_id/common/lite_strings/str-cat.h"
namespace libtextclassifier3 {
namespace mobile {
-void ToFELFunction(const FeatureFunctionDescriptor &function, string *output) {
+void ToFELFunction(const FeatureFunctionDescriptor &function,
+ std::string *output) {
LiteStrAppend(output, function.type());
if (function.argument() != 0 || function.parameter_size() > 0) {
LiteStrAppend(output, "(");
@@ -40,7 +43,7 @@
}
}
-void ToFEL(const FeatureFunctionDescriptor &function, string *output) {
+void ToFEL(const FeatureFunctionDescriptor &function, std::string *output) {
ToFELFunction(function, output);
if (function.feature_size() == 1) {
LiteStrAppend(output, ".");
@@ -55,21 +58,21 @@
}
}
-void ToFEL(const FeatureExtractorDescriptor &extractor, string *output) {
+void ToFEL(const FeatureExtractorDescriptor &extractor, std::string *output) {
for (int i = 0; i < extractor.feature_size(); ++i) {
ToFEL(extractor.feature(i), output);
LiteStrAppend(output, "\n");
}
}
-string FeatureFunctionDescriptor::DebugString() const {
- string str;
+std::string FeatureFunctionDescriptor::DebugString() const {
+ std::string str;
ToFEL(*this, &str);
return str;
}
-string FeatureExtractorDescriptor::DebugString() const {
- string str;
+std::string FeatureExtractorDescriptor::DebugString() const {
+ std::string str;
ToFEL(*this, &str);
return str;
}
diff --git a/native/lang_id/common/fel/feature-descriptors.h b/native/lang_id/common/fel/feature-descriptors.h
index a9408c9..3bdc2fa 100644
--- a/native/lang_id/common/fel/feature-descriptors.h
+++ b/native/lang_id/common/fel/feature-descriptors.h
@@ -33,15 +33,15 @@
public:
Parameter() {}
- void set_name(const string &name) { name_ = name; }
- const string &name() const { return name_; }
+ void set_name(const std::string &name) { name_ = name; }
+ const std::string &name() const { return name_; }
- void set_value(const string &value) { value_ = value; }
- const string &value() const { return value_; }
+ void set_value(const std::string &value) { value_ = value; }
+ const std::string &value() const { return value_; }
private:
- string name_;
- string value_;
+ std::string name_;
+ std::string value_;
};
// Descriptor for a feature function. Used to store the results of parsing one
@@ -52,14 +52,14 @@
// Accessors for the feature function type. The function type is the string
// that the feature extractor code is registered under.
- void set_type(const string &type) { type_ = type; }
- const string &type() const { return type_; }
+ void set_type(const std::string &type) { type_ = type; }
+ const std::string &type() const { return type_; }
// Accessors for the feature function name. The function name (if available)
// is used for some log messages. Otherwise, a more precise, but also more
// verbose name based on the feature specification is used.
- void set_name(const string &name) { name_ = name; }
- const string &name() const { return name_; }
+ void set_name(const std::string &name) { name_ = name; }
+ const std::string &name() const { return name_; }
// Accessors for the default (name-less) parameter.
void set_argument(int32 argument) { argument_ = argument; }
@@ -95,14 +95,14 @@
}
// Returns human-readable representation of this FeatureFunctionDescriptor.
- string DebugString() const;
+ std::string DebugString() const;
private:
// See comments for set_type().
- string type_;
+ std::string type_;
// See comments for set_name().
- string name_;
+ std::string name_;
// See comments for set_argument().
int32 argument_ = 0;
@@ -135,7 +135,7 @@
}
// Returns human-readable representation of this FeatureExtractorDescriptor.
- string DebugString() const;
+ std::string DebugString() const;
private:
std::vector<std::unique_ptr<FeatureFunctionDescriptor>> features_;
@@ -145,13 +145,14 @@
// Appends to |*output| the FEL representation of the top-level feature from
// |function|, without diving into the nested features.
-void ToFELFunction(const FeatureFunctionDescriptor &function, string *output);
+void ToFELFunction(const FeatureFunctionDescriptor &function,
+ std::string *output);
// Appends to |*output| the FEL representation of |function|.
-void ToFEL(const FeatureFunctionDescriptor &function, string *output);
+void ToFEL(const FeatureFunctionDescriptor &function, std::string *output);
// Appends to |*output| the FEL representation of |extractor|.
-void ToFEL(const FeatureExtractorDescriptor &extractor, string *output);
+void ToFEL(const FeatureExtractorDescriptor &extractor, std::string *output);
} // namespace mobile
} // namespace nlp_saft
diff --git a/native/lang_id/common/fel/feature-extractor.cc b/native/lang_id/common/fel/feature-extractor.cc
index c256257..ab8a1a6 100644
--- a/native/lang_id/common/fel/feature-extractor.cc
+++ b/native/lang_id/common/fel/feature-extractor.cc
@@ -16,6 +16,8 @@
#include "lang_id/common/fel/feature-extractor.h"
+#include <string>
+
#include "lang_id/common/fel/feature-types.h"
#include "lang_id/common/fel/fel-parser.h"
#include "lang_id/common/lite_base/logging.h"
@@ -30,7 +32,7 @@
GenericFeatureExtractor::~GenericFeatureExtractor() {}
-bool GenericFeatureExtractor::Parse(const string &source) {
+bool GenericFeatureExtractor::Parse(const std::string &source) {
// Parse feature specification into descriptor.
FELParser parser;
@@ -61,8 +63,8 @@
return true;
}
-string GenericFeatureFunction::GetParameter(const string &name,
- const string &default_value) const {
+std::string GenericFeatureFunction::GetParameter(
+ const std::string &name, const std::string &default_value) const {
// Find named parameter in feature descriptor.
for (int i = 0; i < descriptor_->parameter_size(); ++i) {
if (name == descriptor_->parameter(i).name()) {
@@ -76,9 +78,9 @@
GenericFeatureFunction::~GenericFeatureFunction() { delete feature_type_; }
-int GenericFeatureFunction::GetIntParameter(const string &name,
+int GenericFeatureFunction::GetIntParameter(const std::string &name,
int default_value) const {
- string value_str = GetParameter(name, "");
+ std::string value_str = GetParameter(name, "");
if (value_str.empty()) {
// Parameter not specified, use default value for it.
return default_value;
@@ -92,9 +94,9 @@
return value;
}
-bool GenericFeatureFunction::GetBoolParameter(const string &name,
+bool GenericFeatureFunction::GetBoolParameter(const std::string &name,
bool default_value) const {
- string value = GetParameter(name, "");
+ std::string value = GetParameter(name, "");
if (value.empty()) return default_value;
if (value == "true") return true;
if (value == "false") return false;
@@ -121,8 +123,8 @@
return nullptr;
}
-string GenericFeatureFunction::name() const {
- string output;
+std::string GenericFeatureFunction::name() const {
+ std::string output;
if (descriptor_->name().empty()) {
if (!prefix_.empty()) {
output.append(prefix_);
diff --git a/native/lang_id/common/fel/feature-extractor.h b/native/lang_id/common/fel/feature-extractor.h
index 8763852..c09e1eb 100644
--- a/native/lang_id/common/fel/feature-extractor.h
+++ b/native/lang_id/common/fel/feature-extractor.h
@@ -126,7 +126,7 @@
// Initializes the feature extractor from the FEL specification |source|.
//
// Returns true on success, false otherwise (e.g., FEL syntax error).
- SAFTM_MUST_USE_RESULT bool Parse(const string &source);
+ SAFTM_MUST_USE_RESULT bool Parse(const std::string &source);
// Returns the feature extractor descriptor.
const FeatureExtractorDescriptor &descriptor() const { return descriptor_; }
@@ -207,30 +207,31 @@
// Returns value of parameter |name| from the feature function descriptor.
// If the parameter is not present, returns the indicated |default_value|.
- string GetParameter(const string &name, const string &default_value) const;
+ std::string GetParameter(const std::string &name,
+ const std::string &default_value) const;
// Returns value of int parameter |name| from feature function descriptor.
// If the parameter is not present, or its value can't be parsed as an int,
// returns |default_value|.
- int GetIntParameter(const string &name, int default_value) const;
+ int GetIntParameter(const std::string &name, int default_value) const;
// Returns value of bool parameter |name| from feature function descriptor.
// If the parameter is not present, or its value is not "true" or "false",
// returns |default_value|. NOTE: this method is case sensitive, it doesn't
// do any lower-casing.
- bool GetBoolParameter(const string &name, bool default_value) const;
+ bool GetBoolParameter(const std::string &name, bool default_value) const;
// Returns the FEL function description for the feature function, i.e. the
// name and parameters without the nested features.
- string FunctionName() const {
- string output;
+ std::string FunctionName() const {
+ std::string output;
ToFELFunction(*descriptor_, &output);
return output;
}
// Returns the prefix for nested feature functions. This is the prefix of this
// feature function concatenated with the feature function name.
- string SubPrefix() const {
+ std::string SubPrefix() const {
return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName();
}
@@ -250,7 +251,7 @@
// the descriptor for the feature function. If the name is empty or the
// feature function is a variable the name is the FEL representation of the
// feature, including the prefix.
- string name() const;
+ std::string name() const;
// Returns the argument from the feature function descriptor. It defaults to
// 0 if the argument has not been specified.
@@ -259,8 +260,8 @@
}
// Returns/sets/clears function name prefix.
- const string &prefix() const { return prefix_; }
- void set_prefix(const string &prefix) { prefix_ = prefix; }
+ const std::string &prefix() const { return prefix_; }
+ void set_prefix(const std::string &prefix) { prefix_ = prefix; }
protected:
// Returns the feature type for single-type feature functions.
@@ -291,7 +292,7 @@
FeatureType *feature_type_ = nullptr;
// Prefix used for sub-feature types of this function.
- string prefix_;
+ std::string prefix_;
};
// Feature function that can extract features from an object. Templated on
@@ -340,7 +341,7 @@
// the relevant cc_library was not linked-in).
static Self *Instantiate(const GenericFeatureExtractor *extractor,
const FeatureFunctionDescriptor *fd,
- const string &prefix) {
+ const std::string &prefix) {
Self *f = Self::Create(fd->type());
if (f != nullptr) {
f->set_extractor(extractor);
@@ -439,7 +440,7 @@
SAFTM_MUST_USE_RESULT static bool CreateNested(
const GenericFeatureExtractor *extractor,
const FeatureFunctionDescriptor *fd, std::vector<NES *> *functions,
- const string &prefix) {
+ const std::string &prefix) {
for (int i = 0; i < fd->feature_size(); ++i) {
const FeatureFunctionDescriptor &sub = fd->feature(i);
NES *f = NES::Instantiate(extractor, &sub, prefix);
diff --git a/native/lang_id/common/fel/feature-types.h b/native/lang_id/common/fel/feature-types.h
index 18cf69a..ae422af 100644
--- a/native/lang_id/common/fel/feature-types.h
+++ b/native/lang_id/common/fel/feature-types.h
@@ -44,21 +44,21 @@
class FeatureType {
public:
// Initializes a feature type.
- explicit FeatureType(const string &name)
- : name_(name), base_(0),
- is_continuous_(name.find("continuous") != string::npos) {
- }
+ explicit FeatureType(const std::string &name)
+ : name_(name),
+ base_(0),
+ is_continuous_(name.find("continuous") != std::string::npos) {}
virtual ~FeatureType() {}
// Converts a feature value to a name.
- virtual string GetFeatureValueName(FeatureValue value) const = 0;
+ virtual std::string GetFeatureValueName(FeatureValue value) const = 0;
// Returns the size of the feature values domain.
virtual int64 GetDomainSize() const = 0;
// Returns the feature type name.
- const string &name() const { return name_; }
+ const std::string &name() const { return name_; }
Predicate base() const { return base_; }
void set_base(Predicate base) { base_ = base; }
@@ -68,7 +68,7 @@
private:
// Feature type name.
- string name_;
+ std::string name_;
// "Base" feature value: i.e. a "slot" in a global ordering of features.
Predicate base_;
@@ -91,8 +91,8 @@
// };
class EnumFeatureType : public FeatureType {
public:
- EnumFeatureType(const string &name,
- const std::map<FeatureValue, string> &value_names)
+ EnumFeatureType(const std::string &name,
+ const std::map<FeatureValue, std::string> &value_names)
: FeatureType(name), value_names_(value_names) {
for (const auto &pair : value_names) {
SAFTM_CHECK_GE(pair.first, 0)
@@ -102,7 +102,7 @@
}
// Returns the feature name for a given feature value.
- string GetFeatureValueName(FeatureValue value) const override {
+ std::string GetFeatureValueName(FeatureValue value) const override {
auto it = value_names_.find(value);
if (it == value_names_.end()) {
SAFTM_LOG(ERROR) << "Invalid feature value " << value << " for "
@@ -121,17 +121,18 @@
FeatureValue domain_size_ = 0;
// Names of feature values.
- std::map<FeatureValue, string> value_names_;
+ std::map<FeatureValue, std::string> value_names_;
};
// Feature type for binary features.
class BinaryFeatureType : public FeatureType {
public:
- BinaryFeatureType(const string &name, const string &off, const string &on)
+ BinaryFeatureType(const std::string &name, const std::string &off,
+ const std::string &on)
: FeatureType(name), off_(off), on_(on) {}
// Returns the feature name for a given feature value.
- string GetFeatureValueName(FeatureValue value) const override {
+ std::string GetFeatureValueName(FeatureValue value) const override {
if (value == 0) return off_;
if (value == 1) return on_;
return "";
@@ -142,19 +143,19 @@
private:
// Feature value names for on and off.
- string off_;
- string on_;
+ std::string off_;
+ std::string on_;
};
// Feature type for numeric features.
class NumericFeatureType : public FeatureType {
public:
// Initializes numeric feature.
- NumericFeatureType(const string &name, FeatureValue size)
+ NumericFeatureType(const std::string &name, FeatureValue size)
: FeatureType(name), size_(size) {}
// Returns numeric feature value.
- string GetFeatureValueName(FeatureValue value) const override {
+ std::string GetFeatureValueName(FeatureValue value) const override {
if (value < 0) return "";
return LiteStrCat(value);
}
@@ -170,14 +171,14 @@
// Feature type for byte features, including an "outside" value.
class ByteFeatureType : public NumericFeatureType {
public:
- explicit ByteFeatureType(const string &name)
+ explicit ByteFeatureType(const std::string &name)
: NumericFeatureType(name, 257) {}
- string GetFeatureValueName(FeatureValue value) const override {
+ std::string GetFeatureValueName(FeatureValue value) const override {
if (value == 256) {
return "<NULL>";
}
- string result;
+ std::string result;
result += static_cast<char>(value);
return result;
}
diff --git a/native/lang_id/common/fel/fel-parser.cc b/native/lang_id/common/fel/fel-parser.cc
index 4346fb7..2682941 100644
--- a/native/lang_id/common/fel/fel-parser.cc
+++ b/native/lang_id/common/fel/fel-parser.cc
@@ -17,6 +17,7 @@
#include "lang_id/common/fel/fel-parser.h"
#include <ctype.h>
+
#include <string>
#include "lang_id/common/lite_base/logging.h"
@@ -46,7 +47,7 @@
}
} // namespace
-bool FELParser::Initialize(const string &source) {
+bool FELParser::Initialize(const std::string &source) {
// Initialize parser state.
source_ = source;
current_ = source_.begin();
@@ -57,9 +58,9 @@
return NextItem();
}
-void FELParser::ReportError(const string &error_message) {
+void FELParser::ReportError(const std::string &error_message) {
const int position = item_start_ - line_start_ + 1;
- const string line(line_start_, current_);
+ const std::string line(line_start_, current_);
SAFTM_LOG(ERROR) << "Error in feature model, line " << item_line_number_
<< ", position " << position << ": " << error_message
@@ -104,7 +105,7 @@
// Parse number.
if (IsValidCharAtStartOfNumber(CurrentChar())) {
- string::iterator start = current_;
+ std::string::iterator start = current_;
Next();
while (!eos() && IsValidCharInsideNumber(CurrentChar())) Next();
item_text_.assign(start, current_);
@@ -115,7 +116,7 @@
// Parse string.
if (CurrentChar() == '"') {
Next();
- string::iterator start = current_;
+ std::string::iterator start = current_;
while (CurrentChar() != '"') {
if (eos()) {
ReportError("Unterminated string");
@@ -131,7 +132,7 @@
// Parse identifier name.
if (IsValidCharAtStartOfIdentifier(CurrentChar())) {
- string::iterator start = current_;
+ std::string::iterator start = current_;
while (!eos() && IsValidCharInsideIdentifier(CurrentChar())) {
Next();
}
@@ -146,7 +147,7 @@
return true;
}
-bool FELParser::Parse(const string &source,
+bool FELParser::Parse(const std::string &source,
FeatureExtractorDescriptor *result) {
// Initialize parser.
if (!Initialize(source)) {
@@ -159,7 +160,7 @@
ReportError("Feature type name expected");
return false;
}
- string name = item_text_;
+ std::string name = item_text_;
if (!NextItem()) {
return false;
}
@@ -204,7 +205,7 @@
ReportError("Feature name expected");
return false;
}
- string name = item_text_;
+ std::string name = item_text_;
if (!NextItem()) return false;
// Set feature name.
@@ -219,7 +220,7 @@
ReportError("Feature type name expected");
return false;
}
- string type = item_text_;
+ std::string type = item_text_;
if (!NextItem()) return false;
// Parse sub-feature.
@@ -234,7 +235,7 @@
ReportError("Feature type name expected");
return false;
}
- string type = item_text_;
+ std::string type = item_text_;
if (!NextItem()) return false;
// Parse sub-feature.
@@ -259,7 +260,7 @@
// Set default argument for feature.
result->set_argument(argument);
} else if (item_type_ == NAME) {
- string name = item_text_;
+ std::string name = item_text_;
if (!NextItem()) return false;
if (item_type_ != '=') {
ReportError("= expected");
@@ -270,7 +271,7 @@
ReportError("Parameter value expected");
return false;
}
- string value = item_text_;
+ std::string value = item_text_;
if (!NextItem()) return false;
// Add parameter to feature.
diff --git a/native/lang_id/common/fel/fel-parser.h b/native/lang_id/common/fel/fel-parser.h
index eacb442..d2c454c 100644
--- a/native/lang_id/common/fel/fel-parser.h
+++ b/native/lang_id/common/fel/fel-parser.h
@@ -53,15 +53,15 @@
public:
// Parses fml specification into feature extractor descriptor.
// Returns true on success, false on error (e.g., syntax errors).
- bool Parse(const string &source, FeatureExtractorDescriptor *result);
+ bool Parse(const std::string &source, FeatureExtractorDescriptor *result);
private:
// Initializes the parser with the source text.
// Returns true on success, false on syntax error.
- bool Initialize(const string &source);
+ bool Initialize(const std::string &source);
// Outputs an error message, with context info.
- void ReportError(const string &error_message);
+ void ReportError(const std::string &error_message);
// Moves to the next input character.
void Next();
@@ -104,19 +104,19 @@
};
// Source text.
- string source_;
+ std::string source_;
// Current input position.
- string::iterator current_;
+ std::string::iterator current_;
// Line number for current input position.
int line_number_;
// Start position for current item.
- string::iterator item_start_;
+ std::string::iterator item_start_;
// Start position for current line.
- string::iterator line_start_;
+ std::string::iterator line_start_;
// Line number for current item.
int item_line_number_;
@@ -126,7 +126,7 @@
int item_type_;
// Text for current item.
- string item_text_;
+ std::string item_text_;
};
} // namespace mobile
diff --git a/native/lang_id/common/fel/task-context.cc b/native/lang_id/common/fel/task-context.cc
index f8b0701..5e1d7f6 100644
--- a/native/lang_id/common/fel/task-context.cc
+++ b/native/lang_id/common/fel/task-context.cc
@@ -23,7 +23,7 @@
namespace libtextclassifier3 {
namespace mobile {
-string TaskContext::GetInputPath(const string &name) const {
+std::string TaskContext::GetInputPath(const std::string &name) const {
auto it = inputs_.find(name);
if (it != inputs_.end()) {
return it->second;
@@ -31,11 +31,13 @@
return "";
}
-void TaskContext::SetInputPath(const string &name, const string &path) {
+void TaskContext::SetInputPath(const std::string &name,
+ const std::string &path) {
inputs_[name] = path;
}
-string TaskContext::Get(const string &name, const char *defval) const {
+std::string TaskContext::Get(const std::string &name,
+ const char *defval) const {
auto it = parameters_.find(name);
if (it != parameters_.end()) {
return it->second;
@@ -43,8 +45,8 @@
return defval;
}
-int TaskContext::Get(const string &name, int defval) const {
- const string s = Get(name, "");
+int TaskContext::Get(const std::string &name, int defval) const {
+ const std::string s = Get(name, "");
int value = defval;
if (LiteAtoi(s, &value)) {
return value;
@@ -52,8 +54,8 @@
return defval;
}
-float TaskContext::Get(const string &name, float defval) const {
- const string s = Get(name, "");
+float TaskContext::Get(const std::string &name, float defval) const {
+ const std::string s = Get(name, "");
float value = defval;
if (LiteAtof(s, &value)) {
return value;
@@ -61,12 +63,13 @@
return defval;
}
-bool TaskContext::Get(const string &name, bool defval) const {
- string value = Get(name, "");
+bool TaskContext::Get(const std::string &name, bool defval) const {
+ std::string value = Get(name, "");
return value.empty() ? defval : value == "true";
}
-void TaskContext::SetParameter(const string &name, const string &value) {
+void TaskContext::SetParameter(const std::string &name,
+ const std::string &value) {
parameters_[name] = value;
}
diff --git a/native/lang_id/common/fel/task-context.h b/native/lang_id/common/fel/task-context.h
index ddc8cfe..b6bcd92 100644
--- a/native/lang_id/common/fel/task-context.h
+++ b/native/lang_id/common/fel/task-context.h
@@ -43,27 +43,27 @@
// Returns path for the input named |name|. Returns empty string ("") if
// there is no input with that name. Note: this can be a standard file path,
// or a path in a more special file system.
- string GetInputPath(const string &name) const;
+ std::string GetInputPath(const std::string &name) const;
// Sets path for input |name|. Previous path, if any, is overwritten.
- void SetInputPath(const string &name, const string &path);
+ void SetInputPath(const std::string &name, const std::string &path);
// Returns parameter value. If the parameter is not specified in this
// context, the default value is returned.
- string Get(const string &name, const char *defval) const;
- int Get(const string &name, int defval) const;
- float Get(const string &name, float defval) const;
- bool Get(const string &name, bool defval) const;
+ std::string Get(const std::string &name, const char *defval) const;
+ int Get(const std::string &name, int defval) const;
+ float Get(const std::string &name, float defval) const;
+ bool Get(const std::string &name, bool defval) const;
// Sets value of parameter |name| to |value|.
- void SetParameter(const string &name, const string &value);
+ void SetParameter(const std::string &name, const std::string &value);
private:
// Maps input name -> path.
- std::map<string, string> inputs_;
+ std::map<std::string, std::string> inputs_;
// Maps parameter name -> value.
- std::map<string, string> parameters_;
+ std::map<std::string, std::string> parameters_;
};
} // namespace mobile
diff --git a/native/lang_id/common/fel/workspace.cc b/native/lang_id/common/fel/workspace.cc
index 8cab281..af41e29 100644
--- a/native/lang_id/common/fel/workspace.cc
+++ b/native/lang_id/common/fel/workspace.cc
@@ -29,12 +29,12 @@
return counter++;
}
-string WorkspaceRegistry::DebugString() const {
- string str;
+std::string WorkspaceRegistry::DebugString() const {
+ std::string str;
for (auto &it : workspace_names_) {
- const string &type_name = workspace_types_.at(it.first);
+ const std::string &type_name = workspace_types_.at(it.first);
for (size_t index = 0; index < it.second.size(); ++index) {
- const string &workspace_name = it.second[index];
+ const std::string &workspace_name = it.second[index];
str.append("\n ");
str.append(type_name);
str.append(" :: ");
@@ -52,7 +52,7 @@
VectorIntWorkspace::VectorIntWorkspace(const std::vector<int> &elements)
: elements_(elements) {}
-string VectorIntWorkspace::TypeName() { return "Vector"; }
+std::string VectorIntWorkspace::TypeName() { return "Vector"; }
} // namespace mobile
} // namespace nlp_saft
diff --git a/native/lang_id/common/fel/workspace.h b/native/lang_id/common/fel/workspace.h
index 09095e4..f13d802 100644
--- a/native/lang_id/common/fel/workspace.h
+++ b/native/lang_id/common/fel/workspace.h
@@ -22,6 +22,7 @@
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
#include <stddef.h>
+
#include <string>
#include <unordered_map>
#include <utility>
@@ -69,11 +70,11 @@
// Returns the index of a named workspace, adding it to the registry first
// if necessary.
template <class W>
- int Request(const string &name) {
+ int Request(const std::string &name) {
const int id = TypeId<W>::type_id;
max_workspace_id_ = std::max(id, max_workspace_id_);
workspace_types_[id] = W::TypeName();
- std::vector<string> &names = workspace_names_[id];
+ std::vector<std::string> &names = workspace_names_[id];
for (int i = 0; i < names.size(); ++i) {
if (names[i] == name) return i;
}
@@ -86,20 +87,20 @@
return max_workspace_id_;
}
- const std::unordered_map<int, std::vector<string> > &WorkspaceNames()
+ const std::unordered_map<int, std::vector<std::string> > &WorkspaceNames()
const {
return workspace_names_;
}
// Returns a string describing the registered workspaces.
- string DebugString() const;
+ std::string DebugString() const;
private:
// Workspace type names, indexed as workspace_types_[typeid].
- std::unordered_map<int, string> workspace_types_;
+ std::unordered_map<int, std::string> workspace_types_;
// Workspace names, indexed as workspace_names_[typeid][workspace].
- std::unordered_map<int, std::vector<string> > workspace_names_;
+ std::unordered_map<int, std::vector<std::string> > workspace_names_;
// The maximum workspace id that has been registered.
int max_workspace_id_ = 0;
@@ -182,7 +183,7 @@
VectorIntWorkspace(int size, int value);
// Returns the name of this type of workspace.
- static string TypeName();
+ static std::string TypeName();
// Returns the i'th element.
int element(int i) const { return elements_[i]; }
diff --git a/native/lang_id/common/file/file-utils.cc b/native/lang_id/common/file/file-utils.cc
index 108c7d5..1ee229f 100644
--- a/native/lang_id/common/file/file-utils.cc
+++ b/native/lang_id/common/file/file-utils.cc
@@ -21,6 +21,8 @@
#include <sys/stat.h>
#include <sys/types.h>
+#include <string>
+
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_strings/stringpiece.h"
@@ -29,7 +31,7 @@
namespace file_utils {
-bool GetFileContent(const string &filename, string *content) {
+bool GetFileContent(const std::string &filename, std::string *content) {
ScopedMmap scoped_mmap(filename);
const MmapHandle &handle = scoped_mmap.handle();
if (!handle.ok()) {
@@ -41,7 +43,7 @@
return true;
}
-bool FileExists(const string &filename) {
+bool FileExists(const std::string &filename) {
struct stat s = {0};
if (!stat(filename.c_str(), &s)) {
return s.st_mode & S_IFREG;
@@ -50,7 +52,7 @@
}
}
-bool DirectoryExists(const string &dirpath) {
+bool DirectoryExists(const std::string &dirpath) {
struct stat s = {0};
if (!stat(dirpath.c_str(), &s)) {
return s.st_mode & S_IFDIR;
diff --git a/native/lang_id/common/file/file-utils.h b/native/lang_id/common/file/file-utils.h
index 6377d7a..e8b0fef 100644
--- a/native/lang_id/common/file/file-utils.h
+++ b/native/lang_id/common/file/file-utils.h
@@ -18,6 +18,7 @@
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
#include <stddef.h>
+
#include <string>
#include "lang_id/common/file/mmap.h"
@@ -30,7 +31,7 @@
// Reads the entire content of a file into a string. Returns true on success,
// false on error.
-bool GetFileContent(const string &filename, string *content);
+bool GetFileContent(const std::string &filename, std::string *content);
// Parses a proto from its serialized representation in memory. That
// representation starts at address |data| and should contain exactly
@@ -57,8 +58,8 @@
//
// Note: when we compile for Android, the proto parsing methods need to know the
// type of the message they are parsing. We use template polymorphism for that.
-template<class Proto>
-bool ReadProtoFromFile(const string &filename, Proto *proto) {
+template <class Proto>
+bool ReadProtoFromFile(const std::string &filename, Proto *proto) {
ScopedMmap scoped_mmap(filename);
const MmapHandle &handle = scoped_mmap.handle();
if (!handle.ok()) {
@@ -69,11 +70,11 @@
// Returns true if filename is the name of an existing file, and false
// otherwise.
-bool FileExists(const string &filename);
+bool FileExists(const std::string &filename);
// Returns true if dirpath is the path to an existing directory, and false
// otherwise.
-bool DirectoryExists(const string &dirpath);
+bool DirectoryExists(const std::string &dirpath);
} // namespace file_utils
diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc
index 89efa99..3dcdd3b 100644
--- a/native/lang_id/common/file/mmap.cc
+++ b/native/lang_id/common/file/mmap.cc
@@ -20,9 +20,14 @@
#include <fcntl.h>
#include <stdint.h>
#include <string.h>
+#ifdef _WIN32
+#include <winbase.h>
+#include <windows.h>
+#else
#include <sys/mman.h>
-#include <sys/stat.h>
#include <unistd.h>
+#endif
+#include <sys/stat.h>
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_base/macros.h"
@@ -31,21 +36,127 @@
namespace mobile {
namespace {
-inline string GetLastSystemError() {
- return string(strerror(errno));
+inline MmapHandle GetErrorMmapHandle() { return MmapHandle(nullptr, 0); }
+} // anonymous namespace
+
+#ifdef _WIN32
+
+namespace {
+inline std::string GetLastSystemError() {
+ LPTSTR message_buffer;
+ DWORD error_code = GetLastError();
+ FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
+ FORMAT_MESSAGE_IGNORE_INSERTS,
+ NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
+ (LPTSTR)&message_buffer, 0, NULL);
+ std::string result(message_buffer);
+ LocalFree(message_buffer);
+ return result;
}
-inline MmapHandle GetErrorMmapHandle() {
- return MmapHandle(nullptr, 0);
+// Class for automatically closing a Win32 HANDLE on exit from a scope.
+class Win32HandleCloser {
+ public:
+ explicit Win32HandleCloser(HANDLE handle) : handle_(handle) {}
+ ~Win32HandleCloser() {
+ bool result = CloseHandle(handle_);
+ if (!result) {
+ const DWORD last_error = GetLastError();
+ SAFTM_LOG(ERROR) << "Error closing handle: " << last_error << ": "
+ << GetLastSystemError();
+ }
+ }
+
+ private:
+ const HANDLE handle_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(Win32HandleCloser);
+};
+} // namespace
+
+MmapHandle MmapFile(const std::string &filename) {
+ HANDLE handle =
+ CreateFile(filename.c_str(), // File to open.
+ GENERIC_READ, // Open for reading.
+ FILE_SHARE_READ, // Share for reading.
+ NULL, // Default security.
+ OPEN_EXISTING, // Existing file only.
+ FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, // Normal file.
+ NULL); // No attr. template.
+ if (handle == INVALID_HANDLE_VALUE) {
+ const std::string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error opening " << filename << ": " << last_error;
+ return GetErrorMmapHandle();
+ }
+
+ // Make sure we close handle no matter how we exit this function.
+ Win32HandleCloser handle_closer(handle);
+
+ return MmapFile(handle);
}
+MmapHandle MmapFile(HANDLE file_handle) {
+ // Get the file size.
+ DWORD file_size_high = 0;
+ DWORD file_size_low = GetFileSize(file_handle, &file_size_high);
+ if (file_size_low == INVALID_FILE_SIZE && GetLastError() != NO_ERROR) {
+ const std::string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Unable to stat fd: " << last_error;
+ return GetErrorMmapHandle();
+ }
+ size_t file_size_in_bytes = (static_cast<size_t>(file_size_high) << 32) +
+ static_cast<size_t>(file_size_low);
+
+ // Create a file mapping object that refers to the file.
+ HANDLE file_mapping_object =
+ CreateFileMappingA(file_handle, nullptr, PAGE_READONLY, 0, 0, nullptr);
+ if (file_mapping_object == NULL) {
+ const std::string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
+ return GetErrorMmapHandle();
+ }
+ Win32HandleCloser handle_closer(file_mapping_object);
+
+ // Map the file mapping object into memory.
+ void *mmap_addr =
+ MapViewOfFile(file_mapping_object, FILE_MAP_READ, 0, 0, // File offset.
+ 0 // Number of bytes to map; 0 means map the whole file.
+ );
+ if (mmap_addr == nullptr) {
+ const std::string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
+ return GetErrorMmapHandle();
+ }
+
+ return MmapHandle(mmap_addr, file_size_in_bytes);
+}
+
+bool Unmap(MmapHandle mmap_handle) {
+ if (!mmap_handle.ok()) {
+ // Unmapping something that hasn't been mapped is trivially successful.
+ return true;
+ }
+ bool succeeded = UnmapViewOfFile(mmap_handle.start());
+ if (!succeeded) {
+ const std::string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error during Unmap / UnmapViewOfFile: " << last_error;
+ return false;
+ }
+ return true;
+}
+
+#else
+
+namespace {
+inline std::string GetLastSystemError() { return std::string(strerror(errno)); }
+
class FileCloser {
public:
explicit FileCloser(int fd) : fd_(fd) {}
~FileCloser() {
int result = close(fd_);
if (result != 0) {
- const string last_error = GetLastSystemError();
+ const std::string last_error = GetLastSystemError();
SAFTM_LOG(ERROR) << "Error closing file descriptor: " << last_error;
}
}
@@ -56,11 +167,11 @@
};
} // namespace
-MmapHandle MmapFile(const string &filename) {
+MmapHandle MmapFile(const std::string &filename) {
int fd = open(filename.c_str(), O_RDONLY);
if (fd < 0) {
- const string last_error = GetLastSystemError();
+ const std::string last_error = GetLastSystemError();
SAFTM_LOG(ERROR) << "Error opening " << filename << ": " << last_error;
return GetErrorMmapHandle();
}
@@ -77,7 +188,7 @@
// Get file stats to obtain file size.
struct stat sb;
if (fstat(fd, &sb) != 0) {
- const string last_error = GetLastSystemError();
+ const std::string last_error = GetLastSystemError();
SAFTM_LOG(ERROR) << "Unable to stat fd: " << last_error;
return GetErrorMmapHandle();
}
@@ -108,7 +219,7 @@
// file_size_in_bytes (2nd argument) means we map all bytes from the file.
0);
if (mmap_addr == MAP_FAILED) {
- const string last_error = GetLastSystemError();
+ const std::string last_error = GetLastSystemError();
SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
return GetErrorMmapHandle();
}
@@ -122,12 +233,14 @@
return true;
}
if (munmap(mmap_handle.start(), mmap_handle.num_bytes()) != 0) {
- const string last_error = GetLastSystemError();
+ const std::string last_error = GetLastSystemError();
SAFTM_LOG(ERROR) << "Error during Unmap / munmap: " << last_error;
return false;
}
return true;
}
+#endif
+
} // namespace mobile
} // namespace nlp_saft
diff --git a/native/lang_id/common/file/mmap.h b/native/lang_id/common/file/mmap.h
index 6131803..f785465 100644
--- a/native/lang_id/common/file/mmap.h
+++ b/native/lang_id/common/file/mmap.h
@@ -23,6 +23,11 @@
#include "lang_id/common/lite_strings/stringpiece.h"
+#ifdef _WIN32
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#endif
+
namespace libtextclassifier3 {
namespace mobile {
@@ -83,10 +88,16 @@
// Note: one can read *and* write the num_bytes bytes from start, but those
// writes are not propagated to the underlying file, nor to other processes that
// may have mmapped that file (all changes are local to current process).
-MmapHandle MmapFile(const string &filename);
+MmapHandle MmapFile(const std::string &filename);
-// Like MmapFile(const string &filename), but uses a file descriptor.
-MmapHandle MmapFile(int fd);
+#ifdef _WIN32
+using FileDescriptorOrHandle = HANDLE;
+#else
+using FileDescriptorOrHandle = int;
+#endif
+
+// Like MmapFile(const std::string &filename), but uses a file descriptor.
+MmapHandle MmapFile(FileDescriptorOrHandle fd);
// Unmaps a file mapped using MmapFile. Returns true on success, false
// otherwise.
@@ -96,11 +107,10 @@
// destruction.
class ScopedMmap {
public:
- explicit ScopedMmap(const string &filename)
+ explicit ScopedMmap(const std::string &filename)
: handle_(MmapFile(filename)) {}
- explicit ScopedMmap(int fd)
- : handle_(MmapFile(fd)) {}
+ explicit ScopedMmap(FileDescriptorOrHandle fd) : handle_(MmapFile(fd)) {}
~ScopedMmap() {
if (handle_.ok()) {
diff --git a/native/lang_id/common/flatbuffers/model-utils.cc b/native/lang_id/common/flatbuffers/model-utils.cc
index 2c57aa2..66f7f38 100644
--- a/native/lang_id/common/flatbuffers/model-utils.cc
+++ b/native/lang_id/common/flatbuffers/model-utils.cc
@@ -18,6 +18,8 @@
#include <string.h>
+#include <string>
+
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/math/checksum.h"
@@ -45,7 +47,7 @@
<< " vs " << expected_crc32;
return true;
}
- SAFTM_LOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
+ SAFTM_DLOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
return false;
}
} // namespace
@@ -71,7 +73,7 @@
return model;
}
-const ModelInput *GetInputByName(const Model *model, const string &name) {
+const ModelInput *GetInputByName(const Model *model, const std::string &name) {
if (model == nullptr) {
SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr";
return nullptr;
@@ -129,7 +131,7 @@
SAFTM_LOG(ERROR) << "null parameter name";
return false;
}
- const string name = p->name()->str();
+ const std::string name = p->name()->str();
if (name.empty()) {
SAFTM_LOG(ERROR) << "empty parameter name";
return false;
diff --git a/native/lang_id/common/flatbuffers/model-utils.h b/native/lang_id/common/flatbuffers/model-utils.h
index 5427f70..197e1e3 100644
--- a/native/lang_id/common/flatbuffers/model-utils.h
+++ b/native/lang_id/common/flatbuffers/model-utils.h
@@ -46,7 +46,7 @@
// Returns the |model| input with specified |name|. Returns nullptr if no such
// input exists. If |model| contains multiple inputs with that |name|, returns
// the first one (model builders should avoid building such models).
-const ModelInput *GetInputByName(const Model *model, const string &name);
+const ModelInput *GetInputByName(const Model *model, const std::string &name);
// Returns a StringPiece pointing to the bytes for the content of |input|. In
// case of errors, returns StringPiece(nullptr, 0).
diff --git a/native/lang_id/common/lite_base/compact-logging-raw.cc b/native/lang_id/common/lite_base/compact-logging-raw.cc
index 53dfc8e..27c6446 100644
--- a/native/lang_id/common/lite_base/compact-logging-raw.cc
+++ b/native/lang_id/common/lite_base/compact-logging-raw.cc
@@ -17,6 +17,7 @@
#include "lang_id/common/lite_base/compact-logging-raw.h"
#include <stdio.h>
+
#include <string>
// NOTE: this file contains two implementations: one for Android, one for all
@@ -48,8 +49,8 @@
}
} // namespace
-void LowLevelLogging(LogSeverity severity, const string &tag,
- const string &message) {
+void LowLevelLogging(LogSeverity severity, const std::string &tag,
+ const std::string &message) {
const int android_log_level = GetAndroidLogLevel(severity);
#if !defined(SAFTM_DEBUG_LOGGING)
if (android_log_level != ANDROID_LOG_ERROR &&
@@ -89,8 +90,8 @@
}
} // namespace
-void LowLevelLogging(LogSeverity severity, const string &tag,
- const string &message) {
+void LowLevelLogging(LogSeverity severity, const std::string &tag,
+ const std::string &message) {
fprintf(stderr, "[%s] %s : %s\n", LogSeverityToString(severity), tag.c_str(),
message.c_str());
fflush(stderr);
diff --git a/native/lang_id/common/lite_base/compact-logging-raw.h b/native/lang_id/common/lite_base/compact-logging-raw.h
index f67287c..d77a990 100644
--- a/native/lang_id/common/lite_base/compact-logging-raw.h
+++ b/native/lang_id/common/lite_base/compact-logging-raw.h
@@ -28,8 +28,8 @@
// Low-level logging primitive. Logs a message, with the indicated log
// severity. From android/log.h: "the tag normally corresponds to the component
// that emits the log message, and should be reasonably small".
-void LowLevelLogging(LogSeverity severity, const string &tag,
- const string &message);
+void LowLevelLogging(LogSeverity severity, const std::string &tag,
+ const std::string &message);
} // namespace internal_logging
} // namespace mobile
diff --git a/native/lang_id/common/lite_base/compact-logging.h b/native/lang_id/common/lite_base/compact-logging.h
index eccb7d1..29450b1 100644
--- a/native/lang_id/common/lite_base/compact-logging.h
+++ b/native/lang_id/common/lite_base/compact-logging.h
@@ -36,7 +36,7 @@
// Needed for invocation in SAFTM_CHECK macro.
explicit operator bool() const { return true; }
- string message;
+ std::string message;
};
template <typename T>
@@ -53,7 +53,7 @@
}
inline LoggingStringStream &operator<<(LoggingStringStream &stream,
- const string &message) {
+ const std::string &message) {
stream.message.append(message);
return stream;
}
diff --git a/native/lang_id/common/lite_strings/numbers.cc b/native/lang_id/common/lite_strings/numbers.cc
index e0c66f3..f933f04 100644
--- a/native/lang_id/common/lite_strings/numbers.cc
+++ b/native/lang_id/common/lite_strings/numbers.cc
@@ -18,6 +18,7 @@
#include <ctype.h>
#include <stdlib.h>
+
#include <climits>
namespace libtextclassifier3 {
diff --git a/native/lang_id/common/lite_strings/numbers.h b/native/lang_id/common/lite_strings/numbers.h
index 4b3c93c..f832a96 100644
--- a/native/lang_id/common/lite_strings/numbers.h
+++ b/native/lang_id/common/lite_strings/numbers.h
@@ -40,14 +40,14 @@
// underflows.
bool LiteAtoi(const char *c_str, int *value);
-inline bool LiteAtoi(const string &s, int *value) {
+inline bool LiteAtoi(const std::string &s, int *value) {
return LiteAtoi(s.c_str(), value);
}
inline bool LiteAtoi(StringPiece sp, int *value) {
// Unfortunately, we can't directly call LiteAtoi(sp.data()): LiteAtoi(const
// char *) needs a zero-terminated string.
- const string temp(sp.data(), sp.size());
+ const std::string temp(sp.data(), sp.size());
return LiteAtoi(temp.c_str(), value);
}
@@ -57,14 +57,14 @@
// TODO(salcianu): fix that.
bool LiteAtof(const char *c_str, float *value);
-inline bool LiteAtof(const string &s, float *value) {
+inline bool LiteAtof(const std::string &s, float *value) {
return LiteAtof(s.c_str(), value);
}
inline bool LiteAtof(StringPiece sp, float *value) {
// Unfortunately, we can't directly call LiteAtoi(sp.data()): LiteAtoi(const
// char *) needs a zero-terminated string.
- const string temp(sp.data(), sp.size());
+ const std::string temp(sp.data(), sp.size());
return LiteAtof(temp.c_str(), value);
}
diff --git a/native/lang_id/common/lite_strings/str-cat.h b/native/lang_id/common/lite_strings/str-cat.h
index f0c1682..25cec4d 100644
--- a/native/lang_id/common/lite_strings/str-cat.h
+++ b/native/lang_id/common/lite_strings/str-cat.h
@@ -41,10 +41,10 @@
// string that contains the representation of v. For examples, see
// str-cat_test.cc.
template <typename T>
-inline string LiteStrCat(T v) {
+inline std::string LiteStrCat(T v) {
#ifdef COMPILER_MSVC
std::stringstream stream;
- stream << input;
+ stream << v;
return stream.str();
#else
return std::to_string(v);
@@ -52,42 +52,42 @@
}
template <>
-inline string LiteStrCat(const char *v) {
- return string(v);
+inline std::string LiteStrCat(const char *v) {
+ return std::string(v);
}
-// TODO(salcianu): use a reference type (const string &). For some reason, I
-// couldn't get that to work on a first try.
+// TODO(salcianu): use a reference type (const std::string &). For some reason,
+// I couldn't get that to work on a first try.
template <>
-inline string LiteStrCat(string v) {
+inline std::string LiteStrCat(std::string v) {
return v;
}
template <>
-inline string LiteStrCat(char v) {
- return string(1, v);
+inline std::string LiteStrCat(char v) {
+ return std::string(1, v);
}
// Less efficient but more compact version of absl::LiteStrAppend().
template <typename T>
-inline void LiteStrAppend(string *dest, T v) {
+inline void LiteStrAppend(std::string *dest, T v) {
dest->append(LiteStrCat(v)); // NOLINT
}
template <typename T1, typename T2>
-inline void LiteStrAppend(string *dest, T1 v1, T2 v2) {
+inline void LiteStrAppend(std::string *dest, T1 v1, T2 v2) {
dest->append(LiteStrCat(v1)); // NOLINT
dest->append(LiteStrCat(v2)); // NOLINT
}
template <typename T1, typename T2, typename T3>
-inline void LiteStrAppend(string *dest, T1 v1, T2 v2, T3 v3) {
+inline void LiteStrAppend(std::string *dest, T1 v1, T2 v2, T3 v3) {
LiteStrAppend(dest, v1, v2);
dest->append(LiteStrCat(v3)); // NOLINT
}
template <typename T1, typename T2, typename T3, typename T4>
-inline void LiteStrAppend(string *dest, T1 v1, T2 v2, T3 v3, T4 v4) {
+inline void LiteStrAppend(std::string *dest, T1 v1, T2 v2, T3 v3, T4 v4) {
LiteStrAppend(dest, v1, v2, v3);
dest->append(LiteStrCat(v4)); // NOLINT
}
diff --git a/native/lang_id/common/lite_strings/stringpiece.h b/native/lang_id/common/lite_strings/stringpiece.h
index 59a2176..6565053 100644
--- a/native/lang_id/common/lite_strings/stringpiece.h
+++ b/native/lang_id/common/lite_strings/stringpiece.h
@@ -49,10 +49,10 @@
// Intentionally no "explicit" keyword: in function calls, we want strings to
// be converted to StringPiece implicitly.
- StringPiece(const string &s) // NOLINT
+ StringPiece(const std::string &s) // NOLINT
: StringPiece(s.data(), s.size()) {}
- StringPiece(const string &s, int offset, int len)
+ StringPiece(const std::string &s, int offset, int len)
: StringPiece(s.data() + offset, len) {}
char operator[](size_t i) const { return start_[i]; }
@@ -67,9 +67,9 @@
bool empty() const { return size() == 0; }
template <typename A>
- explicit operator basic_string<char, std::char_traits<char>, A>() const {
+ explicit operator std::basic_string<char, std::char_traits<char>, A>() const {
if (!data()) return {};
- return basic_string<char, std::char_traits<char>, A>(data(), size());
+ return std::basic_string<char, std::char_traits<char>, A>(data(), size());
}
private:
diff --git a/native/lang_id/common/math/algorithm.h b/native/lang_id/common/math/algorithm.h
index a963807..5c8596b 100644
--- a/native/lang_id/common/math/algorithm.h
+++ b/native/lang_id/common/math/algorithm.h
@@ -20,6 +20,7 @@
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
#include <algorithm>
+#include <queue>
#include <vector>
namespace libtextclassifier3 {
@@ -43,6 +44,104 @@
std::min_element(elements.begin(), elements.end()));
}
+// Returns indices of greatest k elements from |v|.
+//
+// The order between elements is indicated by |smaller|, which should be an
+// object like std::less<T>, std::greater<T>, etc. If smaller(a, b) is true,
+// that means that "a is smaller than b". Intuitively, |smaller| is a
+// generalization of operator<. Formally, it is a strict weak ordering, see
+// https://en.cppreference.com/w/cpp/named_req/Compare
+//
+// Calling this function with std::less<T>() returns the indices of the larger k
+// elements; calling it with std::greater<T>() returns the indices of the
+// smallest k elements. This is similar to e.g., std::priority_queue: using the
+// default std::less gives you a max-heap, while using std::greater results in a
+// min-heap.
+//
+// Returned indices are sorted in decreasing order of the corresponding elements
+// (e.g., first element of the returned array is the index of the largest
+// element). In case of ties (e.g., equal elements) we select the one with the
+// smallest index. E.g., getting the indices of the top-2 elements from [3, 2,
+// 1, 3, 0, 3] returns [0, 3] (the indices of the first and the second 3).
+//
+// Corner cases: If k <= 0, this function returns an empty vector. If |v| has
+// only n < k elements, this function returns all n indices [0, 1, 2, ..., n -
+// 1], sorted according to the comp order of the indicated elements.
+//
+// Assuming each comparison is O(1), this function uses O(k) auxiliary space,
+// and runs in O(n * log k) time. Note: it is possible to use std::nth_element
+// and obtain an O(n + k * log k) time algorithm, but that uses O(n) auxiliary
+// space. In our case, k << n, e.g., we may want to select the top-3 most
+// likely classes from a set of 100 classes, so the time complexity difference
+// should not matter in practice.
+template <typename T, typename Smaller>
+std::vector<int> GetTopKIndices(int k, const std::vector<T> &v,
+ Smaller smaller) {
+ if (k <= 0) {
+ return std::vector<int>();
+ }
+
+ if (k > v.size()) {
+ k = v.size();
+ }
+
+ // An order between indices. Intuitively, rev_vcomp(i1, i2) iff v[i2] is
+ // smaller than v[i1]. No typo: this inversion is necessary for Invariant B
+ // below. "vcomp" stands for "value comparator" (we compare the values
+ // indicates by the two indices) and "rev_" stands for the reverse order.
+ const auto rev_vcomp = [&v, &smaller](int i1, int i2) -> bool {
+ if (smaller(v[i2], v[i1])) return true;
+ if (smaller(v[i1], v[i2])) return false;
+
+ // Break ties in favor of earlier elements.
+ return i1 < i2;
+ };
+
+ // Indices of the top-k elements seen so far.
+ std::vector<int> heap(k);
+
+ // First, we fill |heap| with the first k indices.
+ for (int i = 0; i < k; ++i) {
+ heap[i] = i;
+ }
+ std::make_heap(heap.begin(), heap.end(), rev_vcomp);
+
+ // Next, we explore the rest of the vector v. Loop invariants:
+ //
+ // Invariant A: |heap| contains the indices of the top-k elements from v[0:i].
+ //
+ // Invariant B: heap[0] is the index of the smallest element from all elements
+ // indicated by the indices from |heap|.
+ //
+ // Invariant C: |heap| is a max heap, according to order rev_vcomp.
+ for (int i = k; i < v.size(); ++i) {
+ // We have to update |heap| iff v[i] is larger than the smallest of the
+ // top-k seen so far. This test is easy to do, due to Invariant B above.
+ if (smaller(v[heap[0]], v[i])) {
+ // Next lines replace heap[0] with i and re-"heapify" heap[0:k-1].
+ heap.push_back(i);
+ std::pop_heap(heap.begin(), heap.end(), rev_vcomp);
+ heap.pop_back();
+ }
+ }
+
+ // Arrange indices from |heap| in decreasing order of corresponding elements.
+ //
+ // More info: in iteration #0, we extract the largest heap element (according
+ // to rev_vcomp, i.e., the index of the smallest of the top-k elements) and
+ // place it at the end of heap, i.e., in heap[k-1]. In iteration #1, we
+ // extract the second largest and place it in heap[k-2], etc.
+ for (int i = 0; i < k; ++i) {
+ std::pop_heap(heap.begin(), heap.end() - i, rev_vcomp);
+ }
+ return heap;
+}
+
+template <typename T>
+std::vector<int> GetTopKIndices(int k, const std::vector<T> &elements) {
+ return GetTopKIndices(k, elements, std::less<T>());
+}
+
} // namespace mobile
} // namespace nlp_saft
diff --git a/native/lang_id/common/math/hash.h b/native/lang_id/common/math/hash.h
index 08c32be..a1c24d5 100644
--- a/native/lang_id/common/math/hash.h
+++ b/native/lang_id/common/math/hash.h
@@ -51,7 +51,7 @@
return Hash32(data, n, 0xBEEF);
}
-static inline uint32 Hash32WithDefaultSeed(const string &input) {
+static inline uint32 Hash32WithDefaultSeed(const std::string &input) {
return Hash32WithDefaultSeed(input.data(), input.size());
}
diff --git a/native/lang_id/common/registry.h b/native/lang_id/common/registry.h
index d2c5271..632f917 100644
--- a/native/lang_id/common/registry.h
+++ b/native/lang_id/common/registry.h
@@ -178,14 +178,14 @@
return (cell == nullptr) ? nullptr : cell->value();
}
- T *Lookup(const string &key) const { return Lookup(key.c_str()); }
+ T *Lookup(const std::string &key) const { return Lookup(key.c_str()); }
// Returns name of this ComponentRegistry.
const char *name() const { return name_; }
// Fills *names with names of all components registered in this
// ComponentRegistry. Previous content of *names is cleared out.
- void GetComponentNames(std::vector<string> *names) {
+ void GetComponentNames(std::vector<std::string> *names) {
names->clear();
for (const Cell *c = head_; c!= nullptr; c = c->next()) {
names->emplace_back(c->key());
@@ -247,7 +247,7 @@
// case of errors (e.g., unknown component).
//
// Passes ownership of the returned pointer to the caller.
- static T *Create(const string &name) { // NOLINT
+ static T *Create(const std::string &name) { // NOLINT
auto *factory = registry()->Lookup(name);
if (factory == nullptr) {
SAFTM_LOG(ERROR) << "Unknown RegisterableClass " << name;
diff --git a/native/lang_id/common/utf8.h b/native/lang_id/common/utf8.h
index 2365429..6103bdd 100644
--- a/native/lang_id/common/utf8.h
+++ b/native/lang_id/common/utf8.h
@@ -65,7 +65,7 @@
// Preconditions: data != nullptr.
const char *GetSafeEndOfUtf8String(const char *data, size_t size);
-static inline const char *GetSafeEndOfUtf8String(const string &text) {
+static inline const char *GetSafeEndOfUtf8String(const std::string &text) {
return GetSafeEndOfUtf8String(text.data(), text.size());
}
diff --git a/native/lang_id/custom-tokenizer.cc b/native/lang_id/custom-tokenizer.cc
index f77ad53..46a64b2 100644
--- a/native/lang_id/custom-tokenizer.cc
+++ b/native/lang_id/custom-tokenizer.cc
@@ -44,7 +44,7 @@
// we append the original UTF8 character.
inline SAFTM_ATTRIBUTE_ALWAYS_INLINE void AppendLowerCase(const char *curr,
int num_bytes,
- string *word) {
+ std::string *word) {
if (num_bytes == 1) {
// Optimize the ASCII case.
word->push_back(tolower(*curr));
@@ -126,7 +126,7 @@
// If control reaches this point, we are at beginning of a non-empty token.
sentence->emplace_back();
- string *word = &(sentence->back());
+ std::string *word = &(sentence->back());
// Add special token-start character.
word->push_back('^');
diff --git a/native/lang_id/fb_model/lang-id-from-fb.cc b/native/lang_id/fb_model/lang-id-from-fb.cc
index f8e39d7..b2163eb 100644
--- a/native/lang_id/fb_model/lang-id-from-fb.cc
+++ b/native/lang_id/fb_model/lang-id-from-fb.cc
@@ -16,13 +16,16 @@
#include "lang_id/fb_model/lang-id-from-fb.h"
+#include <string>
+
#include "lang_id/fb_model/model-provider-from-fb.h"
namespace libtextclassifier3 {
namespace mobile {
namespace lang_id {
-std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(const string &filename) {
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(
+ const std::string &filename) {
std::unique_ptr<ModelProvider> model_provider(
new ModelProviderFromFlatbuffer(filename));
@@ -31,7 +34,8 @@
new LangId(std::move(model_provider)));
}
-std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(int fd) {
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(
+ FileDescriptorOrHandle fd) {
std::unique_ptr<ModelProvider> model_provider(
new ModelProviderFromFlatbuffer(fd));
diff --git a/native/lang_id/fb_model/lang-id-from-fb.h b/native/lang_id/fb_model/lang-id-from-fb.h
index 51bcffe..061247b 100644
--- a/native/lang_id/fb_model/lang-id-from-fb.h
+++ b/native/lang_id/fb_model/lang-id-from-fb.h
@@ -22,6 +22,7 @@
#include <memory>
#include <string>
+#include "lang_id/common/file/mmap.h"
#include "lang_id/lang-id.h"
namespace libtextclassifier3 {
@@ -30,11 +31,13 @@
// Returns a LangId built using the SAFT model in flatbuffer format from
// |filename|.
-std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(const string &filename);
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(
+ const std::string &filename);
// Returns a LangId built using the SAFT model in flatbuffer format from
// given file descriptor.
-std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(int fd);
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(
+ FileDescriptorOrHandle fd);
// Returns a LangId built using the SAFT model in flatbuffer format from
// the |num_bytes| bytes that start at address |data|.
@@ -50,7 +53,7 @@
//
// IMPORTANT: |bytes| must be alive during the lifetime of the returned LangId.
inline std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(
- const string &bytes) {
+ const std::string &bytes) {
return GetLangIdFromFlatbufferBytes(bytes.data(), bytes.size());
}
diff --git a/native/lang_id/fb_model/model-provider-from-fb.cc b/native/lang_id/fb_model/model-provider-from-fb.cc
index 3357963..c81b116 100644
--- a/native/lang_id/fb_model/model-provider-from-fb.cc
+++ b/native/lang_id/fb_model/model-provider-from-fb.cc
@@ -16,7 +16,10 @@
#include "lang_id/fb_model/model-provider-from-fb.h"
+#include <string>
+
#include "lang_id/common/file/file-utils.h"
+#include "lang_id/common/file/mmap.h"
#include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
#include "lang_id/common/flatbuffers/model-utils.h"
#include "lang_id/common/lite_strings/str-split.h"
@@ -25,7 +28,8 @@
namespace mobile {
namespace lang_id {
-ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(const string &filename)
+ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
+ const std::string &filename)
// Using mmap as a fast way to read the model bytes. As the file is
// unmapped only when the field scoped_mmap_ is destructed, the model bytes
@@ -34,7 +38,8 @@
Initialize(scoped_mmap_->handle().to_stringpiece());
}
-ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(int fd)
+ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
+ FileDescriptorOrHandle fd)
// Using mmap as a fast way to read the model bytes. As the file is
// unmapped only when the field scoped_mmap_ is destructed, the model bytes
@@ -60,7 +65,8 @@
}
// Init languages_.
- const string known_languages_str = context_.Get("supported_languages", "");
+ const std::string known_languages_str =
+ context_.Get("supported_languages", "");
for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
languages_.emplace_back(sp);
}
@@ -80,7 +86,7 @@
}
bool ModelProviderFromFlatbuffer::InitNetworkParams() {
- const string kInputName = "language-identifier-network";
+ const std::string kInputName = "language-identifier-network";
StringPiece bytes =
saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
if ((bytes.data() == nullptr) || bytes.empty()) {
diff --git a/native/lang_id/fb_model/model-provider-from-fb.h b/native/lang_id/fb_model/model-provider-from-fb.h
index d25c903..c3def49 100644
--- a/native/lang_id/fb_model/model-provider-from-fb.h
+++ b/native/lang_id/fb_model/model-provider-from-fb.h
@@ -37,11 +37,11 @@
public:
// Constructs a model provider based on a flatbuffer-format SAFT model from
// |filename|.
- explicit ModelProviderFromFlatbuffer(const string &filename);
+ explicit ModelProviderFromFlatbuffer(const std::string &filename);
// Constructs a model provider based on a flatbuffer-format SAFT model from
// file descriptor |fd|.
- explicit ModelProviderFromFlatbuffer(int fd);
+ explicit ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd);
// Constructs a model provider from a flatbuffer-format SAFT model the bytes
// of which are already in RAM (size bytes starting from address data).
@@ -71,9 +71,7 @@
return nn_params_.get();
}
- std::vector<string> GetLanguages() const override {
- return languages_;
- }
+ std::vector<std::string> GetLanguages() const override { return languages_; }
private:
// Initializes the fields of this class based on the flatbuffer from
@@ -104,7 +102,7 @@
// List of supported languages, see GetLanguages(). We expect this list to be
// specified by the ModelParameter named "supported_languages" from model_.
- std::vector<string> languages_;
+ std::vector<std::string> languages_;
// EmbeddingNetworkParams, see GetNnParams(). Set based on the ModelInput
// named "language-identifier-network" from model_.
diff --git a/native/lang_id/features/char-ngram-feature.cc b/native/lang_id/features/char-ngram-feature.cc
index 83d7588..31faf2f 100644
--- a/native/lang_id/features/char-ngram-feature.cc
+++ b/native/lang_id/features/char-ngram-feature.cc
@@ -16,6 +16,7 @@
#include "lang_id/features/char-ngram-feature.h"
+#include <string>
#include <utility>
#include <vector>
@@ -68,7 +69,7 @@
int total_count = 0;
- for (const string &word : sentence) {
+ for (const std::string &word : sentence) {
const char *const word_end = word.data() + word.size();
// Set ngram_start at the start of the current token (word).
diff --git a/native/lang_id/features/relevant-script-feature.cc b/native/lang_id/features/relevant-script-feature.cc
index 0fde87b..e88b328 100644
--- a/native/lang_id/features/relevant-script-feature.cc
+++ b/native/lang_id/features/relevant-script-feature.cc
@@ -30,7 +30,7 @@
namespace lang_id {
bool RelevantScriptFeature::Setup(TaskContext *context) {
- string script_detector_name = GetParameter(
+ std::string script_detector_name = GetParameter(
"script_detector_name", /* default_value = */ "tiny-script-detector");
// We don't use absl::WrapUnique, nor the rest of absl, see http://b/71873194
@@ -60,7 +60,7 @@
// counts[s] is the number of characters with script s.
std::vector<int> counts(num_supported_scripts_);
int total_count = 0;
- for (const string &word : sentence) {
+ for (const std::string &word : sentence) {
const char *const word_end = word.data() + word.size();
const char *curr = word.data();
diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc
index c892329..ef82456 100644
--- a/native/lang_id/lang-id.cc
+++ b/native/lang_id/lang-id.cc
@@ -18,7 +18,6 @@
#include <stdio.h>
-#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
@@ -36,7 +35,13 @@
#include "lang_id/common/math/softmax.h"
#include "lang_id/custom-tokenizer.h"
#include "lang_id/features/light-sentence-features.h"
+// The two features/ headers below are needed only for RegisterClass().
+#include "lang_id/features/char-ngram-feature.h"
+#include "lang_id/features/relevant-script-feature.h"
#include "lang_id/light-sentence.h"
+// The two script/ headers below are needed only for RegisterClass().
+#include "lang_id/script/approx-script.h"
+#include "lang_id/script/tiny-script-detector.h"
namespace libtextclassifier3 {
namespace mobile {
@@ -91,21 +96,15 @@
valid_ = true;
}
- string FindLanguage(StringPiece text) const {
- // NOTE: it would be wasteful to implement this method in terms of
- // FindLanguages(). We just need the most likely language and its
- // probability; no need to compute (and allocate) a vector of pairs for all
- // languages, nor to compute probabilities for all non-top languages.
- if (!is_valid()) {
+ std::string FindLanguage(StringPiece text) const {
+ LangIdResult lang_id_result;
+ FindLanguages(text, &lang_id_result, /* max_results = */ 1);
+ if (lang_id_result.predictions.empty()) {
return LangId::kUnknownLanguageCode;
}
- std::vector<float> scores;
- ComputeScores(text, &scores);
-
- int prediction_id = GetArgMax(scores);
- const string language = GetLanguageForSoftmaxLabel(prediction_id);
- float probability = ComputeSoftmaxProbability(scores, prediction_id);
+ const std::string &language = lang_id_result.predictions[0].first;
+ const float probability = lang_id_result.predictions[0].second;
SAFTM_DLOG(INFO) << "Predicted " << language
<< " with prob: " << probability << " for \"" << text
<< "\"";
@@ -124,39 +123,50 @@
return language;
}
- void FindLanguages(StringPiece text, LangIdResult *result) const {
+ void FindLanguages(StringPiece text, LangIdResult *result,
+ int max_results) const {
if (result == nullptr) return;
+ if (max_results <= 0) {
+ max_results = languages_.size();
+ }
result->predictions.clear();
- if (!is_valid()) {
+ if (!is_valid() || (max_results == 0)) {
result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
return;
}
+ // Tokenize the input text (this also does some pre-processing, like
+ // removing ASCII digits, punctuation, etc).
+ LightSentence sentence;
+ tokenizer_.Tokenize(text, &sentence);
+
+ // Extract features from the tokenized text.
+ std::vector<FeatureVector> features =
+ lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
+
+ // Run feed-forward neural network to compute scores (softmax logits).
std::vector<float> scores;
- ComputeScores(text, &scores);
+ network_->ComputeFinalScores(features, &scores);
- // Compute and sort softmax in descending order by probability and convert
- // IDs to language code strings. When probabilities are equal, we sort by
- // language code string in ascending order.
- std::vector<float> softmax = ComputeSoftmax(scores);
-
- for (int i = 0; i < softmax.size(); ++i) {
- result->predictions.emplace_back(GetLanguageForSoftmaxLabel(i),
- softmax[i]);
+ if (max_results == 1) {
+ // Optimization for the case when the user wants only the top result.
+ // Computing argmax is faster than the general top-k code.
+ int prediction_id = GetArgMax(scores);
+ const std::string language = GetLanguageForSoftmaxLabel(prediction_id);
+ float probability = ComputeSoftmaxProbability(scores, prediction_id);
+ result->predictions.emplace_back(language, probability);
+ } else {
+ // Compute and sort softmax in descending order by probability and convert
+ // IDs to language code strings. When probabilities are equal, we sort by
+ // language code string in ascending order.
+ const std::vector<float> softmax = ComputeSoftmax(scores);
+ const std::vector<int> indices = GetTopKIndices(max_results, softmax);
+ for (const int index : indices) {
+ result->predictions.emplace_back(GetLanguageForSoftmaxLabel(index),
+ softmax[index]);
+ }
}
-
- // Sort the resulting language predictions by probability in descending
- // order.
- std::sort(result->predictions.begin(), result->predictions.end(),
- [](const std::pair<string, float> &a,
- const std::pair<string, float> &b) {
- if (a.second == b.second) {
- return a.first.compare(b.first) < 0;
- } else {
- return a.second > b.second;
- }
- });
}
bool is_valid() const { return valid_; }
@@ -165,10 +175,32 @@
// Returns a property stored in the model file.
template <typename T, typename R>
- R GetProperty(const string &property, T default_value) const {
+ R GetProperty(const std::string &property, T default_value) const {
return model_provider_->GetTaskContext()->Get(property, default_value);
}
+ // Perform any necessary static initialization.
+ // This function is thread-safe.
+ // It's also safe to call this function multiple times.
+ //
+ // We explicitly call RegisterClass() rather than relying on alwayslink=1 in
+ // the BUILD file, because the build process for some users of this code
+ // doesn't support any equivalent to alwayslink=1 (in particular the
+ // Firebase C++ SDK build uses a Kokoro-based CMake build). While it might
+ // be possible to add such support, avoiding the need for an equivalent to
+ // alwayslink=1 is preferable because it avoids unnecessarily bloating code
+ // size in apps that link against this code but don't use it.
+ static void RegisterClasses() {
+ static bool initialized = []() -> bool {
+ libtextclassifier3::mobile::ApproxScriptDetector::RegisterClass();
+ libtextclassifier3::mobile::lang_id::ContinuousBagOfNgramsFunction::RegisterClass();
+ libtextclassifier3::mobile::lang_id::TinyScriptDetector::RegisterClass();
+ libtextclassifier3::mobile::lang_id::RelevantScriptFeature::RegisterClass();
+ return true;
+ }();
+ (void)initialized; // Variable used only for initializer's side effects.
+ }
+
private:
bool Setup(TaskContext *context) {
tokenizer_.Setup(context);
@@ -178,7 +210,7 @@
// Parse task parameter "per_lang_reliability_thresholds", fill
// per_lang_thresholds_.
- const string thresholds_str =
+ const std::string thresholds_str =
context->Get("per_lang_reliability_thresholds", "");
std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
for (const auto &token : tokens) {
@@ -186,7 +218,7 @@
std::vector<StringPiece> parts = LiteStrSplit(token, '=');
float threshold = 0.0f;
if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
- per_lang_thresholds_[string(parts[0])] = threshold;
+ per_lang_thresholds_[std::string(parts[0])] = threshold;
} else {
SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
}
@@ -199,25 +231,9 @@
return lang_id_brain_interface_.InitForProcessing(context);
}
- // Extracts features for |text|, runs them through the feed-forward neural
- // network, and computes the output scores (activations from the last layer).
- // These scores can be used to compute the softmax probabilities for our
- // labels (in this case, the languages).
- void ComputeScores(StringPiece text, std::vector<float> *scores) const {
- // Create a Sentence storing the input text.
- LightSentence sentence;
- tokenizer_.Tokenize(text, &sentence);
-
- std::vector<FeatureVector> features =
- lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
-
- // Run feed-forward neural network to compute scores.
- network_->ComputeFinalScores(features, scores);
- }
-
// Returns language code for a softmax label. See comments for languages_
// field. If label is out of range, returns LangId::kUnknownLanguageCode.
- string GetLanguageForSoftmaxLabel(int label) const {
+ std::string GetLanguageForSoftmaxLabel(int label) const {
if ((label >= 0) && (label < languages_.size())) {
return languages_[label];
} else {
@@ -244,11 +260,11 @@
// reported. Otherwise, we report LangId::kUnknownLanguageCode.
float default_threshold_ = kDefaultConfidenceThreshold;
- std::unordered_map<string, float> per_lang_thresholds_;
+ std::unordered_map<std::string, float> per_lang_thresholds_;
// Recognized languages: softmax label i means languages_[i] (something like
// "en", "fr", "ru", etc).
- std::vector<string> languages_;
+ std::vector<std::string> languages_;
// Version of the model used by this LangIdImpl object. Zero means that the
// model version could not be determined.
@@ -258,27 +274,29 @@
const char LangId::kUnknownLanguageCode[] = "und";
LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
- : pimpl_(new LangIdImpl(std::move(model_provider))) {}
+ : pimpl_(new LangIdImpl(std::move(model_provider))) {
+ LangIdImpl::RegisterClasses();
+}
LangId::~LangId() = default;
-string LangId::FindLanguage(const char *data, size_t num_bytes) const {
+std::string LangId::FindLanguage(const char *data, size_t num_bytes) const {
StringPiece text(data, num_bytes);
return pimpl_->FindLanguage(text);
}
void LangId::FindLanguages(const char *data, size_t num_bytes,
- LangIdResult *result) const {
+ LangIdResult *result, int max_results) const {
SAFTM_DCHECK(result) << "LangIdResult must not be null.";
StringPiece text(data, num_bytes);
- pimpl_->FindLanguages(text, result);
+ pimpl_->FindLanguages(text, result, max_results);
}
bool LangId::is_valid() const { return pimpl_->is_valid(); }
int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }
-float LangId::GetFloatProperty(const string &property,
+float LangId::GetFloatProperty(const std::string &property,
float default_value) const {
return pimpl_->GetProperty<float, float>(property, default_value);
}
diff --git a/native/lang_id/lang-id.h b/native/lang_id/lang-id.h
index 94af0c3..18c6e77 100644
--- a/native/lang_id/lang-id.h
+++ b/native/lang_id/lang-id.h
@@ -45,7 +45,7 @@
//
// If the model cannot make a prediction, this array contains a single result:
// a language code LangId::kUnknownLanguageCode with probability 1.
- std::vector<std::pair<string, float>> predictions;
+ std::vector<std::pair<std::string, float>> predictions;
};
// Class for detecting the language of a document.
@@ -69,21 +69,27 @@
virtual ~LangId();
- // Computes the an n-best list of language codes and probabilities
- // corresponding to the most likely languages the given input text is written
- // in. The list is sorted in descending order by language probability.
+ // Computes the n-best list of language codes and probabilities corresponding
+ // to the most likely languages the given input text is written in. That list
+ // includes the most likely |max_results| languages and is sorted in
+ // descending order by language probability.
//
// The input text consists of the |num_bytes| bytes that starts at |data|.
//
+ // If max_results <= 0, we report probabilities for all languages known by
+ // this LangId object (as always, in decreasing order of their probabilities).
+ //
// Note: If this LangId object is not valid (see is_valid()) or if this LangId
// object can't make a prediction, this method sets the LangIdResult to
// contain a single entry with kUnknownLanguageCode with probability 1.
- void FindLanguages(const char *data, size_t num_bytes,
- LangIdResult *result) const;
+ //
+ void FindLanguages(const char *data, size_t num_bytes, LangIdResult *result,
+ int max_results = 0) const;
// Convenience version of FindLanguages(const char *, size_t, LangIdResult *).
- void FindLanguages(const string &text, LangIdResult *result) const {
- FindLanguages(text.data(), text.size(), result);
+ void FindLanguages(const std::string &text, LangIdResult *result,
+ int max_results = 0) const {
+ FindLanguages(text.data(), text.size(), result, max_results);
}
// Returns language code for the most likely language for a piece of text.
@@ -101,10 +107,10 @@
// object can't make a prediction, then this method returns
// LangId::kUnknownLanguageCode.
//
- string FindLanguage(const char *data, size_t num_bytes) const;
+ std::string FindLanguage(const char *data, size_t num_bytes) const;
// Convenience version of FindLanguage(const char *, size_t).
- string FindLanguage(const string &text) const {
+ std::string FindLanguage(const std::string &text) const {
return FindLanguage(text.data(), text.size());
}
@@ -120,7 +126,8 @@
int GetModelVersion() const;
// Returns a typed property stored in the model file.
- float GetFloatProperty(const string &property, float default_value) const;
+ float GetFloatProperty(const std::string &property,
+ float default_value) const;
private:
// Pimpl ("pointer to implementation") pattern, to hide all internals from our
diff --git a/native/lang_id/lang-id_jni.cc b/native/lang_id/lang-id_jni.cc
index 6696298..02b388f 100644
--- a/native/lang_id/lang-id_jni.cc
+++ b/native/lang_id/lang-id_jni.cc
@@ -34,7 +34,8 @@
namespace {
jobjectArray LangIdResultToJObjectArray(JNIEnv* env,
- const LangIdResult& lang_id_result) {
+ const LangIdResult& lang_id_result,
+ const float significant_threshold) {
const ScopedLocalRef<jclass> result_class(
env->FindClass(TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR
"$LanguageResult"),
@@ -44,10 +45,14 @@
return nullptr;
}
- // clang-format off
- const std::vector<std::pair<std::string, float>>& predictions =
- lang_id_result.predictions;
- // clang-format on
+ std::vector<std::pair<std::string, float>> predictions;
+ std::copy_if(lang_id_result.predictions.begin(),
+ lang_id_result.predictions.end(),
+ std::back_inserter(predictions),
+ [significant_threshold](std::pair<std::string, float> pair) {
+ return pair.second >= significant_threshold;
+ });
+
const jmethodID result_class_constructor =
env->GetMethodID(result_class.get(), "<init>", "(Ljava/lang/String;F)V");
const jobjectArray results =
@@ -61,6 +66,10 @@
}
return results;
}
+
+float GetNoiseThreshold(const LangId& model) {
+ return model.GetFloatProperty("text_classifier_langid_noise_threshold", -1.0);
+}
} // namespace
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
@@ -90,10 +99,16 @@
}
const std::string text_str = ToStlString(env, text);
+ const float noise_threshold = GetNoiseThreshold(*model);
+ // Speed up the things by specifying the max results we want. For example, if
+ // the noise threshold is 0.1, we don't need more than 10 results.
+ const int max_results =
+ noise_threshold < 0.01
+ ? -1 // -1 means FindLanguages returns all predictions
+ : static_cast<int>(1 / noise_threshold) + 1;
LangIdResult result;
- model->FindLanguages(text_str, &result);
-
- return LangIdResultToJObjectArray(env, result);
+ model->FindLanguages(text_str, &result, max_results);
+ return LangIdResultToJObjectArray(env, result, noise_threshold);
}
TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
@@ -132,3 +147,12 @@
LangId* model = reinterpret_cast<LangId*>(ptr);
return model->GetFloatProperty("text_classifier_langid_threshold", -1.0);
}
+
+TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdNoiseThreshold)
+(JNIEnv* env, jobject thizz, jlong ptr) {
+ if (!ptr) {
+ return -1.0;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ return GetNoiseThreshold(*model);
+}
diff --git a/native/lang_id/lang-id_jni.h b/native/lang_id/lang-id_jni.h
index cd67a4c..b765ad4 100644
--- a/native/lang_id/lang-id_jni.h
+++ b/native/lang_id/lang-id_jni.h
@@ -54,6 +54,9 @@
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
(JNIEnv* env, jobject thizz, jlong ptr);
+TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdNoiseThreshold)
+(JNIEnv* env, jobject thizz, jlong ptr);
+
#ifdef __cplusplus
}
#endif
diff --git a/native/lang_id/light-sentence.h b/native/lang_id/light-sentence.h
index 2937549..2aee2ea 100644
--- a/native/lang_id/light-sentence.h
+++ b/native/lang_id/light-sentence.h
@@ -27,7 +27,7 @@
// Very simplified alternative to heavy sentence.proto, for the purpose of
// LangId. It turns out that in this case, all we need is a vector of strings,
// which uses a lot less code size than a Sentence proto.
-using LightSentence = std::vector<string>;
+using LightSentence = std::vector<std::string>;
} // namespace lang_id
} // namespace mobile
diff --git a/native/lang_id/model-provider.h b/native/lang_id/model-provider.h
index a076871..bf250ed 100644
--- a/native/lang_id/model-provider.h
+++ b/native/lang_id/model-provider.h
@@ -51,7 +51,7 @@
// returned vector should be a BCP-47 language code (e.g., "en", "ro", etc).
// Language at index i from the returned vector corresponds to softmax label
// i.
- virtual std::vector<string> GetLanguages() const = 0;
+ virtual std::vector<std::string> GetLanguages() const = 0;
protected:
bool valid_ = false;
diff --git a/native/models/actions_suggestions.en.model b/native/models/actions_suggestions.en.model
index 6cec2b7..90d66ba 100644
--- a/native/models/actions_suggestions.en.model
+++ b/native/models/actions_suggestions.en.model
Binary files differ
diff --git a/native/models/actions_suggestions.universal.model b/native/models/actions_suggestions.universal.model
index 60f10e6..74f9ee5 100644
--- a/native/models/actions_suggestions.universal.model
+++ b/native/models/actions_suggestions.universal.model
Binary files differ
diff --git a/native/models/lang_id.model b/native/models/lang_id.model
index 49b4b07..92f0103 100644
--- a/native/models/lang_id.model
+++ b/native/models/lang_id.model
Binary files differ
diff --git a/native/utils/base/arena.cc b/native/utils/base/arena.cc
new file mode 100644
index 0000000..fcaed8e
--- /dev/null
+++ b/native/utils/base/arena.cc
@@ -0,0 +1,513 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This approach to arenas overcomes many of the limitations described
+// in the "Specialized allocators" section of
+// http://www.pdos.lcs.mit.edu/~dm/c++-new.html
+//
+// A somewhat similar approach to Gladiator, but for heap-detection, was
+// suggested by Ron van der Wal and Scott Meyers at
+// http://www.aristeia.com/BookErrata/M27Comments_frames.html
+
+#include "utils/base/arena.h"
+
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+
+namespace libtextclassifier3 {
+
+static void *aligned_malloc(size_t size, int minimum_alignment) {
+ void *ptr = nullptr;
+ // posix_memalign requires that the requested alignment be at least
+ // sizeof(void*). In this case, fall back on malloc which should return memory
+ // aligned to at least the size of a pointer.
+ const int required_alignment = sizeof(void*);
+ if (minimum_alignment < required_alignment)
+ return malloc(size);
+ if (posix_memalign(&ptr, static_cast<size_t>(minimum_alignment), size) != 0)
+ return nullptr;
+ else
+ return ptr;
+}
+
+// The value here doesn't matter until page_aligned_ is supported.
+static const int kPageSize = 8192; // should be getpagesize()
+
+// We used to only keep track of how much space has been allocated in
+// debug mode. Now we track this for optimized builds, as well. If you
+// want to play with the old scheme to see if this helps performance,
+// change this TC3_ARENASET() macro to a NOP. However, NOTE: some
+// applications of arenas depend on this space information (exported
+// via bytes_allocated()).
+#define TC3_ARENASET(x) (x)
+
+namespace {
+
+#ifdef __cpp_aligned_new
+
+char* AllocateBytes(size_t size) {
+ return static_cast<char*>(::operator new(size));
+}
+
+// REQUIRES: alignment > __STDCPP_DEFAULT_NEW_ALIGNMENT__
+//
+// For alignments <=__STDCPP_DEFAULT_NEW_ALIGNMENT__, AllocateBytes() will
+// provide the correct alignment.
+char* AllocateAlignedBytes(size_t size, size_t alignment) {
+ TC3_CHECK_GT(alignment, __STDCPP_DEFAULT_NEW_ALIGNMENT__);
+ return static_cast<char*>(::operator new(size, std::align_val_t(alignment)));
+}
+
+void DeallocateBytes(void* ptr, size_t size, size_t alignment) {
+ if (alignment > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
+#ifdef __cpp_sized_deallocation
+ ::operator delete(ptr, size, std::align_val_t(alignment));
+#else // !__cpp_sized_deallocation
+ ::operator delete(ptr, std::align_val_t(alignment));
+#endif // !__cpp_sized_deallocation
+ } else {
+#ifdef __cpp_sized_deallocation
+ ::operator delete(ptr, size);
+#else // !__cpp_sized_deallocation
+ ::operator delete(ptr);
+#endif // !__cpp_sized_deallocation
+ }
+}
+
+#else // !__cpp_aligned_new
+
+char* AllocateBytes(size_t size) {
+ return static_cast<char*>(malloc(size));
+}
+
+char* AllocateAlignedBytes(size_t size, size_t alignment) {
+ return static_cast<char*>(aligned_malloc(size, alignment));
+}
+
+void DeallocateBytes(void* ptr, size_t size, size_t alignment) {
+ free(ptr);
+}
+
+#endif // !__cpp_aligned_new
+
+} // namespace
+
+const int BaseArena::kDefaultAlignment;
+
+// ----------------------------------------------------------------------
+// BaseArena::BaseArena()
+// BaseArena::~BaseArena()
+// Destroying the arena automatically calls Reset()
+// ----------------------------------------------------------------------
+
+BaseArena::BaseArena(char* first, const size_t orig_block_size,
+ bool align_to_page)
+ : remaining_(0),
+ block_size_(orig_block_size),
+ freestart_(nullptr), // set for real in Reset()
+ last_alloc_(nullptr),
+ overflow_blocks_(nullptr),
+ first_block_externally_owned_(first != nullptr),
+ page_aligned_(align_to_page),
+ blocks_alloced_(1) {
+ // Trivial check that aligned objects can actually be allocated.
+ TC3_CHECK_GT(block_size_, kDefaultAlignment)
+ << "orig_block_size = " << orig_block_size;
+ if (page_aligned_) {
+ // kPageSize must be power of 2, so make sure of this.
+ TC3_CHECK(kPageSize > 0 && 0 == (kPageSize & (kPageSize - 1)))
+ << "kPageSize[ " << kPageSize << "] is not "
+ << "correctly initialized: not a power of 2.";
+ }
+
+ if (first) {
+ TC3_CHECK(!page_aligned_ ||
+ (reinterpret_cast<uintptr_t>(first) & (kPageSize - 1)) == 0);
+ first_blocks_[0].mem = first;
+ first_blocks_[0].size = orig_block_size;
+ } else {
+ if (page_aligned_) {
+ // Make sure the blocksize is page multiple, as we need to end on a page
+ // boundary.
+ TC3_CHECK_EQ(block_size_ & (kPageSize - 1), 0) << "block_size is not a"
+ << "multiple of kPageSize";
+ first_blocks_[0].mem = AllocateAlignedBytes(block_size_, kPageSize);
+ first_blocks_[0].alignment = kPageSize;
+ TC3_CHECK(nullptr != first_blocks_[0].mem);
+ } else {
+ first_blocks_[0].mem = AllocateBytes(block_size_);
+ first_blocks_[0].alignment = 0;
+ }
+ first_blocks_[0].size = block_size_;
+ }
+
+ Reset();
+}
+
+BaseArena::~BaseArena() {
+ FreeBlocks();
+ assert(overflow_blocks_ == nullptr); // FreeBlocks() should do that
+#ifdef ADDRESS_SANITIZER
+ if (first_block_externally_owned_) {
+ ASAN_UNPOISON_MEMORY_REGION(first_blocks_[0].mem, first_blocks_[0].size);
+ }
+#endif
+ // The first X blocks stay allocated always by default. Delete them now.
+ for (int i = first_block_externally_owned_ ? 1 : 0;
+ i < blocks_alloced_; ++i) {
+ DeallocateBytes(first_blocks_[i].mem, first_blocks_[i].size,
+ first_blocks_[i].alignment);
+ }
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::block_count()
+// Only reason this is in .cc file is because it involves STL.
+// ----------------------------------------------------------------------
+
+int BaseArena::block_count() const {
+ return (blocks_alloced_ +
+ (overflow_blocks_ ? static_cast<int>(overflow_blocks_->size()) : 0));
+}
+
+// Returns true iff it advances freestart_ to the first position
+// satisfying alignment without exhausting the current block.
+bool BaseArena::SatisfyAlignment(size_t alignment) {
+ const size_t overage =
+ reinterpret_cast<size_t>(freestart_) & (alignment - 1);
+ if (overage > 0) {
+ const size_t waste = alignment - overage;
+ if (waste >= remaining_) {
+ return false;
+ }
+ freestart_ += waste;
+ remaining_ -= waste;
+ }
+ TC3_DCHECK_EQ(0, reinterpret_cast<size_t>(freestart_) & (alignment - 1));
+ return true;
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::Reset()
+// Clears all the memory an arena is using.
+// ----------------------------------------------------------------------
+
+void BaseArena::Reset() {
+ FreeBlocks();
+ freestart_ = first_blocks_[0].mem;
+ remaining_ = first_blocks_[0].size;
+ last_alloc_ = nullptr;
+#ifdef ADDRESS_SANITIZER
+ ASAN_POISON_MEMORY_REGION(freestart_, remaining_);
+#endif
+
+ TC3_ARENASET(status_.bytes_allocated_ = block_size_);
+
+ // There is no guarantee the first block is properly aligned, so
+ // enforce that now.
+ TC3_CHECK(SatisfyAlignment(kDefaultAlignment));
+
+ freestart_when_empty_ = freestart_;
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::MakeNewBlock()
+// Our sbrk() equivalent. We always make blocks of the same size
+// (though GetMemory() can also make a new block for really big
+// data.
+// ----------------------------------------------------------------------
+
+void BaseArena::MakeNewBlock(const uint32 alignment) {
+ AllocatedBlock *block = AllocNewBlock(block_size_, alignment);
+ freestart_ = block->mem;
+ remaining_ = block->size;
+ TC3_CHECK(SatisfyAlignment(alignment));
+}
+
+// The following simple numeric routines also exist in util/math/mathutil.h
+// but we don't want to depend on that library.
+
+// Euclid's algorithm for Greatest Common Denominator.
+static uint32 GCD(uint32 x, uint32 y) {
+ while (y != 0) {
+ uint32 r = x % y;
+ x = y;
+ y = r;
+ }
+ return x;
+}
+
+static uint32 LeastCommonMultiple(uint32 a, uint32 b) {
+ if (a > b) {
+ return (a / GCD(a, b)) * b;
+ } else if (a < b) {
+ return (b / GCD(b, a)) * a;
+ } else {
+ return a;
+ }
+}
+
+// -------------------------------------------------------------
+// BaseArena::AllocNewBlock()
+// Adds and returns an AllocatedBlock.
+// The returned AllocatedBlock* is valid until the next call
+// to AllocNewBlock or Reset. (i.e. anything that might
+// affect overflow_blocks_).
+// -------------------------------------------------------------
+
+BaseArena::AllocatedBlock* BaseArena::AllocNewBlock(const size_t block_size,
+ const uint32 alignment) {
+ AllocatedBlock *block;
+ // Find the next block.
+ if (blocks_alloced_ < TC3_ARRAYSIZE(first_blocks_)) {
+ // Use one of the pre-allocated blocks
+ block = &first_blocks_[blocks_alloced_++];
+ } else { // oops, out of space, move to the vector
+ if (overflow_blocks_ == nullptr)
+ overflow_blocks_ = new std::vector<AllocatedBlock>;
+ // Adds another block to the vector.
+ overflow_blocks_->resize(overflow_blocks_->size()+1);
+ // block points to the last block of the vector.
+ block = &overflow_blocks_->back();
+ }
+
+ // NOTE(tucker): this utility is made slightly more complex by
+ // not disallowing the case where alignment > block_size.
+ // Can we, without breaking existing code?
+
+ // If page_aligned_, then alignment must be a multiple of page size.
+ // Otherwise, must be a multiple of kDefaultAlignment, unless
+ // requested alignment is 1, in which case we don't care at all.
+ const uint32 adjusted_alignment =
+ page_aligned_ ? LeastCommonMultiple(kPageSize, alignment)
+ : (alignment > 1 ? LeastCommonMultiple(alignment, kDefaultAlignment) : 1);
+ TC3_CHECK_LE(adjusted_alignment, 1 << 20)
+ << "Alignment on boundaries greater than 1MB not supported.";
+
+ // If block_size > alignment we force block_size to be a multiple
+ // of alignment; if block_size < alignment we make no adjustment, unless
+ // page_aligned_ is true, in which case it must be a multiple of
+ // kPageSize because SetProtect() will assume that.
+ size_t adjusted_block_size = block_size;
+#ifdef __STDCPP_DEFAULT_NEW_ALIGNMENT__
+ if (adjusted_alignment > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
+#else
+ if (adjusted_alignment > 1) {
+#endif
+ if (adjusted_block_size > adjusted_alignment) {
+ const uint32 excess = adjusted_block_size % adjusted_alignment;
+ adjusted_block_size += (excess > 0 ? adjusted_alignment - excess : 0);
+ }
+ if (page_aligned_) {
+ size_t num_pages = ((adjusted_block_size - 1)/kPageSize) + 1;
+ adjusted_block_size = num_pages * kPageSize;
+ }
+ block->mem = AllocateAlignedBytes(adjusted_block_size, adjusted_alignment);
+ } else {
+ block->mem = AllocateBytes(adjusted_block_size);
+ }
+ block->size = adjusted_block_size;
+ block->alignment = adjusted_alignment;
+ TC3_CHECK(nullptr != block->mem)
+ << "block_size=" << block_size
+ << " adjusted_block_size=" << adjusted_block_size
+ << " alignment=" << alignment
+ << " adjusted_alignment=" << adjusted_alignment;
+
+ TC3_ARENASET(status_.bytes_allocated_ += adjusted_block_size);
+
+#ifdef ADDRESS_SANITIZER
+ ASAN_POISON_MEMORY_REGION(block->mem, block->size);
+#endif
+ return block;
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::IndexToBlock()
+// Index encoding is as follows:
+// For blocks in the first_blocks_ array, we use index of the block in
+// the array.
+// For blocks in the overflow_blocks_ vector, we use the index of the
+// block in iverflow_blocks_, plus the size of the first_blocks_ array.
+// ----------------------------------------------------------------------
+
+const BaseArena::AllocatedBlock *BaseArena::IndexToBlock(int index) const {
+ if (index < TC3_ARRAYSIZE(first_blocks_)) {
+ return &first_blocks_[index];
+ }
+ TC3_CHECK(overflow_blocks_ != nullptr);
+ int index_in_overflow_blocks = index - TC3_ARRAYSIZE(first_blocks_);
+ TC3_CHECK_GE(index_in_overflow_blocks, 0);
+ TC3_CHECK_LT(static_cast<size_t>(index_in_overflow_blocks),
+ overflow_blocks_->size());
+ return &(*overflow_blocks_)[index_in_overflow_blocks];
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::GetMemoryFallback()
+// We take memory out of our pool, aligned on the byte boundary
+// requested. If we don't have space in our current pool, we
+// allocate a new block (wasting the remaining space in the
+// current block) and give you that. If your memory needs are
+// too big for a single block, we make a special your-memory-only
+// allocation -- this is equivalent to not using the arena at all.
+// ----------------------------------------------------------------------
+
+void* BaseArena::GetMemoryFallback(const size_t size, const int alignment) {
+ if (0 == size) {
+ return nullptr; // stl/stl_alloc.h says this is okay
+ }
+
+ // alignment must be a positive power of 2.
+ TC3_CHECK(alignment > 0 && 0 == (alignment & (alignment - 1)));
+
+ // If the object is more than a quarter of the block size, allocate
+ // it separately to avoid wasting too much space in leftover bytes.
+ if (block_size_ == 0 || size > block_size_/4) {
+ // Use a block separate from all other allocations; in particular
+ // we don't update last_alloc_ so you can't reclaim space on this block.
+ AllocatedBlock* b = AllocNewBlock(size, alignment);
+#ifdef ADDRESS_SANITIZER
+ ASAN_UNPOISON_MEMORY_REGION(b->mem, b->size);
+#endif
+ return b->mem;
+ }
+
+ // Enforce alignment on freestart_ then check for adequate space,
+ // which may require starting a new block.
+ if (!SatisfyAlignment(alignment) || size > remaining_) {
+ MakeNewBlock(alignment);
+ }
+ TC3_CHECK_LE(size, remaining_);
+
+ remaining_ -= size;
+ last_alloc_ = freestart_;
+ freestart_ += size;
+
+#ifdef ADDRESS_SANITIZER
+ ASAN_UNPOISON_MEMORY_REGION(last_alloc_, size);
+#endif
+ return reinterpret_cast<void*>(last_alloc_);
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::ReturnMemoryFallback()
+// BaseArena::FreeBlocks()
+// Unlike GetMemory(), which does actual work, ReturnMemory() is a
+// no-op: we don't "free" memory until Reset() is called. We do
+// update some stats, though. Note we do no checking that the
+// pointer you pass in was actually allocated by us, or that it
+// was allocated for the size you say, so be careful here!
+// FreeBlocks() does the work for Reset(), actually freeing all
+// memory allocated in one fell swoop.
+// ----------------------------------------------------------------------
+
+void BaseArena::FreeBlocks() {
+ for ( int i = 1; i < blocks_alloced_; ++i ) { // keep first block alloced
+ DeallocateBytes(first_blocks_[i].mem, first_blocks_[i].size,
+ first_blocks_[i].alignment);
+ first_blocks_[i].mem = nullptr;
+ first_blocks_[i].size = 0;
+ }
+ blocks_alloced_ = 1;
+ if (overflow_blocks_ != nullptr) {
+ std::vector<AllocatedBlock>::iterator it;
+ for (it = overflow_blocks_->begin(); it != overflow_blocks_->end(); ++it) {
+ DeallocateBytes(it->mem, it->size, it->alignment);
+ }
+ delete overflow_blocks_; // These should be used very rarely
+ overflow_blocks_ = nullptr;
+ }
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::AdjustLastAlloc()
+// If you realize you didn't want your last alloc to be for
+// the size you asked, after all, you can fix it by calling
+// this. We'll grow or shrink the last-alloc region if we
+// can (we can always shrink, but we might not be able to
+// grow if you want to grow too big.
+// RETURNS true if we successfully modified the last-alloc
+// region, false if the pointer you passed in wasn't actually
+// the last alloc or if you tried to grow bigger than we could.
+// ----------------------------------------------------------------------
+
+bool BaseArena::AdjustLastAlloc(void *last_alloc, const size_t newsize) {
+ // It's only legal to call this on the last thing you alloced.
+ if (last_alloc == nullptr || last_alloc != last_alloc_) return false;
+ // last_alloc_ should never point into a "big" block, w/ size >= block_size_
+ assert(freestart_ >= last_alloc_ && freestart_ <= last_alloc_ + block_size_);
+ assert(remaining_ >= 0); // should be: it's a size_t!
+ if (newsize > (freestart_ - last_alloc_) + remaining_)
+ return false; // not enough room, even after we get back last_alloc_ space
+ const char* old_freestart = freestart_; // where last alloc used to end
+ freestart_ = last_alloc_ + newsize; // where last alloc ends now
+ remaining_ -= (freestart_ - old_freestart); // how much new space we've taken
+
+#ifdef ADDRESS_SANITIZER
+ ASAN_UNPOISON_MEMORY_REGION(last_alloc_, newsize);
+ ASAN_POISON_MEMORY_REGION(freestart_, remaining_);
+#endif
+ return true;
+}
+
+// ----------------------------------------------------------------------
+// UnsafeArena::Realloc()
+// SafeArena::Realloc()
+// If you decide you want to grow -- or shrink -- a memory region,
+// we'll do it for you here. Typically this will involve copying
+// the existing memory to somewhere else on the arena that has
+// more space reserved. But if you're reallocing the last-allocated
+// block, we may be able to accommodate you just by updating a
+// pointer. In any case, we return a pointer to the new memory
+// location, which may be the same as the pointer you passed in.
+// Here's an example of how you might use Realloc():
+//
+// compr_buf = arena->Alloc(uncompr_size); // get too-much space
+// int compr_size;
+// zlib.Compress(uncompr_buf, uncompr_size, compr_buf, &compr_size);
+// compr_buf = arena->Realloc(compr_buf, uncompr_size, compr_size);
+// ----------------------------------------------------------------------
+
+char* UnsafeArena::Realloc(char* original, size_t oldsize, size_t newsize) {
+ assert(oldsize >= 0 && newsize >= 0);
+ // if original happens to be the last allocation we can avoid fragmentation.
+ if (AdjustLastAlloc(original, newsize)) {
+ return original;
+ }
+
+ char* resized = original;
+ if (newsize > oldsize) {
+ resized = Alloc(newsize);
+ memcpy(resized, original, oldsize);
+ } else {
+ // no need to do anything; we're ain't reclaiming any memory!
+ }
+
+#ifdef ADDRESS_SANITIZER
+ // Alloc already returns unpoisoned memory, but handling both cases here
+ // allows us to poison the old memory without worrying about whether or not it
+ // overlaps with the new memory. Thus, we must poison the old memory first.
+ ASAN_POISON_MEMORY_REGION(original, oldsize);
+ ASAN_UNPOISON_MEMORY_REGION(resized, newsize);
+#endif
+ return resized;
+}
+
+// Avoid weak vtables by defining a dummy key method.
+void UnsafeArena::UnusedKeyMethod() {}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/base/arena.h b/native/utils/base/arena.h
new file mode 100644
index 0000000..aec1950
--- /dev/null
+++ b/native/utils/base/arena.h
@@ -0,0 +1,283 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Sometimes it is necessary to allocate a large number of small
+// objects. Doing this the usual way (malloc, new) is slow,
+// especially for multithreaded programs. A BaseArena provides a
+// mark/release method of memory management: it asks for a large chunk
+// from the operating system and doles it out bit by bit as required.
+// Then you free all the memory at once by calling BaseArena::Reset().
+//
+//
+// --Example Uses Of UnsafeArena
+// This is the simplest way. Just create an arena, and whenever you
+// need a block of memory to put something in, call BaseArena::Alloc(). eg
+// s = arena.Alloc(100);
+// snprintf(s, 100, "%s:%d", host, port);
+// arena.Shrink(strlen(s)+1); // optional; see below for use
+//
+// You'll probably use the convenience routines more often:
+// s = arena.Strdup(host); // a copy of host lives in the arena
+// s = arena.Strndup(host, 100); // we guarantee to NUL-terminate!
+// s = arena.Memdup(protobuf, sizeof(protobuf);
+//
+// If you go the Alloc() route, you'll probably allocate too-much-space.
+// You can reclaim the extra space by calling Shrink() before the next
+// Alloc() (or Strdup(), or whatever), with the #bytes you actually used.
+// If you use this method, memory management is easy: just call Alloc()
+// and friends a lot, and call Reset() when you're done with the data.
+//
+// FOR STRINGS: --Uses UnsafeArena
+// This is a special case of STL (below), but is simpler. Use an
+// astring, which acts like a string but allocates from the passed-in
+// arena:
+// astring s(arena); // or "sastring" to use a SafeArena
+// s.assign(host);
+// astring s2(host, hostlen, arena);
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_ARENA_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_ARENA_H_
+
+#include <assert.h>
+#include <string.h>
+#include <vector>
+#ifdef ADDRESS_SANITIZER
+#include <sanitizer/asan_interface.h>
+#endif
+
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+// This class is "thread-compatible": different threads can access the
+// arena at the same time without locking, as long as they use only
+// const methods.
+class BaseArena {
+ protected: // You can't make an arena directly; only a subclass of one
+ BaseArena(char* first_block, const size_t block_size, bool align_to_page);
+
+ public:
+ virtual ~BaseArena();
+
+ virtual void Reset();
+
+ // they're "slow" only 'cause they're virtual (subclasses define "fast" ones)
+ virtual char* SlowAlloc(size_t size) = 0;
+ virtual void SlowFree(void* memory, size_t size) = 0;
+ virtual char* SlowRealloc(char* memory, size_t old_size, size_t new_size) = 0;
+
+ class Status {
+ private:
+ friend class BaseArena;
+ size_t bytes_allocated_;
+ public:
+ Status() : bytes_allocated_(0) { }
+ size_t bytes_allocated() const {
+ return bytes_allocated_;
+ }
+ };
+
+ // Accessors and stats counters
+ // This accessor isn't so useful here, but is included so we can be
+ // type-compatible with ArenaAllocator (in arena_allocator.h). That is,
+ // we define arena() because ArenaAllocator does, and that way you
+ // can template on either of these and know it's safe to call arena().
+ virtual BaseArena* arena() { return this; }
+ size_t block_size() const { return block_size_; }
+ int block_count() const;
+ bool is_empty() const {
+ // must check block count in case we allocated a block larger than blksize
+ return freestart_ == freestart_when_empty_ && 1 == block_count();
+ }
+
+ // The alignment that ArenaAllocator uses except for 1-byte objects.
+ static const int kDefaultAlignment = 8;
+
+ protected:
+ bool SatisfyAlignment(const size_t alignment);
+ void MakeNewBlock(const uint32 alignment);
+ void* GetMemoryFallback(const size_t size, const int align);
+ void* GetMemory(const size_t size, const int align) {
+ assert(remaining_ <= block_size_); // an invariant
+ if ( size > 0 && size <= remaining_ && align == 1 ) { // common case
+ last_alloc_ = freestart_;
+ freestart_ += size;
+ remaining_ -= size;
+#ifdef ADDRESS_SANITIZER
+ ASAN_UNPOISON_MEMORY_REGION(last_alloc_, size);
+#endif
+ return reinterpret_cast<void*>(last_alloc_);
+ }
+ return GetMemoryFallback(size, align);
+ }
+
+ // This doesn't actually free any memory except for the last piece allocated
+ void ReturnMemory(void* memory, const size_t size) {
+ if (memory == last_alloc_ &&
+ size == static_cast<size_t>(freestart_ - last_alloc_)) {
+ remaining_ += size;
+ freestart_ = last_alloc_;
+ }
+#ifdef ADDRESS_SANITIZER
+ ASAN_POISON_MEMORY_REGION(memory, size);
+#endif
+ }
+
+ // This is used by Realloc() -- usually we Realloc just by copying to a
+ // bigger space, but for the last alloc we can realloc by growing the region.
+ bool AdjustLastAlloc(void* last_alloc, const size_t newsize);
+
+ Status status_;
+ size_t remaining_;
+
+ private:
+ struct AllocatedBlock {
+ char* mem;
+ size_t size;
+ size_t alignment;
+ };
+
+ // Allocate new new block of at least block_size, with the specified
+ // alignment.
+ // The returned AllocatedBlock* is valid until the next call to AllocNewBlock
+ // or Reset (i.e. anything that might affect overflow_blocks_).
+ AllocatedBlock* AllocNewBlock(const size_t block_size,
+ const uint32 alignment);
+
+ const AllocatedBlock* IndexToBlock(int index) const;
+
+ const size_t block_size_;
+ char* freestart_; // beginning of the free space in most recent block
+ char* freestart_when_empty_; // beginning of the free space when we're empty
+ char* last_alloc_; // used to make sure ReturnBytes() is safe
+ // if the first_blocks_ aren't enough, expand into overflow_blocks_.
+ std::vector<AllocatedBlock>* overflow_blocks_;
+ // STL vector isn't as efficient as it could be, so we use an array at first
+ const bool first_block_externally_owned_; // true if they pass in 1st block
+ const bool page_aligned_; // when true, all blocks need to be page aligned
+ int8_t blocks_alloced_; // how many of the first_blocks_ have been allocated
+ AllocatedBlock first_blocks_[16]; // the length of this array is arbitrary
+
+ void FreeBlocks(); // Frees all except first block
+
+ BaseArena(const BaseArena&) = delete;
+ BaseArena& operator=(const BaseArena&) = delete;
+};
+
+class UnsafeArena : public BaseArena {
+ public:
+ // Allocates a thread-compatible arena with the specified block size.
+ explicit UnsafeArena(const size_t block_size)
+ : BaseArena(nullptr, block_size, false) { }
+ UnsafeArena(const size_t block_size, bool align)
+ : BaseArena(nullptr, block_size, align) { }
+
+ // Allocates a thread-compatible arena with the specified block
+ // size. "first_block" must have size "block_size". Memory is
+ // allocated from "first_block" until it is exhausted; after that
+ // memory is allocated by allocating new blocks from the heap.
+ UnsafeArena(char* first_block, const size_t block_size)
+ : BaseArena(first_block, block_size, false) { }
+ UnsafeArena(char* first_block, const size_t block_size, bool align)
+ : BaseArena(first_block, block_size, align) { }
+
+ char* Alloc(const size_t size) {
+ return reinterpret_cast<char*>(GetMemory(size, 1));
+ }
+ void* AllocAligned(const size_t size, const int align) {
+ return GetMemory(size, align);
+ }
+ char* Calloc(const size_t size) {
+ void* return_value = Alloc(size);
+ memset(return_value, 0, size);
+ return reinterpret_cast<char*>(return_value);
+ }
+
+ void* CallocAligned(const size_t size, const int align) {
+ void* return_value = AllocAligned(size, align);
+ memset(return_value, 0, size);
+ return return_value;
+ }
+
+ // Free does nothing except for the last piece allocated.
+ void Free(void* memory, size_t size) {
+ ReturnMemory(memory, size);
+ }
+ char* SlowAlloc(size_t size) override { // "slow" 'cause it's virtual
+ return Alloc(size);
+ }
+ void SlowFree(void* memory,
+ size_t size) override { // "slow" 'cause it's virt
+ Free(memory, size);
+ }
+ char* SlowRealloc(char* memory, size_t old_size, size_t new_size) override {
+ return Realloc(memory, old_size, new_size);
+ }
+
+ char* Memdup(const char* s, size_t bytes) {
+ char* newstr = Alloc(bytes);
+ memcpy(newstr, s, bytes);
+ return newstr;
+ }
+ char* MemdupPlusNUL(const char* s, size_t bytes) { // like "string(s, len)"
+ char* newstr = Alloc(bytes+1);
+ memcpy(newstr, s, bytes);
+ newstr[bytes] = '\0';
+ return newstr;
+ }
+ char* Strdup(const char* s) {
+ return Memdup(s, strlen(s) + 1);
+ }
+ // Unlike libc's strncpy, I always NUL-terminate. libc's semantics are dumb.
+ // This will allocate at most n+1 bytes (+1 is for the nul terminator).
+ char* Strndup(const char* s, size_t n) {
+ // Use memchr so we don't walk past n.
+ // We can't use the one in //strings since this is the base library,
+ // so we have to reinterpret_cast from the libc void*.
+ const char* eos = reinterpret_cast<const char*>(memchr(s, '\0', n));
+ // if no null terminator found, use full n
+ const size_t bytes = (eos == nullptr) ? n : eos - s;
+ return MemdupPlusNUL(s, bytes);
+ }
+
+ // You can realloc a previously-allocated string either bigger or smaller.
+ // We can be more efficient if you realloc a string right after you allocate
+ // it (eg allocate way-too-much space, fill it, realloc to just-big-enough)
+ char* Realloc(char* original, size_t oldsize, size_t newsize);
+ // If you know the new size is smaller (or equal), you don't need to know
+ // oldsize. We don't check that newsize is smaller, so you'd better be sure!
+ char* Shrink(char* s, size_t newsize) {
+ AdjustLastAlloc(s, newsize); // reclaim space if we can
+ return s; // never need to move if we go smaller
+ }
+
+ // We make a copy so you can keep track of status at a given point in time
+ Status status() const { return status_; }
+
+ // Number of bytes remaining before the arena has to allocate another block.
+ size_t bytes_until_next_allocation() const { return remaining_; }
+
+ private:
+ UnsafeArena(const UnsafeArena&) = delete;
+ UnsafeArena& operator=(const UnsafeArena&) = delete;
+
+ virtual void UnusedKeyMethod(); // Dummy key method to avoid weak vtable.
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_ARENA_H_
diff --git a/native/utils/base/logging.cc b/native/utils/base/logging.cc
index d7ddeb8..ddd1170 100644
--- a/native/utils/base/logging.cc
+++ b/native/utils/base/logging.cc
@@ -17,8 +17,8 @@
#include "utils/base/logging.h"
#include <stdlib.h>
+
#include <exception>
-#include <iostream>
#include "utils/base/logging_raw.h"
diff --git a/native/utils/base/logging.h b/native/utils/base/logging.h
index 1267f5e..2983e1f 100644
--- a/native/utils/base/logging.h
+++ b/native/utils/base/logging.h
@@ -20,6 +20,7 @@
#include <cassert>
#include <string>
+#include "utils/base/integral_types.h"
#include "utils/base/logging_levels.h"
#include "utils/base/port.h"
@@ -44,20 +45,19 @@
return stream;
}
+template <typename T>
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ T *const entry) {
+ stream.message.append(std::to_string(reinterpret_cast<const uint64>(entry)));
+ return stream;
+}
+
inline LoggingStringStream &operator<<(LoggingStringStream &stream,
const char *message) {
stream.message.append(message);
return stream;
}
-#if defined(HAS_GLOBAL_STRING)
-inline LoggingStringStream &operator<<(LoggingStringStream &stream,
- const ::string &message) {
- stream.message.append(message);
- return stream;
-}
-#endif
-
inline LoggingStringStream &operator<<(LoggingStringStream &stream,
const std::string &message) {
stream.message.append(message);
diff --git a/native/utils/base/macros.h b/native/utils/base/macros.h
index 6739c0b..3517225 100644
--- a/native/utils/base/macros.h
+++ b/native/utils/base/macros.h
@@ -21,6 +21,9 @@
namespace libtextclassifier3 {
+#define TC3_ARRAYSIZE(a) \
+ ((sizeof(a) / sizeof(*(a))) / (size_t)(!(sizeof(a) % sizeof(*(a)))))
+
#if LANG_CXX11
#define TC3_DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName &) = delete; \
diff --git a/native/utils/base/unaligned_access.h b/native/utils/base/unaligned_access.h
new file mode 100644
index 0000000..d6907db
--- /dev/null
+++ b/native/utils/base/unaligned_access.h
@@ -0,0 +1,301 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_UNALIGNED_ACCESS_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_UNALIGNED_ACCESS_H_
+
+#include <string.h>
+#include <cstdint>
+
+#include "third_party/absl/base/attributes.h"
+#include "third_party/absl/base/integral_types.h"
+
+// unaligned APIs
+
+// Portable handling of unaligned loads, stores, and copies.
+// On some platforms, like ARM, the copy functions can be more efficient
+// then a load and a store.
+//
+// It is possible to implement all of these these using constant-length memcpy
+// calls, which is portable and will usually be inlined into simple loads and
+// stores if the architecture supports it. However, such inlining usually
+// happens in a pass that's quite late in compilation, which means the resulting
+// loads and stores cannot participate in many other optimizations, leading to
+// overall worse code.
+
+// The unaligned API is C++ only. The declarations use C++ features
+// (namespaces, inline) which are absent or incompatible in C.
+#if defined(__cplusplus)
+
+#if defined(ADDRESS_SANITIZER) || defined(THREAD_SANITIZER) ||\
+ defined(MEMORY_SANITIZER)
+// Consider we have an unaligned load/store of 4 bytes from address 0x...05.
+// AddressSanitizer will treat it as a 3-byte access to the range 05:07 and
+// will miss a bug if 08 is the first unaddressable byte.
+// ThreadSanitizer will also treat this as a 3-byte access to 05:07 and will
+// miss a race between this access and some other accesses to 08.
+// MemorySanitizer will correctly propagate the shadow on unaligned stores
+// and correctly report bugs on unaligned loads, but it may not properly
+// update and report the origin of the uninitialized memory.
+// For all three tools, replacing an unaligned access with a tool-specific
+// callback solves the problem.
+
+// Make sure uint16_t/uint32_t/uint64_t are defined.
+#include <stdint.h>
+
+extern "C" {
+uint16_t __sanitizer_unaligned_load16(const void *p);
+uint32_t __sanitizer_unaligned_load32(const void *p);
+uint64_t __sanitizer_unaligned_load64(const void *p);
+void __sanitizer_unaligned_store16(void *p, uint16_t v);
+void __sanitizer_unaligned_store32(void *p, uint32_t v);
+void __sanitizer_unaligned_store64(void *p, uint64_t v);
+} // extern "C"
+
+namespace libtextclassifier3 {
+
+inline uint16_t UnalignedLoad16(const void *p) {
+ return __sanitizer_unaligned_load16(p);
+}
+
+inline uint32_t UnalignedLoad32(const void *p) {
+ return __sanitizer_unaligned_load32(p);
+}
+
+inline uint64 UnalignedLoad64(const void *p) {
+ return __sanitizer_unaligned_load64(p);
+}
+
+inline void UnalignedStore16(void *p, uint16_t v) {
+ __sanitizer_unaligned_store16(p, v);
+}
+
+inline void UnalignedStore32(void *p, uint32_t v) {
+ __sanitizer_unaligned_store32(p, v);
+}
+
+inline void UnalignedStore64(void *p, uint64 v) {
+ __sanitizer_unaligned_store64(p, v);
+}
+
+} // namespace libtextclassifier3
+
+#define TC3_INTERNAL_UNALIGNED_LOAD16(_p) \
+ (::libtextclassifier3::UnalignedLoad16(_p))
+#define TC3_INTERNAL_UNALIGNED_LOAD32(_p) \
+ (::libtextclassifier3::UnalignedLoad32(_p))
+#define TC3_UNALIGNED_LOAD64(_p) \
+ (::libtextclassifier3::UnalignedLoad64(_p))
+
+#define TC3_UNALIGNED_STORE16(_p, _val) \
+ (::libtextclassifier3::UnalignedStore16(_p, _val))
+#define TC3_UNALIGNED_STORE32(_p, _val) \
+ (::libtextclassifier3::UnalignedStore32(_p, _val))
+#define TC3_UNALIGNED_STORE64(_p, _val) \
+ (::libtextclassifier3::UnalignedStore64(_p, _val))
+
+#elif defined(UNDEFINED_BEHAVIOR_SANITIZER)
+
+namespace libtextclassifier3 {
+
+inline uint16_t UnalignedLoad16(const void *p) {
+ uint16_t t;
+ memcpy(&t, p, sizeof t);
+ return t;
+}
+
+inline uint32_t UnalignedLoad32(const void *p) {
+ uint32_t t;
+ memcpy(&t, p, sizeof t);
+ return t;
+}
+
+inline uint64 UnalignedLoad64(const void *p) {
+ uint64 t;
+ memcpy(&t, p, sizeof t);
+ return t;
+}
+
+inline void UnalignedStore16(void *p, uint16_t v) { memcpy(p, &v, sizeof v); }
+
+inline void UnalignedStore32(void *p, uint32_t v) { memcpy(p, &v, sizeof v); }
+
+inline void UnalignedStore64(void *p, uint64 v) { memcpy(p, &v, sizeof v); }
+
+} // namespace libtextclassifier3
+
+#define TC3_UNALIGNED_LOAD16(_p) (::libtextclassifier3::UnalignedLoad16(_p))
+#define TC3_UNALIGNED_LOAD32(_p) (::libtextclassifier3::UnalignedLoad32(_p))
+#define TC3_UNALIGNED_LOAD64(_p) (::libtextclassifier3::UnalignedLoad64(_p))
+
+#define TC3_UNALIGNED_STORE16(_p, _val) \
+ (::libtextclassifier3::UnalignedStore16(_p, _val))
+#define TC3_UNALIGNED_STORE32(_p, _val) \
+ (::libtextclassifier3::UnalignedStore32(_p, _val))
+#define TC3_UNALIGNED_STORE64(_p, _val) \
+ (::libtextclassifier3::UnalignedStore64(_p, _val))
+
+#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
+ defined(_M_IX86) || defined(__ppc__) || defined(__PPC__) || \
+ defined(__ppc64__) || defined(__PPC64__)
+
+// x86 and x86-64 can perform unaligned loads/stores directly;
+// modern PowerPC hardware can also do unaligned integer loads and stores;
+// but note: the FPU still sends unaligned loads and stores to a trap handler!
+
+#define TC3_UNALIGNED_LOAD16(_p) \
+ (*reinterpret_cast<const uint16_t *>(_p))
+#define TC3_UNALIGNED_LOAD32(_p) \
+ (*reinterpret_cast<const uint32_t *>(_p))
+#define TC3_UNALIGNED_LOAD64(_p) \
+ (*reinterpret_cast<const uint64 *>(_p))
+
+#define TC3_UNALIGNED_STORE16(_p, _val) \
+ (*reinterpret_cast<uint16_t *>(_p) = (_val))
+#define TC3_UNALIGNED_STORE32(_p, _val) \
+ (*reinterpret_cast<uint32_t *>(_p) = (_val))
+#define TC3_UNALIGNED_STORE64(_p, _val) \
+ (*reinterpret_cast<uint64 *>(_p) = (_val))
+
+#elif defined(__arm__) && \
+ !defined(__ARM_ARCH_5__) && \
+ !defined(__ARM_ARCH_5T__) && \
+ !defined(__ARM_ARCH_5TE__) && \
+ !defined(__ARM_ARCH_5TEJ__) && \
+ !defined(__ARM_ARCH_6__) && \
+ !defined(__ARM_ARCH_6J__) && \
+ !defined(__ARM_ARCH_6K__) && \
+ !defined(__ARM_ARCH_6Z__) && \
+ !defined(__ARM_ARCH_6ZK__) && \
+ !defined(__ARM_ARCH_6T2__)
+
+
+// ARMv7 and newer support native unaligned accesses, but only of 16-bit
+// and 32-bit values (not 64-bit); older versions either raise a fatal signal,
+// do an unaligned read and rotate the words around a bit, or do the reads very
+// slowly (trip through kernel mode). There's no simple #define that says just
+// "ARMv7 or higher", so we have to filter away all ARMv5 and ARMv6
+// sub-architectures. Newer gcc (>= 4.6) set an __ARM_FEATURE_ALIGNED #define,
+// so in time, maybe we can move on to that.
+//
+// This is a mess, but there's not much we can do about it.
+//
+// To further complicate matters, only LDR instructions (single reads) are
+// allowed to be unaligned, not LDRD (two reads) or LDM (many reads). Unless we
+// explicitly tell the compiler that these accesses can be unaligned, it can and
+// will combine accesses. On armcc, the way to signal this is done by accessing
+// through the type (uint32_t __packed *), but GCC has no such attribute
+// (it ignores __attribute__((packed)) on individual variables). However,
+// we can tell it that a _struct_ is unaligned, which has the same effect,
+// so we do that.
+
+namespace libtextclassifier3 {
+
+struct Unaligned16Struct {
+ uint16_t value;
+ uint8_t dummy; // To make the size non-power-of-two.
+} ABSL_ATTRIBUTE_PACKED;
+
+struct Unaligned32Struct {
+ uint32_t value;
+ uint8_t dummy; // To make the size non-power-of-two.
+} ABSL_ATTRIBUTE_PACKED;
+
+} // namespace libtextclassifier3
+
+#define TC3_UNALIGNED_LOAD16(_p) \
+ ((reinterpret_cast<const ::libtextclassifier3::Unaligned16Struct *>(_p)) \
+ ->value)
+#define TC3_UNALIGNED_LOAD32(_p) \
+ ((reinterpret_cast<const ::libtextclassifier3::Unaligned32Struct *>(_p)) \
+ ->value)
+
+#define TC3_UNALIGNED_STORE16(_p, _val) \
+ ((reinterpret_cast< ::libtextclassifier3::Unaligned16Struct *>(_p)) \
+ ->value = (_val))
+#define TC3_UNALIGNED_STORE32(_p, _val) \
+ ((reinterpret_cast< ::libtextclassifier3::Unaligned32Struct *>(_p)) \
+ ->value = (_val))
+
+namespace libtextclassifier3 {
+
+inline uint64 UnalignedLoad64(const void *p) {
+ uint64 t;
+ memcpy(&t, p, sizeof t);
+ return t;
+}
+
+inline void UnalignedStore64(void *p, uint64 v) { memcpy(p, &v, sizeof v); }
+
+} // namespace libtextclassifier3
+
+#define TC3_UNALIGNED_LOAD64(_p) (::libtextclassifier3::UnalignedLoad64(_p))
+#define TC3_UNALIGNED_STORE64(_p, _val) \
+ (::libtextclassifier3::UnalignedStore64(_p, _val))
+
+#else
+
+// TC3_NEED_ALIGNED_LOADS is defined when the underlying platform
+// doesn't support unaligned access.
+#define TC3_NEED_ALIGNED_LOADS
+
+// These functions are provided for architectures that don't support
+// unaligned loads and stores.
+
+namespace libtextclassifier3 {
+
+inline uint16_t UnalignedLoad16(const void *p) {
+ uint16_t t;
+ memcpy(&t, p, sizeof t);
+ return t;
+}
+
+inline uint32_t UnalignedLoad32(const void *p) {
+ uint32_t t;
+ memcpy(&t, p, sizeof t);
+ return t;
+}
+
+inline uint64 UnalignedLoad64(const void *p) {
+ uint64 t;
+ memcpy(&t, p, sizeof t);
+ return t;
+}
+
+inline void UnalignedStore16(void *p, uint16_t v) { memcpy(p, &v, sizeof v); }
+
+inline void UnalignedStore32(void *p, uint32_t v) { memcpy(p, &v, sizeof v); }
+
+inline void UnalignedStore64(void *p, uint64 v) { memcpy(p, &v, sizeof v); }
+
+} // namespace libtextclassifier3
+
+#define TC3_UNALIGNED_LOAD16(_p) (::libtextclassifier3::UnalignedLoad16(_p))
+#define TC3_UNALIGNED_LOAD32(_p) (::libtextclassifier3::UnalignedLoad32(_p))
+#define TC3_UNALIGNED_LOAD64(_p) (::libtextclassifier3::UnalignedLoad64(_p))
+
+#define TC3_UNALIGNED_STORE16(_p, _val) \
+ (::libtextclassifier3::UnalignedStore16(_p, _val))
+#define TC3_UNALIGNED_STORE32(_p, _val) \
+ (::libtextclassifier3::UnalignedStore32(_p, _val))
+#define TC3_UNALIGNED_STORE64(_p, _val) \
+ (::libtextclassifier3::UnalignedStore64(_p, _val))
+
+#endif
+
+#endif // defined(__cplusplus), end of unaligned API
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_UNALIGNED_ACCESS_H_
diff --git a/native/utils/calendar/CalendarJavaIcuLocalTest.java b/native/utils/calendar/CalendarJavaIcuLocalTest.java
deleted file mode 100644
index 9beb36e..0000000
--- a/native/utils/calendar/CalendarJavaIcuLocalTest.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier.utils.calendar;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import com.google.thirdparty.robolectric.GoogleRobolectricTestRunner;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
-@RunWith(GoogleRobolectricTestRunner.class)
-public class CalendarJavaIcuLocalTest {
-
- @Before
- public void setUp() throws Exception {
- System.loadLibrary("calendar-javaicu_test-lib");
- }
-
- private native boolean testsMain();
-
- @Test
- public void testNative() {
- assertThat(testsMain()).isTrue();
- }
-}
diff --git a/native/utils/calendar/CalendarJavaIcuTest.java b/native/utils/calendar/CalendarJavaIcuTest.java
deleted file mode 100644
index ab1f00a..0000000
--- a/native/utils/calendar/CalendarJavaIcuTest.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier.utils.calendar;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
-@RunWith(JUnit4.class)
-public class CalendarJavaIcuTest {
-
- @Before
- public void setUp() throws Exception {
- System.loadLibrary("calendar-javaicu_test-lib");
- }
-
- private native boolean testsMain();
-
- @Test
- public void testNative() {
- assertThat(testsMain()).isTrue();
- }
-}
diff --git a/native/utils/calendar/calendar-common.h b/native/utils/calendar/calendar-common.h
index 5c91e22..f47b367 100644
--- a/native/utils/calendar/calendar-common.h
+++ b/native/utils/calendar/calendar-common.h
@@ -37,19 +37,19 @@
template <class TCalendar>
class CalendarLibTempl {
public:
- bool InterpretParseData(const DateParseData& parse_data,
+ bool InterpretParseData(const DatetimeParsedData& parse_data,
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& reference_locale,
TCalendar* calendar,
DatetimeGranularity* granularity) const;
- DatetimeGranularity GetGranularity(const DateParseData& data) const;
+ DatetimeGranularity GetGranularity(const DatetimeParsedData& data) const;
private:
// Adjusts the calendar's time instant according to a relative date reference
// in the parsed data.
- bool ApplyRelationField(const DateParseData& parse_data,
+ bool ApplyRelationField(const DatetimeParsedData& parse_data,
TCalendar* calendar) const;
// Round the time instant's precision down to the given granularity.
@@ -64,13 +64,13 @@
// Wednesday at least 4 weeks from now.
// If allow_today is true, the same day of the week may be kept
// if it already matches the relation type.
- bool AdjustByRelation(DateParseData::RelationType relation_type, int distance,
+ bool AdjustByRelation(DatetimeComponent date_time_component, int distance,
bool allow_today, TCalendar* calendar) const;
};
template <class TCalendar>
bool CalendarLibTempl<TCalendar>::InterpretParseData(
- const DateParseData& parse_data, int64 reference_time_ms_utc,
+ const DatetimeParsedData& parse_data, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& reference_locale,
TCalendar* calendar, DatetimeGranularity* granularity) const {
TC3_CALENDAR_CHECK(calendar->Initialize(reference_timezone, reference_locale,
@@ -81,23 +81,26 @@
// Apply each of the parsed fields in order of increasing granularity.
static const int64 kMillisInHour = 1000 * 60 * 60;
- if (parse_data.field_set_mask & DateParseData::Fields::ZONE_OFFSET_FIELD) {
- TC3_CALENDAR_CHECK(
- calendar->SetZoneOffset(parse_data.zone_offset * kMillisInHour))
+ if (parse_data.HasFieldType(DatetimeComponent::ComponentType::ZONE_OFFSET)) {
+ int zone_offset;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::ZONE_OFFSET,
+ &zone_offset);
+ TC3_CALENDAR_CHECK(calendar->SetZoneOffset(zone_offset * kMillisInHour))
}
- if (parse_data.field_set_mask & DateParseData::Fields::DST_OFFSET_FIELD) {
- TC3_CALENDAR_CHECK(
- calendar->SetDstOffset(parse_data.dst_offset * kMillisInHour))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::RELATION_FIELD) {
- TC3_CALENDAR_CHECK(ApplyRelationField(parse_data, calendar));
- // Don't round to the granularity for relative expressions that specify the
- // distance. So that, e.g. "in 2 hours" when it's 8:35:03 will result in
- // 10:35:03.
- if (parse_data.field_set_mask &
- DateParseData::Fields::RELATION_DISTANCE_FIELD) {
- should_round_to_granularity = false;
+
+ if (parse_data.HasFieldType(DatetimeComponent::ComponentType::DST_OFFSET)) {
+ int dst_offset;
+ if (parse_data.GetFieldValue(DatetimeComponent::ComponentType::DST_OFFSET,
+ &dst_offset)) {
+ TC3_CALENDAR_CHECK(calendar->SetDstOffset(dst_offset * kMillisInHour))
}
+ }
+ std::vector<DatetimeComponent> relative_components;
+ parse_data.GetRelativeDatetimeComponents(&relative_components);
+ if (!relative_components.empty()) {
+ TC3_CALENDAR_CHECK(ApplyRelationField(parse_data, calendar));
+ const DatetimeComponent& relative_component = relative_components.back();
+ should_round_to_granularity = relative_component.ShouldRoundToGranularity();
} else {
// By default, the parsed time is interpreted to be on the reference day.
// But a parsed date should have time 0:00:00 unless specified.
@@ -106,35 +109,56 @@
TC3_CALENDAR_CHECK(calendar->SetSecond(0))
TC3_CALENDAR_CHECK(calendar->SetMillisecond(0))
}
- if (parse_data.field_set_mask & DateParseData::Fields::YEAR_FIELD) {
- TC3_CALENDAR_CHECK(calendar->SetYear(parse_data.year))
+ if (parse_data.HasAbsoluteValue(DatetimeComponent::ComponentType::YEAR)) {
+ int year;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::YEAR, &year);
+ TC3_CALENDAR_CHECK(calendar->SetYear(year))
}
- if (parse_data.field_set_mask & DateParseData::Fields::MONTH_FIELD) {
+ if (parse_data.HasAbsoluteValue(DatetimeComponent::ComponentType::MONTH)) {
+ int month;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::MONTH, &month);
// ICU has months starting at 0, Java and Datetime parser at 1, so we
// need to subtract 1.
- TC3_CALENDAR_CHECK(calendar->SetMonth(parse_data.month - 1))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::DAY_FIELD) {
- TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(parse_data.day_of_month))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::HOUR_FIELD) {
- if (parse_data.field_set_mask & DateParseData::Fields::AMPM_FIELD &&
- parse_data.ampm == DateParseData::AMPM::PM && parse_data.hour < 12) {
- TC3_CALENDAR_CHECK(calendar->SetHourOfDay(parse_data.hour + 12))
- } else if (parse_data.ampm == DateParseData::AMPM::AM &&
- parse_data.hour == 12) {
- // Do nothing. 12am == 0.
- } else {
- TC3_CALENDAR_CHECK(calendar->SetHourOfDay(parse_data.hour))
- }
- }
- if (parse_data.field_set_mask & DateParseData::Fields::MINUTE_FIELD) {
- TC3_CALENDAR_CHECK(calendar->SetMinute(parse_data.minute))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::SECOND_FIELD) {
- TC3_CALENDAR_CHECK(calendar->SetSecond(parse_data.second))
+ TC3_CALENDAR_CHECK(calendar->SetMonth(month - 1))
}
+ if (parse_data.HasAbsoluteValue(
+ DatetimeComponent::ComponentType::DAY_OF_MONTH)) {
+ int day_of_month;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::DAY_OF_MONTH,
+ &day_of_month);
+ TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(day_of_month))
+ }
+ if (parse_data.HasAbsoluteValue(DatetimeComponent::ComponentType::HOUR)) {
+ int hour;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::HOUR, &hour);
+ if (parse_data.HasFieldType(DatetimeComponent::ComponentType::MERIDIEM)) {
+ int merdiem;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::MERIDIEM,
+ &merdiem);
+ if (merdiem == 1 && hour < 12) {
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(hour + 12))
+ } else if (merdiem == 0 && hour == 12) {
+ // Set hour of the day's value to zero (12am == 0:00 in 24 hour format).
+ // Please see issue b/139923083.
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0));
+ } else {
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(hour))
+ }
+ } else {
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(hour))
+ }
+ }
+ if (parse_data.HasAbsoluteValue(DatetimeComponent::ComponentType::MINUTE)) {
+ int minute;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::MINUTE, &minute);
+ TC3_CALENDAR_CHECK(calendar->SetMinute(minute))
+ }
+ if (parse_data.HasAbsoluteValue(DatetimeComponent::ComponentType::SECOND)) {
+ int second;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::SECOND, &second);
+ TC3_CALENDAR_CHECK(calendar->SetSecond(second))
+ }
if (should_round_to_granularity) {
TC3_CALENDAR_CHECK(RoundToGranularity(*granularity, calendar))
}
@@ -143,58 +167,55 @@
template <class TCalendar>
bool CalendarLibTempl<TCalendar>::ApplyRelationField(
- const DateParseData& parse_data, TCalendar* calendar) const {
- constexpr int relation_type_mask = DateParseData::Fields::RELATION_TYPE_FIELD;
- constexpr int relation_distance_mask =
- DateParseData::Fields::RELATION_DISTANCE_FIELD;
- switch (parse_data.relation) {
- case DateParseData::Relation::UNSPECIFIED:
+ const DatetimeParsedData& parse_data, TCalendar* calendar) const {
+ std::vector<DatetimeComponent> relative_date_time_components;
+ parse_data.GetRelativeDatetimeComponents(&relative_date_time_components);
+ if (relative_date_time_components.empty()) {
+ // There is no relative field set in the parsed data.
+ return false;
+ }
+ // Current only one relative date time component is possible.
+ DatetimeComponent relative_date_time_component =
+ relative_date_time_components.back();
+
+ switch (relative_date_time_component.relative_qualifier) {
+ case DatetimeComponent::RelativeQualifier::UNSPECIFIED:
TC3_LOG(ERROR) << "UNSPECIFIED RelationType.";
return false;
- case DateParseData::Relation::NEXT:
- if (parse_data.field_set_mask & relation_type_mask) {
- TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
- /*distance=*/1,
- /*allow_today=*/false, calendar));
- }
+ case DatetimeComponent::RelativeQualifier::NEXT:
+ TC3_CALENDAR_CHECK(AdjustByRelation(relative_date_time_component,
+ /*distance=*/1,
+ /*allow_today=*/false, calendar));
return true;
- case DateParseData::Relation::NEXT_OR_SAME:
- if (parse_data.field_set_mask & relation_type_mask) {
- TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
- /*distance=*/1,
- /*allow_today=*/true, calendar))
- }
+ case DatetimeComponent::RelativeQualifier::THIS:
+ TC3_CALENDAR_CHECK(AdjustByRelation(relative_date_time_component,
+ /*distance=*/1,
+ /*allow_today=*/true, calendar))
return true;
- case DateParseData::Relation::LAST:
- if (parse_data.field_set_mask & relation_type_mask) {
- TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
- /*distance=*/-1,
- /*allow_today=*/false, calendar))
- }
+ case DatetimeComponent::RelativeQualifier::LAST:
+ TC3_CALENDAR_CHECK(AdjustByRelation(relative_date_time_component,
+ /*distance=*/-1,
+ /*allow_today=*/false, calendar))
return true;
- case DateParseData::Relation::NOW:
+ case DatetimeComponent::RelativeQualifier::NOW:
return true; // NOOP
- case DateParseData::Relation::TOMORROW:
+ case DatetimeComponent::RelativeQualifier::TOMORROW:
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(1));
return true;
- case DateParseData::Relation::YESTERDAY:
+ case DatetimeComponent::RelativeQualifier::YESTERDAY:
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(-1));
return true;
- case DateParseData::Relation::PAST:
- if ((parse_data.field_set_mask & relation_type_mask) &&
- (parse_data.field_set_mask & relation_distance_mask)) {
- TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
- -parse_data.relation_distance,
- /*allow_today=*/false, calendar))
- }
+ case DatetimeComponent::RelativeQualifier::PAST:
+ TC3_CALENDAR_CHECK(
+ AdjustByRelation(relative_date_time_component,
+ -relative_date_time_component.relative_count,
+ /*allow_today=*/false, calendar))
return true;
- case DateParseData::Relation::FUTURE:
- if ((parse_data.field_set_mask & relation_type_mask) &&
- (parse_data.field_set_mask & relation_distance_mask)) {
- TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
- parse_data.relation_distance,
- /*allow_today=*/false, calendar))
- }
+ case DatetimeComponent::RelativeQualifier::FUTURE:
+ TC3_CALENDAR_CHECK(
+ AdjustByRelation(relative_date_time_component,
+ relative_date_time_component.relative_count,
+ /*allow_today=*/false, calendar))
return true;
}
return false;
@@ -242,17 +263,11 @@
template <class TCalendar>
bool CalendarLibTempl<TCalendar>::AdjustByRelation(
- DateParseData::RelationType relation_type, int distance, bool allow_today,
+ DatetimeComponent date_time_component, int distance, bool allow_today,
TCalendar* calendar) const {
const int distance_sign = distance < 0 ? -1 : 1;
- switch (relation_type) {
- case DateParseData::RelationType::MONDAY:
- case DateParseData::RelationType::TUESDAY:
- case DateParseData::RelationType::WEDNESDAY:
- case DateParseData::RelationType::THURSDAY:
- case DateParseData::RelationType::FRIDAY:
- case DateParseData::RelationType::SATURDAY:
- case DateParseData::RelationType::SUNDAY:
+ switch (date_time_component.component_type) {
+ case DatetimeComponent::ComponentType::DAY_OF_WEEK:
if (!allow_today) {
// If we're not including the same day as the reference, skip it.
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign))
@@ -261,40 +276,40 @@
while (distance != 0) {
int day_of_week;
TC3_CALENDAR_CHECK(calendar->GetDayOfWeek(&day_of_week))
- if (day_of_week == static_cast<int>(relation_type)) {
+ if (day_of_week == (date_time_component.value)) {
distance += -distance_sign;
if (distance == 0) break;
}
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign))
}
return true;
- case DateParseData::RelationType::SECOND:
+ case DatetimeComponent::ComponentType::SECOND:
TC3_CALENDAR_CHECK(calendar->AddSecond(distance));
return true;
- case DateParseData::RelationType::MINUTE:
+ case DatetimeComponent::ComponentType::MINUTE:
TC3_CALENDAR_CHECK(calendar->AddMinute(distance));
return true;
- case DateParseData::RelationType::HOUR:
+ case DatetimeComponent::ComponentType::HOUR:
TC3_CALENDAR_CHECK(calendar->AddHourOfDay(distance));
return true;
- case DateParseData::RelationType::DAY:
+ case DatetimeComponent::ComponentType::DAY_OF_MONTH:
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance));
return true;
- case DateParseData::RelationType::WEEK:
+ case DatetimeComponent::ComponentType::WEEK:
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(7 * distance))
TC3_CALENDAR_CHECK(calendar->SetDayOfWeek(1))
return true;
- case DateParseData::RelationType::MONTH:
+ case DatetimeComponent::ComponentType::MONTH:
TC3_CALENDAR_CHECK(calendar->AddMonth(distance))
TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(1))
return true;
- case DateParseData::RelationType::YEAR:
+ case DatetimeComponent::ComponentType::YEAR:
TC3_CALENDAR_CHECK(calendar->AddYear(distance))
TC3_CALENDAR_CHECK(calendar->SetDayOfYear(1))
return true;
default:
TC3_LOG(ERROR) << "Unknown relation type: "
- << static_cast<int>(relation_type);
+ << static_cast<int>(date_time_component.component_type);
return false;
}
return false;
@@ -302,55 +317,8 @@
template <class TCalendar>
DatetimeGranularity CalendarLibTempl<TCalendar>::GetGranularity(
- const DateParseData& data) const {
- DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
- if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::YEAR))) {
- granularity = DatetimeGranularity::GRANULARITY_YEAR;
- }
- if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::MONTH))) {
- granularity = DatetimeGranularity::GRANULARITY_MONTH;
- }
- if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::WEEK)) {
- granularity = DatetimeGranularity::GRANULARITY_WEEK;
- }
- if (data.field_set_mask & DateParseData::DAY_FIELD ||
- (data.field_set_mask & DateParseData::RELATION_FIELD &&
- (data.relation == DateParseData::Relation::NOW ||
- data.relation == DateParseData::Relation::TOMORROW ||
- data.relation == DateParseData::Relation::YESTERDAY)) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::MONDAY ||
- data.relation_type == DateParseData::RelationType::TUESDAY ||
- data.relation_type == DateParseData::RelationType::WEDNESDAY ||
- data.relation_type == DateParseData::RelationType::THURSDAY ||
- data.relation_type == DateParseData::RelationType::FRIDAY ||
- data.relation_type == DateParseData::RelationType::SATURDAY ||
- data.relation_type == DateParseData::RelationType::SUNDAY ||
- data.relation_type == DateParseData::RelationType::DAY))) {
- granularity = DatetimeGranularity::GRANULARITY_DAY;
- }
- if (data.field_set_mask & DateParseData::HOUR_FIELD ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::HOUR))) {
- granularity = DatetimeGranularity::GRANULARITY_HOUR;
- }
- if (data.field_set_mask & DateParseData::MINUTE_FIELD ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- data.relation_type == DateParseData::RelationType::MINUTE)) {
- granularity = DatetimeGranularity::GRANULARITY_MINUTE;
- }
- if (data.field_set_mask & DateParseData::SECOND_FIELD ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::SECOND))) {
- granularity = DatetimeGranularity::GRANULARITY_SECOND;
- }
-
- return granularity;
+ const DatetimeParsedData& data) const {
+ return data.GetFinestGranularity();
}
}; // namespace calendar
diff --git a/native/utils/calendar/calendar-javaicu.cc b/native/utils/calendar/calendar-javaicu.cc
index ac09979..59af9d4 100644
--- a/native/utils/calendar/calendar-javaicu.cc
+++ b/native/utils/calendar/calendar-javaicu.cc
@@ -67,20 +67,13 @@
}
// We'll assume the day indices match later on, so verify it here.
- if (jni_cache_->calendar_sunday !=
- static_cast<int>(DateParseData::RelationType::SUNDAY) ||
- jni_cache_->calendar_monday !=
- static_cast<int>(DateParseData::RelationType::MONDAY) ||
- jni_cache_->calendar_tuesday !=
- static_cast<int>(DateParseData::RelationType::TUESDAY) ||
- jni_cache_->calendar_wednesday !=
- static_cast<int>(DateParseData::RelationType::WEDNESDAY) ||
- jni_cache_->calendar_thursday !=
- static_cast<int>(DateParseData::RelationType::THURSDAY) ||
- jni_cache_->calendar_friday !=
- static_cast<int>(DateParseData::RelationType::FRIDAY) ||
- jni_cache_->calendar_saturday !=
- static_cast<int>(DateParseData::RelationType::SATURDAY)) {
+ if (jni_cache_->calendar_sunday != kSunday ||
+ jni_cache_->calendar_monday != kMonday ||
+ jni_cache_->calendar_tuesday != kTuesday ||
+ jni_cache_->calendar_wednesday != kWednesday ||
+ jni_cache_->calendar_thursday != kThursday ||
+ jni_cache_->calendar_friday != kFriday ||
+ jni_cache_->calendar_saturday != kSaturday) {
TC3_LOG(ERROR) << "day of the week indices mismatch";
return false;
}
diff --git a/native/utils/calendar/calendar-javaicu.h b/native/utils/calendar/calendar-javaicu.h
index 02673cc..035530e 100644
--- a/native/utils/calendar/calendar-javaicu.h
+++ b/native/utils/calendar/calendar-javaicu.h
@@ -67,7 +67,7 @@
explicit CalendarLib(const std::shared_ptr<JniCache>& jni_cache);
// Returns false (dummy version).
- bool InterpretParseData(const DateParseData& parse_data,
+ bool InterpretParseData(const DatetimeParsedData& parse_data,
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& reference_locale,
@@ -82,7 +82,7 @@
return calendar.GetTimeInMillis(interpreted_time_ms_utc);
}
- DatetimeGranularity GetGranularity(const DateParseData& data) const {
+ DatetimeGranularity GetGranularity(const DatetimeParsedData& data) const {
return impl_.GetGranularity(data);
}
diff --git a/native/utils/calendar/calendar_test-include.cc b/native/utils/calendar/calendar_test-include.cc
index 70520a2..a145fc2 100644
--- a/native/utils/calendar/calendar_test-include.cc
+++ b/native/utils/calendar/calendar_test-include.cc
@@ -19,29 +19,24 @@
namespace libtextclassifier3 {
namespace test_internal {
+static constexpr int kWednesday = 4;
+
TEST_F(CalendarTest, Interface) {
int64 time;
DatetimeGranularity granularity;
std::string timezone;
- bool result = calendarlib_.InterpretParseData(
- DateParseData{/*field_set_mask=*/0, /*year=*/0, /*month=*/0,
- /*day_of_month=*/0, /*hour=*/0, /*minute=*/0, /*second=*/0,
- /*ampm=*/static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0, /*dst_offset=*/0,
- static_cast<DateParseData::Relation>(0),
- static_cast<DateParseData::RelationType>(0),
- /*relation_distance=*/0},
- 0L, "Zurich", "en-CH", &time, &granularity);
+ DatetimeParsedData data;
+ bool result = calendarlib_.InterpretParseData(data, 0L, "Zurich", "en-CH",
+ &time, &granularity);
TC3_LOG(INFO) << result;
}
TEST_F(CalendarTest, SetsZeroTimeWhenNotRelative) {
int64 time;
DatetimeGranularity granularity;
- DateParseData data;
+ DatetimeParsedData data;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, 2018);
- data.year = 2018;
- data.field_set_mask = DateParseData::YEAR_FIELD;
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
@@ -58,50 +53,44 @@
TEST_F(CalendarTest, RoundingToGranularityBasic) {
int64 time;
DatetimeGranularity granularity;
- DateParseData data;
+ DatetimeParsedData data;
- data.year = 2018;
- data.field_set_mask = DateParseData::YEAR_FIELD;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, 2018);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH", &time, &granularity));
EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
- data.month = 4;
- data.field_set_mask |= DateParseData::MONTH_FIELD;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::MONTH, 4);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH", &time, &granularity));
EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */);
- data.day_of_month = 25;
- data.field_set_mask |= DateParseData::DAY_FIELD;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_MONTH, 25);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH", &time, &granularity));
EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */);
- data.hour = 9;
- data.field_set_mask |= DateParseData::HOUR_FIELD;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 9);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH", &time, &granularity));
EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */);
- data.minute = 33;
- data.field_set_mask |= DateParseData::MINUTE_FIELD;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 33);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH", &time, &granularity));
EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */);
- data.second = 59;
- data.field_set_mask |= DateParseData::SECOND_FIELD;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::SECOND, 59);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
@@ -113,11 +102,10 @@
int64 time;
DatetimeGranularity granularity;
// Prepare data structure that means: "next week"
- DateParseData data;
- data.field_set_mask =
- DateParseData::RELATION_FIELD | DateParseData::RELATION_TYPE_FIELD;
- data.relation = DateParseData::Relation::NEXT;
- data.relation_type = DateParseData::RelationType::WEEK;
+ DatetimeParsedData data;
+ data.SetRelativeValue(DatetimeComponent::ComponentType::WEEK,
+ DatetimeComponent::RelativeQualifier::NEXT);
+ data.SetRelativeCount(DatetimeComponent::ComponentType::WEEK, 1);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
@@ -133,28 +121,20 @@
}
TEST_F(CalendarTest, RelativeTime) {
- const int field_mask = DateParseData::RELATION_FIELD |
- DateParseData::RELATION_TYPE_FIELD |
- DateParseData::RELATION_DISTANCE_FIELD;
const int64 ref_time = 1524648839000L; /* 25 April 2018 09:33:59 */
int64 time;
DatetimeGranularity granularity;
// Two Weds from now.
- const DateParseData future_wed_parse = {
- field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::FUTURE,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/2};
+ DatetimeParsedData future_wed_parse;
+ future_wed_parse.SetRelativeValue(
+ DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ DatetimeComponent::RelativeQualifier::FUTURE);
+ future_wed_parse.SetRelativeCount(
+ DatetimeComponent::ComponentType::DAY_OF_WEEK, 2);
+ future_wed_parse.SetAbsoluteValue(
+ DatetimeComponent::ComponentType::DAY_OF_WEEK, kWednesday);
+
ASSERT_TRUE(calendarlib_.InterpretParseData(
future_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US", &time, &granularity));
@@ -162,79 +142,59 @@
EXPECT_EQ(granularity, GRANULARITY_DAY);
// Next Wed.
- const DateParseData next_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::NEXT,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/0};
+ DatetimeParsedData next_wed_parse;
+ next_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ kWednesday);
+ next_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ DatetimeComponent::RelativeQualifier::NEXT);
+ next_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ 1);
+
ASSERT_TRUE(calendarlib_.InterpretParseData(
next_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1525253639000L /* Wed May 02 2018 11:33:59 */);
+ EXPECT_EQ(time, 1525212000000L /* Wed May 02 2018 00:00:00 */);
EXPECT_EQ(granularity, GRANULARITY_DAY);
// Same Wed.
- const DateParseData same_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::NEXT_OR_SAME,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/0};
+ DatetimeParsedData same_wed_parse;
+ same_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ DatetimeComponent::RelativeQualifier::THIS);
+ same_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ kWednesday);
+ same_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ 1);
+
ASSERT_TRUE(calendarlib_.InterpretParseData(
same_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1524648839000L /* Wed Apr 25 2018 11:33:59 */);
+ EXPECT_EQ(time, 1524607200000L /* Wed Apr 25 2018 00:00:00 */);
EXPECT_EQ(granularity, GRANULARITY_DAY);
// Previous Wed.
- const DateParseData last_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::LAST,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/0};
+ DatetimeParsedData last_wed_parse;
+ last_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ DatetimeComponent::RelativeQualifier::LAST);
+ last_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ kWednesday);
+ last_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ 1);
+
ASSERT_TRUE(calendarlib_.InterpretParseData(
last_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1524044039000L /* Wed Apr 18 2018 11:33:59 */);
+ EXPECT_EQ(time, 1524002400000L /* Wed Apr 18 2018 00:00:00 */);
EXPECT_EQ(granularity, GRANULARITY_DAY);
// Two Weds ago.
- const DateParseData past_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::PAST,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/2};
+ DatetimeParsedData past_wed_parse;
+ past_wed_parse.SetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ DatetimeComponent::RelativeQualifier::PAST);
+ past_wed_parse.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ kWednesday);
+ past_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ 2);
+
ASSERT_TRUE(calendarlib_.InterpretParseData(
past_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US", &time, &granularity));
@@ -242,20 +202,12 @@
EXPECT_EQ(granularity, GRANULARITY_DAY);
// In 3 hours.
- const DateParseData in_3_hours_parse = {
- field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::FUTURE,
- DateParseData::RelationType::HOUR,
- /*relation_distance=*/3};
+ DatetimeParsedData in_3_hours_parse;
+ in_3_hours_parse.SetRelativeValue(
+ DatetimeComponent::ComponentType::HOUR,
+ DatetimeComponent::RelativeQualifier::FUTURE);
+ in_3_hours_parse.SetRelativeCount(DatetimeComponent::ComponentType::HOUR, 3);
+
ASSERT_TRUE(calendarlib_.InterpretParseData(
in_3_hours_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US", &time, &granularity));
@@ -263,20 +215,13 @@
EXPECT_EQ(granularity, GRANULARITY_HOUR);
// In 5 minutes.
- const DateParseData in_5_minutes_parse = {
- field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::FUTURE,
- DateParseData::RelationType::MINUTE,
- /*relation_distance=*/5};
+ DatetimeParsedData in_5_minutes_parse;
+ in_5_minutes_parse.SetRelativeValue(
+ DatetimeComponent::ComponentType::MINUTE,
+ DatetimeComponent::RelativeQualifier::FUTURE);
+ in_5_minutes_parse.SetRelativeCount(DatetimeComponent::ComponentType::MINUTE,
+ 5);
+
ASSERT_TRUE(calendarlib_.InterpretParseData(
in_5_minutes_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US", &time, &granularity));
@@ -284,20 +229,13 @@
EXPECT_EQ(granularity, GRANULARITY_MINUTE);
// In 10 seconds.
- const DateParseData in_10_seconds_parse = {
- field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::FUTURE,
- DateParseData::RelationType::SECOND,
- /*relation_distance=*/10};
+ DatetimeParsedData in_10_seconds_parse;
+ in_10_seconds_parse.SetRelativeValue(
+ DatetimeComponent::ComponentType::SECOND,
+ DatetimeComponent::RelativeQualifier::FUTURE);
+ in_10_seconds_parse.SetRelativeCount(DatetimeComponent::ComponentType::SECOND,
+ 10);
+
ASSERT_TRUE(calendarlib_.InterpretParseData(
in_10_seconds_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US", &time, &granularity));
diff --git a/native/utils/calendar/calendar_test-include.h b/native/utils/calendar/calendar_test-include.h
index 169a4ed..58ad6e0 100644
--- a/native/utils/calendar/calendar_test-include.h
+++ b/native/utils/calendar/calendar_test-include.h
@@ -20,9 +20,14 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
+#include "gtest/gtest.h"
+
#if defined TC3_CALENDAR_ICU
#include "utils/calendar/calendar-icu.h"
#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) VAR()
+#elif defined TC3_CALENDAR_APPLE
+#include "utils/calendar/calendar-apple.h"
+#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) VAR()
#elif defined TC3_CALENDAR_JAVAICU
#include <jni.h>
extern JNIEnv* g_jenv;
@@ -32,9 +37,6 @@
#else
#error Unsupported calendar implementation.
#endif
-#include "utils/base/logging.h"
-
-#include "gtest/gtest.h"
// This can get overridden in the javaicu version which needs to pass an JNIEnv*
// argument to the constructor.
diff --git a/native/utils/calendar/calendar_test.cc b/native/utils/calendar/calendar_test.cc
new file mode 100644
index 0000000..54ed2a0
--- /dev/null
+++ b/native/utils/calendar/calendar_test.cc
@@ -0,0 +1,20 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "gtest/gtest.h"
+
+// The actual code of the test is in the following include:
+#include "utils/calendar/calendar_test-include.h"
diff --git a/native/utils/flatbuffers.cc b/native/utils/flatbuffers.cc
index a4dbabd..005041d 100644
--- a/native/utils/flatbuffers.cc
+++ b/native/utils/flatbuffers.cc
@@ -22,36 +22,31 @@
namespace libtextclassifier3 {
namespace {
-bool CreateRepeatedField(
- const reflection::Schema* schema, const reflection::Type* type,
- std::unique_ptr<ReflectiveFlatbuffer::RepeatedField>* repeated_field) {
+bool CreateRepeatedField(const reflection::Schema* schema,
+ const reflection::Type* type,
+ std::unique_ptr<RepeatedField>* repeated_field) {
switch (type->element()) {
case reflection::Bool:
- repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<bool>);
+ repeated_field->reset(new TypedRepeatedField<bool>);
return true;
case reflection::Int:
- repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<int>);
+ repeated_field->reset(new TypedRepeatedField<int>);
return true;
case reflection::Long:
- repeated_field->reset(
- new ReflectiveFlatbuffer::TypedRepeatedField<int64>);
+ repeated_field->reset(new TypedRepeatedField<int64>);
return true;
case reflection::Float:
- repeated_field->reset(
- new ReflectiveFlatbuffer::TypedRepeatedField<float>);
+ repeated_field->reset(new TypedRepeatedField<float>);
return true;
case reflection::Double:
- repeated_field->reset(
- new ReflectiveFlatbuffer::TypedRepeatedField<double>);
+ repeated_field->reset(new TypedRepeatedField<double>);
return true;
case reflection::String:
- repeated_field->reset(
- new ReflectiveFlatbuffer::TypedRepeatedField<std::string>);
+ repeated_field->reset(new TypedRepeatedField<std::string>);
return true;
case reflection::Obj:
repeated_field->reset(
- new ReflectiveFlatbuffer::TypedRepeatedField<ReflectiveFlatbuffer>(
- schema, type));
+ new TypedRepeatedField<ReflectiveFlatbuffer>(schema, type));
return true;
default:
TC3_LOG(ERROR) << "Unsupported type: " << type->element();
@@ -237,8 +232,7 @@
return it->second.get();
}
-ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
- StringPiece field_name) {
+RepeatedField* ReflectiveFlatbuffer::Repeated(StringPiece field_name) {
if (const reflection::Field* field = GetFieldOrNull(field_name)) {
return Repeated(field);
}
@@ -246,8 +240,7 @@
return nullptr;
}
-ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
- const reflection::Field* field) {
+RepeatedField* ReflectiveFlatbuffer::Repeated(const reflection::Field* field) {
if (field->type()->base_type() != reflection::Vector) {
TC3_LOG(ERROR) << "Field is not of type Vector.";
return nullptr;
diff --git a/native/utils/flatbuffers.fbs b/native/utils/flatbuffers.fbs
index 584b885..155e8f8 100755
--- a/native/utils/flatbuffers.fbs
+++ b/native/utils/flatbuffers.fbs
@@ -18,7 +18,7 @@
namespace libtextclassifier3;
table FlatbufferField {
// Name of the field.
- field_name:string;
+ field_name:string (shared);
// Offset of the field
field_offset:int;
diff --git a/native/utils/flatbuffers.h b/native/utils/flatbuffers.h
index 76b095f..17668ff 100644
--- a/native/utils/flatbuffers.h
+++ b/native/utils/flatbuffers.h
@@ -31,6 +31,11 @@
namespace libtextclassifier3 {
+class ReflectiveFlatBuffer;
+class RepeatedField;
+template <typename T>
+class TypedRepeatedField;
+
// Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
// integrity.
template <typename FlatbufferMessage>
@@ -111,81 +116,6 @@
const reflection::Object* type)
: schema_(schema), type_(type) {}
- // Encapsulates a repeated field.
- // Serves as a common base class for repeated fields.
- class RepeatedField {
- public:
- virtual ~RepeatedField() {}
-
- virtual flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const = 0;
- };
-
- // Represents a repeated field of particular type.
- template <typename T>
- class TypedRepeatedField : public RepeatedField {
- public:
- void Add(const T value) { items_.push_back(value); }
-
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const override {
- return builder->CreateVector(items_).o;
- }
-
- private:
- std::vector<T> items_;
- };
-
- // Specialization for strings.
- template <>
- class TypedRepeatedField<std::string> : public RepeatedField {
- public:
- void Add(const std::string& value) { items_.push_back(value); }
-
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const override {
- std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(
- items_.size());
- for (int i = 0; i < items_.size(); i++) {
- offsets[i] = builder->CreateString(items_[i]);
- }
- return builder->CreateVector(offsets).o;
- }
-
- private:
- std::vector<std::string> items_;
- };
-
- // Specialization for repeated sub-messages.
- template <>
- class TypedRepeatedField<ReflectiveFlatbuffer> : public RepeatedField {
- public:
- TypedRepeatedField<ReflectiveFlatbuffer>(
- const reflection::Schema* const schema,
- const reflection::Type* const type)
- : schema_(schema), type_(type) {}
-
- ReflectiveFlatbuffer* Add() {
- items_.emplace_back(new ReflectiveFlatbuffer(
- schema_, schema_->objects()->Get(type_->index())));
- return items_.back().get();
- }
-
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const override {
- std::vector<flatbuffers::Offset<void>> offsets(items_.size());
- for (int i = 0; i < items_.size(); i++) {
- offsets[i] = items_[i]->Serialize(builder);
- }
- return builder->CreateVector(offsets).o;
- }
-
- private:
- const reflection::Schema* const schema_;
- const reflection::Type* const type_;
- std::vector<std::unique_ptr<ReflectiveFlatbuffer>> items_;
- };
-
// Gets the field information for a field name, returns nullptr if the
// field was not defined.
const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
@@ -332,6 +262,81 @@
const reflection::Schema* const schema_;
};
+// Encapsulates a repeated field.
+// Serves as a common base class for repeated fields.
+class RepeatedField {
+ public:
+ virtual ~RepeatedField() {}
+
+ virtual flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const = 0;
+};
+
+// Represents a repeated field of particular type.
+template <typename T>
+class TypedRepeatedField : public RepeatedField {
+ public:
+ void Add(const T value) { items_.push_back(value); }
+
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return builder->CreateVector(items_).o;
+ }
+
+ private:
+ std::vector<T> items_;
+};
+
+// Specialization for strings.
+template <>
+class TypedRepeatedField<std::string> : public RepeatedField {
+ public:
+ void Add(const std::string& value) { items_.push_back(value); }
+
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(
+ items_.size());
+ for (int i = 0; i < items_.size(); i++) {
+ offsets[i] = builder->CreateString(items_[i]);
+ }
+ return builder->CreateVector(offsets).o;
+ }
+
+ private:
+ std::vector<std::string> items_;
+};
+
+// Specialization for repeated sub-messages.
+template <>
+class TypedRepeatedField<ReflectiveFlatbuffer> : public RepeatedField {
+ public:
+ TypedRepeatedField<ReflectiveFlatbuffer>(
+ const reflection::Schema* const schema,
+ const reflection::Type* const type)
+ : schema_(schema), type_(type) {}
+
+ ReflectiveFlatbuffer* Add() {
+ items_.emplace_back(new ReflectiveFlatbuffer(
+ schema_, schema_->objects()->Get(type_->index())));
+ return items_.back().get();
+ }
+
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ std::vector<flatbuffers::Offset<void>> offsets(items_.size());
+ for (int i = 0; i < items_.size(); i++) {
+ offsets[i] = items_[i]->Serialize(builder);
+ }
+ return builder->CreateVector(offsets).o;
+ }
+
+ private:
+ const reflection::Schema* const schema_;
+ const reflection::Type* const type_;
+ std::vector<std::unique_ptr<ReflectiveFlatbuffer>> items_;
+};
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
diff --git a/native/utils/intents/IntentGeneratorTest.java b/native/utils/intents/IntentGeneratorTest.java
deleted file mode 100644
index f43ecc0..0000000
--- a/native/utils/intents/IntentGeneratorTest.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier.utils.intents;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import android.content.Context;
-import androidx.test.InstrumentationRegistry;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
-@RunWith(JUnit4.class)
-public final class IntentGeneratorTest {
-
- @Before
- public void setUp() throws Exception {
- System.loadLibrary("intent-generator-test-lib");
- }
-
- private native boolean testsMain(Context context);
-
- @Test
- public void testNative() {
- assertThat(testsMain(InstrumentationRegistry.getContext())).isTrue();
- }
-}
diff --git a/native/utils/intents/intent-config.fbs b/native/utils/intents/intent-config.fbs
index 09ebbb4..4b1a288 100755
--- a/native/utils/intents/intent-config.fbs
+++ b/native/utils/intents/intent-config.fbs
@@ -83,7 +83,7 @@
table AndroidIntentFactoryEntityOptions {
// The entity type as defined by one of the TextClassifier ENTITY_TYPE
// constants. (e.g. "address", "phone", etc.)
- entity_type:string;
+ entity_type:string (shared);
// List of generators for all the different types of intents that should
// be made available for the entity type.
@@ -107,26 +107,26 @@
// restrictions, this must /not/ use wildcards. To e.g. match all English
// locales, use only "en" and not "en_*". Reference the java.util.Locale
// constructor for details.
- language_tag:string;
+ language_tag:string (shared);
// Title shown for the action (see RemoteAction.getTitle).
- title:string;
+ title:string (shared);
// Description shown for the action (see
// RemoteAction.getContentDescription).
- description:string;
+ description:string (shared);
}
// An extra to set on a simple intent generator Intent.
namespace libtextclassifier3;
table AndroidSimpleIntentGeneratorExtra {
// The name of the extra to set.
- name:string;
+ name:string (shared);
// The type of the extra to set.
type:AndroidSimpleIntentGeneratorExtraType;
- string_:string;
+ string_:string (shared);
bool_:bool;
int32_:int;
@@ -137,7 +137,7 @@
table AndroidSimpleIntentGeneratorCondition {
type:AndroidSimpleIntentGeneratorConditionType;
- string_:string;
+ string_:string (shared);
int32_:int;
int64_:long;
@@ -154,13 +154,13 @@
namespace libtextclassifier3;
table AndroidSimpleIntentGeneratorOptions {
// The action to set on the Intent (see Intent.setAction). Supports variables.
- action:string;
+ action:string (shared);
// The data to set on the Intent (see Intent.setData). Supports variables.
- data:string;
+ data:string (shared);
// The type to set on the Intent (see Intent.setType). Supports variables.
- type:string;
+ type:string (shared);
// The list of all the extras to add to the Intent.
extra:[AndroidSimpleIntentGeneratorExtra];
@@ -179,7 +179,7 @@
table IntentGenerator {
// The type of the intent generator, e.g. the entity type as defined by
// on the TextClassifier ENTITY_TYPE constants e.g. "address", "phone", etc.
- type:string;
+ type:string (shared);
// The template generator lua code, either as text source or precompiled
// bytecode.
diff --git a/native/utils/java/jni-base.cc b/native/utils/java/jni-base.cc
index 4483b79..e04fcf3 100644
--- a/native/utils/java/jni-base.cc
+++ b/native/utils/java/jni-base.cc
@@ -16,17 +16,7 @@
#include "utils/java/jni-base.h"
-#include <jni.h>
-#include <type_traits>
-#include <vector>
-
-#include "utils/base/integral_types.h"
-#include "utils/java/scoped_local_ref.h"
#include "utils/java/string_utils.h"
-#include "utils/memory/mmap.h"
-
-using libtextclassifier3::JStringToUtf8String;
-using libtextclassifier3::ScopedLocalRef;
namespace libtextclassifier3 {
@@ -36,41 +26,4 @@
return result;
}
-jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd) {
- ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
- env);
- if (fd_class == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find FileDescriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
- jfieldID fd_class_descriptor =
- env->GetFieldID(fd_class.get(), "descriptor", "I");
- if (fd_class_descriptor == nullptr) {
- env->ExceptionClear();
- fd_class_descriptor = env->GetFieldID(fd_class.get(), "fd", "I");
- }
- if (fd_class_descriptor == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find descriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
- return env->GetIntField(fd, fd_class_descriptor);
-}
-
-jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
- ScopedLocalRef<jclass> afd_class(
- env->FindClass("android/content/res/AssetFileDescriptor"), env);
- if (afd_class == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
- jmethodID afd_class_getFileDescriptor = env->GetMethodID(
- afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
- if (afd_class_getFileDescriptor == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find getFileDescriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
- jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
- return GetFdFromFileDescriptor(env, bundle_jfd);
-}
-
} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-base.h b/native/utils/java/jni-base.h
index 23658a3..05bc082 100644
--- a/native/utils/java/jni-base.h
+++ b/native/utils/java/jni-base.h
@@ -73,11 +73,6 @@
std::string ToStlString(JNIEnv* env, const jstring& str);
-// Get system-level file descriptor from AssetFileDescriptor.
-jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd);
-
-// Get system-level file descriptor from FileDescriptor.
-jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd);
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
diff --git a/native/utils/regex-match.cc b/native/utils/regex-match.cc
index 8c55e6b..1120b56 100644
--- a/native/utils/regex-match.cc
+++ b/native/utils/regex-match.cc
@@ -149,16 +149,14 @@
} // namespace
-bool SetFieldFromCapturingGroup(const int group_id,
- const FlatbufferFieldPath* field_path,
- const UniLib::RegexMatcher* matcher,
- ReflectiveFlatbuffer* flatbuffer) {
+Optional<std::string> GetCapturingGroupText(const UniLib::RegexMatcher* matcher,
+ const int group_id) {
int status = UniLib::RegexMatcher::kNoError;
std::string group_text = matcher->Group(group_id, &status).ToUTF8String();
if (status != UniLib::RegexMatcher::kNoError || group_text.empty()) {
- return false;
+ return Optional<std::string>();
}
- return flatbuffer->ParseAndSet(field_path, group_text);
+ return Optional<std::string>(group_text);
}
bool VerifyMatch(const std::string& context,
diff --git a/native/utils/regex-match.h b/native/utils/regex-match.h
index f77f6b1..1466b86 100644
--- a/native/utils/regex-match.h
+++ b/native/utils/regex-match.h
@@ -17,17 +17,15 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
#define LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
-#include "utils/flatbuffers.h"
-#include "utils/flatbuffers_generated.h"
+#include "utils/optional.h"
#include "utils/utf8/unilib.h"
namespace libtextclassifier3 {
-// Sets a field in the flatbuffer from a regex match group.
-// Returns true if successful, and false if the field couldn't be set.
-bool SetFieldFromCapturingGroup(const int group_id,
- const FlatbufferFieldPath* field_path,
- const UniLib::RegexMatcher* matcher,
- ReflectiveFlatbuffer* flatbuffer);
+
+// Returns text of a capturing group if the capturing group was fulfilled in
+// the regex match.
+Optional<std::string> GetCapturingGroupText(const UniLib::RegexMatcher* matcher,
+ const int group_id);
// Post-checks a regular expression match with a lua verifier script.
// The verifier can access:
diff --git a/native/utils/regex-match_test.cc b/native/utils/regex-match_test.cc
index ef86d65..9cf22ab 100644
--- a/native/utils/regex-match_test.cc
+++ b/native/utils/regex-match_test.cc
@@ -18,6 +18,7 @@
#include <memory>
+#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -25,18 +26,18 @@
namespace libtextclassifier3 {
namespace {
-class LuaVerifierTest : public testing::Test {
+class RegexMatchTest : public testing::Test {
protected:
- LuaVerifierTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+ RegexMatchTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
UniLib unilib_;
};
#ifdef TC3_UNILIB_ICU
-TEST_F(LuaVerifierTest, HandlesSimpleVerification) {
+TEST_F(RegexMatchTest, HandlesSimpleVerification) {
EXPECT_TRUE(VerifyMatch(/*context=*/"", /*matcher=*/nullptr, "return true;"));
}
-TEST_F(LuaVerifierTest, HandlesCustomVerification) {
+TEST_F(RegexMatchTest, HandlesCustomVerification) {
UnicodeText pattern = UTF8ToUnicodeText("(\\d{16})",
/*do_copy=*/true);
UnicodeText message = UTF8ToUnicodeText("cc: 4012888888881881",
@@ -60,9 +61,11 @@
end
return luhn(match[1].text);
)";
- auto regex_pattern = unilib_.CreateRegexPattern(pattern);
+ const std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ unilib_.CreateRegexPattern(pattern);
ASSERT_TRUE(regex_pattern != nullptr);
- auto matcher = regex_pattern->Matcher(message);
+ const std::unique_ptr<UniLib::RegexMatcher> matcher =
+ regex_pattern->Matcher(message);
ASSERT_TRUE(matcher != nullptr);
int status = UniLib::RegexMatcher::kNoError;
ASSERT_TRUE(matcher->Find(&status) &&
@@ -70,6 +73,37 @@
EXPECT_TRUE(VerifyMatch(message.ToUTF8String(), matcher.get(), verifier));
}
+
+TEST_F(RegexMatchTest, RetrievesMatchGroupTest) {
+ UnicodeText pattern =
+ UTF8ToUnicodeText("never gonna (?:give (you) up|let (you) down)",
+ /*do_copy=*/true);
+ const std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ unilib_.CreateRegexPattern(pattern);
+ ASSERT_TRUE(regex_pattern != nullptr);
+ UnicodeText message =
+ UTF8ToUnicodeText("never gonna give you up - never gonna let you down");
+ const std::unique_ptr<UniLib::RegexMatcher> matcher =
+ regex_pattern->Matcher(message);
+ ASSERT_TRUE(matcher != nullptr);
+ int status = UniLib::RegexMatcher::kNoError;
+
+ ASSERT_TRUE(matcher->Find(&status) &&
+ status == UniLib::RegexMatcher::kNoError);
+ EXPECT_THAT(GetCapturingGroupText(matcher.get(), 0).value(),
+ testing::Eq("never gonna give you up"));
+ EXPECT_THAT(GetCapturingGroupText(matcher.get(), 1).value(),
+ testing::Eq("you"));
+ EXPECT_FALSE(GetCapturingGroupText(matcher.get(), 2).has_value());
+
+ ASSERT_TRUE(matcher->Find(&status) &&
+ status == UniLib::RegexMatcher::kNoError);
+ EXPECT_THAT(GetCapturingGroupText(matcher.get(), 0).value(),
+ testing::Eq("never gonna let you down"));
+ EXPECT_FALSE(GetCapturingGroupText(matcher.get(), 1).has_value());
+ EXPECT_THAT(GetCapturingGroupText(matcher.get(), 2).value(),
+ testing::Eq("you"));
+}
#endif
} // namespace
diff --git a/native/utils/resources.fbs b/native/utils/resources.fbs
index a88c56d..2319d53 100755
--- a/native/utils/resources.fbs
+++ b/native/utils/resources.fbs
@@ -19,22 +19,22 @@
namespace libtextclassifier3;
table Resource {
locale:[int];
- content:string;
+ content:string (shared);
compressed_content:CompressedBuffer;
}
namespace libtextclassifier3;
table ResourceEntry {
- name:string (key);
+ name:string (key, shared);
resource:[Resource];
}
// BCP 47 tag for the supported locale.
namespace libtextclassifier3;
table LanguageTag {
- language:string;
- script:string;
- region:string;
+ language:string (shared);
+ script:string (shared);
+ region:string (shared);
}
namespace libtextclassifier3;
diff --git a/native/utils/sentencepiece/sorted_strings_table.cc b/native/utils/sentencepiece/sorted_strings_table.cc
index 8e7e9ba..da5d21d 100644
--- a/native/utils/sentencepiece/sorted_strings_table.cc
+++ b/native/utils/sentencepiece/sorted_strings_table.cc
@@ -17,6 +17,8 @@
#include "utils/sentencepiece/sorted_strings_table.h"
#include <algorithm>
+
+#include "utils/base/endian.h"
#include "utils/base/logging.h"
namespace libtextclassifier3 {
@@ -45,15 +47,17 @@
static_cast<unsigned char>(input[match_length]),
[this, match_length](uint32 piece_offset, uint32 c) -> bool {
return static_cast<unsigned char>(
- pieces_[piece_offset + match_length]) < c;
+ pieces_[piece_offset + match_length]) <
+ LittleEndian::ToHost32(c);
}) -
offsets_);
right = (std::upper_bound(
offsets_ + left, offsets_ + right,
static_cast<unsigned char>(input[match_length]),
[this, match_length](uint32 c, uint32 piece_offset) -> bool {
- return c < static_cast<unsigned char>(
- pieces_[piece_offset + match_length]);
+ return LittleEndian::ToHost32(c) <
+ static_cast<unsigned char>(
+ pieces_[piece_offset + match_length]);
}) -
offsets_);
span_size = right - left;
@@ -64,7 +68,7 @@
// Due to the loop invariant and the fact that the strings are sorted, there
// can only be one piece matching completely now, namely at left.
- if (pieces_[offsets_[left] + match_length] == 0) {
+ if (pieces_[LittleEndian::ToHost32(offsets_[left]) + match_length] == 0) {
update_fn(TrieMatch(/*id=*/left,
/*match_length=*/match_length));
left++;
@@ -77,7 +81,8 @@
for (int i = left; i < right; i++) {
bool matches = true;
int piece_match_length = match_length;
- for (int k = offsets_[i] + piece_match_length; pieces_[k] != 0; k++) {
+ for (int k = LittleEndian::ToHost32(offsets_[i]) + piece_match_length;
+ pieces_[k] != 0; k++) {
if (match_length >= input.size() ||
input[piece_match_length] != pieces_[k]) {
matches = false;
diff --git a/native/utils/strings/append.cc b/native/utils/strings/append.cc
new file mode 100644
index 0000000..36712e8
--- /dev/null
+++ b/native/utils/strings/append.cc
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/strings/append.h"
+
+#include <cstring>
+#include <string>
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace strings {
+
+void SStringAppendV(std::string *strp, int bufsize, const char *fmt,
+ va_list arglist) {
+ int capacity = bufsize;
+ if (capacity <= 0) {
+ va_list backup;
+ va_copy(backup, arglist);
+ capacity = vsnprintf(nullptr, 0, fmt, backup);
+ va_end(arglist);
+ }
+
+ size_t start = strp->size();
+ strp->resize(strp->size() + capacity + 1);
+
+ int written = vsnprintf(&(*strp)[start], capacity + 1, fmt, arglist);
+ va_end(arglist);
+ strp->resize(start + std::min(capacity, written));
+}
+
+void SStringAppendF(std::string *strp,
+ int bufsize,
+ const char *fmt, ...) {
+ va_list arglist;
+ va_start(arglist, fmt);
+ SStringAppendV(strp, bufsize, fmt, arglist);
+}
+
+std::string StringPrintf(const char* fmt, ...) {
+ std::string s;
+ va_list arglist;
+ va_start(arglist, fmt);
+ SStringAppendV(&s, 0, fmt, arglist);
+ return s;
+}
+
+std::string JoinStrings(const char *delim,
+ const std::vector<std::string> &vec) {
+ int delim_len = strlen(delim);
+
+ // Calc size.
+ int out_len = 0;
+ for (size_t i = 0; i < vec.size(); i++) {
+ out_len += vec[i].size() + delim_len;
+ }
+
+ // Write out.
+ std::string ret;
+ ret.reserve(out_len);
+ for (size_t i = 0; i < vec.size(); i++) {
+ ret.append(vec[i]);
+ ret.append(delim, delim_len);
+ }
+
+ // Strip last delimiter.
+ if (!ret.empty()) {
+ // Must be at least delim_len.
+ ret.resize(ret.size() - delim_len);
+ }
+ return ret;
+}
+
+} // namespace strings
+} // namespace libtextclassifier3
diff --git a/native/utils/strings/append.h b/native/utils/strings/append.h
new file mode 100644
index 0000000..4b4d0b0
--- /dev/null
+++ b/native/utils/strings/append.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_APPEND_H_
+#define LIBTEXTCLASSIFIER_UTILS_STRINGS_APPEND_H_
+
+#include <string>
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace strings {
+
+// Append vsnprintf to strp. If bufsize hint is > 0 it is
+// used. Otherwise we compute the required bufsize (which is somewhat
+// expensive).
+void SStringAppendV(std::string *strp, int bufsize, const char *fmt,
+ va_list arglist);
+
+void SStringAppendF(std::string *strp, int bufsize, const char *fmt, ...)
+ __attribute__((format(printf, 3, 4)));
+
+std::string StringPrintf(const char *fmt, ...)
+ __attribute__((format(printf, 1, 2)));
+
+std::string JoinStrings(const char *delim, const std::vector<std::string> &vec);
+
+} // namespace strings
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_APPEND_H_
diff --git a/native/utils/strings/append_test.cc b/native/utils/strings/append_test.cc
new file mode 100644
index 0000000..8950761
--- /dev/null
+++ b/native/utils/strings/append_test.cc
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/strings/append.h"
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace strings {
+
+TEST(StringUtilTest, SStringAppendF) {
+ std::string str;
+ SStringAppendF(&str, 5, "%d %d", 0, 1);
+ EXPECT_EQ(str, "0 1");
+
+ SStringAppendF(&str, 1, "%d", 9);
+ EXPECT_EQ(str, "0 19");
+
+ SStringAppendF(&str, 1, "%d", 10);
+ EXPECT_EQ(str, "0 191");
+
+ str.clear();
+
+ SStringAppendF(&str, 5, "%d", 100);
+ EXPECT_EQ(str, "100");
+}
+
+TEST(StringUtilTest, SStringAppendFBufCalc) {
+ std::string str;
+ SStringAppendF(&str, 0, "%d %s %d", 1, "hello", 2);
+ EXPECT_EQ(str, "1 hello 2");
+}
+
+TEST(StringUtilTest, JoinStrings) {
+ std::vector<std::string> vec;
+ vec.push_back("1");
+ vec.push_back("2");
+ vec.push_back("3");
+
+ EXPECT_EQ("1,2,3", JoinStrings(",", vec));
+ EXPECT_EQ("123", JoinStrings("", vec));
+ EXPECT_EQ("1, 2, 3", JoinStrings(", ", vec));
+ EXPECT_EQ("", JoinStrings(",", std::vector<std::string>()));
+}
+
+} // namespace strings
+} // namespace libtextclassifier3
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 9ba232e..4ad60cd 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -201,6 +201,10 @@
TfLiteModelExecutor::TfLiteModelExecutor(
std::unique_ptr<const tflite::FlatBufferModel> model)
: model_(std::move(model)), resolver_(BuildOpResolver()) {}
+TfLiteModelExecutor::TfLiteModelExecutor(
+ std::unique_ptr<const tflite::FlatBufferModel> model,
+ std::unique_ptr<tflite::OpResolver> resolver)
+ : model_(std::move(model)), resolver_(std::move(resolver)) {}
std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter()
const {
diff --git a/native/utils/tflite-model-executor.h b/native/utils/tflite-model-executor.h
index 10d4233..e9c6af9 100644
--- a/native/utils/tflite-model-executor.h
+++ b/native/utils/tflite-model-executor.h
@@ -132,6 +132,8 @@
protected:
explicit TfLiteModelExecutor(
std::unique_ptr<const tflite::FlatBufferModel> model);
+ TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model,
+ std::unique_ptr<tflite::OpResolver> resolver);
std::unique_ptr<const tflite::FlatBufferModel> model_;
std::unique_ptr<tflite::OpResolver> resolver_;
diff --git a/native/utils/utf8/NSString+Unicode.h b/native/utils/utf8/NSString+Unicode.h
new file mode 100644
index 0000000..734d58f
--- /dev/null
+++ b/native/utils/utf8/NSString+Unicode.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#import <Foundation/Foundation.h>
+
+/// Defines utility methods for operating with Unicode in @c NSString.
+/// @discussion Unicode has 1,114,112 code points ( http://en.wikipedia.org/wiki/Code_point ),
+/// and multiple encodings that map these code points into code units.
+/// @c NSString API exposes the string as if it were encoded in UTF-16, which makes use
+/// of surrogate pairs ( http://en.wikipedia.org/wiki/UTF-16 ).
+/// The methods in this category translate indices between Unicode codepoints and
+/// UTF-16 unichars.
+@interface NSString (Unicode)
+
+/// Returns the number of Unicode codepoints for a string slice.
+/// @param start The NSString start index.
+/// @param length The number of unichar units.
+/// @return The number of Unicode code points in the specified unichar range.
+- (NSUInteger)tc_countChar32:(NSUInteger)start withLength:(NSUInteger)length;
+
+/// Returns the length of the string in terms of Unicode codepoints.
+/// @return The number of Unicode codepoints in this string.
+- (NSUInteger)tc_codepointLength;
+
+@end
diff --git a/native/utils/utf8/UniLibJavaIcuTest.java b/native/utils/utf8/UniLibJavaIcuTest.java
deleted file mode 100644
index d6a0a06..0000000
--- a/native/utils/utf8/UniLibJavaIcuTest.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier.utils.utf8;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
-@RunWith(JUnit4.class)
-public class UniLibJavaIcuTest {
-
- @Before
- public void setUp() throws Exception {
- System.loadLibrary("unilib-javaicu_test-jni");
- }
-
- private native boolean testsMain();
-
- @Test
- public void testNative() {
- assertThat(testsMain()).isTrue();
- }
-}
diff --git a/native/utils/utf8/unicodetext.h b/native/utils/utf8/unicodetext.h
index 310fd38..3f884f9 100644
--- a/native/utils/utf8/unicodetext.h
+++ b/native/utils/utf8/unicodetext.h
@@ -119,11 +119,12 @@
}
int utf8_length() const {
- if (it_[0] < 0x80) {
+ const unsigned char byte = static_cast<unsigned char>(it_[0]);
+ if (byte < 0x80) {
return 1;
- } else if (it_[0] < 0xE0) {
+ } else if (byte < 0xE0) {
return 2;
- } else if (it_[0] < 0xF0) {
+ } else if (byte < 0xF0) {
return 3;
} else {
return 4;
diff --git a/native/utils/utf8/unilib-common.cc b/native/utils/utf8/unilib-common.cc
new file mode 100644
index 0000000..2b6deda
--- /dev/null
+++ b/native/utils/utf8/unilib-common.cc
@@ -0,0 +1,414 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/utf8/unilib-common.h"
+
+#include <algorithm>
+
+namespace libtextclassifier3 {
+namespace {
+
+#define ARRAYSIZE(a) sizeof(a) / sizeof(*a)
+
+// Derived from http://www.unicode.org/Public/UNIDATA/UnicodeData.txt
+// grep -E "Ps" UnicodeData.txt | \
+// sed -rne "s/^([0-9A-Z]{4});.*(PAREN|BRACKET|BRAKCET|BRACE).*/0x\1, /p"
+// IMPORTANT: entries with the same offsets in kOpeningBrackets and
+// kClosingBrackets must be counterparts.
+constexpr char32 kOpeningBrackets[] = {
+ 0x0028, 0x005B, 0x007B, 0x0F3C, 0x2045, 0x207D, 0x208D, 0x2329, 0x2768,
+ 0x276A, 0x276C, 0x2770, 0x2772, 0x2774, 0x27E6, 0x27E8, 0x27EA, 0x27EC,
+ 0x27EE, 0x2983, 0x2985, 0x2987, 0x2989, 0x298B, 0x298D, 0x298F, 0x2991,
+ 0x2993, 0x2995, 0x2997, 0x29FC, 0x2E22, 0x2E24, 0x2E26, 0x2E28, 0x3008,
+ 0x300A, 0x300C, 0x300E, 0x3010, 0x3014, 0x3016, 0x3018, 0x301A, 0xFD3F,
+ 0xFE17, 0xFE35, 0xFE37, 0xFE39, 0xFE3B, 0xFE3D, 0xFE3F, 0xFE41, 0xFE43,
+ 0xFE47, 0xFE59, 0xFE5B, 0xFE5D, 0xFF08, 0xFF3B, 0xFF5B, 0xFF5F, 0xFF62};
+constexpr int kNumOpeningBrackets = ARRAYSIZE(kOpeningBrackets);
+
+// grep -E "Pe" UnicodeData.txt | \
+// sed -rne "s/^([0-9A-Z]{4});.*(PAREN|BRACKET|BRAKCET|BRACE).*/0x\1, /p"
+constexpr char32 kClosingBrackets[] = {
+ 0x0029, 0x005D, 0x007D, 0x0F3D, 0x2046, 0x207E, 0x208E, 0x232A, 0x2769,
+ 0x276B, 0x276D, 0x2771, 0x2773, 0x2775, 0x27E7, 0x27E9, 0x27EB, 0x27ED,
+ 0x27EF, 0x2984, 0x2986, 0x2988, 0x298A, 0x298C, 0x298E, 0x2990, 0x2992,
+ 0x2994, 0x2996, 0x2998, 0x29FD, 0x2E23, 0x2E25, 0x2E27, 0x2E29, 0x3009,
+ 0x300B, 0x300D, 0x300F, 0x3011, 0x3015, 0x3017, 0x3019, 0x301B, 0xFD3E,
+ 0xFE18, 0xFE36, 0xFE38, 0xFE3A, 0xFE3C, 0xFE3E, 0xFE40, 0xFE42, 0xFE44,
+ 0xFE48, 0xFE5A, 0xFE5C, 0xFE5E, 0xFF09, 0xFF3D, 0xFF5D, 0xFF60, 0xFF63};
+constexpr int kNumClosingBrackets = ARRAYSIZE(kClosingBrackets);
+
+// grep -E "WS" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
+constexpr char32 kWhitespaces[] = {
+ 0x000C, 0x0020, 0x1680, 0x2000, 0x2001, 0x2002, 0x2003, 0x2004,
+ 0x2005, 0x2006, 0x2007, 0x2008, 0x2009, 0x200A, 0x2028, 0x205F,
+ 0x21C7, 0x21C8, 0x21C9, 0x21CA, 0x21F6, 0x2B31, 0x2B84, 0x2B85,
+ 0x2B86, 0x2B87, 0x2B94, 0x3000, 0x4DCC, 0x10344, 0x10347, 0x1DA0A,
+ 0x1DA0B, 0x1DA0C, 0x1DA0D, 0x1DA0E, 0x1DA0F, 0x1DA10, 0x1F4F0, 0x1F500,
+ 0x1F501, 0x1F502, 0x1F503, 0x1F504, 0x1F5D8, 0x1F5DE};
+constexpr int kNumWhitespaces = ARRAYSIZE(kWhitespaces);
+
+// grep -E "Nd" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
+// As the name suggests, these ranges are always 10 codepoints long, so we just
+// store the end of the range.
+constexpr char32 kDecimalDigitRangesEnd[] = {
+ 0x0039, 0x0669, 0x06f9, 0x07c9, 0x096f, 0x09ef, 0x0a6f, 0x0aef,
+ 0x0b6f, 0x0bef, 0x0c6f, 0x0cef, 0x0d6f, 0x0def, 0x0e59, 0x0ed9,
+ 0x0f29, 0x1049, 0x1099, 0x17e9, 0x1819, 0x194f, 0x19d9, 0x1a89,
+ 0x1a99, 0x1b59, 0x1bb9, 0x1c49, 0x1c59, 0xa629, 0xa8d9, 0xa909,
+ 0xa9d9, 0xa9f9, 0xaa59, 0xabf9, 0xff19, 0x104a9, 0x1106f, 0x110f9,
+ 0x1113f, 0x111d9, 0x112f9, 0x11459, 0x114d9, 0x11659, 0x116c9, 0x11739,
+ 0x118e9, 0x11c59, 0x11d59, 0x16a69, 0x16b59, 0x1d7ff};
+constexpr int kNumDecimalDigitRangesEnd = ARRAYSIZE(kDecimalDigitRangesEnd);
+
+// grep -E "Lu" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
+// There are three common ways in which upper/lower case codepoint ranges
+// were introduced: one offs, dense ranges, and ranges that alternate between
+// lower and upper case. For the sake of keeping out binary size down, we
+// treat each independently.
+constexpr char32 kUpperSingles[] = {
+ 0x01b8, 0x01bc, 0x01c4, 0x01c7, 0x01ca, 0x01f1, 0x0376, 0x037f,
+ 0x03cf, 0x03f4, 0x03fa, 0x10c7, 0x10cd, 0x2102, 0x2107, 0x2115,
+ 0x2145, 0x2183, 0x2c72, 0x2c75, 0x2cf2, 0xa7b6};
+constexpr int kNumUpperSingles = ARRAYSIZE(kUpperSingles);
+constexpr char32 kUpperRanges1Start[] = {
+ 0x0041, 0x00c0, 0x00d8, 0x0181, 0x018a, 0x018e, 0x0193, 0x0196,
+ 0x019c, 0x019f, 0x01b2, 0x01f7, 0x023a, 0x023d, 0x0244, 0x0389,
+ 0x0392, 0x03a3, 0x03d2, 0x03fd, 0x0531, 0x10a0, 0x13a0, 0x1f08,
+ 0x1f18, 0x1f28, 0x1f38, 0x1f48, 0x1f68, 0x1fb8, 0x1fc8, 0x1fd8,
+ 0x1fe8, 0x1ff8, 0x210b, 0x2110, 0x2119, 0x212b, 0x2130, 0x213e,
+ 0x2c00, 0x2c63, 0x2c6e, 0x2c7e, 0xa7ab, 0xa7b0};
+constexpr int kNumUpperRanges1Start = ARRAYSIZE(kUpperRanges1Start);
+constexpr char32 kUpperRanges1End[] = {
+ 0x005a, 0x00d6, 0x00de, 0x0182, 0x018b, 0x0191, 0x0194, 0x0198,
+ 0x019d, 0x01a0, 0x01b3, 0x01f8, 0x023b, 0x023e, 0x0246, 0x038a,
+ 0x03a1, 0x03ab, 0x03d4, 0x042f, 0x0556, 0x10c5, 0x13f5, 0x1f0f,
+ 0x1f1d, 0x1f2f, 0x1f3f, 0x1f4d, 0x1f6f, 0x1fbb, 0x1fcb, 0x1fdb,
+ 0x1fec, 0x1ffb, 0x210d, 0x2112, 0x211d, 0x212d, 0x2133, 0x213f,
+ 0x2c2e, 0x2c64, 0x2c70, 0x2c80, 0xa7ae, 0xa7b4};
+constexpr int kNumUpperRanges1End = ARRAYSIZE(kUpperRanges1End);
+constexpr char32 kUpperRanges2Start[] = {
+ 0x0100, 0x0139, 0x014a, 0x0179, 0x0184, 0x0187, 0x01a2, 0x01a7, 0x01ac,
+ 0x01af, 0x01b5, 0x01cd, 0x01de, 0x01f4, 0x01fa, 0x0241, 0x0248, 0x0370,
+ 0x0386, 0x038c, 0x038f, 0x03d8, 0x03f7, 0x0460, 0x048a, 0x04c1, 0x04d0,
+ 0x1e00, 0x1e9e, 0x1f59, 0x2124, 0x2c60, 0x2c67, 0x2c82, 0x2ceb, 0xa640,
+ 0xa680, 0xa722, 0xa732, 0xa779, 0xa77e, 0xa78b, 0xa790, 0xa796};
+constexpr int kNumUpperRanges2Start = ARRAYSIZE(kUpperRanges2Start);
+constexpr char32 kUpperRanges2End[] = {
+ 0x0136, 0x0147, 0x0178, 0x017d, 0x0186, 0x0189, 0x01a6, 0x01a9, 0x01ae,
+ 0x01b1, 0x01b7, 0x01db, 0x01ee, 0x01f6, 0x0232, 0x0243, 0x024e, 0x0372,
+ 0x0388, 0x038e, 0x0391, 0x03ee, 0x03f9, 0x0480, 0x04c0, 0x04cd, 0x052e,
+ 0x1e94, 0x1efe, 0x1f5f, 0x212a, 0x2c62, 0x2c6d, 0x2ce2, 0x2ced, 0xa66c,
+ 0xa69a, 0xa72e, 0xa76e, 0xa77d, 0xa786, 0xa78d, 0xa792, 0xa7aa};
+constexpr int kNumUpperRanges2End = ARRAYSIZE(kUpperRanges2End);
+
+// grep -E "Ll" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
+constexpr char32 kLowerSingles[] = {
+ 0x00b5, 0x0188, 0x0192, 0x0195, 0x019e, 0x01b0, 0x01c6, 0x01c9,
+ 0x01f0, 0x023c, 0x0242, 0x0377, 0x0390, 0x03f5, 0x03f8, 0x1fbe,
+ 0x210a, 0x2113, 0x212f, 0x2134, 0x2139, 0x214e, 0x2184, 0x2c61,
+ 0x2ce4, 0x2cf3, 0x2d27, 0x2d2d, 0xa7af, 0xa7c3, 0xa7fa, 0x1d7cb};
+constexpr int kNumLowerSingles = ARRAYSIZE(kLowerSingles);
+constexpr char32 kLowerRanges1Start[] = {
+ 0x0061, 0x00df, 0x00f8, 0x017f, 0x018c, 0x0199, 0x01b9, 0x01bd,
+ 0x0234, 0x023f, 0x0250, 0x0295, 0x037b, 0x03ac, 0x03d0, 0x03d5,
+ 0x03f0, 0x03fb, 0x0430, 0x0560, 0x10d0, 0x10fd, 0x13f8, 0x1c80,
+ 0x1d00, 0x1d6b, 0x1d79, 0x1e96, 0x1f00, 0x1f10, 0x1f20, 0x1f30,
+ 0x1f40, 0x1f50, 0x1f60, 0x1f70, 0x1f80, 0x1f90, 0x1fa0, 0x1fb0,
+ 0x1fb6, 0x1fc2, 0x1fc6, 0x1fd0, 0x1fd6, 0x1fe0, 0x1ff2, 0x1ff6,
+ 0x210e, 0x213c, 0x2146, 0x2c30, 0x2c65, 0x2c77, 0x2d00, 0xa730,
+ 0xa772, 0xa794, 0xab30, 0xab60, 0xab70, 0xfb00, 0xfb13, 0xff41,
+ 0x10428, 0x104d8, 0x10cc0, 0x118c0, 0x16e60, 0x1d41a, 0x1d44e, 0x1d456,
+ 0x1d482, 0x1d4b6, 0x1d4be, 0x1d4c5, 0x1d4ea, 0x1d51e, 0x1d552, 0x1d586,
+ 0x1d5ba, 0x1d5ee, 0x1d622, 0x1d656, 0x1d68a, 0x1d6c2, 0x1d6dc, 0x1d6fc,
+ 0x1d716, 0x1d736, 0x1d750, 0x1d770, 0x1d78a, 0x1d7aa, 0x1d7c4, 0x1e922};
+constexpr int kNumLowerRanges1Start = ARRAYSIZE(kLowerRanges1Start);
+constexpr char32 kLowerRanges1End[] = {
+ 0x007a, 0x00f6, 0x00ff, 0x0180, 0x018d, 0x019b, 0x01ba, 0x01bf,
+ 0x0239, 0x0240, 0x0293, 0x02af, 0x037d, 0x03ce, 0x03d1, 0x03d7,
+ 0x03f3, 0x03fc, 0x045f, 0x0588, 0x10fa, 0x10ff, 0x13fd, 0x1c88,
+ 0x1d2b, 0x1d77, 0x1d9a, 0x1e9d, 0x1f07, 0x1f15, 0x1f27, 0x1f37,
+ 0x1f45, 0x1f57, 0x1f67, 0x1f7d, 0x1f87, 0x1f97, 0x1fa7, 0x1fb4,
+ 0x1fb7, 0x1fc4, 0x1fc7, 0x1fd3, 0x1fd7, 0x1fe7, 0x1ff4, 0x1ff7,
+ 0x210f, 0x213d, 0x2149, 0x2c5e, 0x2c66, 0x2c7b, 0x2d25, 0xa731,
+ 0xa778, 0xa795, 0xab5a, 0xab67, 0xabbf, 0xfb06, 0xfb17, 0xff5a,
+ 0x1044f, 0x104fb, 0x10cf2, 0x118df, 0x16e7f, 0x1d433, 0x1d454, 0x1d467,
+ 0x1d49b, 0x1d4b9, 0x1d4c3, 0x1d4cf, 0x1d503, 0x1d537, 0x1d56b, 0x1d59f,
+ 0x1d5d3, 0x1d607, 0x1d63b, 0x1d66f, 0x1d6a5, 0x1d6da, 0x1d6e1, 0x1d714,
+ 0x1d71b, 0x1d74e, 0x1d755, 0x1d788, 0x1d78f, 0x1d7c2, 0x1d7c9, 0x1e943};
+constexpr int kNumLowerRanges1End = ARRAYSIZE(kLowerRanges1End);
+constexpr char32 kLowerRanges2Start[] = {
+ 0x0101, 0x0138, 0x0149, 0x017a, 0x0183, 0x01a1, 0x01a8, 0x01ab,
+ 0x01b4, 0x01cc, 0x01dd, 0x01f3, 0x01f9, 0x0247, 0x0371, 0x03d9,
+ 0x0461, 0x048b, 0x04c2, 0x04cf, 0x1e01, 0x1e9f, 0x2c68, 0x2c71,
+ 0x2c74, 0x2c81, 0x2cec, 0xa641, 0xa681, 0xa723, 0xa733, 0xa77a,
+ 0xa77f, 0xa78c, 0xa791, 0xa797, 0xa7b5, 0x1d4bb};
+constexpr int kNumLowerRanges2Start = ARRAYSIZE(kLowerRanges2Start);
+constexpr char32 kLowerRanges2End[] = {
+ 0x0137, 0x0148, 0x0177, 0x017e, 0x0185, 0x01a5, 0x01aa, 0x01ad,
+ 0x01b6, 0x01dc, 0x01ef, 0x01f5, 0x0233, 0x024f, 0x0373, 0x03ef,
+ 0x0481, 0x04bf, 0x04ce, 0x052f, 0x1e95, 0x1eff, 0x2c6c, 0x2c73,
+ 0x2c76, 0x2ce3, 0x2cee, 0xa66d, 0xa69b, 0xa72f, 0xa771, 0xa77c,
+ 0xa787, 0xa78e, 0xa793, 0xa7a9, 0xa7bf, 0x1d4bd};
+constexpr int kNumLowerRanges2End = ARRAYSIZE(kLowerRanges2End);
+
+// grep -E "Lu" UnicodeData.txt | \
+// sed -rne "s/^([0-9A-Z]+);.*;([0-9A-Z]+);$/(0x\1, 0x\2), /p"
+// We have two strategies for mapping from upper to lower case. We have single
+// character lookups that do not follow a pattern, and ranges for which there
+// is a constant codepoint shift.
+// Note that these ranges ignore anything that's not an upper case character,
+// so when applied to a non-uppercase character the result is incorrect.
+constexpr int kToLowerSingles[] = {
+ 0x0130, 0x0178, 0x0181, 0x0186, 0x018b, 0x018e, 0x018f, 0x0190, 0x0191,
+ 0x0194, 0x0196, 0x0197, 0x0198, 0x019c, 0x019d, 0x019f, 0x01a6, 0x01a9,
+ 0x01ae, 0x01b7, 0x01f6, 0x01f7, 0x0220, 0x023a, 0x023d, 0x023e, 0x0243,
+ 0x0244, 0x0245, 0x037f, 0x0386, 0x038c, 0x03cf, 0x03f4, 0x03f9, 0x04c0,
+ 0x1e9e, 0x1fec, 0x2126, 0x212a, 0x212b, 0x2132, 0x2183, 0x2c60, 0x2c62,
+ 0x2c63, 0x2c64, 0x2c6d, 0x2c6e, 0x2c6f, 0x2c70, 0xa77d, 0xa78d, 0xa7aa,
+ 0xa7ab, 0xa7ac, 0xa7ad, 0xa7ae, 0xa7b0, 0xa7b1, 0xa7b2, 0xa7b3};
+constexpr int kNumToLowerSingles = ARRAYSIZE(kToLowerSingles);
+constexpr int kToUpperSingles[] = {
+ 0x0069, 0x00ff, 0x0253, 0x0254, 0x018c, 0x01dd, 0x0259, 0x025b, 0x0192,
+ 0x0263, 0x0269, 0x0268, 0x0199, 0x026f, 0x0272, 0x0275, 0x0280, 0x0283,
+ 0x0288, 0x0292, 0x0195, 0x01bf, 0x019e, 0x2c65, 0x019a, 0x2c66, 0x0180,
+ 0x0289, 0x028c, 0x03f3, 0x03ac, 0x03cc, 0x03d7, 0x03b8, 0x03f2, 0x04cf,
+ 0x00df, 0x1fe5, 0x03c9, 0x006b, 0x00e5, 0x214e, 0x2184, 0x2c61, 0x026b,
+ 0x1d7d, 0x027d, 0x0251, 0x0271, 0x0250, 0x0252, 0x1d79, 0x0265, 0x0266,
+ 0x025c, 0x0261, 0x026c, 0x026a, 0x029e, 0x0287, 0x029d, 0xab53};
+constexpr int kNumToUpperSingles = ARRAYSIZE(kToUpperSingles);
+constexpr int kToLowerRangesStart[] = {
+ 0x0041, 0x0100, 0x0189, 0x01a0, 0x01b1, 0x01b3, 0x0388, 0x038e, 0x0391,
+ 0x03d8, 0x03fd, 0x0400, 0x0410, 0x0460, 0x0531, 0x10a0, 0x13a0, 0x13f0,
+ 0x1e00, 0x1f08, 0x1fba, 0x1fc8, 0x1fd8, 0x1fda, 0x1fe8, 0x1fea, 0x1ff8,
+ 0x1ffa, 0x2c00, 0x2c67, 0x2c7e, 0x2c80, 0xff21, 0x10400, 0x10c80, 0x118a0};
+constexpr int kNumToLowerRangesStart = ARRAYSIZE(kToLowerRangesStart);
+constexpr int kToLowerRangesEnd[] = {
+ 0x00de, 0x0187, 0x019f, 0x01af, 0x01b2, 0x0386, 0x038c, 0x038f, 0x03cf,
+ 0x03fa, 0x03ff, 0x040f, 0x042f, 0x052e, 0x0556, 0x10cd, 0x13ef, 0x13f5,
+ 0x1efe, 0x1fb9, 0x1fbb, 0x1fcb, 0x1fd9, 0x1fdb, 0x1fe9, 0x1fec, 0x1ff9,
+ 0x2183, 0x2c64, 0x2c75, 0x2c7f, 0xa7b6, 0xff3a, 0x104d3, 0x10cb2, 0x118bf};
+constexpr int kNumToLowerRangesEnd = ARRAYSIZE(kToLowerRangesEnd);
+constexpr int kToLowerRangesOffsets[] = {
+ 32, 1, 205, 1, 217, 1, 37, 63, 32, 1, -130, 80,
+ 32, 1, 48, 7264, 38864, 8, 1, -8, -74, -86, -8, -100,
+ -8, -112, -128, -126, 48, 1, -10815, 1, 32, 40, 64, 32};
+constexpr int kNumToLowerRangesOffsets = ARRAYSIZE(kToLowerRangesOffsets);
+constexpr int kToUpperRangesStart[] = {
+ 0x0061, 0x0101, 0x01a1, 0x01b4, 0x023f, 0x0256, 0x028a, 0x037b, 0x03ad,
+ 0x03b1, 0x03cd, 0x03d9, 0x0430, 0x0450, 0x0461, 0x0561, 0x13f8, 0x1e01,
+ 0x1f00, 0x1f70, 0x1f72, 0x1f76, 0x1f78, 0x1f7a, 0x1f7c, 0x1fd0, 0x1fe0,
+ 0x2c30, 0x2c68, 0x2c81, 0x2d00, 0xab70, 0xff41, 0x10428, 0x10cc0, 0x118c0};
+constexpr int kNumToUpperRangesStart = ARRAYSIZE(kToUpperRangesStart);
+constexpr int kToUpperRangesEnd[] = {
+ 0x00fe, 0x0188, 0x01b0, 0x0387, 0x0240, 0x026c, 0x028b, 0x037d, 0x03b0,
+ 0x03ef, 0x03ce, 0x03fb, 0x044f, 0x045f, 0x052f, 0x0586, 0x13fd, 0x1eff,
+ 0x1fb1, 0x1f71, 0x1f75, 0x1f77, 0x1f79, 0x1f7c, 0x2105, 0x1fd1, 0x1fe1,
+ 0x2c94, 0x2c76, 0xa7b7, 0x2d2d, 0xabbf, 0xff5a, 0x104fb, 0x10cf2, 0x118df};
+constexpr int kNumToUpperRangesEnd = ARRAYSIZE(kToUpperRangesEnd);
+constexpr int kToUpperRangesOffsets[]{
+ -32, -1, -1, -1, 10815, -205, -217, 130, -37, -32, -63, -1,
+ -32, -80, -1, -48, -8, -1, 8, 74, 86, 100, 128, 112,
+ 126, 8, 8, -48, -1, -1, -7264, -38864, -32, -40, -64, -32};
+constexpr int kNumToUpperRangesOffsets = ARRAYSIZE(kToUpperRangesOffsets);
+
+#undef ARRAYSIZE
+
+static_assert(kNumOpeningBrackets == kNumClosingBrackets,
+ "mismatching number of opening and closing brackets");
+static_assert(kNumLowerRanges1Start == kNumLowerRanges1End,
+ "number of uppercase stride 1 range starts/ends doesn't match");
+static_assert(kNumLowerRanges2Start == kNumLowerRanges2End,
+ "number of uppercase stride 2 range starts/ends doesn't match");
+static_assert(kNumUpperRanges1Start == kNumUpperRanges1End,
+ "number of uppercase stride 1 range starts/ends doesn't match");
+static_assert(kNumUpperRanges2Start == kNumUpperRanges2End,
+ "number of uppercase stride 2 range starts/ends doesn't match");
+static_assert(kNumToLowerSingles == kNumToUpperSingles,
+ "number of to lower and upper singles doesn't match");
+static_assert(kNumToLowerRangesStart == kNumToLowerRangesEnd,
+ "mismatching number of range starts/ends for to lower ranges");
+static_assert(kNumToLowerRangesStart == kNumToLowerRangesOffsets,
+ "number of to lower ranges and offsets doesn't match");
+static_assert(kNumToUpperRangesStart == kNumToUpperRangesEnd,
+ "mismatching number of range starts/ends for to upper ranges");
+static_assert(kNumToUpperRangesStart == kNumToUpperRangesOffsets,
+ "number of to upper ranges and offsets doesn't match");
+
+constexpr int kNoMatch = -1;
+
+// Returns the index of the element in the array that matched the given
+// codepoint, or kNoMatch if the element didn't exist.
+// The input array must be in sorted order.
+int GetMatchIndex(const char32* array, int array_length, char32 c) {
+ const char32* end = array + array_length;
+ const auto find_it = std::lower_bound(array, end, c);
+ if (find_it != end && *find_it == c) {
+ return find_it - array;
+ } else {
+ return kNoMatch;
+ }
+}
+
+// Returns the index of the range in the array that overlapped the given
+// codepoint, or kNoMatch if no such range existed.
+// The input array must be in sorted order.
+int GetOverlappingRangeIndex(const char32* arr, int arr_length,
+ int range_length, char32 c) {
+ const char32* end = arr + arr_length;
+ const auto find_it = std::lower_bound(arr, end, c);
+ if (find_it == end) {
+ return kNoMatch;
+ }
+ // The end is inclusive, we so subtract one less than the range length.
+ const char32 range_end = *find_it;
+ const char32 range_start = range_end - (range_length - 1);
+ if (c < range_start || range_end < c) {
+ return kNoMatch;
+ } else {
+ return find_it - arr;
+ }
+}
+
+// As above, but with explicit codepoint start and end indices for the range.
+// The input array must be in sorted order.
+int GetOverlappingRangeIndex(const char32* start_arr, const char32* end_arr,
+ int arr_length, int stride, char32 c) {
+ const char32* end_arr_end = end_arr + arr_length;
+ const auto find_it = std::lower_bound(end_arr, end_arr_end, c);
+ if (find_it == end_arr_end) {
+ return kNoMatch;
+ }
+ // Find the corresponding start.
+ const int range_index = find_it - end_arr;
+ const char32 range_start = start_arr[range_index];
+ const char32 range_end = *find_it;
+ if (c < range_start || range_end < c) {
+ return kNoMatch;
+ }
+ if ((c - range_start) % stride == 0) {
+ return range_index;
+ } else {
+ return kNoMatch;
+ }
+}
+
+} // anonymous namespace
+
+bool IsOpeningBracket(char32 codepoint) {
+ return GetMatchIndex(kOpeningBrackets, kNumOpeningBrackets, codepoint) >= 0;
+}
+
+bool IsClosingBracket(char32 codepoint) {
+ return GetMatchIndex(kClosingBrackets, kNumClosingBrackets, codepoint) >= 0;
+}
+
+bool IsWhitespace(char32 codepoint) {
+ return GetMatchIndex(kWhitespaces, kNumWhitespaces, codepoint) >= 0;
+}
+
+bool IsDigit(char32 codepoint) {
+ return GetOverlappingRangeIndex(kDecimalDigitRangesEnd,
+ kNumDecimalDigitRangesEnd,
+ /*range_length=*/10, codepoint) >= 0;
+}
+
+bool IsLower(char32 codepoint) {
+ if (GetMatchIndex(kLowerSingles, kNumLowerSingles, codepoint) >= 0) {
+ return true;
+ } else if (GetOverlappingRangeIndex(kLowerRanges1Start, kLowerRanges1End,
+ kNumLowerRanges1Start, /*stride=*/1,
+ codepoint) >= 0) {
+ return true;
+ } else if (GetOverlappingRangeIndex(kLowerRanges2Start, kLowerRanges2End,
+ kNumLowerRanges2Start, /*stride=*/2,
+ codepoint) >= 0) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool IsUpper(char32 codepoint) {
+ if (GetMatchIndex(kUpperSingles, kNumUpperSingles, codepoint) >= 0) {
+ return true;
+ } else if (GetOverlappingRangeIndex(kUpperRanges1Start, kUpperRanges1End,
+ kNumUpperRanges1Start, /*stride=*/1,
+ codepoint) >= 0) {
+ return true;
+ } else if (GetOverlappingRangeIndex(kUpperRanges2Start, kUpperRanges2End,
+ kNumUpperRanges2Start, /*stride=*/2,
+ codepoint) >= 0) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+char32 ToLower(char32 codepoint) {
+ // Make sure we still produce output even if the method is called for a
+ // codepoint that's not an uppercase character.
+ if (!IsUpper(codepoint)) {
+ return codepoint;
+ }
+ const int singles_idx =
+ GetMatchIndex(kToLowerSingles, kNumToLowerSingles, codepoint);
+ if (singles_idx >= 0) {
+ return kToUpperSingles[singles_idx];
+ }
+ const int ranges_idx =
+ GetOverlappingRangeIndex(kToLowerRangesStart, kToLowerRangesEnd,
+ kNumToLowerRangesStart, /*stride=*/1, codepoint);
+ if (ranges_idx >= 0) {
+ return codepoint + kToLowerRangesOffsets[ranges_idx];
+ }
+ return codepoint;
+}
+
+char32 ToUpper(char32 codepoint) {
+ // Make sure we still produce output even if the method is called for a
+ // codepoint that's not an uppercase character.
+ if (!IsLower(codepoint)) {
+ return codepoint;
+ }
+ const int singles_idx =
+ GetMatchIndex(kToUpperSingles, kNumToUpperSingles, codepoint);
+ if (singles_idx >= 0) {
+ return kToLowerSingles[singles_idx];
+ }
+ const int ranges_idx =
+ GetOverlappingRangeIndex(kToUpperRangesStart, kToUpperRangesEnd,
+ kNumToUpperRangesStart, /*stride=*/1, codepoint);
+ if (ranges_idx >= 0) {
+ return codepoint + kToUpperRangesOffsets[ranges_idx];
+ }
+ return codepoint;
+}
+
+char32 GetPairedBracket(char32 codepoint) {
+ const int open_offset =
+ GetMatchIndex(kOpeningBrackets, kNumOpeningBrackets, codepoint);
+ if (open_offset >= 0) {
+ return kClosingBrackets[open_offset];
+ }
+ const int close_offset =
+ GetMatchIndex(kClosingBrackets, kNumClosingBrackets, codepoint);
+ if (close_offset >= 0) {
+ return kOpeningBrackets[close_offset];
+ }
+ return codepoint;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib-common.h b/native/utils/utf8/unilib-common.h
new file mode 100644
index 0000000..0394cc3
--- /dev/null
+++ b/native/utils/utf8/unilib-common.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_COMMON_H_
+#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_COMMON_H_
+
+#include "utils/base/integral_types.h"
+
+namespace libtextclassifier3 {
+
+bool IsOpeningBracket(char32 codepoint);
+bool IsClosingBracket(char32 codepoint);
+bool IsWhitespace(char32 codepoint);
+bool IsDigit(char32 codepoint);
+bool IsLower(char32 codepoint);
+bool IsUpper(char32 codepoint);
+char32 ToLower(char32 codepoint);
+char32 ToUpper(char32 codepoint);
+char32 GetPairedBracket(char32 codepoint);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_COMMON_H_
diff --git a/native/utils/utf8/unilib-javaicu.cc b/native/utils/utf8/unilib-javaicu.cc
index 8cddddd..13bb536 100644
--- a/native/utils/utf8/unilib-javaicu.cc
+++ b/native/utils/utf8/unilib-javaicu.cc
@@ -16,230 +16,14 @@
#include "utils/utf8/unilib-javaicu.h"
-#include <algorithm>
#include <cassert>
#include <cctype>
#include <map>
#include "utils/java/string_utils.h"
+#include "utils/utf8/unilib-common.h"
namespace libtextclassifier3 {
-namespace {
-
-// -----------------------------------------------------------------------------
-// Native implementations.
-// -----------------------------------------------------------------------------
-
-#define ARRAYSIZE(a) sizeof(a) / sizeof(*a)
-
-// Derived from http://www.unicode.org/Public/UNIDATA/UnicodeData.txt
-// grep -E "Ps" UnicodeData.txt | \
-// sed -rne "s/^([0-9A-Z]{4});.*(PAREN|BRACKET|BRAKCET|BRACE).*/0x\1, /p"
-// IMPORTANT: entries with the same offsets in kOpeningBrackets and
-// kClosingBrackets must be counterparts.
-constexpr char32 kOpeningBrackets[] = {
- 0x0028, 0x005B, 0x007B, 0x0F3C, 0x2045, 0x207D, 0x208D, 0x2329, 0x2768,
- 0x276A, 0x276C, 0x2770, 0x2772, 0x2774, 0x27E6, 0x27E8, 0x27EA, 0x27EC,
- 0x27EE, 0x2983, 0x2985, 0x2987, 0x2989, 0x298B, 0x298D, 0x298F, 0x2991,
- 0x2993, 0x2995, 0x2997, 0x29FC, 0x2E22, 0x2E24, 0x2E26, 0x2E28, 0x3008,
- 0x300A, 0x300C, 0x300E, 0x3010, 0x3014, 0x3016, 0x3018, 0x301A, 0xFD3F,
- 0xFE17, 0xFE35, 0xFE37, 0xFE39, 0xFE3B, 0xFE3D, 0xFE3F, 0xFE41, 0xFE43,
- 0xFE47, 0xFE59, 0xFE5B, 0xFE5D, 0xFF08, 0xFF3B, 0xFF5B, 0xFF5F, 0xFF62};
-constexpr int kNumOpeningBrackets = ARRAYSIZE(kOpeningBrackets);
-
-// grep -E "Pe" UnicodeData.txt | \
-// sed -rne "s/^([0-9A-Z]{4});.*(PAREN|BRACKET|BRAKCET|BRACE).*/0x\1, /p"
-constexpr char32 kClosingBrackets[] = {
- 0x0029, 0x005D, 0x007D, 0x0F3D, 0x2046, 0x207E, 0x208E, 0x232A, 0x2769,
- 0x276B, 0x276D, 0x2771, 0x2773, 0x2775, 0x27E7, 0x27E9, 0x27EB, 0x27ED,
- 0x27EF, 0x2984, 0x2986, 0x2988, 0x298A, 0x298C, 0x298E, 0x2990, 0x2992,
- 0x2994, 0x2996, 0x2998, 0x29FD, 0x2E23, 0x2E25, 0x2E27, 0x2E29, 0x3009,
- 0x300B, 0x300D, 0x300F, 0x3011, 0x3015, 0x3017, 0x3019, 0x301B, 0xFD3E,
- 0xFE18, 0xFE36, 0xFE38, 0xFE3A, 0xFE3C, 0xFE3E, 0xFE40, 0xFE42, 0xFE44,
- 0xFE48, 0xFE5A, 0xFE5C, 0xFE5E, 0xFF09, 0xFF3D, 0xFF5D, 0xFF60, 0xFF63};
-constexpr int kNumClosingBrackets = ARRAYSIZE(kClosingBrackets);
-
-// grep -E "WS" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
-constexpr char32 kWhitespaces[] = {
- 0x000C, 0x0020, 0x1680, 0x2000, 0x2001, 0x2002, 0x2003, 0x2004,
- 0x2005, 0x2006, 0x2007, 0x2008, 0x2009, 0x200A, 0x2028, 0x205F,
- 0x21C7, 0x21C8, 0x21C9, 0x21CA, 0x21F6, 0x2B31, 0x2B84, 0x2B85,
- 0x2B86, 0x2B87, 0x2B94, 0x3000, 0x4DCC, 0x10344, 0x10347, 0x1DA0A,
- 0x1DA0B, 0x1DA0C, 0x1DA0D, 0x1DA0E, 0x1DA0F, 0x1DA10, 0x1F4F0, 0x1F500,
- 0x1F501, 0x1F502, 0x1F503, 0x1F504, 0x1F5D8, 0x1F5DE};
-constexpr int kNumWhitespaces = ARRAYSIZE(kWhitespaces);
-
-// grep -E "Nd" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
-// As the name suggests, these ranges are always 10 codepoints long, so we just
-// store the end of the range.
-constexpr char32 kDecimalDigitRangesEnd[] = {
- 0x0039, 0x0669, 0x06f9, 0x07c9, 0x096f, 0x09ef, 0x0a6f, 0x0aef,
- 0x0b6f, 0x0bef, 0x0c6f, 0x0cef, 0x0d6f, 0x0def, 0x0e59, 0x0ed9,
- 0x0f29, 0x1049, 0x1099, 0x17e9, 0x1819, 0x194f, 0x19d9, 0x1a89,
- 0x1a99, 0x1b59, 0x1bb9, 0x1c49, 0x1c59, 0xa629, 0xa8d9, 0xa909,
- 0xa9d9, 0xa9f9, 0xaa59, 0xabf9, 0xff19, 0x104a9, 0x1106f, 0x110f9,
- 0x1113f, 0x111d9, 0x112f9, 0x11459, 0x114d9, 0x11659, 0x116c9, 0x11739,
- 0x118e9, 0x11c59, 0x11d59, 0x16a69, 0x16b59, 0x1d7ff};
-constexpr int kNumDecimalDigitRangesEnd = ARRAYSIZE(kDecimalDigitRangesEnd);
-
-// grep -E "Lu" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
-// There are three common ways in which upper/lower case codepoint ranges
-// were introduced: one offs, dense ranges, and ranges that alternate between
-// lower and upper case. For the sake of keeping out binary size down, we
-// treat each independently.
-constexpr char32 kUpperSingles[] = {
- 0x01b8, 0x01bc, 0x01c4, 0x01c7, 0x01ca, 0x01f1, 0x0376, 0x037f,
- 0x03cf, 0x03f4, 0x03fa, 0x10c7, 0x10cd, 0x2102, 0x2107, 0x2115,
- 0x2145, 0x2183, 0x2c72, 0x2c75, 0x2cf2, 0xa7b6};
-constexpr int kNumUpperSingles = ARRAYSIZE(kUpperSingles);
-constexpr char32 kUpperRanges1Start[] = {
- 0x0041, 0x00c0, 0x00d8, 0x0181, 0x018a, 0x018e, 0x0193, 0x0196,
- 0x019c, 0x019f, 0x01b2, 0x01f7, 0x023a, 0x023d, 0x0244, 0x0389,
- 0x0392, 0x03a3, 0x03d2, 0x03fd, 0x0531, 0x10a0, 0x13a0, 0x1f08,
- 0x1f18, 0x1f28, 0x1f38, 0x1f48, 0x1f68, 0x1fb8, 0x1fc8, 0x1fd8,
- 0x1fe8, 0x1ff8, 0x210b, 0x2110, 0x2119, 0x212b, 0x2130, 0x213e,
- 0x2c00, 0x2c63, 0x2c6e, 0x2c7e, 0xa7ab, 0xa7b0};
-constexpr int kNumUpperRanges1Start = ARRAYSIZE(kUpperRanges1Start);
-constexpr char32 kUpperRanges1End[] = {
- 0x005a, 0x00d6, 0x00de, 0x0182, 0x018b, 0x0191, 0x0194, 0x0198,
- 0x019d, 0x01a0, 0x01b3, 0x01f8, 0x023b, 0x023e, 0x0246, 0x038a,
- 0x03a1, 0x03ab, 0x03d4, 0x042f, 0x0556, 0x10c5, 0x13f5, 0x1f0f,
- 0x1f1d, 0x1f2f, 0x1f3f, 0x1f4d, 0x1f6f, 0x1fbb, 0x1fcb, 0x1fdb,
- 0x1fec, 0x1ffb, 0x210d, 0x2112, 0x211d, 0x212d, 0x2133, 0x213f,
- 0x2c2e, 0x2c64, 0x2c70, 0x2c80, 0xa7ae, 0xa7b4};
-constexpr int kNumUpperRanges1End = ARRAYSIZE(kUpperRanges1End);
-constexpr char32 kUpperRanges2Start[] = {
- 0x0100, 0x0139, 0x014a, 0x0179, 0x0184, 0x0187, 0x01a2, 0x01a7, 0x01ac,
- 0x01af, 0x01b5, 0x01cd, 0x01de, 0x01f4, 0x01fa, 0x0241, 0x0248, 0x0370,
- 0x0386, 0x038c, 0x038f, 0x03d8, 0x03f7, 0x0460, 0x048a, 0x04c1, 0x04d0,
- 0x1e00, 0x1e9e, 0x1f59, 0x2124, 0x2c60, 0x2c67, 0x2c82, 0x2ceb, 0xa640,
- 0xa680, 0xa722, 0xa732, 0xa779, 0xa77e, 0xa78b, 0xa790, 0xa796};
-constexpr int kNumUpperRanges2Start = ARRAYSIZE(kUpperRanges2Start);
-constexpr char32 kUpperRanges2End[] = {
- 0x0136, 0x0147, 0x0178, 0x017d, 0x0186, 0x0189, 0x01a6, 0x01a9, 0x01ae,
- 0x01b1, 0x01b7, 0x01db, 0x01ee, 0x01f6, 0x0232, 0x0243, 0x024e, 0x0372,
- 0x0388, 0x038e, 0x0391, 0x03ee, 0x03f9, 0x0480, 0x04c0, 0x04cd, 0x052e,
- 0x1e94, 0x1efe, 0x1f5f, 0x212a, 0x2c62, 0x2c6d, 0x2ce2, 0x2ced, 0xa66c,
- 0xa69a, 0xa72e, 0xa76e, 0xa77d, 0xa786, 0xa78d, 0xa792, 0xa7aa};
-constexpr int kNumUpperRanges2End = ARRAYSIZE(kUpperRanges2End);
-
-// grep -E "Lu" UnicodeData.txt | \
-// sed -rne "s/^([0-9A-Z]+);.*;([0-9A-Z]+);$/(0x\1, 0x\2), /p"
-// We have two strategies for mapping from upper to lower case. We have single
-// character lookups that do not follow a pattern, and ranges for which there
-// is a constant codepoint shift.
-// Note that these ranges ignore anything that's not an upper case character,
-// so when applied to a non-uppercase character the result is incorrect.
-constexpr int kToLowerSingles[] = {
- 0x0130, 0x0178, 0x0181, 0x0186, 0x018b, 0x018e, 0x018f, 0x0190, 0x0191,
- 0x0194, 0x0196, 0x0197, 0x0198, 0x019c, 0x019d, 0x019f, 0x01a6, 0x01a9,
- 0x01ae, 0x01b7, 0x01f6, 0x01f7, 0x0220, 0x023a, 0x023d, 0x023e, 0x0243,
- 0x0244, 0x0245, 0x037f, 0x0386, 0x038c, 0x03cf, 0x03f4, 0x03f9, 0x04c0,
- 0x1e9e, 0x1fec, 0x2126, 0x212a, 0x212b, 0x2132, 0x2183, 0x2c60, 0x2c62,
- 0x2c63, 0x2c64, 0x2c6d, 0x2c6e, 0x2c6f, 0x2c70, 0xa77d, 0xa78d, 0xa7aa,
- 0xa7ab, 0xa7ac, 0xa7ad, 0xa7ae, 0xa7b0, 0xa7b1, 0xa7b2, 0xa7b3};
-constexpr int kNumToLowerSingles = ARRAYSIZE(kToLowerSingles);
-constexpr int kToLowerSinglesOffsets[] = {
- -199, -121, 210, 206, 1, 79, 202, 203, 1,
- 207, 211, 209, 1, 211, 213, 214, 218, 218,
- 218, 219, -97, -56, -130, 10795, -163, 10792, -195,
- 69, 71, 116, 38, 64, 8, -60, -7, 15,
- -7615, -7, -7517, -8383, -8262, 28, 1, 1, -10743,
- -3814, -10727, -10780, -10749, -10783, -10782, -35332, -42280, -42308,
- -42319, -42315, -42305, -42308, -42258, -42282, -42261, 928};
-constexpr int kNumToLowerSinglesOffsets = ARRAYSIZE(kToLowerSinglesOffsets);
-constexpr int kToLowerRangesStart[] = {
- 0x0041, 0x0100, 0x0189, 0x01a0, 0x01b1, 0x01b3, 0x0388, 0x038e, 0x0391,
- 0x03d8, 0x03fd, 0x0400, 0x0410, 0x0460, 0x0531, 0x10a0, 0x13a0, 0x13f0,
- 0x1e00, 0x1f08, 0x1fba, 0x1fc8, 0x1fd8, 0x1fda, 0x1fe8, 0x1fea, 0x1ff8,
- 0x1ffa, 0x2c00, 0x2c67, 0x2c7e, 0x2c80, 0xff21, 0x10400, 0x10c80, 0x118a0};
-constexpr int kNumToLowerRangesStart = ARRAYSIZE(kToLowerRangesStart);
-constexpr int kToLowerRangesEnd[] = {
- 0x00de, 0x0187, 0x019f, 0x01af, 0x01b2, 0x0386, 0x038c, 0x038f, 0x03cf,
- 0x03fa, 0x03ff, 0x040f, 0x042f, 0x052e, 0x0556, 0x10cd, 0x13ef, 0x13f5,
- 0x1efe, 0x1fb9, 0x1fbb, 0x1fcb, 0x1fd9, 0x1fdb, 0x1fe9, 0x1fec, 0x1ff9,
- 0x2183, 0x2c64, 0x2c75, 0x2c7f, 0xa7b6, 0xff3a, 0x104d3, 0x10cb2, 0x118bf};
-constexpr int kNumToLowerRangesEnd = ARRAYSIZE(kToLowerRangesEnd);
-constexpr int kToLowerRangesOffsets[] = {
- 32, 1, 205, 1, 217, 1, 37, 63, 32, 1, -130, 80,
- 32, 1, 48, 7264, 38864, 8, 1, -8, -74, -86, -8, -100,
- -8, -112, -128, -126, 48, 1, -10815, 1, 32, 40, 64, 32};
-constexpr int kNumToLowerRangesOffsets = ARRAYSIZE(kToLowerRangesOffsets);
-
-#undef ARRAYSIZE
-
-static_assert(kNumOpeningBrackets == kNumClosingBrackets,
- "mismatching number of opening and closing brackets");
-static_assert(kNumUpperRanges1Start == kNumUpperRanges1End,
- "number of uppercase stride 1 range starts/ends doesn't match");
-static_assert(kNumUpperRanges2Start == kNumUpperRanges2End,
- "number of uppercase stride 2 range starts/ends doesn't match");
-static_assert(kNumToLowerSingles == kNumToLowerSinglesOffsets,
- "number of to lower singles and offsets doesn't match");
-static_assert(kNumToLowerRangesStart == kNumToLowerRangesEnd,
- "mismatching number of range starts/ends for to lower ranges");
-static_assert(kNumToLowerRangesStart == kNumToLowerRangesOffsets,
- "number of to lower ranges and offsets doesn't match");
-
-constexpr int kNoMatch = -1;
-
-// Returns the index of the element in the array that matched the given
-// codepoint, or kNoMatch if the element didn't exist.
-// The input array must be in sorted order.
-int GetMatchIndex(const char32* array, int array_length, char32 c) {
- const char32* end = array + array_length;
- const auto find_it = std::lower_bound(array, end, c);
- if (find_it != end && *find_it == c) {
- return find_it - array;
- } else {
- return kNoMatch;
- }
-}
-
-// Returns the index of the range in the array that overlapped the given
-// codepoint, or kNoMatch if no such range existed.
-// The input array must be in sorted order.
-int GetOverlappingRangeIndex(const char32* arr, int arr_length,
- int range_length, char32 c) {
- const char32* end = arr + arr_length;
- const auto find_it = std::lower_bound(arr, end, c);
- if (find_it == end) {
- return kNoMatch;
- }
- // The end is inclusive, we so subtract one less than the range length.
- const char32 range_end = *find_it;
- const char32 range_start = range_end - (range_length - 1);
- if (c < range_start || range_end < c) {
- return kNoMatch;
- } else {
- return find_it - arr;
- }
-}
-
-// As above, but with explicit codepoint start and end indices for the range.
-// The input array must be in sorted order.
-int GetOverlappingRangeIndex(const char32* start_arr, const char32* end_arr,
- int arr_length, int stride, char32 c) {
- const char32* end_arr_end = end_arr + arr_length;
- const auto find_it = std::lower_bound(end_arr, end_arr_end, c);
- if (find_it == end_arr_end) {
- return kNoMatch;
- }
- // Find the corresponding start.
- const int range_index = find_it - end_arr;
- const char32 range_start = start_arr[range_index];
- const char32 range_end = *find_it;
- if (c < range_start || range_end < c) {
- return kNoMatch;
- }
- if ((c - range_start) % stride == 0) {
- return range_index;
- } else {
- return kNoMatch;
- }
-}
-
-} // anonymous namespace
UniLib::UniLib() {
TC3_LOG(FATAL) << "Java ICU UniLib must be initialized with a JniCache.";
@@ -249,71 +33,39 @@
: jni_cache_(jni_cache) {}
bool UniLib::IsOpeningBracket(char32 codepoint) const {
- return GetMatchIndex(kOpeningBrackets, kNumOpeningBrackets, codepoint) >= 0;
+ return libtextclassifier3::IsOpeningBracket(codepoint);
}
bool UniLib::IsClosingBracket(char32 codepoint) const {
- return GetMatchIndex(kClosingBrackets, kNumClosingBrackets, codepoint) >= 0;
+ return libtextclassifier3::IsClosingBracket(codepoint);
}
bool UniLib::IsWhitespace(char32 codepoint) const {
- return GetMatchIndex(kWhitespaces, kNumWhitespaces, codepoint) >= 0;
+ return libtextclassifier3::IsWhitespace(codepoint);
}
bool UniLib::IsDigit(char32 codepoint) const {
- return GetOverlappingRangeIndex(kDecimalDigitRangesEnd,
- kNumDecimalDigitRangesEnd,
- /*range_length=*/10, codepoint) >= 0;
+ return libtextclassifier3::IsDigit(codepoint);
+}
+
+bool UniLib::IsLower(char32 codepoint) const {
+ return libtextclassifier3::IsLower(codepoint);
}
bool UniLib::IsUpper(char32 codepoint) const {
- if (GetMatchIndex(kUpperSingles, kNumUpperSingles, codepoint) >= 0) {
- return true;
- } else if (GetOverlappingRangeIndex(kUpperRanges1Start, kUpperRanges1End,
- kNumUpperRanges1Start, /*stride=*/1,
- codepoint) >= 0) {
- return true;
- } else if (GetOverlappingRangeIndex(kUpperRanges2Start, kUpperRanges2End,
- kNumUpperRanges2Start, /*stride=*/2,
- codepoint) >= 0) {
- return true;
- } else {
- return false;
- }
+ return libtextclassifier3::IsUpper(codepoint);
}
char32 UniLib::ToLower(char32 codepoint) const {
- // Make sure we still produce output even if the method is called for a
- // codepoint that's not an uppercase character.
- if (!IsUpper(codepoint)) {
- return codepoint;
- }
- const int singles_idx =
- GetMatchIndex(kToLowerSingles, kNumToLowerSingles, codepoint);
- if (singles_idx >= 0) {
- return codepoint + kToLowerSinglesOffsets[singles_idx];
- }
- const int ranges_idx =
- GetOverlappingRangeIndex(kToLowerRangesStart, kToLowerRangesEnd,
- kNumToLowerRangesStart, /*stride=*/1, codepoint);
- if (ranges_idx >= 0) {
- return codepoint + kToLowerRangesOffsets[ranges_idx];
- }
- return codepoint;
+ return libtextclassifier3::ToLower(codepoint);
+}
+
+char32 UniLib::ToUpper(char32 codepoint) const {
+ return libtextclassifier3::ToUpper(codepoint);
}
char32 UniLib::GetPairedBracket(char32 codepoint) const {
- const int open_offset =
- GetMatchIndex(kOpeningBrackets, kNumOpeningBrackets, codepoint);
- if (open_offset >= 0) {
- return kClosingBrackets[open_offset];
- }
- const int close_offset =
- GetMatchIndex(kClosingBrackets, kNumClosingBrackets, codepoint);
- if (close_offset >= 0) {
- return kOpeningBrackets[close_offset];
- }
- return codepoint;
+ return libtextclassifier3::GetPairedBracket(codepoint);
}
// -----------------------------------------------------------------------------
diff --git a/native/utils/utf8/unilib-javaicu.h b/native/utils/utf8/unilib-javaicu.h
index 0a5d339..77f4970 100644
--- a/native/utils/utf8/unilib-javaicu.h
+++ b/native/utils/utf8/unilib-javaicu.h
@@ -45,9 +45,11 @@
bool IsClosingBracket(char32 codepoint) const;
bool IsWhitespace(char32 codepoint) const;
bool IsDigit(char32 codepoint) const;
+ bool IsLower(char32 codepoint) const;
bool IsUpper(char32 codepoint) const;
char32 ToLower(char32 codepoint) const;
+ char32 ToUpper(char32 codepoint) const;
char32 GetPairedBracket(char32 codepoint) const;
// Forward declaration for friend.
diff --git a/native/utils/utf8/unilib_test-include.cc b/native/utils/utf8/unilib_test-include.cc
index bd53208..9465ea6 100644
--- a/native/utils/utf8/unilib_test-include.cc
+++ b/native/utils/utf8/unilib_test-include.cc
@@ -16,6 +16,7 @@
#include "utils/utf8/unilib_test-include.h"
+#include "utils/utf8/unicodetext.h"
#include "gmock/gmock.h"
namespace libtextclassifier3 {
@@ -34,9 +35,15 @@
EXPECT_FALSE(unilib_.IsUpper(')'));
EXPECT_TRUE(unilib_.IsUpper('A'));
EXPECT_TRUE(unilib_.IsUpper('Z'));
+ EXPECT_FALSE(unilib_.IsLower(')'));
+ EXPECT_TRUE(unilib_.IsLower('a'));
+ EXPECT_TRUE(unilib_.IsLower('z'));
EXPECT_EQ(unilib_.ToLower('A'), 'a');
EXPECT_EQ(unilib_.ToLower('Z'), 'z');
EXPECT_EQ(unilib_.ToLower(')'), ')');
+ EXPECT_EQ(unilib_.ToUpper('a'), 'A');
+ EXPECT_EQ(unilib_.ToUpper('z'), 'Z');
+ EXPECT_EQ(unilib_.ToUpper(')'), ')');
EXPECT_EQ(unilib_.GetPairedBracket(')'), '(');
EXPECT_EQ(unilib_.GetPairedBracket('}'), '{');
}
@@ -55,9 +62,21 @@
EXPECT_TRUE(unilib_.IsUpper(0x0391)); // GREEK CAPITAL ALPHA
EXPECT_TRUE(unilib_.IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL
EXPECT_FALSE(unilib_.IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
+ EXPECT_TRUE(unilib_.IsLower(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
+ EXPECT_TRUE(unilib_.IsLower(0x03B1)); // GREEK SMALL ALPHA
+ EXPECT_TRUE(unilib_.IsLower(0x03CB)); // GREEK SMALL UPSILON
+ EXPECT_TRUE(unilib_.IsLower(0x0211)); // SMALL R WITH DOUBLE GRAVE
+ EXPECT_TRUE(unilib_.IsLower(0x03C0)); // GREEK SMALL PI
+ EXPECT_TRUE(unilib_.IsLower(0x007a)); // SMALL Z
+ EXPECT_FALSE(unilib_.IsLower(0x005a)); // CAPITAL Z
+ EXPECT_FALSE(unilib_.IsLower(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
+ EXPECT_FALSE(unilib_.IsLower(0x0391)); // GREEK CAPITAL ALPHA
EXPECT_EQ(unilib_.ToLower(0x0391), 0x03B1); // GREEK ALPHA
EXPECT_EQ(unilib_.ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA
EXPECT_EQ(unilib_.ToLower(0x03C0), 0x03C0); // GREEK SMALL PI
+ EXPECT_EQ(unilib_.ToUpper(0x03B1), 0x0391); // GREEK ALPHA
+ EXPECT_EQ(unilib_.ToUpper(0x03CB), 0x03AB); // GREEK UPSILON WITH DIALYTIKA
+ EXPECT_EQ(unilib_.ToUpper(0x0391), 0x0391); // GREEK CAPITAL ALPHA
EXPECT_EQ(unilib_.GetPairedBracket(0x0F3C), 0x0F3D);
EXPECT_EQ(unilib_.GetPairedBracket(0x0F3D), 0x0F3C);
diff --git a/native/utils/utf8/unilib_test-include.h b/native/utils/utf8/unilib_test-include.h
index 151a6f0..b4efcd6 100644
--- a/native/utils/utf8/unilib_test-include.h
+++ b/native/utils/utf8/unilib_test-include.h
@@ -26,6 +26,9 @@
extern JNIEnv* g_jenv;
#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR(JniCache::Create(g_jenv))
#include "utils/utf8/unilib-javaicu.h"
+#elif defined TC3_UNILIB_APPLE
+#include "utils/utf8/unilib-apple.h"
+#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
#elif defined TC3_UNILIB_DUMMY
#include "utils/utf8/unilib-dummy.h"
#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
diff --git a/native/utils/utf8/unilib_test.cc b/native/utils/utf8/unilib_test.cc
new file mode 100644
index 0000000..b5658af
--- /dev/null
+++ b/native/utils/utf8/unilib_test.cc
@@ -0,0 +1,20 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "gtest/gtest.h"
+
+// The actual code of the test is in the following include:
+#include "utils/utf8/unilib_test-include.h"
diff --git a/native/utils/zlib/zlib_regex.cc b/native/utils/zlib/zlib_regex.cc
index bfe3f5b..73b6d30 100644
--- a/native/utils/zlib/zlib_regex.cc
+++ b/native/utils/zlib/zlib_regex.cc
@@ -20,6 +20,7 @@
#include "utils/base/logging.h"
#include "utils/flatbuffers.h"
+#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {