Sync from google3.
Bug: 68239358
Test: Builds. Tested on device. CTS test passes.
bit FrameworksCoreTests:android.view.textclassifier.TextClassificationManagerTest
Change-Id: Ie5e20b06b1c615ab246e7ed7f08e980e61c492c4
diff --git a/Android.mk b/Android.mk
index 88aad2b..2317b83 100644
--- a/Android.mk
+++ b/Android.mk
@@ -69,7 +69,7 @@
LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS)
LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS)
-LOCAL_SRC_FILES := $(filter-out tests/% %_test.cc,$(call all-subdir-cpp-files))
+LOCAL_SRC_FILES := $(filter-out tests/% %_test.cc test-util.%,$(call all-subdir-cpp-files))
LOCAL_C_INCLUDES := $(TOP)/external/tensorflow $(TOP)/external/flatbuffers/include
LOCAL_SHARED_LIBRARIES += liblog
@@ -81,6 +81,9 @@
LOCAL_ADDITIONAL_DEPENDENCIES += $(LOCAL_PATH)/jni.lds
LOCAL_LDFLAGS += -Wl,-version-script=$(LOCAL_PATH)/jni.lds
+LOCAL_CPPFLAGS_32 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\""
+LOCAL_CPPFLAGS_64 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\""
+
include $(BUILD_SHARED_LIBRARY)
# -----------------------
diff --git a/cached-features.cc b/cached-features.cc
index 0b22d6d..0863c0e 100644
--- a/cached-features.cc
+++ b/cached-features.cc
@@ -54,90 +54,139 @@
return true;
}
+int CalculateOutputFeaturesSize(const FeatureProcessorOptions* options,
+ int feature_vector_size) {
+ const bool bounds_sensitive_enabled =
+ options->bounds_sensitive_features() &&
+ options->bounds_sensitive_features()->enabled();
+
+ int num_extracted_tokens = 0;
+ if (bounds_sensitive_enabled) {
+ const FeatureProcessorOptions_::BoundsSensitiveFeatures* config =
+ options->bounds_sensitive_features();
+ num_extracted_tokens += config->num_tokens_before();
+ num_extracted_tokens += config->num_tokens_inside_left();
+ num_extracted_tokens += config->num_tokens_inside_right();
+ num_extracted_tokens += config->num_tokens_after();
+ if (config->include_inside_bag()) {
+ ++num_extracted_tokens;
+ }
+ } else {
+ num_extracted_tokens = 2 * options->context_size() + 1;
+ }
+
+ int output_features_size = num_extracted_tokens * feature_vector_size;
+
+ if (bounds_sensitive_enabled &&
+ options->bounds_sensitive_features()->include_inside_length()) {
+ ++output_features_size;
+ }
+
+ return output_features_size;
+}
+
} // namespace
-CachedFeatures::CachedFeatures(
+std::unique_ptr<CachedFeatures> CachedFeatures::Create(
const TokenSpan& extraction_span,
const std::vector<std::vector<int>>& sparse_features,
const std::vector<std::vector<float>>& dense_features,
const std::vector<int>& padding_sparse_features,
const std::vector<float>& padding_dense_features,
- const FeatureProcessorOptions_::BoundsSensitiveFeatures* config,
- EmbeddingExecutor* embedding_executor, int feature_vector_size)
- : extraction_span_(extraction_span), config_(config) {
- int num_extracted_tokens = 0;
- num_extracted_tokens += config->num_tokens_before();
- num_extracted_tokens += config->num_tokens_inside_left();
- num_extracted_tokens += config->num_tokens_inside_right();
- num_extracted_tokens += config->num_tokens_after();
- if (config->include_inside_bag()) {
- ++num_extracted_tokens;
- }
- output_features_size_ = num_extracted_tokens * feature_vector_size;
- if (config->include_inside_length()) {
- ++output_features_size_;
+ const FeatureProcessorOptions* options,
+ EmbeddingExecutor* embedding_executor, int feature_vector_size) {
+ const int min_feature_version =
+ options->bounds_sensitive_features() &&
+ options->bounds_sensitive_features()->enabled()
+ ? 2
+ : 1;
+ if (options->feature_version() < min_feature_version) {
+ TC_LOG(ERROR) << "Unsupported feature version.";
+ return nullptr;
}
- features_.resize(feature_vector_size * TokenSpanSize(extraction_span));
+ std::unique_ptr<CachedFeatures> cached_features(new CachedFeatures());
+ cached_features->extraction_span_ = extraction_span;
+ cached_features->options_ = options;
+
+ cached_features->output_features_size_ =
+ CalculateOutputFeaturesSize(options, feature_vector_size);
+
+ cached_features->features_.resize(feature_vector_size *
+ TokenSpanSize(extraction_span));
for (int i = 0; i < TokenSpanSize(extraction_span); ++i) {
if (!PopulateTokenFeatures(/*target_feature_index=*/i, sparse_features[i],
dense_features[i], feature_vector_size,
- embedding_executor, &features_)) {
+ embedding_executor,
+ &cached_features->features_)) {
TC_LOG(ERROR) << "Could not embed sparse token features.";
- return;
+ return nullptr;
}
}
- padding_features_.resize(feature_vector_size);
+ cached_features->padding_features_.resize(feature_vector_size);
if (!PopulateTokenFeatures(/*target_feature_index=*/0,
padding_sparse_features, padding_dense_features,
feature_vector_size, embedding_executor,
- &padding_features_)) {
+ &cached_features->padding_features_)) {
TC_LOG(ERROR) << "Could not embed sparse padding token features.";
- return;
+ return nullptr;
}
+
+ return cached_features;
}
-std::vector<float> CachedFeatures::Get(TokenSpan selected_span) const {
+void CachedFeatures::AppendClickContextFeaturesForClick(
+ int click_pos, std::vector<float>* output_features) const {
+ click_pos -= extraction_span_.first;
+
+ AppendFeaturesInternal(
+ /*intended_span=*/ExpandTokenSpan(SingleTokenSpan(click_pos),
+ options_->context_size(),
+ options_->context_size()),
+ /*read_mask_span=*/{0, TokenSpanSize(extraction_span_)}, output_features);
+}
+
+void CachedFeatures::AppendBoundsSensitiveFeaturesForSpan(
+ TokenSpan selected_span, std::vector<float>* output_features) const {
+ const FeatureProcessorOptions_::BoundsSensitiveFeatures* config =
+ options_->bounds_sensitive_features();
+
selected_span.first -= extraction_span_.first;
selected_span.second -= extraction_span_.first;
- std::vector<float> output_features;
- output_features.reserve(output_features_size_);
-
// Append the features for tokens around the left bound. Masks out tokens
// after the right bound, so that if num_tokens_inside_left goes past it,
// padding tokens will be used.
- AppendFeatures(
- /*intended_span=*/{selected_span.first - config_->num_tokens_before(),
+ AppendFeaturesInternal(
+ /*intended_span=*/{selected_span.first - config->num_tokens_before(),
selected_span.first +
- config_->num_tokens_inside_left()},
- /*read_mask_span=*/{0, selected_span.second}, &output_features);
+ config->num_tokens_inside_left()},
+ /*read_mask_span=*/{0, selected_span.second}, output_features);
// Append the features for tokens around the right bound. Masks out tokens
// before the left bound, so that if num_tokens_inside_right goes past it,
// padding tokens will be used.
- AppendFeatures(
+ AppendFeaturesInternal(
/*intended_span=*/{selected_span.second -
- config_->num_tokens_inside_right(),
- selected_span.second + config_->num_tokens_after()},
+ config->num_tokens_inside_right(),
+ selected_span.second + config->num_tokens_after()},
/*read_mask_span=*/{selected_span.first, TokenSpanSize(extraction_span_)},
- &output_features);
+ output_features);
- if (config_->include_inside_bag()) {
- AppendSummedFeatures(selected_span, &output_features);
+ if (config->include_inside_bag()) {
+ AppendBagFeatures(selected_span, output_features);
}
- if (config_->include_inside_length()) {
- output_features.push_back(static_cast<float>(TokenSpanSize(selected_span)));
+ if (config->include_inside_length()) {
+ output_features->push_back(
+ static_cast<float>(TokenSpanSize(selected_span)));
}
-
- return output_features;
}
-void CachedFeatures::AppendFeatures(const TokenSpan& intended_span,
- const TokenSpan& read_mask_span,
- std::vector<float>* output_features) const {
+void CachedFeatures::AppendFeaturesInternal(
+ const TokenSpan& intended_span, const TokenSpan& read_mask_span,
+ std::vector<float>* output_features) const {
const TokenSpan copy_span =
IntersectTokenSpans(intended_span, read_mask_span);
for (int i = intended_span.first; i < copy_span.first; ++i) {
@@ -158,14 +207,14 @@
padding_features_.end());
}
-void CachedFeatures::AppendSummedFeatures(
- const TokenSpan& summing_span, std::vector<float>* output_features) const {
+void CachedFeatures::AppendBagFeatures(
+ const TokenSpan& bag_span, std::vector<float>* output_features) const {
const int offset = output_features->size();
output_features->resize(output_features->size() + NumFeaturesPerToken());
- for (int i = summing_span.first; i < summing_span.second; ++i) {
+ for (int i = bag_span.first; i < bag_span.second; ++i) {
for (int j = 0; j < NumFeaturesPerToken(); ++j) {
(*output_features)[offset + j] +=
- features_[i * NumFeaturesPerToken() + j];
+ features_[i * NumFeaturesPerToken() + j] / TokenSpanSize(bag_span);
}
}
}
diff --git a/cached-features.h b/cached-features.h
index 5ffb9a9..86b700f 100644
--- a/cached-features.h
+++ b/cached-features.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_CACHED_FEATURES_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_CACHED_FEATURES_H_
+#ifndef LIBTEXTCLASSIFIER_CACHED_FEATURES_H_
+#define LIBTEXTCLASSIFIER_CACHED_FEATURES_H_
#include <memory>
#include <vector>
@@ -30,40 +30,52 @@
// Assumes that features for each Token are independent.
class CachedFeatures {
public:
- CachedFeatures(
+ static std::unique_ptr<CachedFeatures> Create(
const TokenSpan& extraction_span,
const std::vector<std::vector<int>>& sparse_features,
const std::vector<std::vector<float>>& dense_features,
const std::vector<int>& padding_sparse_features,
const std::vector<float>& padding_dense_features,
- const FeatureProcessorOptions_::BoundsSensitiveFeatures* config,
+ const FeatureProcessorOptions* options,
EmbeddingExecutor* embedding_executor, int feature_vector_size);
- // Gets a vector of features for the given token span.
- std::vector<float> Get(TokenSpan selected_span) const;
+ // Appends the click context features for the given click position to
+ // 'output_features'.
+ void AppendClickContextFeaturesForClick(
+ int click_pos, std::vector<float>* output_features) const;
+
+ // Appends the bounds-sensitive features for the given token span to
+ // 'output_features'.
+ void AppendBoundsSensitiveFeaturesForSpan(
+ TokenSpan selected_span, std::vector<float>* output_features) const;
+
+ // Returns number of features that 'AppendFeaturesForSpan' appends.
+ int OutputFeaturesSize() const { return output_features_size_; }
private:
+ CachedFeatures() {}
+
// Appends token features to the output. The intended_span specifies which
// tokens' features should be used in principle. The read_mask_span restricts
// which tokens are actually read. For tokens outside of the read_mask_span,
// padding tokens are used instead.
- void AppendFeatures(const TokenSpan& intended_span,
- const TokenSpan& read_mask_span,
- std::vector<float>* output_features) const;
+ void AppendFeaturesInternal(const TokenSpan& intended_span,
+ const TokenSpan& read_mask_span,
+ std::vector<float>* output_features) const;
// Appends features of one padding token to the output.
void AppendPaddingFeatures(std::vector<float>* output_features) const;
// Appends the features of tokens from the given span to the output. The
- // features are summed so that the appended features have the size
+ // features are averaged so that the appended features have the size
// corresponding to one token.
- void AppendSummedFeatures(const TokenSpan& summing_span,
- std::vector<float>* output_features) const;
+ void AppendBagFeatures(const TokenSpan& bag_span,
+ std::vector<float>* output_features) const;
int NumFeaturesPerToken() const;
- const TokenSpan extraction_span_;
- const FeatureProcessorOptions_::BoundsSensitiveFeatures* config_;
+ TokenSpan extraction_span_;
+ const FeatureProcessorOptions* options_;
int output_features_size_;
std::vector<float> features_;
std::vector<float> padding_features_;
@@ -71,4 +83,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_CACHED_FEATURES_H_
+#endif // LIBTEXTCLASSIFIER_CACHED_FEATURES_H_
diff --git a/cached-features_test.cc b/cached-features_test.cc
index 2412ff3..9566a8d 100644
--- a/cached-features_test.cc
+++ b/cached-features_test.cc
@@ -54,18 +54,29 @@
std::vector<float> storage_;
};
-TEST(CachedFeaturesTest, Simple) {
- FeatureProcessorOptions_::BoundsSensitiveFeaturesT config;
- config.enabled = true;
- config.num_tokens_before = 2;
- config.num_tokens_inside_left = 2;
- config.num_tokens_inside_right = 2;
- config.num_tokens_after = 2;
- config.include_inside_bag = true;
- config.include_inside_length = true;
+std::vector<float> GetCachedClickContextFeatures(
+ const CachedFeatures& cached_features, int click_pos) {
+ std::vector<float> output_features;
+ cached_features.AppendClickContextFeaturesForClick(click_pos,
+ &output_features);
+ return output_features;
+}
+
+std::vector<float> GetCachedBoundsSensitiveFeatures(
+ const CachedFeatures& cached_features, TokenSpan selected_span) {
+ std::vector<float> output_features;
+ cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span,
+ &output_features);
+ return output_features;
+}
+
+TEST(CachedFeaturesTest, ClickContext) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.feature_version = 1;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateBoundsSensitiveFeatures(builder, &config));
- flatbuffers::DetachedBuffer config_fb = builder.Release();
+ builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+ flatbuffers::DetachedBuffer options_fb = builder.Release();
std::vector<std::vector<int>> sparse_features(9);
for (int i = 0; i < sparse_features.size(); ++i) {
@@ -80,36 +91,88 @@
std::vector<float> padding_dense_features = {321.0};
FakeEmbeddingExecutor executor;
- const CachedFeatures cached_features(
- {3, 9}, sparse_features, dense_features, padding_sparse_features,
- padding_dense_features,
- flatbuffers::GetRoot<FeatureProcessorOptions_::BoundsSensitiveFeatures>(
- config_fb.data()),
- &executor, /*feature_vector_size=*/3);
+ const std::unique_ptr<CachedFeatures> cached_features =
+ CachedFeatures::Create(
+ {3, 10}, sparse_features, dense_features, padding_sparse_features,
+ padding_dense_features,
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &executor, /*feature_vector_size=*/3);
+ ASSERT_TRUE(cached_features);
- EXPECT_THAT(cached_features.Get({5, 8}),
- ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2,
- 33.0, -33.0, 0.3, 44.0, -44.0, 0.4,
- 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
- 66.0, -66.0, 0.6, 112233.0, -112233.0, 321.0,
- 132.0, -132.0, 1.2, 3.0}));
+ EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
+ ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0, -33.0,
+ 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5}));
+
+ EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6),
+ ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0, -44.0,
+ 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6}));
+
+ EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7),
+ ElementsAreFloat({33.0, -33.0, 0.3, 44.0, -44.0, 0.4, 55.0, -55.0,
+ 0.5, 66.0, -66.0, 0.6, 77.0, -77.0, 0.7}));
+}
+
+TEST(CachedFeaturesTest, BoundsSensitive) {
+ std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config(
+ new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+ config->enabled = true;
+ config->num_tokens_before = 2;
+ config->num_tokens_inside_left = 2;
+ config->num_tokens_inside_right = 2;
+ config->num_tokens_after = 2;
+ config->include_inside_bag = true;
+ config->include_inside_length = true;
+ FeatureProcessorOptionsT options;
+ options.bounds_sensitive_features = std::move(config);
+ options.feature_version = 2;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+ flatbuffers::DetachedBuffer options_fb = builder.Release();
+
+ std::vector<std::vector<int>> sparse_features(6);
+ for (int i = 0; i < sparse_features.size(); ++i) {
+ sparse_features[i].push_back(i + 1);
+ }
+ std::vector<std::vector<float>> dense_features(6);
+ for (int i = 0; i < dense_features.size(); ++i) {
+ dense_features[i].push_back((i + 1) * 0.1);
+ }
+
+ std::vector<int> padding_sparse_features = {10203};
+ std::vector<float> padding_dense_features = {321.0};
+
+ FakeEmbeddingExecutor executor;
+ const std::unique_ptr<CachedFeatures> cached_features =
+ CachedFeatures::Create(
+ {3, 9}, sparse_features, dense_features, padding_sparse_features,
+ padding_dense_features,
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &executor, /*feature_vector_size=*/3);
+ ASSERT_TRUE(cached_features);
EXPECT_THAT(
- cached_features.Get({5, 7}),
+ GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}),
+ ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
+ -33.0, 0.3, 44.0, -44.0, 0.4, 44.0, -44.0,
+ 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
+ 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4, 3.0}));
+
+ EXPECT_THAT(
+ GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}),
ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
-33.0, 0.3, 44.0, -44.0, 0.4, 33.0, -33.0,
0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
- 66.0, -66.0, 0.6, 77.0, -77.0, 0.7, 2.0}));
+ 66.0, -66.0, 0.6, 38.5, -38.5, 0.35, 2.0}));
EXPECT_THAT(
- cached_features.Get({6, 8}),
+ GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}),
ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0,
-44.0, 0.4, 55.0, -55.0, 0.5, 44.0, -44.0,
0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
- 112233.0, -112233.0, 321.0, 99.0, -99.0, 0.9, 2.0}));
+ 112233.0, -112233.0, 321.0, 49.5, -49.5, 0.45, 2.0}));
EXPECT_THAT(
- cached_features.Get({6, 7}),
+ GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}),
ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3,
44.0, -44.0, 0.4, 112233.0, -112233.0, 321.0,
112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4,
diff --git a/datetime/extractor.cc b/datetime/extractor.cc
new file mode 100644
index 0000000..79d7686
--- /dev/null
+++ b/datetime/extractor.cc
@@ -0,0 +1,420 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "datetime/extractor.h"
+
+#include "util/base/logging.h"
+
+namespace libtextclassifier2 {
+
+constexpr char const* kGroupYear = "YEAR";
+constexpr char const* kGroupMonth = "MONTH";
+constexpr char const* kGroupDay = "DAY";
+constexpr char const* kGroupHour = "HOUR";
+constexpr char const* kGroupMinute = "MINUTE";
+constexpr char const* kGroupSecond = "SECOND";
+constexpr char const* kGroupAmpm = "AMPM";
+constexpr char const* kGroupRelationDistance = "RELATIONDISTANCE";
+constexpr char const* kGroupRelation = "RELATION";
+constexpr char const* kGroupRelationType = "RELATIONTYPE";
+
+bool DatetimeExtractor::Extract(DateParseData* result) const {
+ result->field_set_mask = 0;
+
+ UnicodeText group_text;
+ if (GroupNotEmpty(kGroupYear, &group_text)) {
+ result->field_set_mask |= DateParseData::YEAR_FIELD;
+ if (!ParseYear(group_text, &(result->year))) {
+ TC_LOG(ERROR) << "Couldn't extract YEAR.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupMonth, &group_text)) {
+ result->field_set_mask |= DateParseData::MONTH_FIELD;
+ if (!ParseMonth(group_text, &(result->month))) {
+ TC_LOG(ERROR) << "Couldn't extract MONTH.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupDay, &group_text)) {
+ result->field_set_mask |= DateParseData::DAY_FIELD;
+ if (!ParseDigits(group_text, &(result->day_of_month))) {
+ TC_LOG(ERROR) << "Couldn't extract DAY.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupHour, &group_text)) {
+ result->field_set_mask |= DateParseData::HOUR_FIELD;
+ if (!ParseDigits(group_text, &(result->hour))) {
+ TC_LOG(ERROR) << "Couldn't extract HOUR.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupMinute, &group_text)) {
+ result->field_set_mask |= DateParseData::MINUTE_FIELD;
+ if (!ParseDigits(group_text, &(result->minute))) {
+ TC_LOG(ERROR) << "Couldn't extract MINUTE.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupSecond, &group_text)) {
+ result->field_set_mask |= DateParseData::SECOND_FIELD;
+ if (!ParseDigits(group_text, &(result->second))) {
+ TC_LOG(ERROR) << "Couldn't extract SECOND.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupAmpm, &group_text)) {
+ result->field_set_mask |= DateParseData::AMPM_FIELD;
+ if (!ParseAMPM(group_text, &(result->ampm))) {
+ TC_LOG(ERROR) << "Couldn't extract AMPM.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupRelationDistance, &group_text)) {
+ result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD;
+ if (!ParseRelationDistance(group_text, &(result->relation_distance))) {
+ TC_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupRelation, &group_text)) {
+ result->field_set_mask |= DateParseData::RELATION_FIELD;
+ if (!ParseRelation(group_text, &(result->relation))) {
+ TC_LOG(ERROR) << "Couldn't extract RELATION_FIELD.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupRelationType, &group_text)) {
+ result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD;
+ if (!ParseRelationType(group_text, &(result->relation_type))) {
+ TC_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool DatetimeExtractor::RuleIdForType(DatetimeExtractorType type,
+ int* rule_id) const {
+ auto type_it = type_and_locale_to_rule_.find(type);
+ if (type_it == type_and_locale_to_rule_.end()) {
+ return false;
+ }
+
+ auto locale_it = type_it->second.find(locale_id_);
+ if (locale_it == type_it->second.end()) {
+ return false;
+ }
+ *rule_id = locale_it->second;
+ return true;
+}
+
+bool DatetimeExtractor::ExtractType(const UnicodeText& input,
+ DatetimeExtractorType extractor_type,
+ UnicodeText* match_result) const {
+ int rule_id;
+ if (!RuleIdForType(extractor_type, &rule_id)) {
+ return false;
+ }
+
+ std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rules_[rule_id]->Matcher(input);
+ if (!matcher) {
+ return false;
+ }
+
+ int status;
+ if (!matcher->Find(&status)) {
+ return false;
+ }
+
+ if (match_result != nullptr) {
+ *match_result = matcher->Group(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool DatetimeExtractor::GroupNotEmpty(const std::string& name,
+ UnicodeText* result) const {
+ int status;
+ *result = matcher_.Group(name, &status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ return !result->empty();
+}
+
+template <typename T>
+bool DatetimeExtractor::MapInput(
+ const UnicodeText& input,
+ const std::vector<std::pair<DatetimeExtractorType, T>>& mapping,
+ T* result) const {
+ for (const auto& type_value_pair : mapping) {
+ if (ExtractType(input, type_value_pair.first)) {
+ *result = type_value_pair.second;
+ return true;
+ }
+ }
+ return false;
+}
+
+bool DatetimeExtractor::ParseWrittenNumber(const UnicodeText& input,
+ int* parsed_number) const {
+ std::vector<std::pair<int, int>> found_numbers;
+ for (const auto& type_value_pair :
+ std::vector<std::pair<DatetimeExtractorType, int>>{
+ {DatetimeExtractorType_ZERO, 0},
+ {DatetimeExtractorType_ONE, 1},
+ {DatetimeExtractorType_TWO, 2},
+ {DatetimeExtractorType_THREE, 3},
+ {DatetimeExtractorType_FOUR, 4},
+ {DatetimeExtractorType_FIVE, 5},
+ {DatetimeExtractorType_SIX, 6},
+ {DatetimeExtractorType_SEVEN, 7},
+ {DatetimeExtractorType_EIGHT, 8},
+ {DatetimeExtractorType_NINE, 9},
+ {DatetimeExtractorType_TEN, 10},
+ {DatetimeExtractorType_ELEVEN, 11},
+ {DatetimeExtractorType_TWELVE, 12},
+ {DatetimeExtractorType_THIRTEEN, 13},
+ {DatetimeExtractorType_FOURTEEN, 14},
+ {DatetimeExtractorType_FIFTEEN, 15},
+ {DatetimeExtractorType_SIXTEEN, 16},
+ {DatetimeExtractorType_SEVENTEEN, 17},
+ {DatetimeExtractorType_EIGHTEEN, 18},
+ {DatetimeExtractorType_NINETEEN, 19},
+ {DatetimeExtractorType_TWENTY, 20},
+ {DatetimeExtractorType_THIRTY, 30},
+ {DatetimeExtractorType_FORTY, 40},
+ {DatetimeExtractorType_FIFTY, 50},
+ {DatetimeExtractorType_SIXTY, 60},
+ {DatetimeExtractorType_SEVENTY, 70},
+ {DatetimeExtractorType_EIGHTY, 80},
+ {DatetimeExtractorType_NINETY, 90},
+ {DatetimeExtractorType_HUNDRED, 100},
+ {DatetimeExtractorType_THOUSAND, 1000},
+ }) {
+ int rule_id;
+ if (!RuleIdForType(type_value_pair.first, &rule_id)) {
+ return false;
+ }
+
+ std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rules_[rule_id]->Matcher(input);
+ if (!matcher) {
+ return false;
+ }
+
+ int status;
+ while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ int span_start = matcher->Start(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ found_numbers.push_back({span_start, type_value_pair.second});
+ }
+ }
+
+ std::sort(found_numbers.begin(), found_numbers.end(),
+ [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
+ return a.first < b.first;
+ });
+
+ int sum = 0;
+ int running_value = -1;
+ // Simple math to make sure we handle written numerical modifiers correctly
+ // so that :="fifty one thousand and one" maps to 51001 and not 50 1 1000 1.
+ for (const std::pair<int, int> position_number_pair : found_numbers) {
+ if (running_value >= 0) {
+ if (running_value > position_number_pair.second) {
+ sum += running_value;
+ running_value = position_number_pair.second;
+ } else {
+ running_value *= position_number_pair.second;
+ }
+ } else {
+ running_value = position_number_pair.second;
+ }
+ }
+ sum += running_value;
+ *parsed_number = sum;
+ return true;
+}
+
+bool DatetimeExtractor::ParseDigits(const UnicodeText& input,
+ int* parsed_digits) const {
+ UnicodeText digit;
+ if (!ExtractType(input, DatetimeExtractorType_DIGITS, &digit)) {
+ return false;
+ }
+
+ if (!unilib_.ParseInt32(digit, parsed_digits)) {
+ return false;
+ }
+ return true;
+}
+
+bool DatetimeExtractor::ParseYear(const UnicodeText& input,
+ int* parsed_year) const {
+ if (!ParseDigits(input, parsed_year)) {
+ return false;
+ }
+
+ if (*parsed_year < 100) {
+ if (*parsed_year < 50) {
+ *parsed_year += 2000;
+ } else {
+ *parsed_year += 1900;
+ }
+ }
+
+ return true;
+}
+
+bool DatetimeExtractor::ParseMonth(const UnicodeText& input,
+ int* parsed_month) const {
+ if (ParseDigits(input, parsed_month)) {
+ return true;
+ }
+
+ if (MapInput(input,
+ {
+ {DatetimeExtractorType_JANUARY, 1},
+ {DatetimeExtractorType_FEBRUARY, 2},
+ {DatetimeExtractorType_MARCH, 3},
+ {DatetimeExtractorType_APRIL, 4},
+ {DatetimeExtractorType_MAY, 5},
+ {DatetimeExtractorType_JUNE, 6},
+ {DatetimeExtractorType_JULY, 7},
+ {DatetimeExtractorType_AUGUST, 8},
+ {DatetimeExtractorType_SEPTEMBER, 9},
+ {DatetimeExtractorType_OCTOBER, 10},
+ {DatetimeExtractorType_NOVEMBER, 11},
+ {DatetimeExtractorType_DECEMBER, 12},
+ },
+ parsed_month)) {
+ return true;
+ }
+
+ return false;
+}
+
+bool DatetimeExtractor::ParseAMPM(const UnicodeText& input,
+ int* parsed_ampm) const {
+ return MapInput(input,
+ {
+ {DatetimeExtractorType_AM, DateParseData::AMPM::AM},
+ {DatetimeExtractorType_PM, DateParseData::AMPM::PM},
+ },
+ parsed_ampm);
+}
+
+bool DatetimeExtractor::ParseRelationDistance(const UnicodeText& input,
+ int* parsed_distance) const {
+ if (ParseDigits(input, parsed_distance)) {
+ return true;
+ }
+ if (ParseWrittenNumber(input, parsed_distance)) {
+ return true;
+ }
+ return false;
+}
+
+bool DatetimeExtractor::ParseRelation(
+ const UnicodeText& input, DateParseData::Relation* parsed_relation) const {
+ return MapInput(
+ input,
+ {
+ {DatetimeExtractorType_NOW, DateParseData::Relation::NOW},
+ {DatetimeExtractorType_YESTERDAY, DateParseData::Relation::YESTERDAY},
+ {DatetimeExtractorType_TOMORROW, DateParseData::Relation::TOMORROW},
+ {DatetimeExtractorType_NEXT, DateParseData::Relation::NEXT},
+ {DatetimeExtractorType_NEXT_OR_SAME,
+ DateParseData::Relation::NEXT_OR_SAME},
+ {DatetimeExtractorType_LAST, DateParseData::Relation::LAST},
+ {DatetimeExtractorType_PAST, DateParseData::Relation::PAST},
+ {DatetimeExtractorType_FUTURE, DateParseData::Relation::FUTURE},
+ },
+ parsed_relation);
+}
+
+bool DatetimeExtractor::ParseRelationType(
+ const UnicodeText& input,
+ DateParseData::RelationType* parsed_relation_type) const {
+ return MapInput(
+ input,
+ {
+ {DatetimeExtractorType_MONDAY, DateParseData::MONDAY},
+ {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY},
+ {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY},
+ {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY},
+ {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY},
+ {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY},
+ {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY},
+ {DatetimeExtractorType_DAY, DateParseData::DAY},
+ {DatetimeExtractorType_WEEK, DateParseData::WEEK},
+ {DatetimeExtractorType_MONTH, DateParseData::MONTH},
+ {DatetimeExtractorType_YEAR, DateParseData::YEAR},
+ },
+ parsed_relation_type);
+}
+
+bool DatetimeExtractor::ParseTimeUnit(const UnicodeText& input,
+ int* parsed_time_unit) const {
+ return MapInput(input,
+ {
+ {DatetimeExtractorType_DAYS, DateParseData::DAYS},
+ {DatetimeExtractorType_WEEKS, DateParseData::WEEKS},
+ {DatetimeExtractorType_MONTHS, DateParseData::MONTHS},
+ {DatetimeExtractorType_HOURS, DateParseData::HOURS},
+ {DatetimeExtractorType_MINUTES, DateParseData::MINUTES},
+ {DatetimeExtractorType_SECONDS, DateParseData::SECONDS},
+ {DatetimeExtractorType_YEARS, DateParseData::YEARS},
+ },
+ parsed_time_unit);
+}
+
+bool DatetimeExtractor::ParseWeekday(const UnicodeText& input,
+ int* parsed_weekday) const {
+ return MapInput(
+ input,
+ {
+ {DatetimeExtractorType_MONDAY, DateParseData::MONDAY},
+ {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY},
+ {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY},
+ {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY},
+ {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY},
+ {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY},
+ {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY},
+ },
+ parsed_weekday);
+}
+
+} // namespace libtextclassifier2
diff --git a/datetime/extractor.h b/datetime/extractor.h
new file mode 100644
index 0000000..f068dff
--- /dev/null
+++ b/datetime/extractor.h
@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_DATETIME_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_DATETIME_EXTRACTOR_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "model_generated.h"
+#include "util/calendar/types.h"
+#include "util/utf8/unicodetext.h"
+#include "util/utf8/unilib.h"
+
+namespace libtextclassifier2 {
+
+// A helper class for DatetimeParser that extracts structured data
+// (DateParseData) from the current match of the passed RegexMatcher.
+class DatetimeExtractor {
+ public:
+ DatetimeExtractor(
+ const UniLib::RegexMatcher& matcher, int locale_id, const UniLib& unilib,
+ const std::vector<std::unique_ptr<UniLib::RegexPattern>>& extractor_rules,
+ const std::unordered_map<DatetimeExtractorType,
+ std::unordered_map<int, int>>&
+ type_and_locale_to_extractor_rule)
+ : matcher_(matcher),
+ locale_id_(locale_id),
+ unilib_(unilib),
+ rules_(extractor_rules),
+ type_and_locale_to_rule_(type_and_locale_to_extractor_rule) {}
+ bool Extract(DateParseData* result) const;
+
+ private:
+ bool RuleIdForType(DatetimeExtractorType type, int* rule_id) const;
+
+ // Returns true if the rule for the given extractor matched. If it matched,
+ // match_result will contain the first group of the rule (if match_result not
+ // nullptr).
+ bool ExtractType(const UnicodeText& input,
+ DatetimeExtractorType extractor_type,
+ UnicodeText* match_result = nullptr) const;
+
+ bool GroupNotEmpty(const std::string& name, UnicodeText* result) const;
+
+ // Returns true if any of the extractors from 'mapping' matched. If it did,
+ // will fill 'result' with the associated value from 'mapping'.
+ template <typename T>
+ bool MapInput(const UnicodeText& input,
+ const std::vector<std::pair<DatetimeExtractorType, T>>& mapping,
+ T* result) const;
+
+ bool ParseDigits(const UnicodeText& input, int* parsed_digits) const;
+ bool ParseWrittenNumber(const UnicodeText& input, int* parsed_number) const;
+ bool ParseYear(const UnicodeText& input, int* parsed_year) const;
+ bool ParseMonth(const UnicodeText& input, int* parsed_month) const;
+ bool ParseAMPM(const UnicodeText& input, int* parsed_ampm) const;
+ bool ParseRelation(const UnicodeText& input,
+ DateParseData::Relation* parsed_relation) const;
+ bool ParseRelationDistance(const UnicodeText& input,
+ int* parsed_distance) const;
+ bool ParseTimeUnit(const UnicodeText& input, int* parsed_time_unit) const;
+ bool ParseRelationType(
+ const UnicodeText& input,
+ DateParseData::RelationType* parsed_relation_type) const;
+ bool ParseWeekday(const UnicodeText& input, int* parsed_weekday) const;
+
+ const UniLib::RegexMatcher& matcher_;
+ int locale_id_;
+ const UniLib& unilib_;
+ const std::vector<std::unique_ptr<UniLib::RegexPattern>>& rules_;
+ const std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>&
+ type_and_locale_to_rule_;
+};
+
+} // namespace libtextclassifier2
+
+#endif // LIBTEXTCLASSIFIER_DATETIME_EXTRACTOR_H_
diff --git a/datetime/parser.cc b/datetime/parser.cc
new file mode 100644
index 0000000..27d0a00
--- /dev/null
+++ b/datetime/parser.cc
@@ -0,0 +1,244 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "datetime/parser.h"
+
+#include <set>
+#include <unordered_set>
+
+#include "datetime/extractor.h"
+#include "util/calendar/calendar.h"
+#include "util/strings/split.h"
+
+namespace libtextclassifier2 {
+std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
+ const DatetimeModel* model, const UniLib& unilib) {
+ std::unique_ptr<DatetimeParser> result(new DatetimeParser(model, unilib));
+ if (!result->initialized_) {
+ result.reset();
+ }
+ return result;
+}
+
+DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib)
+ : unilib_(unilib) {
+ initialized_ = false;
+ for (int i = 0; i < model->patterns()->Length(); ++i) {
+ const DatetimeModelPattern* pattern = model->patterns()->Get(i);
+ for (int j = 0; j < pattern->regexes()->Length(); ++j) {
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ unilib.CreateRegexPattern(UTF8ToUnicodeText(
+ pattern->regexes()->Get(j)->str(), /*do_copy=*/false));
+ if (!regex_pattern) {
+ TC_LOG(ERROR) << "Couldn't create pattern: "
+ << pattern->regexes()->Get(j)->str();
+ return;
+ }
+ rules_.push_back(std::move(regex_pattern));
+ rule_id_to_pattern_.push_back(pattern);
+ for (int k = 0; k < pattern->locales()->Length(); ++k) {
+ locale_to_rules_[pattern->locales()->Get(k)].push_back(rules_.size() -
+ 1);
+ }
+ }
+ }
+
+ for (int i = 0; i < model->extractors()->Length(); ++i) {
+ const DatetimeModelExtractor* extractor = model->extractors()->Get(i);
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ unilib.CreateRegexPattern(
+ UTF8ToUnicodeText(extractor->pattern()->str(), /*do_copy=*/false));
+ if (!regex_pattern) {
+ TC_LOG(ERROR) << "Couldn't create pattern: "
+ << extractor->pattern()->str();
+ return;
+ }
+ extractor_rules_.push_back(std::move(regex_pattern));
+
+ for (int j = 0; j < extractor->locales()->Length(); ++j) {
+ type_and_locale_to_extractor_rule_[extractor->extractor()]
+ [extractor->locales()->Get(j)] = i;
+ }
+ }
+
+ for (int i = 0; i < model->locales()->Length(); ++i) {
+ locale_string_to_id_[model->locales()->Get(i)->str()] = i;
+ }
+
+ initialized_ = true;
+}
+
+bool DatetimeParser::Parse(
+ const std::string& input, const int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ std::vector<DatetimeParseResultSpan>* results) const {
+ return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
+ reference_time_ms_utc, reference_timezone, locales, results);
+}
+
+bool DatetimeParser::Parse(
+ const UnicodeText& input, const int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ std::vector<DatetimeParseResultSpan>* results) const {
+ std::vector<DatetimeParseResultSpan> found_spans;
+ std::unordered_set<int> executed_rules;
+ for (const int locale_id : ParseLocales(locales)) {
+ auto rules_it = locale_to_rules_.find(locale_id);
+ if (rules_it == locale_to_rules_.end()) {
+ continue;
+ }
+
+ for (const int rule_id : rules_it->second) {
+ // Skip rules that were already executed in previous locales.
+ if (executed_rules.find(rule_id) != executed_rules.end()) {
+ continue;
+ }
+ executed_rules.insert(rule_id);
+
+ if (!ParseWithRule(*rules_[rule_id], rule_id_to_pattern_[rule_id], input,
+ reference_time_ms_utc, reference_timezone, locale_id,
+ &found_spans)) {
+ return false;
+ }
+ }
+ }
+
+ // Resolve conflicts by always picking the longer span.
+ std::sort(
+ found_spans.begin(), found_spans.end(),
+ [](const DatetimeParseResultSpan& a, const DatetimeParseResultSpan& b) {
+ return (a.span.second - a.span.first) > (b.span.second - b.span.first);
+ });
+
+ std::set<int, std::function<bool(int, int)>> chosen_indices_set(
+ [&found_spans](int a, int b) {
+ return found_spans[a].span.first < found_spans[b].span.first;
+ });
+ for (int i = 0; i < found_spans.size(); ++i) {
+ if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
+ chosen_indices_set.insert(i);
+ results->push_back(found_spans[i]);
+ }
+ }
+
+ return true;
+}
+
+bool DatetimeParser::ParseWithRule(
+ const UniLib::RegexPattern& regex, const DatetimeModelPattern* pattern,
+ const UnicodeText& input, const int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const int locale_id,
+ std::vector<DatetimeParseResultSpan>* result) const {
+ std::unique_ptr<UniLib::RegexMatcher> matcher = regex.Matcher(input);
+
+ int status = UniLib::RegexMatcher::kNoError;
+ while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ const int start = matcher->Start(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+
+ const int end = matcher->End(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+
+ DatetimeParseResultSpan parse_result;
+ parse_result.span = {start, end};
+ if (!ExtractDatetime(*matcher, reference_time_ms_utc, reference_timezone,
+ locale_id, &(parse_result.data))) {
+ return false;
+ }
+ parse_result.target_classification_score =
+ pattern->target_classification_score();
+ parse_result.priority_score = pattern->priority_score();
+
+ result->push_back(parse_result);
+ }
+ return true;
+}
+
+constexpr char const* kDefaultLocale = "";
+
+std::vector<int> DatetimeParser::ParseLocales(
+ const std::string& locales) const {
+ std::vector<std::string> split_locales = strings::Split(locales, ',');
+
+ // Add a default fallback locale to the end of the list.
+ split_locales.push_back(kDefaultLocale);
+
+ std::vector<int> result;
+ for (const std::string& locale : split_locales) {
+ auto locale_it = locale_string_to_id_.find(locale);
+ if (locale_it == locale_string_to_id_.end()) {
+ TC_LOG(INFO) << "Ignoring locale: " << locale;
+ continue;
+ }
+ result.push_back(locale_it->second);
+ }
+ return result;
+}
+
+namespace {
+
+DatetimeGranularity GetGranularity(const DateParseData& data) {
+ DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
+ if (data.field_set_mask & DateParseData::YEAR_FIELD) {
+ granularity = DatetimeGranularity::GRANULARITY_YEAR;
+ }
+ if (data.field_set_mask & DateParseData::MONTH_FIELD) {
+ granularity = DatetimeGranularity::GRANULARITY_MONTH;
+ }
+ if (data.field_set_mask & DateParseData::DAY_FIELD) {
+ granularity = DatetimeGranularity::GRANULARITY_DAY;
+ }
+ if (data.field_set_mask & DateParseData::HOUR_FIELD) {
+ granularity = DatetimeGranularity::GRANULARITY_HOUR;
+ }
+ if (data.field_set_mask & DateParseData::MINUTE_FIELD) {
+ granularity = DatetimeGranularity::GRANULARITY_MINUTE;
+ }
+ if (data.field_set_mask & DateParseData::SECOND_FIELD) {
+ granularity = DatetimeGranularity::GRANULARITY_SECOND;
+ }
+ return granularity;
+}
+
+} // namespace
+
+bool DatetimeParser::ExtractDatetime(const UniLib::RegexMatcher& matcher,
+ const int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ int locale_id,
+ DatetimeParseResult* result) const {
+ DateParseData parse;
+ DatetimeExtractor extractor(matcher, locale_id, unilib_, extractor_rules_,
+ type_and_locale_to_extractor_rule_);
+ if (!extractor.Extract(&parse)) {
+ return false;
+ }
+
+ if (!calendar_lib_.InterpretParseData(parse, reference_time_ms_utc,
+ reference_timezone,
+ &(result->time_ms_utc))) {
+ return false;
+ }
+ result->granularity = GetGranularity(parse);
+
+ return true;
+}
+
+} // namespace libtextclassifier2
diff --git a/datetime/parser.h b/datetime/parser.h
new file mode 100644
index 0000000..a56f83d
--- /dev/null
+++ b/datetime/parser.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_DATETIME_PARSER_H_
+#define LIBTEXTCLASSIFIER_DATETIME_PARSER_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "model_generated.h"
+#include "types.h"
+#include "util/base/integral_types.h"
+#include "util/calendar/calendar.h"
+#include "util/utf8/unilib.h"
+
+namespace libtextclassifier2 {
+
+// Parses datetime expressions in the input and resolves them to actual absolute
+// time.
+class DatetimeParser {
+ public:
+ static std::unique_ptr<DatetimeParser> Instance(const DatetimeModel* model,
+ const UniLib& unilib);
+
+ // Parses the dates in 'input' and fills result. Makes sure that the results
+ // do not overlap.
+ bool Parse(const std::string& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ std::vector<DatetimeParseResultSpan>* results) const;
+
+ // Same as above but takes UnicodeText.
+ bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ std::vector<DatetimeParseResultSpan>* results) const;
+
+ protected:
+ DatetimeParser(const DatetimeModel* model, const UniLib& unilib);
+
+ // Returns a list of locale ids for given locale spec string (comma-separated
+ // locale names).
+ std::vector<int> ParseLocales(const std::string& locales) const;
+ bool ParseWithRule(const UniLib::RegexPattern& regex,
+ const DatetimeModelPattern* pattern,
+ const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const int locale_id,
+ std::vector<DatetimeParseResultSpan>* result) const;
+
+ // Converts the current match in 'matcher' into DatetimeParseResult.
+ bool ExtractDatetime(const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone, int locale_id,
+ DatetimeParseResult* result) const;
+
+ private:
+ bool initialized_;
+ const UniLib& unilib_;
+ std::vector<const DatetimeModelPattern*> rule_id_to_pattern_;
+ std::vector<std::unique_ptr<UniLib::RegexPattern>> rules_;
+ std::unordered_map<int, std::vector<int>> locale_to_rules_;
+ std::vector<std::unique_ptr<UniLib::RegexPattern>> extractor_rules_;
+ std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
+ type_and_locale_to_extractor_rule_;
+ std::unordered_map<std::string, int> locale_string_to_id_;
+ CalendarLib calendar_lib_;
+};
+
+} // namespace libtextclassifier2
+
+#endif // LIBTEXTCLASSIFIER_DATETIME_PARSER_H_
diff --git a/datetime/parser_test.cc b/datetime/parser_test.cc
new file mode 100644
index 0000000..721cb86
--- /dev/null
+++ b/datetime/parser_test.cc
@@ -0,0 +1,235 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <time.h>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "datetime/parser.h"
+#include "model_generated.h"
+#include "types-test-util.h"
+
+using testing::ElementsAreArray;
+
+namespace libtextclassifier2 {
+namespace {
+
+std::string GetModelPath() {
+ return LIBTEXTCLASSIFIER_TEST_DATA_DIR;
+}
+
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+std::string FormatMillis(int64 time_ms_utc) {
+ long time_seconds = time_ms_utc / 1000; // NOLINT
+ // Format time, "ddd yyyy-mm-dd hh:mm:ss zzz"
+ char buffer[512];
+ strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z",
+ localtime(&time_seconds));
+ return std::string(buffer);
+}
+
+class ParserTest : public testing::Test {
+ public:
+ void SetUp() override {
+ model_buffer_ = ReadFile(GetModelPath() + "test_model.fb");
+ const Model* model = GetModel(model_buffer_.data());
+ ASSERT_TRUE(model != nullptr);
+ ASSERT_TRUE(model->datetime_model() != nullptr);
+ parser_ = DatetimeParser::Instance(model->datetime_model(), unilib_);
+ }
+
+ bool ParsesCorrectly(const std::string& marked_text,
+ const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ const std::string& timezone = "Europe/Zurich") {
+ auto expected_start_index = marked_text.find("{");
+ EXPECT_TRUE(expected_start_index != std::string::npos);
+ auto expected_end_index = marked_text.find("}");
+ EXPECT_TRUE(expected_end_index != std::string::npos);
+
+ std::string text;
+ text += std::string(marked_text.begin(),
+ marked_text.begin() + expected_start_index);
+ text += std::string(marked_text.begin() + expected_start_index + 1,
+ marked_text.begin() + expected_end_index);
+ text += std::string(marked_text.begin() + expected_end_index + 1,
+ marked_text.end());
+
+ std::vector<DatetimeParseResultSpan> results;
+
+ if (!parser_->Parse(text, 0, timezone, /*locales=*/"", &results)) {
+ TC_LOG(ERROR) << text;
+ TC_CHECK(false);
+ }
+ EXPECT_TRUE(!results.empty());
+
+ std::vector<DatetimeParseResultSpan> filtered_results;
+ for (const DatetimeParseResultSpan& result : results) {
+ if (SpansOverlap(result.span,
+ {expected_start_index, expected_end_index})) {
+ filtered_results.push_back(result);
+ }
+ }
+
+ const std::vector<DatetimeParseResultSpan> expected{
+ {{expected_start_index, expected_end_index - 1},
+ {expected_ms_utc, expected_granularity},
+ /*target_classification_score=*/1.0,
+ /*priority_score=*/0.0}};
+ const bool matches =
+ testing::Matches(ElementsAreArray(expected))(filtered_results);
+ if (!matches) {
+ TC_LOG(ERROR) << "Expected: " << expected[0] << " which corresponds to: "
+ << FormatMillis(expected[0].data.time_ms_utc);
+ for (int i = 0; i < filtered_results.size(); ++i) {
+ TC_LOG(ERROR) << "Actual[" << i << "]: " << filtered_results[i]
+ << " which corresponds to: "
+ << FormatMillis(filtered_results[i].data.time_ms_utc);
+ }
+ }
+ return matches;
+ }
+
+ protected:
+ std::string model_buffer_;
+ std::unique_ptr<DatetimeParser> parser_;
+ UniLib unilib_;
+};
+
+// Test with just a few cases to make debugging of general failures easier.
+TEST_F(ParserTest, ParseShort) {
+ EXPECT_TRUE(
+ ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_YEAR));
+}
+
+TEST_F(ParserTest, Parse) {
+ EXPECT_TRUE(
+ ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly("{1 2 2018}", 1517439600000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectly("{january 31 2018}", 1517353200000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
+ GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly("{19/apr/2010:06:36:15}", 1271651775000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{09/Mar/2004 22:02:40}", 1078866160000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{Dec 2, 2010 2:39:58 AM}", 1291253998000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{Jun 09 2011 15:28:14}", 1307626094000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{Apr 20 00:00:35 2010}", 1271714435000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(
+ ParsesCorrectly("{Mar 16 08:12:04}", 6419524000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{2012-10-14T22:11:20}", 1350245480000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{2014-07-01T14:59:55.711Z}", 1404219595000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29,573}", 1277512289000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}", 1137899465000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(
+ ParsesCorrectly("{150423 11:42:35}", 1429782155000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{11:42:35}", 38555000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{11:42:35.173}", 38555000, GRANULARITY_SECOND));
+ EXPECT_TRUE(
+ ParsesCorrectly("{23/Apr 11:42:35,173}", 9715355000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015:11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35.883}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35.883}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35.883}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{8/5/2011 3:31:18 AM:234}", 1312507878000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{19/apr/2010:06:36:15}", 1271651775000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly(
+ "Are sentiments apartments decisively the especially alteration. "
+ "Thrown shy denote ten ladies though ask saw. Or by to he going "
+ "think order event music. Incommode so intention defective at "
+ "convinced. Led income months itself and houses you. After nor "
+ "you leave might share court balls. {19/apr/2010:06:36:15} Are "
+ "sentiments apartments decisively the especially alteration. "
+ "Thrown shy denote ten ladies though ask saw. Or by to he going "
+ "think order event music. Incommode so intention defective at "
+ "convinced. Led income months itself and houses you. After nor "
+ "you leave might share court balls. ",
+ 1271651775000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000,
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4}", 1514775600000,
+ GRANULARITY_HOUR));
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000,
+ GRANULARITY_HOUR));
+
+ // WARNING: The following cases have incorrect granularity.
+ // TODO(zilka): Fix this, when granularity is correctly implemented in
+ // InterpretParseData.
+ EXPECT_TRUE(ParsesCorrectly("{today}", -3600000, GRANULARITY_YEAR));
+ EXPECT_TRUE(ParsesCorrectly("{today}", -57600000, GRANULARITY_YEAR,
+ "America/Los_Angeles"));
+ EXPECT_TRUE(ParsesCorrectly("{next week}", 255600000, GRANULARITY_YEAR));
+ EXPECT_TRUE(ParsesCorrectly("{next day}", 82800000, GRANULARITY_YEAR));
+ EXPECT_TRUE(ParsesCorrectly("{in three days}", 255600000, GRANULARITY_YEAR));
+ EXPECT_TRUE(
+ ParsesCorrectly("{in three weeks}", 1465200000, GRANULARITY_YEAR));
+ EXPECT_TRUE(ParsesCorrectly("{tomorrow}", 82800000, GRANULARITY_YEAR));
+ EXPECT_TRUE(
+ ParsesCorrectly("{tomorrow at 4:00}", 97200000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4}", 97200000, GRANULARITY_HOUR));
+ EXPECT_TRUE(ParsesCorrectly("{next wednesday}", 514800000, GRANULARITY_YEAR));
+ EXPECT_TRUE(
+ ParsesCorrectly("{next wednesday at 4}", 529200000, GRANULARITY_HOUR));
+ EXPECT_TRUE(ParsesCorrectly("last seen {today at 9:01 PM}", 72060000,
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(
+ ParsesCorrectly("{Three days ago}", -262800000, GRANULARITY_YEAR));
+ EXPECT_TRUE(
+ ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_YEAR));
+}
+
+// TODO(zilka): Add a test that tests multiple locales.
+
+} // namespace
+} // namespace libtextclassifier2
diff --git a/feature-processor.cc b/feature-processor.cc
index c607b13..76a4c83 100644
--- a/feature-processor.cc
+++ b/feature-processor.cc
@@ -111,8 +111,8 @@
}
}
-UniLib* MaybeCreateUnilib(UniLib* unilib,
- std::unique_ptr<UniLib>* owned_unilib) {
+const UniLib* MaybeCreateUnilib(const UniLib* unilib,
+ std::unique_ptr<UniLib>* owned_unilib) {
if (unilib) {
return unilib;
} else {
@@ -128,6 +128,12 @@
std::vector<Token>* tokens) const {
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
+ StripTokensFromOtherLines(context_unicode, span, tokens);
+}
+
+void FeatureProcessor::StripTokensFromOtherLines(
+ const UnicodeText& context_unicode, CodepointSpan span,
+ std::vector<Token>* tokens) const {
std::vector<UnicodeTextRange> lines = SplitContext(context_unicode);
auto span_start = context_unicode.begin();
@@ -168,28 +174,33 @@
return (*options_->collections())[options_->default_collection()]->str();
}
+std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
+ const UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
+ return Tokenize(text_unicode);
+}
+
std::vector<Token> FeatureProcessor::Tokenize(
- const std::string& utf8_text) const {
+ const UnicodeText& text_unicode) const {
if (options_->tokenization_type() ==
FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER) {
- return tokenizer_.Tokenize(utf8_text);
+ return tokenizer_.Tokenize(text_unicode);
} else if (options_->tokenization_type() ==
FeatureProcessorOptions_::TokenizationType_ICU ||
options_->tokenization_type() ==
FeatureProcessorOptions_::TokenizationType_MIXED) {
std::vector<Token> result;
- if (!ICUTokenize(utf8_text, &result)) {
+ if (!ICUTokenize(text_unicode, &result)) {
return {};
}
if (options_->tokenization_type() ==
FeatureProcessorOptions_::TokenizationType_MIXED) {
- InternalRetokenize(utf8_text, &result);
+ InternalRetokenize(text_unicode, &result);
}
return result;
} else {
TC_LOG(ERROR) << "Unknown tokenization type specified. Using "
"internal.";
- return tokenizer_.Tokenize(utf8_text);
+ return tokenizer_.Tokenize(text_unicode);
}
}
@@ -565,22 +576,21 @@
std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
const UnicodeText& context_unicode) const {
- if (options_->only_use_line_with_click()) {
- std::vector<UnicodeTextRange> lines;
- std::set<char32> codepoints;
- codepoints.insert('\n');
- codepoints.insert('|');
- FindSubstrings(context_unicode, codepoints, &lines);
- return lines;
- } else {
- return {{context_unicode.begin(), context_unicode.end()}};
- }
+ std::vector<UnicodeTextRange> lines;
+ const std::set<char32> codepoints{{'\n', '|'}};
+ FindSubstrings(context_unicode, codepoints, &lines);
+ return lines;
}
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
const std::string& context, CodepointSpan span) const {
const UnicodeText context_unicode =
UTF8ToUnicodeText(context, /*do_copy=*/false);
+ return StripBoundaryCodepoints(context_unicode, span);
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const UnicodeText& context_unicode, CodepointSpan span) const {
UnicodeText::const_iterator span_begin = context_unicode.begin();
std::advance(span_begin, span.first);
UnicodeText::const_iterator span_end = context_unicode.begin();
@@ -683,17 +693,29 @@
void FeatureProcessor::TokenizeAndFindClick(const std::string& context,
CodepointSpan input_span,
+ bool only_use_line_with_click,
+ std::vector<Token>* tokens,
+ int* click_pos) const {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ TokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
+ tokens, click_pos);
+}
+
+void FeatureProcessor::TokenizeAndFindClick(const UnicodeText& context_unicode,
+ CodepointSpan input_span,
+ bool only_use_line_with_click,
std::vector<Token>* tokens,
int* click_pos) const {
TC_CHECK(tokens != nullptr);
- *tokens = Tokenize(context);
+ *tokens = Tokenize(context_unicode);
if (options_->split_tokens_on_selection_boundaries()) {
internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
}
- if (options_->only_use_line_with_click()) {
- StripTokensFromOtherLines(context, input_span, tokens);
+ if (only_use_line_with_click) {
+ StripTokensFromOtherLines(context_unicode, input_span, tokens);
}
int local_click_pos;
@@ -748,18 +770,9 @@
bool FeatureProcessor::ExtractFeatures(
const std::vector<Token>& tokens, TokenSpan token_span,
+ CodepointSpan selection_span_for_feature,
EmbeddingExecutor* embedding_executor, int feature_vector_size,
std::unique_ptr<CachedFeatures>* cached_features) const {
- if (options_->feature_version() < 2) {
- TC_LOG(ERROR) << "Unsupported feature version.";
- return false;
- }
- if (!options_->bounds_sensitive_features() ||
- !options_->bounds_sensitive_features()->enabled()) {
- TC_LOG(ERROR) << "Bounds-sensitive features not enabled.";
- return false;
- }
-
if (options_->min_supported_codepoint_ratio() > 0) {
const float supported_codepoint_ratio =
SupportedCodepointsRatio(token_span, tokens);
@@ -775,9 +788,10 @@
for (int i = token_span.first; i < token_span.second; ++i) {
const Token& token = tokens[i];
const int features_index = i - token_span.first;
- if (!feature_extractor_.Extract(token, false,
- &(sparse_features[features_index]),
- &(dense_features[features_index]))) {
+ if (!feature_extractor_.Extract(
+ token, token.IsContainedInSpan(selection_span_for_feature),
+ &(sparse_features[features_index]),
+ &(dense_features[features_index]))) {
TC_LOG(ERROR) << "Could not extract token's features: " << token;
return false;
}
@@ -791,23 +805,25 @@
return false;
}
- cached_features->reset(new CachedFeatures(
- token_span, sparse_features, dense_features, padding_sparse_features,
- padding_dense_features, options_->bounds_sensitive_features(),
- embedding_executor, feature_vector_size));
+ *cached_features =
+ CachedFeatures::Create(token_span, sparse_features, dense_features,
+ padding_sparse_features, padding_dense_features,
+ options_, embedding_executor, feature_vector_size);
+ if (!*cached_features) {
+ TC_LOG(ERROR) << "Could not create cached features.";
+ return false;
+ }
return true;
}
-bool FeatureProcessor::ICUTokenize(const std::string& context,
+bool FeatureProcessor::ICUTokenize(const UnicodeText& context_unicode,
std::vector<Token>* result) const {
std::unique_ptr<UniLib::BreakIterator> break_iterator =
- unilib_->CreateBreakIterator(context);
+ unilib_->CreateBreakIterator(context_unicode);
if (!break_iterator) {
return false;
}
-
- UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
int last_break_index = 0;
int break_index = 0;
int last_unicode_index = 0;
@@ -845,11 +861,8 @@
return true;
}
-void FeatureProcessor::InternalRetokenize(const std::string& context,
+void FeatureProcessor::InternalRetokenize(const UnicodeText& unicode_text,
std::vector<Token>* tokens) const {
- const UnicodeText unicode_text =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
-
std::vector<Token> result;
CodepointSpan span(-1, -1);
for (Token& token : *tokens) {
diff --git a/feature-processor.h b/feature-processor.h
index 834c260..e6f33d6 100644
--- a/feature-processor.h
+++ b/feature-processor.h
@@ -16,8 +16,8 @@
// Feature processing for FFModel (feed-forward SmartSelection model).
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_
+#ifndef LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_
+#define LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_
#include <map>
#include <memory>
@@ -66,8 +66,8 @@
// If unilib is not nullptr, just returns unilib. Otherwise, if unilib is
// nullptr, will create UniLib, assign ownership to owned_unilib, and return it.
-UniLib* MaybeCreateUnilib(UniLib* unilib,
- std::unique_ptr<UniLib>* owned_unilib);
+const UniLib* MaybeCreateUnilib(const UniLib* unilib,
+ std::unique_ptr<UniLib>* owned_unilib);
} // namespace internal
@@ -89,7 +89,7 @@
// If unilib is nullptr, will create and own an instance of a UniLib,
// otherwise will use what's passed in.
explicit FeatureProcessor(const FeatureProcessorOptions* options,
- UniLib* unilib = nullptr)
+ const UniLib* unilib = nullptr)
: owned_unilib_(nullptr),
unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)),
feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
@@ -117,7 +117,10 @@
}
// Tokenizes the input string using the selected tokenization method.
- std::vector<Token> Tokenize(const std::string& utf8_text) const;
+ std::vector<Token> Tokenize(const std::string& text) const;
+
+ // Same as above but takes UnicodeText.
+ std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
// Converts a label into a token span.
bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
@@ -139,11 +142,19 @@
// Tokenizes the context and input span, and finds the click position.
void TokenizeAndFindClick(const std::string& context,
CodepointSpan input_span,
+ bool only_use_line_with_click,
+ std::vector<Token>* tokens, int* click_pos) const;
+
+ // Same as above but takes UnicodeText.
+ void TokenizeAndFindClick(const UnicodeText& context_unicode,
+ CodepointSpan input_span,
+ bool only_use_line_with_click,
std::vector<Token>* tokens, int* click_pos) const;
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
+ CodepointSpan selection_span_for_feature,
EmbeddingExecutor* embedding_executor,
int feature_vector_size,
std::unique_ptr<CachedFeatures>* cached_features) const;
@@ -161,7 +172,7 @@
int EmbeddingSize() const { return options_->embedding_size(); }
- // Splits context to several segments according to configuration.
+ // Splits context to several segments.
std::vector<UnicodeTextRange> SplitContext(
const UnicodeText& context_unicode) const;
@@ -171,6 +182,10 @@
CodepointSpan StripBoundaryCodepoints(const std::string& context,
CodepointSpan span) const;
+ // Same as above but takes UnicodeText.
+ CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
+ CodepointSpan span) const;
+
protected:
// Represents a codepoint range [start, end).
struct CodepointRange {
@@ -241,12 +256,12 @@
const std::vector<Token>& tokens) const;
// Tokenizes the input text using ICU tokenizer.
- bool ICUTokenize(const std::string& context,
+ bool ICUTokenize(const UnicodeText& context_unicode,
std::vector<Token>* result) const;
// Takes the result of ICU tokenization and retokenizes stretches of tokens
// made of a specific subset of characters using the internal tokenizer.
- void InternalRetokenize(const std::string& context,
+ void InternalRetokenize(const UnicodeText& unicode_text,
std::vector<Token>* tokens) const;
// Tokenizes a substring of the unicode string, appending the resulting tokens
@@ -260,9 +275,14 @@
void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
std::vector<Token>* tokens) const;
+ // Same as above but takes UnicodeText.
+ void StripTokensFromOtherLines(const UnicodeText& context_unicode,
+ CodepointSpan span,
+ std::vector<Token>* tokens) const;
+
private:
std::unique_ptr<UniLib> owned_unilib_;
- UniLib* unilib_;
+ const UniLib* unilib_;
protected:
const TokenFeatureExtractor feature_extractor_;
@@ -296,4 +316,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_
+#endif // LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_
diff --git a/feature-processor_test.cc b/feature-processor_test.cc
index 5af8b96..78977d4 100644
--- a/feature-processor_test.cc
+++ b/feature-processor_test.cc
@@ -147,11 +147,13 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickFirst) {
+ CREATE_UNILIB_FOR_TESTING
FeatureProcessorOptionsT options;
options.only_use_line_with_click = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {0, 5};
@@ -171,11 +173,13 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickSecond) {
+ CREATE_UNILIB_FOR_TESTING
FeatureProcessorOptionsT options;
options.only_use_line_with_click = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {18, 22};
@@ -195,11 +199,13 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickThird) {
+ CREATE_UNILIB_FOR_TESTING
FeatureProcessorOptionsT options;
options.only_use_line_with_click = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {24, 33};
@@ -219,11 +225,13 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
+ CREATE_UNILIB_FOR_TESTING
FeatureProcessorOptionsT options;
options.only_use_line_with_click = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
const CodepointSpan span = {18, 22};
@@ -243,11 +251,13 @@
}
TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) {
+ CREATE_UNILIB_FOR_TESTING
FeatureProcessorOptionsT options;
options.only_use_line_with_click = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {5, 23};
@@ -269,6 +279,7 @@
}
TEST(FeatureProcessorTest, SpanToLabel) {
+ CREATE_UNILIB_FOR_TESTING
FeatureProcessorOptionsT options;
options.context_size = 1;
options.max_selection_span = 1;
@@ -283,7 +294,8 @@
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
ASSERT_EQ(3, tokens.size());
int label;
@@ -301,7 +313,8 @@
flatbuffers::DetachedBuffer options2_fb =
PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
+ &unilib);
int label2;
ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
EXPECT_EQ(label, label2);
@@ -322,7 +335,8 @@
flatbuffers::DetachedBuffer options3_fb =
PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor3(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
+ &unilib);
tokens = feature_processor3.Tokenize("zero, one, two, three, four");
ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
EXPECT_NE(kInvalidLabel, label2);
@@ -340,6 +354,7 @@
}
TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
+ CREATE_UNILIB_FOR_TESTING
FeatureProcessorOptionsT options;
options.context_size = 1;
options.max_selection_span = 1;
@@ -354,7 +369,8 @@
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
ASSERT_EQ(3, tokens.size());
int label;
@@ -372,7 +388,8 @@
flatbuffers::DetachedBuffer options2_fb =
PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
+ &unilib);
int label2;
ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
EXPECT_EQ(label, label2);
@@ -393,7 +410,8 @@
flatbuffers::DetachedBuffer options3_fb =
PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor3(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
+ &unilib);
tokens = feature_processor3.Tokenize("zero, one, two, three, four");
ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
EXPECT_NE(kInvalidLabel, label2);
@@ -524,8 +542,10 @@
}
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ CREATE_UNILIB_FOR_TESTING
TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
{0, 3}, feature_processor.Tokenize("aaa bbb ccc")),
FloatEq(1.0));
@@ -565,30 +585,75 @@
flatbuffers::DetachedBuffer options2_fb =
PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
+ &unilib);
EXPECT_TRUE(feature_processor2.ExtractFeatures(
- tokens, {0, 3}, &embedding_executor,
+ tokens, /*token_span=*/{0, 3},
+ /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+ &embedding_executor,
/*feature_vector_size=*/4, &cached_features));
options.min_supported_codepoint_ratio = 0.2;
flatbuffers::DetachedBuffer options3_fb =
PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor3(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
+ &unilib);
EXPECT_TRUE(feature_processor3.ExtractFeatures(
- tokens, {0, 3}, &embedding_executor,
+ tokens, /*token_span=*/{0, 3},
+ /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+ &embedding_executor,
/*feature_vector_size=*/4, &cached_features));
options.min_supported_codepoint_ratio = 0.5;
flatbuffers::DetachedBuffer options4_fb =
PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor4(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()),
+ &unilib);
EXPECT_FALSE(feature_processor4.ExtractFeatures(
- tokens, {0, 3}, &embedding_executor,
+ tokens, /*token_span=*/{0, 3},
+ /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+ &embedding_executor,
/*feature_vector_size=*/4, &cached_features));
}
+TEST(FeatureProcessorTest, InSpanFeature) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.feature_version = 2;
+ options.embedding_size = 4;
+ options.extract_selection_mask_feature = true;
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ CREATE_UNILIB_FOR_TESTING
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
+
+ std::unique_ptr<CachedFeatures> cached_features;
+
+ FakeEmbeddingExecutor embedding_executor;
+
+ const std::vector<Token> tokens = {Token("aaa", 0, 3), Token("bbb", 4, 7),
+ Token("ccc", 8, 11), Token("ddd", 12, 15)};
+
+ EXPECT_TRUE(feature_processor.ExtractFeatures(
+ tokens, /*token_span=*/{0, 4},
+ /*selection_span_for_feature=*/{4, 11}, &embedding_executor,
+ /*feature_vector_size=*/5, &cached_features));
+ std::vector<float> features;
+ cached_features->AppendClickContextFeaturesForClick(1, &features);
+ ASSERT_EQ(features.size(), 25);
+ EXPECT_THAT(features[4], FloatEq(0.0));
+ EXPECT_THAT(features[9], FloatEq(0.0));
+ EXPECT_THAT(features[14], FloatEq(1.0));
+ EXPECT_THAT(features[19], FloatEq(1.0));
+ EXPECT_THAT(features[24], FloatEq(0.0));
+}
+
TEST(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
std::vector<Token> tokens_orig{
Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
@@ -702,6 +767,7 @@
}
TEST(FeatureProcessorTest, InternalTokenizeOnScriptChange) {
+ CREATE_UNILIB_FOR_TESTING
FeatureProcessorOptionsT options;
options.tokenization_codepoint_config.emplace_back(
new TokenizationCodepointRangeT());
@@ -716,7 +782,8 @@
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
EXPECT_EQ(feature_processor.Tokenize("앨라배마123웹사이트"),
std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)}));
@@ -725,7 +792,8 @@
flatbuffers::DetachedBuffer options_fb2 =
PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb2.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb2.data()),
+ &unilib);
EXPECT_EQ(feature_processor2.Tokenize("앨라배마123웹사이트"),
std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7),
@@ -839,6 +907,7 @@
#endif
TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
+ CREATE_UNILIB_FOR_TESTING
FeatureProcessorOptionsT options;
options.ignored_span_boundary_codepoints.push_back('.');
options.ignored_span_boundary_codepoints.push_back(',');
@@ -847,7 +916,8 @@
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
const std::string text1_utf8 = "ěščř";
const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
diff --git a/model-executor.cc b/model-executor.cc
index 2b1fc11..fc7d4ae 100644
--- a/model-executor.cc
+++ b/model-executor.cc
@@ -16,6 +16,7 @@
#include "model-executor.h"
+#include "quantization.h"
#include "util/base/logging.h"
namespace libtextclassifier2 {
@@ -40,7 +41,10 @@
} // namespace internal
TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor(
- const tflite::Model* model_spec) {
+ const tflite::Model* model_spec, const int embedding_size,
+ const int quantization_bits)
+ : quantization_bits_(quantization_bits),
+ output_embedding_size_(embedding_size) {
internal::FromModelSpec(model_spec, &model_, &interpreter_);
if (!interpreter_) {
return;
@@ -58,13 +62,21 @@
scales_->dims->data[1] != 1) {
return;
}
- embedding_size_ = embeddings_->dims->data[1];
+ bytes_per_embedding_ = embeddings_->dims->data[1];
+ if (!CheckQuantizationParams(bytes_per_embedding_, quantization_bits_,
+ output_embedding_size_)) {
+ TC_LOG(ERROR) << "Mismatch in quantization parameters.";
+ return;
+ }
+
initialized_ = true;
}
bool TFLiteEmbeddingExecutor::AddEmbedding(
const TensorView<int>& sparse_features, float* dest, int dest_size) {
- if (!initialized_ || dest_size != embedding_size_) {
+ if (!initialized_ || dest_size != output_embedding_size_) {
+ TC_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: "
+ << dest_size << " " << output_embedding_size_;
return false;
}
const int num_sparse_features = sparse_features.size();
@@ -73,15 +85,11 @@
if (bucket_id >= num_buckets_) {
return false;
}
- const float multiplier = scales_->data.f[bucket_id];
- for (int k = 0; k < embedding_size_; ++k) {
- // Dequantize and add the embedding.
- dest[k] +=
- 1.0 / num_sparse_features *
- (static_cast<int>(
- embeddings_->data.uint8[bucket_id * embedding_size_ + k]) -
- kQuantBias) *
- multiplier;
+
+ if (!DequantizeAdd(scales_->data.f, embeddings_->data.uint8,
+ bytes_per_embedding_, num_sparse_features,
+ quantization_bits_, bucket_id, dest, dest_size)) {
+ return false;
}
}
return true;
diff --git a/model-executor.h b/model-executor.h
index b16d53d..b495427 100644
--- a/model-executor.h
+++ b/model-executor.h
@@ -16,8 +16,8 @@
// Contains classes that can execute different models/parts of a model.
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_MODEL_EXECUTOR_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_MODEL_EXECUTOR_H_
+#ifndef LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_
+#define LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_
#include <memory>
@@ -81,17 +81,19 @@
// NOTE: This class is not thread-safe.
class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
public:
- explicit TFLiteEmbeddingExecutor(const tflite::Model* model_spec);
+ explicit TFLiteEmbeddingExecutor(const tflite::Model* model_spec,
+ int embedding_size, int quantization_bits);
bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
int dest_size) override;
bool IsReady() override { return initialized_; }
protected:
- static const int kQuantBias = 128;
+ int quantization_bits_;
bool initialized_ = false;
int num_buckets_ = -1;
- int embedding_size_ = -1;
+ int bytes_per_embedding_ = -1;
+ int output_embedding_size_ = -1;
const TfLiteTensor* scales_ = nullptr;
const TfLiteTensor* embeddings_ = nullptr;
@@ -101,4 +103,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_MODEL_EXECUTOR_H_
+#endif // LIBTEXTCLASSIFIER_MODEL_EXECUTOR_H_
diff --git a/model.fbs b/model.fbs
index d98e5ac..632703c 100755
--- a/model.fbs
+++ b/model.fbs
@@ -1,145 +1,439 @@
-// Generated from model.proto
-
-namespace libtextclassifier2.TokenizationCodepointRange_;
-
-enum Role : int {
- DEFAULT_ROLE = 0,
- SPLIT_BEFORE = 1,
- SPLIT_AFTER = 2,
- TOKEN_SEPARATOR = 3,
- DISCARD_CODEPOINT = 4,
- WHITESPACE_SEPARATOR = 7,
-}
-
-namespace libtextclassifier2.FeatureProcessorOptions_;
-
-enum CenterTokenSelectionMethod : int {
- DEFAULT_CENTER_TOKEN_METHOD = 0,
- CENTER_TOKEN_FROM_CLICK = 1,
- CENTER_TOKEN_MIDDLE_OF_SELECTION = 2,
-}
-
-enum TokenizationType : int {
- INVALID_TOKENIZATION_TYPE = 0,
- INTERNAL_TOKENIZER = 1,
- ICU = 2,
- MIXED = 3,
-}
+file_identifier "TC2 ";
namespace libtextclassifier2;
-
-table SelectionModelOptions {
- strip_unpaired_brackets:bool;
- symmetry_context_size:int;
+enum DatetimeExtractorType : int {
+ UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0,
+ AM = 1,
+ PM = 2,
+ JANUARY = 3,
+ FEBRUARY = 4,
+ MARCH = 5,
+ APRIL = 6,
+ MAY = 7,
+ JUNE = 8,
+ JULY = 9,
+ AUGUST = 10,
+ SEPTEMBER = 11,
+ OCTOBER = 12,
+ NOVEMBER = 13,
+ DECEMBER = 14,
+ NEXT = 15,
+ NEXT_OR_SAME = 16,
+ LAST = 17,
+ NOW = 18,
+ TOMORROW = 19,
+ YESTERDAY = 20,
+ PAST = 21,
+ FUTURE = 22,
+ DAY = 23,
+ WEEK = 24,
+ MONTH = 25,
+ YEAR = 26,
+ MONDAY = 27,
+ TUESDAY = 28,
+ WEDNESDAY = 29,
+ THURSDAY = 30,
+ FRIDAY = 31,
+ SATURDAY = 32,
+ SUNDAY = 33,
+ DAYS = 34,
+ WEEKS = 35,
+ MONTHS = 36,
+ HOURS = 37,
+ MINUTES = 38,
+ SECONDS = 39,
+ YEARS = 40,
+ DIGITS = 41,
+ SIGNEDDIGITS = 42,
+ ZERO = 43,
+ ONE = 44,
+ TWO = 45,
+ THREE = 46,
+ FOUR = 47,
+ FIVE = 48,
+ SIX = 49,
+ SEVEN = 50,
+ EIGHT = 51,
+ NINE = 52,
+ TEN = 53,
+ ELEVEN = 54,
+ TWELVE = 55,
+ THIRTEEN = 56,
+ FOURTEEN = 57,
+ FIFTEEN = 58,
+ SIXTEEN = 59,
+ SEVENTEEN = 60,
+ EIGHTEEN = 61,
+ NINETEEN = 62,
+ TWENTY = 63,
+ THIRTY = 64,
+ FORTY = 65,
+ FIFTY = 66,
+ SIXTY = 67,
+ SEVENTY = 68,
+ EIGHTY = 69,
+ NINETY = 70,
+ HUNDRED = 71,
+ THOUSAND = 72,
}
+// Options for the model that predicts text selection.
+namespace libtextclassifier2;
+table SelectionModelOptions {
+ // If true, before the selection is returned, the unpaired brackets contained
+ // in the predicted selection are stripped from the both selection ends.
+ // The bracket codepoints are defined in the Unicode standard:
+ // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt
+ strip_unpaired_brackets:bool = 1;
+
+ // Number of hypothetical click positions on either side of the actual click
+ // to consider in order to enforce symmetry.
+ symmetry_context_size:int;
+
+ // Number of examples to bundle in one batch for inference.
+ batch_size:int = 1024;
+}
+
+// Options for the model that classifies a text selection.
+namespace libtextclassifier2;
table ClassificationModelOptions {
+ // Limits for phone numbers.
phone_min_num_digits:int = 7;
+
phone_max_num_digits:int = 15;
}
-table RegexModelOptions {
- patterns:[libtextclassifier2.RegexModelOptions_.Pattern];
-}
-
-namespace libtextclassifier2.RegexModelOptions_;
-
+// List of regular expression matchers to check.
+namespace libtextclassifier2.RegexModel_;
table Pattern {
+ // The name of the collection of a match.
collection_name:string;
+
+ // The pattern to check.
+ // Can specify a single capturing group used as match boundaries.
pattern:string;
+
+ // Whether to apply the pattern for annotation.
+ enabled_for_annotation:bool = 0;
+
+ // Whether to apply the pattern for classification.
+ enabled_for_classification:bool = 0;
+
+ // Whether to apply the pattern for selection.
+ enabled_for_selection:bool = 0;
+
+ // The final score to assign to the results of this pattern.
+ target_classification_score:float = 1;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
}
namespace libtextclassifier2;
-
-table StructuredRegexModel {
- patterns:[libtextclassifier2.StructuredRegexModel_.StructuredPattern];
-}
-
-namespace libtextclassifier2.StructuredRegexModel_;
-
-table StructuredPattern {
- pattern:string;
- node_names:[string];
+table RegexModel {
+ patterns:[libtextclassifier2.RegexModel_.Pattern];
}
namespace libtextclassifier2;
+table DatetimeModelPattern {
+ // List of regex patterns.
+ regexes:[string];
+ // List of locale indices in DatetimeModel that represent the locales that
+ // these patterns should be used for. If empty, can be used for all locales.
+ locales:[int];
+
+ // The final score to assign to the results of this pattern.
+ target_classification_score:float = 1;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+}
+
+namespace libtextclassifier2;
+table DatetimeModelExtractor {
+ extractor:libtextclassifier2.DatetimeExtractorType;
+ pattern:string;
+ locales:[int];
+}
+
+namespace libtextclassifier2;
+table DatetimeModel {
+ // List of BCP 47 locale strings representing all locales supported by the
+ // model. The individual patterns refer back to them using an index.
+ locales:[string];
+
+ patterns:[libtextclassifier2.DatetimeModelPattern];
+ extractors:[libtextclassifier2.DatetimeModelExtractor];
+}
+
+// Options controlling the output of the models.
+namespace libtextclassifier2;
+table ModelTriggeringOptions {
+ // Lower bound threshold for filtering annotation model outputs.
+ min_annotate_confidence:float = 0;
+}
+
+namespace libtextclassifier2;
table Model {
- language:string;
+ // Comma-separated list of locales supported by the model as BCP 47 tags.
+ locales:string;
+
version:int;
selection_feature_options:libtextclassifier2.FeatureProcessorOptions;
classification_feature_options:libtextclassifier2.FeatureProcessorOptions;
- selection_model:[ubyte];
- classification_model:[ubyte];
- embedding_model:[ubyte];
- regex_options:libtextclassifier2.RegexModelOptions;
+
+ // TFLite models.
+ selection_model:[ubyte] (force_align: 16);
+
+ classification_model:[ubyte] (force_align: 16);
+ embedding_model:[ubyte] (force_align: 16);
+ regex_model:libtextclassifier2.RegexModel;
+
+ // Options for the different models.
selection_options:libtextclassifier2.SelectionModelOptions;
+
classification_options:libtextclassifier2.ClassificationModelOptions;
- regex_model:libtextclassifier2.StructuredRegexModel;
+ datetime_model:libtextclassifier2.DatetimeModel;
+
+ // Options controlling the output of the models.
+ triggering_options:libtextclassifier2.ModelTriggeringOptions;
}
+// Role of the codepoints in the range.
+namespace libtextclassifier2.TokenizationCodepointRange_;
+enum Role : int {
+ // Concatenates the codepoint to the current run of codepoints.
+ DEFAULT_ROLE = 0,
+
+ // Splits a run of codepoints before the current codepoint.
+ SPLIT_BEFORE = 1,
+
+ // Splits a run of codepoints after the current codepoint.
+ SPLIT_AFTER = 2,
+
+ // Each codepoint will be a separate token. Good e.g. for Chinese
+ // characters.
+ TOKEN_SEPARATOR = 3,
+
+ // Discards the codepoint.
+ DISCARD_CODEPOINT = 4,
+
+ // Common values:
+ // Splits on the characters and discards them. Good e.g. for the space
+ // character.
+ WHITESPACE_SEPARATOR = 7,
+}
+
+// Represents a codepoint range [start, end) with its role for tokenization.
+namespace libtextclassifier2;
table TokenizationCodepointRange {
start:int;
end:int;
role:libtextclassifier2.TokenizationCodepointRange_.Role;
+
+ // Integer identifier of the script this range denotes. Negative values are
+ // reserved for Tokenizer's internal use.
script_id:int;
}
-table FeatureProcessorOptions {
- num_buckets:int = -1;
- embedding_size:int = -1;
- context_size:int = -1;
- max_selection_span:int = -1;
- chargram_orders:[int];
- max_word_length:int = 20;
- unicode_aware_features:bool;
- extract_case_feature:bool;
- extract_selection_mask_feature:bool;
- regexp_feature:[string];
- remap_digits:bool;
- lowercase_tokens:bool;
- selection_reduced_output_space:bool;
- collections:[string];
- default_collection:int = -1;
- only_use_line_with_click:bool;
- split_tokens_on_selection_boundaries:bool;
- tokenization_codepoint_config:[libtextclassifier2.TokenizationCodepointRange];
- center_token_selection_method:libtextclassifier2.FeatureProcessorOptions_.CenterTokenSelectionMethod;
- snap_label_span_boundaries_to_containing_tokens:bool;
- supported_codepoint_ranges:[libtextclassifier2.FeatureProcessorOptions_.CodepointRange];
- internal_tokenizer_codepoint_ranges:[libtextclassifier2.FeatureProcessorOptions_.CodepointRange];
- min_supported_codepoint_ratio:float = 0.0;
- feature_version:int;
- tokenization_type:libtextclassifier2.FeatureProcessorOptions_.TokenizationType;
- icu_preserve_whitespace_tokens:bool;
- ignored_span_boundary_codepoints:[int];
- click_random_token_in_selection:bool;
- alternative_collection_map:[libtextclassifier2.FeatureProcessorOptions_.CollectionMapEntry];
- bounds_sensitive_features:libtextclassifier2.FeatureProcessorOptions_.BoundsSensitiveFeatures;
- split_selection_candidates:bool;
- allowed_chargrams:[string];
- tokenize_on_script_change:bool;
+// Method for selecting the center token.
+namespace libtextclassifier2.FeatureProcessorOptions_;
+enum CenterTokenSelectionMethod : int {
+ DEFAULT_CENTER_TOKEN_METHOD = 0,
+
+ // Use click indices to determine the center token.
+ CENTER_TOKEN_FROM_CLICK = 1,
+
+ // Use selection indices to get a token range, and select the middle of it
+ // as the center token.
+ CENTER_TOKEN_MIDDLE_OF_SELECTION = 2,
}
+// Controls the type of tokenization the model will use for the input text.
namespace libtextclassifier2.FeatureProcessorOptions_;
+enum TokenizationType : int {
+ INVALID_TOKENIZATION_TYPE = 0,
+ // Use the internal tokenizer for tokenization.
+ INTERNAL_TOKENIZER = 1,
+
+ // Use ICU for tokenization.
+ ICU = 2,
+
+ // First apply ICU tokenization. Then identify stretches of tokens
+ // consisting only of codepoints in internal_tokenizer_codepoint_ranges
+ // and re-tokenize them using the internal tokenizer.
+ MIXED = 3,
+}
+
+// Range of codepoints start - end, where end is exclusive.
+namespace libtextclassifier2.FeatureProcessorOptions_;
table CodepointRange {
start:int;
end:int;
}
-table CollectionMapEntry {
+// Bounds-sensitive feature extraction configuration go/tc-bounds-sensitive.
+namespace libtextclassifier2.FeatureProcessorOptions_;
+table BoundsSensitiveFeatures {
+ // Enables the extraction of bounds-sensitive features, instead of the click
+ // context features.
+ enabled:bool;
+
+ // The numbers of tokens to extract in specific locations relative to the
+ // bounds.
+ // Immediately before the span.
+ num_tokens_before:int;
+
+ // Inside the span, aligned with the beginning.
+ num_tokens_inside_left:int;
+
+ // Inside the span, aligned with the end.
+ num_tokens_inside_right:int;
+
+ // Immediately after the span.
+ num_tokens_after:int;
+
+ // If true, also extracts the tokens of the entire span and adds up their
+ // features forming one "token" to include in the extracted features.
+ include_inside_bag:bool;
+
+ // If true, includes the selection length (in the number of tokens) as a
+ // feature.
+ include_inside_length:bool;
+}
+
+namespace libtextclassifier2.FeatureProcessorOptions_;
+table AlternativeCollectionMapEntry {
key:string;
value:string;
}
-table BoundsSensitiveFeatures {
- enabled:bool;
- num_tokens_before:int;
- num_tokens_inside_left:int;
- num_tokens_inside_right:int;
- num_tokens_after:int;
- include_inside_bag:bool;
- include_inside_length:bool;
+// TC_STRIP
+// Next ID: 44
+// TC_END_STRIP
+namespace libtextclassifier2;
+table FeatureProcessorOptions {
+ // Number of buckets used for hashing charactergrams.
+ num_buckets:int = -1;
+
+ // Size of the embedding.
+ embedding_size:int = -1;
+
+ // Context size defines the number of words to the left and to the right of
+ // the selected word to be used as context. For example, if context size is
+ // N, then we take N words to the left and N words to the right of the
+ // selected word as its context.
+ context_size:int = -1;
+
+ // Maximum number of words of the context to select in total.
+ max_selection_span:int = -1;
+
+ // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
+ // character trigrams etc.
+ chargram_orders:[int];
+
+ // Maximum length of a word, in codepoints.
+ max_word_length:int = 20;
+
+ // If true, will use the unicode-aware functionality for extracting features.
+ unicode_aware_features:bool = 0;
+
+ // Whether to extract the token case feature.
+ extract_case_feature:bool = 0;
+
+ // Whether to extract the selection mask feature.
+ extract_selection_mask_feature:bool = 0;
+
+ // List of regexps to run over each token. For each regexp, if there is a
+ // match, a dense feature of 1.0 is emitted. Otherwise -1.0 is used.
+ regexp_feature:[string];
+
+ // Whether to remap all digits to a single number.
+ remap_digits:bool = 0;
+
+ // Whether to lower-case each token before generating hashgrams.
+ lowercase_tokens:bool;
+
+ // If true, the selection classifier output will contain only the selections
+ // that are feasible (e.g., those that are shorter than max_selection_span),
+ // if false, the output will be a complete cross-product of possible
+ // selections to the left and possible selections to the right, including the
+ // infeasible ones.
+ // NOTE: Exists mainly for compatibility with older models that were trained
+ // with the non-reduced output space.
+ selection_reduced_output_space:bool = 1;
+
+ // Collection names.
+ collections:[string];
+
+ // An index of collection in collections to be used if a collection name can't
+ // be mapped to an id.
+ default_collection:int = -1;
+
+ // If true, will split the input by lines, and only use the line that contains
+ // the clicked token.
+ only_use_line_with_click:bool = 0;
+
+ // If true, will split tokens that contain the selection boundary, at the
+ // position of the boundary.
+ // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
+ split_tokens_on_selection_boundaries:bool = 0;
+
+ // Codepoint ranges that determine how different codepoints are tokenized.
+ // The ranges must not overlap.
+ tokenization_codepoint_config:[libtextclassifier2.TokenizationCodepointRange];
+
+ center_token_selection_method:libtextclassifier2.FeatureProcessorOptions_.CenterTokenSelectionMethod;
+
+ // If true, span boundaries will be snapped to containing tokens and not
+ // required to exactly match token boundaries.
+ snap_label_span_boundaries_to_containing_tokens:bool;
+
+ // A set of codepoint ranges supported by the model.
+ supported_codepoint_ranges:[libtextclassifier2.FeatureProcessorOptions_.CodepointRange];
+
+ // A set of codepoint ranges to use in the mixed tokenization mode to identify
+ // stretches of tokens to re-tokenize using the internal tokenizer.
+ internal_tokenizer_codepoint_ranges:[libtextclassifier2.FeatureProcessorOptions_.CodepointRange];
+
+ // Minimum ratio of supported codepoints in the input context. If the ratio
+ // is lower than this, the feature computation will fail.
+ min_supported_codepoint_ratio:float = 0;
+
+ // Used for versioning the format of features the model expects.
+ // - feature_version == 0:
+ // For each token the features consist of:
+ // - chargram embeddings
+ // - dense features
+ // Chargram embeddings for tokens are concatenated first together,
+ // and at the end, the dense features for the tokens are concatenated
+ // to it. So the resulting feature vector has two regions.
+ feature_version:int = 0;
+
+ tokenization_type:libtextclassifier2.FeatureProcessorOptions_.TokenizationType;
+ icu_preserve_whitespace_tokens:bool = 0;
+
+ // List of codepoints that will be stripped from beginning and end of
+ // predicted spans.
+ ignored_span_boundary_codepoints:[int];
+
+ bounds_sensitive_features:libtextclassifier2.FeatureProcessorOptions_.BoundsSensitiveFeatures;
+
+ // List of allowed charactergrams. The extracted charactergrams are filtered
+ // using this list, and charactergrams that are not present are interpreted as
+ // out-of-vocabulary.
+ // If no allowed_chargrams are specified, all charactergrams are allowed.
+ // The field is typed as bytes type to allow non-UTF8 chargrams.
+ allowed_chargrams:[string];
+
+ // If true, tokens will be also split when the codepoint's script_id changes
+ // as defined in TokenizationCodepointRange.
+ tokenize_on_script_change:bool = 0;
+
+ // Number of bits for quantization for embeddings.
+ embedding_quantization_bits:int = 8;
}
+root_type libtextclassifier2.Model;
diff --git a/model_generated.h b/model_generated.h
index fd11c39..914da8a 100755
--- a/model_generated.h
+++ b/model_generated.h
@@ -17,8 +17,8 @@
// automatically generated by the FlatBuffers compiler, do not modify
-#ifndef FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_FEATUREPROCESSOROPTIONS__H_
-#define FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_FEATUREPROCESSOROPTIONS__H_
+#ifndef FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_H_
+#define FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_H_
#include "flatbuffers/flatbuffers.h"
@@ -30,25 +30,27 @@
struct ClassificationModelOptions;
struct ClassificationModelOptionsT;
-struct RegexModelOptions;
-struct RegexModelOptionsT;
-
-namespace RegexModelOptions_ {
+namespace RegexModel_ {
struct Pattern;
struct PatternT;
-} // namespace RegexModelOptions_
+} // namespace RegexModel_
-struct StructuredRegexModel;
-struct StructuredRegexModelT;
+struct RegexModel;
+struct RegexModelT;
-namespace StructuredRegexModel_ {
+struct DatetimeModelPattern;
+struct DatetimeModelPatternT;
-struct StructuredPattern;
-struct StructuredPatternT;
+struct DatetimeModelExtractor;
+struct DatetimeModelExtractorT;
-} // namespace StructuredRegexModel_
+struct DatetimeModel;
+struct DatetimeModelT;
+
+struct ModelTriggeringOptions;
+struct ModelTriggeringOptionsT;
struct Model;
struct ModelT;
@@ -56,22 +58,264 @@
struct TokenizationCodepointRange;
struct TokenizationCodepointRangeT;
-struct FeatureProcessorOptions;
-struct FeatureProcessorOptionsT;
-
namespace FeatureProcessorOptions_ {
struct CodepointRange;
struct CodepointRangeT;
-struct CollectionMapEntry;
-struct CollectionMapEntryT;
-
struct BoundsSensitiveFeatures;
struct BoundsSensitiveFeaturesT;
+struct AlternativeCollectionMapEntry;
+struct AlternativeCollectionMapEntryT;
+
} // namespace FeatureProcessorOptions_
+struct FeatureProcessorOptions;
+struct FeatureProcessorOptionsT;
+
+enum DatetimeExtractorType {
+ DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0,
+ DatetimeExtractorType_AM = 1,
+ DatetimeExtractorType_PM = 2,
+ DatetimeExtractorType_JANUARY = 3,
+ DatetimeExtractorType_FEBRUARY = 4,
+ DatetimeExtractorType_MARCH = 5,
+ DatetimeExtractorType_APRIL = 6,
+ DatetimeExtractorType_MAY = 7,
+ DatetimeExtractorType_JUNE = 8,
+ DatetimeExtractorType_JULY = 9,
+ DatetimeExtractorType_AUGUST = 10,
+ DatetimeExtractorType_SEPTEMBER = 11,
+ DatetimeExtractorType_OCTOBER = 12,
+ DatetimeExtractorType_NOVEMBER = 13,
+ DatetimeExtractorType_DECEMBER = 14,
+ DatetimeExtractorType_NEXT = 15,
+ DatetimeExtractorType_NEXT_OR_SAME = 16,
+ DatetimeExtractorType_LAST = 17,
+ DatetimeExtractorType_NOW = 18,
+ DatetimeExtractorType_TOMORROW = 19,
+ DatetimeExtractorType_YESTERDAY = 20,
+ DatetimeExtractorType_PAST = 21,
+ DatetimeExtractorType_FUTURE = 22,
+ DatetimeExtractorType_DAY = 23,
+ DatetimeExtractorType_WEEK = 24,
+ DatetimeExtractorType_MONTH = 25,
+ DatetimeExtractorType_YEAR = 26,
+ DatetimeExtractorType_MONDAY = 27,
+ DatetimeExtractorType_TUESDAY = 28,
+ DatetimeExtractorType_WEDNESDAY = 29,
+ DatetimeExtractorType_THURSDAY = 30,
+ DatetimeExtractorType_FRIDAY = 31,
+ DatetimeExtractorType_SATURDAY = 32,
+ DatetimeExtractorType_SUNDAY = 33,
+ DatetimeExtractorType_DAYS = 34,
+ DatetimeExtractorType_WEEKS = 35,
+ DatetimeExtractorType_MONTHS = 36,
+ DatetimeExtractorType_HOURS = 37,
+ DatetimeExtractorType_MINUTES = 38,
+ DatetimeExtractorType_SECONDS = 39,
+ DatetimeExtractorType_YEARS = 40,
+ DatetimeExtractorType_DIGITS = 41,
+ DatetimeExtractorType_SIGNEDDIGITS = 42,
+ DatetimeExtractorType_ZERO = 43,
+ DatetimeExtractorType_ONE = 44,
+ DatetimeExtractorType_TWO = 45,
+ DatetimeExtractorType_THREE = 46,
+ DatetimeExtractorType_FOUR = 47,
+ DatetimeExtractorType_FIVE = 48,
+ DatetimeExtractorType_SIX = 49,
+ DatetimeExtractorType_SEVEN = 50,
+ DatetimeExtractorType_EIGHT = 51,
+ DatetimeExtractorType_NINE = 52,
+ DatetimeExtractorType_TEN = 53,
+ DatetimeExtractorType_ELEVEN = 54,
+ DatetimeExtractorType_TWELVE = 55,
+ DatetimeExtractorType_THIRTEEN = 56,
+ DatetimeExtractorType_FOURTEEN = 57,
+ DatetimeExtractorType_FIFTEEN = 58,
+ DatetimeExtractorType_SIXTEEN = 59,
+ DatetimeExtractorType_SEVENTEEN = 60,
+ DatetimeExtractorType_EIGHTEEN = 61,
+ DatetimeExtractorType_NINETEEN = 62,
+ DatetimeExtractorType_TWENTY = 63,
+ DatetimeExtractorType_THIRTY = 64,
+ DatetimeExtractorType_FORTY = 65,
+ DatetimeExtractorType_FIFTY = 66,
+ DatetimeExtractorType_SIXTY = 67,
+ DatetimeExtractorType_SEVENTY = 68,
+ DatetimeExtractorType_EIGHTY = 69,
+ DatetimeExtractorType_NINETY = 70,
+ DatetimeExtractorType_HUNDRED = 71,
+ DatetimeExtractorType_THOUSAND = 72,
+ DatetimeExtractorType_MIN = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE,
+ DatetimeExtractorType_MAX = DatetimeExtractorType_THOUSAND
+};
+
+inline DatetimeExtractorType (&EnumValuesDatetimeExtractorType())[73] {
+ static DatetimeExtractorType values[] = {
+ DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE,
+ DatetimeExtractorType_AM,
+ DatetimeExtractorType_PM,
+ DatetimeExtractorType_JANUARY,
+ DatetimeExtractorType_FEBRUARY,
+ DatetimeExtractorType_MARCH,
+ DatetimeExtractorType_APRIL,
+ DatetimeExtractorType_MAY,
+ DatetimeExtractorType_JUNE,
+ DatetimeExtractorType_JULY,
+ DatetimeExtractorType_AUGUST,
+ DatetimeExtractorType_SEPTEMBER,
+ DatetimeExtractorType_OCTOBER,
+ DatetimeExtractorType_NOVEMBER,
+ DatetimeExtractorType_DECEMBER,
+ DatetimeExtractorType_NEXT,
+ DatetimeExtractorType_NEXT_OR_SAME,
+ DatetimeExtractorType_LAST,
+ DatetimeExtractorType_NOW,
+ DatetimeExtractorType_TOMORROW,
+ DatetimeExtractorType_YESTERDAY,
+ DatetimeExtractorType_PAST,
+ DatetimeExtractorType_FUTURE,
+ DatetimeExtractorType_DAY,
+ DatetimeExtractorType_WEEK,
+ DatetimeExtractorType_MONTH,
+ DatetimeExtractorType_YEAR,
+ DatetimeExtractorType_MONDAY,
+ DatetimeExtractorType_TUESDAY,
+ DatetimeExtractorType_WEDNESDAY,
+ DatetimeExtractorType_THURSDAY,
+ DatetimeExtractorType_FRIDAY,
+ DatetimeExtractorType_SATURDAY,
+ DatetimeExtractorType_SUNDAY,
+ DatetimeExtractorType_DAYS,
+ DatetimeExtractorType_WEEKS,
+ DatetimeExtractorType_MONTHS,
+ DatetimeExtractorType_HOURS,
+ DatetimeExtractorType_MINUTES,
+ DatetimeExtractorType_SECONDS,
+ DatetimeExtractorType_YEARS,
+ DatetimeExtractorType_DIGITS,
+ DatetimeExtractorType_SIGNEDDIGITS,
+ DatetimeExtractorType_ZERO,
+ DatetimeExtractorType_ONE,
+ DatetimeExtractorType_TWO,
+ DatetimeExtractorType_THREE,
+ DatetimeExtractorType_FOUR,
+ DatetimeExtractorType_FIVE,
+ DatetimeExtractorType_SIX,
+ DatetimeExtractorType_SEVEN,
+ DatetimeExtractorType_EIGHT,
+ DatetimeExtractorType_NINE,
+ DatetimeExtractorType_TEN,
+ DatetimeExtractorType_ELEVEN,
+ DatetimeExtractorType_TWELVE,
+ DatetimeExtractorType_THIRTEEN,
+ DatetimeExtractorType_FOURTEEN,
+ DatetimeExtractorType_FIFTEEN,
+ DatetimeExtractorType_SIXTEEN,
+ DatetimeExtractorType_SEVENTEEN,
+ DatetimeExtractorType_EIGHTEEN,
+ DatetimeExtractorType_NINETEEN,
+ DatetimeExtractorType_TWENTY,
+ DatetimeExtractorType_THIRTY,
+ DatetimeExtractorType_FORTY,
+ DatetimeExtractorType_FIFTY,
+ DatetimeExtractorType_SIXTY,
+ DatetimeExtractorType_SEVENTY,
+ DatetimeExtractorType_EIGHTY,
+ DatetimeExtractorType_NINETY,
+ DatetimeExtractorType_HUNDRED,
+ DatetimeExtractorType_THOUSAND
+ };
+ return values;
+}
+
+inline const char **EnumNamesDatetimeExtractorType() {
+ static const char *names[] = {
+ "UNKNOWN_DATETIME_EXTRACTOR_TYPE",
+ "AM",
+ "PM",
+ "JANUARY",
+ "FEBRUARY",
+ "MARCH",
+ "APRIL",
+ "MAY",
+ "JUNE",
+ "JULY",
+ "AUGUST",
+ "SEPTEMBER",
+ "OCTOBER",
+ "NOVEMBER",
+ "DECEMBER",
+ "NEXT",
+ "NEXT_OR_SAME",
+ "LAST",
+ "NOW",
+ "TOMORROW",
+ "YESTERDAY",
+ "PAST",
+ "FUTURE",
+ "DAY",
+ "WEEK",
+ "MONTH",
+ "YEAR",
+ "MONDAY",
+ "TUESDAY",
+ "WEDNESDAY",
+ "THURSDAY",
+ "FRIDAY",
+ "SATURDAY",
+ "SUNDAY",
+ "DAYS",
+ "WEEKS",
+ "MONTHS",
+ "HOURS",
+ "MINUTES",
+ "SECONDS",
+ "YEARS",
+ "DIGITS",
+ "SIGNEDDIGITS",
+ "ZERO",
+ "ONE",
+ "TWO",
+ "THREE",
+ "FOUR",
+ "FIVE",
+ "SIX",
+ "SEVEN",
+ "EIGHT",
+ "NINE",
+ "TEN",
+ "ELEVEN",
+ "TWELVE",
+ "THIRTEEN",
+ "FOURTEEN",
+ "FIFTEEN",
+ "SIXTEEN",
+ "SEVENTEEN",
+ "EIGHTEEN",
+ "NINETEEN",
+ "TWENTY",
+ "THIRTY",
+ "FORTY",
+ "FIFTY",
+ "SIXTY",
+ "SEVENTY",
+ "EIGHTY",
+ "NINETY",
+ "HUNDRED",
+ "THOUSAND",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameDatetimeExtractorType(DatetimeExtractorType e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesDatetimeExtractorType()[index];
+}
+
namespace TokenizationCodepointRange_ {
enum Role {
@@ -194,9 +438,11 @@
typedef SelectionModelOptions TableType;
bool strip_unpaired_brackets;
int32_t symmetry_context_size;
+ int32_t batch_size;
SelectionModelOptionsT()
- : strip_unpaired_brackets(false),
- symmetry_context_size(0) {
+ : strip_unpaired_brackets(true),
+ symmetry_context_size(0),
+ batch_size(1024) {
}
};
@@ -204,18 +450,23 @@
typedef SelectionModelOptionsT NativeTableType;
enum {
VT_STRIP_UNPAIRED_BRACKETS = 4,
- VT_SYMMETRY_CONTEXT_SIZE = 6
+ VT_SYMMETRY_CONTEXT_SIZE = 6,
+ VT_BATCH_SIZE = 8
};
bool strip_unpaired_brackets() const {
- return GetField<uint8_t>(VT_STRIP_UNPAIRED_BRACKETS, 0) != 0;
+ return GetField<uint8_t>(VT_STRIP_UNPAIRED_BRACKETS, 1) != 0;
}
int32_t symmetry_context_size() const {
return GetField<int32_t>(VT_SYMMETRY_CONTEXT_SIZE, 0);
}
+ int32_t batch_size() const {
+ return GetField<int32_t>(VT_BATCH_SIZE, 1024);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_STRIP_UNPAIRED_BRACKETS) &&
VerifyField<int32_t>(verifier, VT_SYMMETRY_CONTEXT_SIZE) &&
+ VerifyField<int32_t>(verifier, VT_BATCH_SIZE) &&
verifier.EndTable();
}
SelectionModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -227,11 +478,14 @@
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
void add_strip_unpaired_brackets(bool strip_unpaired_brackets) {
- fbb_.AddElement<uint8_t>(SelectionModelOptions::VT_STRIP_UNPAIRED_BRACKETS, static_cast<uint8_t>(strip_unpaired_brackets), 0);
+ fbb_.AddElement<uint8_t>(SelectionModelOptions::VT_STRIP_UNPAIRED_BRACKETS, static_cast<uint8_t>(strip_unpaired_brackets), 1);
}
void add_symmetry_context_size(int32_t symmetry_context_size) {
fbb_.AddElement<int32_t>(SelectionModelOptions::VT_SYMMETRY_CONTEXT_SIZE, symmetry_context_size, 0);
}
+ void add_batch_size(int32_t batch_size) {
+ fbb_.AddElement<int32_t>(SelectionModelOptions::VT_BATCH_SIZE, batch_size, 1024);
+ }
explicit SelectionModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -246,9 +500,11 @@
inline flatbuffers::Offset<SelectionModelOptions> CreateSelectionModelOptions(
flatbuffers::FlatBufferBuilder &_fbb,
- bool strip_unpaired_brackets = false,
- int32_t symmetry_context_size = 0) {
+ bool strip_unpaired_brackets = true,
+ int32_t symmetry_context_size = 0,
+ int32_t batch_size = 1024) {
SelectionModelOptionsBuilder builder_(_fbb);
+ builder_.add_batch_size(batch_size);
builder_.add_symmetry_context_size(symmetry_context_size);
builder_.add_strip_unpaired_brackets(strip_unpaired_brackets);
return builder_.Finish();
@@ -322,76 +578,23 @@
flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-struct RegexModelOptionsT : public flatbuffers::NativeTable {
- typedef RegexModelOptions TableType;
- std::vector<std::unique_ptr<libtextclassifier2::RegexModelOptions_::PatternT>> patterns;
- RegexModelOptionsT() {
- }
-};
-
-struct RegexModelOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
- typedef RegexModelOptionsT NativeTableType;
- enum {
- VT_PATTERNS = 4
- };
- const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>> *patterns() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>> *>(VT_PATTERNS);
- }
- bool Verify(flatbuffers::Verifier &verifier) const {
- return VerifyTableStart(verifier) &&
- VerifyOffset(verifier, VT_PATTERNS) &&
- verifier.Verify(patterns()) &&
- verifier.VerifyVectorOfTables(patterns()) &&
- verifier.EndTable();
- }
- RegexModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- void UnPackTo(RegexModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- static flatbuffers::Offset<RegexModelOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-};
-
-struct RegexModelOptionsBuilder {
- flatbuffers::FlatBufferBuilder &fbb_;
- flatbuffers::uoffset_t start_;
- void add_patterns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>>> patterns) {
- fbb_.AddOffset(RegexModelOptions::VT_PATTERNS, patterns);
- }
- explicit RegexModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
- : fbb_(_fbb) {
- start_ = fbb_.StartTable();
- }
- RegexModelOptionsBuilder &operator=(const RegexModelOptionsBuilder &);
- flatbuffers::Offset<RegexModelOptions> Finish() {
- const auto end = fbb_.EndTable(start_);
- auto o = flatbuffers::Offset<RegexModelOptions>(end);
- return o;
- }
-};
-
-inline flatbuffers::Offset<RegexModelOptions> CreateRegexModelOptions(
- flatbuffers::FlatBufferBuilder &_fbb,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>>> patterns = 0) {
- RegexModelOptionsBuilder builder_(_fbb);
- builder_.add_patterns(patterns);
- return builder_.Finish();
-}
-
-inline flatbuffers::Offset<RegexModelOptions> CreateRegexModelOptionsDirect(
- flatbuffers::FlatBufferBuilder &_fbb,
- const std::vector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>> *patterns = nullptr) {
- return libtextclassifier2::CreateRegexModelOptions(
- _fbb,
- patterns ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>>(*patterns) : 0);
-}
-
-flatbuffers::Offset<RegexModelOptions> CreateRegexModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-
-namespace RegexModelOptions_ {
+namespace RegexModel_ {
struct PatternT : public flatbuffers::NativeTable {
typedef Pattern TableType;
std::string collection_name;
std::string pattern;
- PatternT() {
+ bool enabled_for_annotation;
+ bool enabled_for_classification;
+ bool enabled_for_selection;
+ float target_classification_score;
+ float priority_score;
+ PatternT()
+ : enabled_for_annotation(false),
+ enabled_for_classification(false),
+ enabled_for_selection(false),
+ target_classification_score(1.0f),
+ priority_score(0.0f) {
}
};
@@ -399,7 +602,12 @@
typedef PatternT NativeTableType;
enum {
VT_COLLECTION_NAME = 4,
- VT_PATTERN = 6
+ VT_PATTERN = 6,
+ VT_ENABLED_FOR_ANNOTATION = 8,
+ VT_ENABLED_FOR_CLASSIFICATION = 10,
+ VT_ENABLED_FOR_SELECTION = 12,
+ VT_TARGET_CLASSIFICATION_SCORE = 14,
+ VT_PRIORITY_SCORE = 16
};
const flatbuffers::String *collection_name() const {
return GetPointer<const flatbuffers::String *>(VT_COLLECTION_NAME);
@@ -407,12 +615,32 @@
const flatbuffers::String *pattern() const {
return GetPointer<const flatbuffers::String *>(VT_PATTERN);
}
+ bool enabled_for_annotation() const {
+ return GetField<uint8_t>(VT_ENABLED_FOR_ANNOTATION, 0) != 0;
+ }
+ bool enabled_for_classification() const {
+ return GetField<uint8_t>(VT_ENABLED_FOR_CLASSIFICATION, 0) != 0;
+ }
+ bool enabled_for_selection() const {
+ return GetField<uint8_t>(VT_ENABLED_FOR_SELECTION, 0) != 0;
+ }
+ float target_classification_score() const {
+ return GetField<float>(VT_TARGET_CLASSIFICATION_SCORE, 1.0f);
+ }
+ float priority_score() const {
+ return GetField<float>(VT_PRIORITY_SCORE, 0.0f);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_COLLECTION_NAME) &&
verifier.Verify(collection_name()) &&
VerifyOffset(verifier, VT_PATTERN) &&
verifier.Verify(pattern()) &&
+ VerifyField<uint8_t>(verifier, VT_ENABLED_FOR_ANNOTATION) &&
+ VerifyField<uint8_t>(verifier, VT_ENABLED_FOR_CLASSIFICATION) &&
+ VerifyField<uint8_t>(verifier, VT_ENABLED_FOR_SELECTION) &&
+ VerifyField<float>(verifier, VT_TARGET_CLASSIFICATION_SCORE) &&
+ VerifyField<float>(verifier, VT_PRIORITY_SCORE) &&
verifier.EndTable();
}
PatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -429,6 +657,21 @@
void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) {
fbb_.AddOffset(Pattern::VT_PATTERN, pattern);
}
+ void add_enabled_for_annotation(bool enabled_for_annotation) {
+ fbb_.AddElement<uint8_t>(Pattern::VT_ENABLED_FOR_ANNOTATION, static_cast<uint8_t>(enabled_for_annotation), 0);
+ }
+ void add_enabled_for_classification(bool enabled_for_classification) {
+ fbb_.AddElement<uint8_t>(Pattern::VT_ENABLED_FOR_CLASSIFICATION, static_cast<uint8_t>(enabled_for_classification), 0);
+ }
+ void add_enabled_for_selection(bool enabled_for_selection) {
+ fbb_.AddElement<uint8_t>(Pattern::VT_ENABLED_FOR_SELECTION, static_cast<uint8_t>(enabled_for_selection), 0);
+ }
+ void add_target_classification_score(float target_classification_score) {
+ fbb_.AddElement<float>(Pattern::VT_TARGET_CLASSIFICATION_SCORE, target_classification_score, 1.0f);
+ }
+ void add_priority_score(float priority_score) {
+ fbb_.AddElement<float>(Pattern::VT_PRIORITY_SCORE, priority_score, 0.0f);
+ }
explicit PatternBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -444,41 +687,61 @@
inline flatbuffers::Offset<Pattern> CreatePattern(
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<flatbuffers::String> collection_name = 0,
- flatbuffers::Offset<flatbuffers::String> pattern = 0) {
+ flatbuffers::Offset<flatbuffers::String> pattern = 0,
+ bool enabled_for_annotation = false,
+ bool enabled_for_classification = false,
+ bool enabled_for_selection = false,
+ float target_classification_score = 1.0f,
+ float priority_score = 0.0f) {
PatternBuilder builder_(_fbb);
+ builder_.add_priority_score(priority_score);
+ builder_.add_target_classification_score(target_classification_score);
builder_.add_pattern(pattern);
builder_.add_collection_name(collection_name);
+ builder_.add_enabled_for_selection(enabled_for_selection);
+ builder_.add_enabled_for_classification(enabled_for_classification);
+ builder_.add_enabled_for_annotation(enabled_for_annotation);
return builder_.Finish();
}
inline flatbuffers::Offset<Pattern> CreatePatternDirect(
flatbuffers::FlatBufferBuilder &_fbb,
const char *collection_name = nullptr,
- const char *pattern = nullptr) {
- return libtextclassifier2::RegexModelOptions_::CreatePattern(
+ const char *pattern = nullptr,
+ bool enabled_for_annotation = false,
+ bool enabled_for_classification = false,
+ bool enabled_for_selection = false,
+ float target_classification_score = 1.0f,
+ float priority_score = 0.0f) {
+ return libtextclassifier2::RegexModel_::CreatePattern(
_fbb,
collection_name ? _fbb.CreateString(collection_name) : 0,
- pattern ? _fbb.CreateString(pattern) : 0);
+ pattern ? _fbb.CreateString(pattern) : 0,
+ enabled_for_annotation,
+ enabled_for_classification,
+ enabled_for_selection,
+ target_classification_score,
+ priority_score);
}
flatbuffers::Offset<Pattern> CreatePattern(flatbuffers::FlatBufferBuilder &_fbb, const PatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-} // namespace RegexModelOptions_
+} // namespace RegexModel_
-struct StructuredRegexModelT : public flatbuffers::NativeTable {
- typedef StructuredRegexModel TableType;
- std::vector<std::unique_ptr<libtextclassifier2::StructuredRegexModel_::StructuredPatternT>> patterns;
- StructuredRegexModelT() {
+struct RegexModelT : public flatbuffers::NativeTable {
+ typedef RegexModel TableType;
+ std::vector<std::unique_ptr<libtextclassifier2::RegexModel_::PatternT>> patterns;
+ RegexModelT() {
}
};
-struct StructuredRegexModel FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
- typedef StructuredRegexModelT NativeTableType;
+struct RegexModel FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef RegexModelT NativeTableType;
enum {
VT_PATTERNS = 4
};
- const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>> *patterns() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>> *>(VT_PATTERNS);
+ const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> *patterns() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> *>(VT_PATTERNS);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
@@ -487,141 +750,403 @@
verifier.VerifyVectorOfTables(patterns()) &&
verifier.EndTable();
}
- StructuredRegexModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- void UnPackTo(StructuredRegexModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- static flatbuffers::Offset<StructuredRegexModel> Pack(flatbuffers::FlatBufferBuilder &_fbb, const StructuredRegexModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+ RegexModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(RegexModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<RegexModel> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
-struct StructuredRegexModelBuilder {
+struct RegexModelBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
- void add_patterns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>>> patterns) {
- fbb_.AddOffset(StructuredRegexModel::VT_PATTERNS, patterns);
+ void add_patterns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>>> patterns) {
+ fbb_.AddOffset(RegexModel::VT_PATTERNS, patterns);
}
- explicit StructuredRegexModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ explicit RegexModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
- StructuredRegexModelBuilder &operator=(const StructuredRegexModelBuilder &);
- flatbuffers::Offset<StructuredRegexModel> Finish() {
+ RegexModelBuilder &operator=(const RegexModelBuilder &);
+ flatbuffers::Offset<RegexModel> Finish() {
const auto end = fbb_.EndTable(start_);
- auto o = flatbuffers::Offset<StructuredRegexModel>(end);
+ auto o = flatbuffers::Offset<RegexModel>(end);
return o;
}
};
-inline flatbuffers::Offset<StructuredRegexModel> CreateStructuredRegexModel(
+inline flatbuffers::Offset<RegexModel> CreateRegexModel(
flatbuffers::FlatBufferBuilder &_fbb,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>>> patterns = 0) {
- StructuredRegexModelBuilder builder_(_fbb);
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>>> patterns = 0) {
+ RegexModelBuilder builder_(_fbb);
builder_.add_patterns(patterns);
return builder_.Finish();
}
-inline flatbuffers::Offset<StructuredRegexModel> CreateStructuredRegexModelDirect(
+inline flatbuffers::Offset<RegexModel> CreateRegexModelDirect(
flatbuffers::FlatBufferBuilder &_fbb,
- const std::vector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>> *patterns = nullptr) {
- return libtextclassifier2::CreateStructuredRegexModel(
+ const std::vector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> *patterns = nullptr) {
+ return libtextclassifier2::CreateRegexModel(
_fbb,
- patterns ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>>(*patterns) : 0);
+ patterns ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>>(*patterns) : 0);
}
-flatbuffers::Offset<StructuredRegexModel> CreateStructuredRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const StructuredRegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+flatbuffers::Offset<RegexModel> CreateRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-namespace StructuredRegexModel_ {
-
-struct StructuredPatternT : public flatbuffers::NativeTable {
- typedef StructuredPattern TableType;
- std::string pattern;
- std::vector<std::string> node_names;
- StructuredPatternT() {
+struct DatetimeModelPatternT : public flatbuffers::NativeTable {
+ typedef DatetimeModelPattern TableType;
+ std::vector<std::string> regexes;
+ std::vector<int32_t> locales;
+ float target_classification_score;
+ float priority_score;
+ DatetimeModelPatternT()
+ : target_classification_score(1.0f),
+ priority_score(0.0f) {
}
};
-struct StructuredPattern FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
- typedef StructuredPatternT NativeTableType;
+struct DatetimeModelPattern FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DatetimeModelPatternT NativeTableType;
enum {
- VT_PATTERN = 4,
- VT_NODE_NAMES = 6
+ VT_REGEXES = 4,
+ VT_LOCALES = 6,
+ VT_TARGET_CLASSIFICATION_SCORE = 8,
+ VT_PRIORITY_SCORE = 10
};
- const flatbuffers::String *pattern() const {
- return GetPointer<const flatbuffers::String *>(VT_PATTERN);
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *regexes() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_REGEXES);
}
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *node_names() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_NODE_NAMES);
+ const flatbuffers::Vector<int32_t> *locales() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_LOCALES);
+ }
+ float target_classification_score() const {
+ return GetField<float>(VT_TARGET_CLASSIFICATION_SCORE, 1.0f);
+ }
+ float priority_score() const {
+ return GetField<float>(VT_PRIORITY_SCORE, 0.0f);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
- VerifyOffset(verifier, VT_PATTERN) &&
- verifier.Verify(pattern()) &&
- VerifyOffset(verifier, VT_NODE_NAMES) &&
- verifier.Verify(node_names()) &&
- verifier.VerifyVectorOfStrings(node_names()) &&
+ VerifyOffset(verifier, VT_REGEXES) &&
+ verifier.Verify(regexes()) &&
+ verifier.VerifyVectorOfStrings(regexes()) &&
+ VerifyOffset(verifier, VT_LOCALES) &&
+ verifier.Verify(locales()) &&
+ VerifyField<float>(verifier, VT_TARGET_CLASSIFICATION_SCORE) &&
+ VerifyField<float>(verifier, VT_PRIORITY_SCORE) &&
verifier.EndTable();
}
- StructuredPatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- void UnPackTo(StructuredPatternT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- static flatbuffers::Offset<StructuredPattern> Pack(flatbuffers::FlatBufferBuilder &_fbb, const StructuredPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+ DatetimeModelPatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(DatetimeModelPatternT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<DatetimeModelPattern> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
-struct StructuredPatternBuilder {
+struct DatetimeModelPatternBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
- void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) {
- fbb_.AddOffset(StructuredPattern::VT_PATTERN, pattern);
+ void add_regexes(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexes) {
+ fbb_.AddOffset(DatetimeModelPattern::VT_REGEXES, regexes);
}
- void add_node_names(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> node_names) {
- fbb_.AddOffset(StructuredPattern::VT_NODE_NAMES, node_names);
+ void add_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales) {
+ fbb_.AddOffset(DatetimeModelPattern::VT_LOCALES, locales);
}
- explicit StructuredPatternBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ void add_target_classification_score(float target_classification_score) {
+ fbb_.AddElement<float>(DatetimeModelPattern::VT_TARGET_CLASSIFICATION_SCORE, target_classification_score, 1.0f);
+ }
+ void add_priority_score(float priority_score) {
+ fbb_.AddElement<float>(DatetimeModelPattern::VT_PRIORITY_SCORE, priority_score, 0.0f);
+ }
+ explicit DatetimeModelPatternBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
- StructuredPatternBuilder &operator=(const StructuredPatternBuilder &);
- flatbuffers::Offset<StructuredPattern> Finish() {
+ DatetimeModelPatternBuilder &operator=(const DatetimeModelPatternBuilder &);
+ flatbuffers::Offset<DatetimeModelPattern> Finish() {
const auto end = fbb_.EndTable(start_);
- auto o = flatbuffers::Offset<StructuredPattern>(end);
+ auto o = flatbuffers::Offset<DatetimeModelPattern>(end);
return o;
}
};
-inline flatbuffers::Offset<StructuredPattern> CreateStructuredPattern(
+inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern(
flatbuffers::FlatBufferBuilder &_fbb,
- flatbuffers::Offset<flatbuffers::String> pattern = 0,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> node_names = 0) {
- StructuredPatternBuilder builder_(_fbb);
- builder_.add_node_names(node_names);
- builder_.add_pattern(pattern);
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexes = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales = 0,
+ float target_classification_score = 1.0f,
+ float priority_score = 0.0f) {
+ DatetimeModelPatternBuilder builder_(_fbb);
+ builder_.add_priority_score(priority_score);
+ builder_.add_target_classification_score(target_classification_score);
+ builder_.add_locales(locales);
+ builder_.add_regexes(regexes);
return builder_.Finish();
}
-inline flatbuffers::Offset<StructuredPattern> CreateStructuredPatternDirect(
+inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPatternDirect(
flatbuffers::FlatBufferBuilder &_fbb,
- const char *pattern = nullptr,
- const std::vector<flatbuffers::Offset<flatbuffers::String>> *node_names = nullptr) {
- return libtextclassifier2::StructuredRegexModel_::CreateStructuredPattern(
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *regexes = nullptr,
+ const std::vector<int32_t> *locales = nullptr,
+ float target_classification_score = 1.0f,
+ float priority_score = 0.0f) {
+ return libtextclassifier2::CreateDatetimeModelPattern(
_fbb,
- pattern ? _fbb.CreateString(pattern) : 0,
- node_names ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*node_names) : 0);
+ regexes ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*regexes) : 0,
+ locales ? _fbb.CreateVector<int32_t>(*locales) : 0,
+ target_classification_score,
+ priority_score);
}
-flatbuffers::Offset<StructuredPattern> CreateStructuredPattern(flatbuffers::FlatBufferBuilder &_fbb, const StructuredPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-} // namespace StructuredRegexModel_
+struct DatetimeModelExtractorT : public flatbuffers::NativeTable {
+ typedef DatetimeModelExtractor TableType;
+ DatetimeExtractorType extractor;
+ std::string pattern;
+ std::vector<int32_t> locales;
+ DatetimeModelExtractorT()
+ : extractor(DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE) {
+ }
+};
+
+struct DatetimeModelExtractor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DatetimeModelExtractorT NativeTableType;
+ enum {
+ VT_EXTRACTOR = 4,
+ VT_PATTERN = 6,
+ VT_LOCALES = 8
+ };
+ DatetimeExtractorType extractor() const {
+ return static_cast<DatetimeExtractorType>(GetField<int32_t>(VT_EXTRACTOR, 0));
+ }
+ const flatbuffers::String *pattern() const {
+ return GetPointer<const flatbuffers::String *>(VT_PATTERN);
+ }
+ const flatbuffers::Vector<int32_t> *locales() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_LOCALES);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_EXTRACTOR) &&
+ VerifyOffset(verifier, VT_PATTERN) &&
+ verifier.Verify(pattern()) &&
+ VerifyOffset(verifier, VT_LOCALES) &&
+ verifier.Verify(locales()) &&
+ verifier.EndTable();
+ }
+ DatetimeModelExtractorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(DatetimeModelExtractorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<DatetimeModelExtractor> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct DatetimeModelExtractorBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_extractor(DatetimeExtractorType extractor) {
+ fbb_.AddElement<int32_t>(DatetimeModelExtractor::VT_EXTRACTOR, static_cast<int32_t>(extractor), 0);
+ }
+ void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) {
+ fbb_.AddOffset(DatetimeModelExtractor::VT_PATTERN, pattern);
+ }
+ void add_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales) {
+ fbb_.AddOffset(DatetimeModelExtractor::VT_LOCALES, locales);
+ }
+ explicit DatetimeModelExtractorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DatetimeModelExtractorBuilder &operator=(const DatetimeModelExtractorBuilder &);
+ flatbuffers::Offset<DatetimeModelExtractor> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<DatetimeModelExtractor>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractor(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ DatetimeExtractorType extractor = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE,
+ flatbuffers::Offset<flatbuffers::String> pattern = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales = 0) {
+ DatetimeModelExtractorBuilder builder_(_fbb);
+ builder_.add_locales(locales);
+ builder_.add_pattern(pattern);
+ builder_.add_extractor(extractor);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractorDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ DatetimeExtractorType extractor = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE,
+ const char *pattern = nullptr,
+ const std::vector<int32_t> *locales = nullptr) {
+ return libtextclassifier2::CreateDatetimeModelExtractor(
+ _fbb,
+ extractor,
+ pattern ? _fbb.CreateString(pattern) : 0,
+ locales ? _fbb.CreateVector<int32_t>(*locales) : 0);
+}
+
+flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractor(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct DatetimeModelT : public flatbuffers::NativeTable {
+ typedef DatetimeModel TableType;
+ std::vector<std::string> locales;
+ std::vector<std::unique_ptr<DatetimeModelPatternT>> patterns;
+ std::vector<std::unique_ptr<DatetimeModelExtractorT>> extractors;
+ DatetimeModelT() {
+ }
+};
+
+struct DatetimeModel FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DatetimeModelT NativeTableType;
+ enum {
+ VT_LOCALES = 4,
+ VT_PATTERNS = 6,
+ VT_EXTRACTORS = 8
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *locales() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_LOCALES);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>> *patterns() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>> *>(VT_PATTERNS);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>> *extractors() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>> *>(VT_EXTRACTORS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_LOCALES) &&
+ verifier.Verify(locales()) &&
+ verifier.VerifyVectorOfStrings(locales()) &&
+ VerifyOffset(verifier, VT_PATTERNS) &&
+ verifier.Verify(patterns()) &&
+ verifier.VerifyVectorOfTables(patterns()) &&
+ VerifyOffset(verifier, VT_EXTRACTORS) &&
+ verifier.Verify(extractors()) &&
+ verifier.VerifyVectorOfTables(extractors()) &&
+ verifier.EndTable();
+ }
+ DatetimeModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(DatetimeModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<DatetimeModel> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct DatetimeModelBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_locales(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> locales) {
+ fbb_.AddOffset(DatetimeModel::VT_LOCALES, locales);
+ }
+ void add_patterns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>>> patterns) {
+ fbb_.AddOffset(DatetimeModel::VT_PATTERNS, patterns);
+ }
+ void add_extractors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>>> extractors) {
+ fbb_.AddOffset(DatetimeModel::VT_EXTRACTORS, extractors);
+ }
+ explicit DatetimeModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DatetimeModelBuilder &operator=(const DatetimeModelBuilder &);
+ flatbuffers::Offset<DatetimeModel> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<DatetimeModel>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<DatetimeModel> CreateDatetimeModel(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> locales = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>>> patterns = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>>> extractors = 0) {
+ DatetimeModelBuilder builder_(_fbb);
+ builder_.add_extractors(extractors);
+ builder_.add_patterns(patterns);
+ builder_.add_locales(locales);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<DatetimeModel> CreateDatetimeModelDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *locales = nullptr,
+ const std::vector<flatbuffers::Offset<DatetimeModelPattern>> *patterns = nullptr,
+ const std::vector<flatbuffers::Offset<DatetimeModelExtractor>> *extractors = nullptr) {
+ return libtextclassifier2::CreateDatetimeModel(
+ _fbb,
+ locales ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*locales) : 0,
+ patterns ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelPattern>>(*patterns) : 0,
+ extractors ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>>(*extractors) : 0);
+}
+
+flatbuffers::Offset<DatetimeModel> CreateDatetimeModel(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct ModelTriggeringOptionsT : public flatbuffers::NativeTable {
+ typedef ModelTriggeringOptions TableType;
+ float min_annotate_confidence;
+ ModelTriggeringOptionsT()
+ : min_annotate_confidence(0.0f) {
+ }
+};
+
+struct ModelTriggeringOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ModelTriggeringOptionsT NativeTableType;
+ enum {
+ VT_MIN_ANNOTATE_CONFIDENCE = 4
+ };
+ float min_annotate_confidence() const {
+ return GetField<float>(VT_MIN_ANNOTATE_CONFIDENCE, 0.0f);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<float>(verifier, VT_MIN_ANNOTATE_CONFIDENCE) &&
+ verifier.EndTable();
+ }
+ ModelTriggeringOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ModelTriggeringOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ModelTriggeringOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ModelTriggeringOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_min_annotate_confidence(float min_annotate_confidence) {
+ fbb_.AddElement<float>(ModelTriggeringOptions::VT_MIN_ANNOTATE_CONFIDENCE, min_annotate_confidence, 0.0f);
+ }
+ explicit ModelTriggeringOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ModelTriggeringOptionsBuilder &operator=(const ModelTriggeringOptionsBuilder &);
+ flatbuffers::Offset<ModelTriggeringOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ModelTriggeringOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ float min_annotate_confidence = 0.0f) {
+ ModelTriggeringOptionsBuilder builder_(_fbb);
+ builder_.add_min_annotate_confidence(min_annotate_confidence);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct ModelT : public flatbuffers::NativeTable {
typedef Model TableType;
- std::string language;
+ std::string locales;
int32_t version;
std::unique_ptr<FeatureProcessorOptionsT> selection_feature_options;
std::unique_ptr<FeatureProcessorOptionsT> classification_feature_options;
std::vector<uint8_t> selection_model;
std::vector<uint8_t> classification_model;
std::vector<uint8_t> embedding_model;
- std::unique_ptr<RegexModelOptionsT> regex_options;
+ std::unique_ptr<RegexModelT> regex_model;
std::unique_ptr<SelectionModelOptionsT> selection_options;
std::unique_ptr<ClassificationModelOptionsT> classification_options;
- std::unique_ptr<StructuredRegexModelT> regex_model;
+ std::unique_ptr<DatetimeModelT> datetime_model;
+ std::unique_ptr<ModelTriggeringOptionsT> triggering_options;
ModelT()
: version(0) {
}
@@ -630,20 +1155,21 @@
struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef ModelT NativeTableType;
enum {
- VT_LANGUAGE = 4,
+ VT_LOCALES = 4,
VT_VERSION = 6,
VT_SELECTION_FEATURE_OPTIONS = 8,
VT_CLASSIFICATION_FEATURE_OPTIONS = 10,
VT_SELECTION_MODEL = 12,
VT_CLASSIFICATION_MODEL = 14,
VT_EMBEDDING_MODEL = 16,
- VT_REGEX_OPTIONS = 18,
+ VT_REGEX_MODEL = 18,
VT_SELECTION_OPTIONS = 20,
VT_CLASSIFICATION_OPTIONS = 22,
- VT_REGEX_MODEL = 24
+ VT_DATETIME_MODEL = 24,
+ VT_TRIGGERING_OPTIONS = 26
};
- const flatbuffers::String *language() const {
- return GetPointer<const flatbuffers::String *>(VT_LANGUAGE);
+ const flatbuffers::String *locales() const {
+ return GetPointer<const flatbuffers::String *>(VT_LOCALES);
}
int32_t version() const {
return GetField<int32_t>(VT_VERSION, 0);
@@ -663,8 +1189,8 @@
const flatbuffers::Vector<uint8_t> *embedding_model() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_EMBEDDING_MODEL);
}
- const RegexModelOptions *regex_options() const {
- return GetPointer<const RegexModelOptions *>(VT_REGEX_OPTIONS);
+ const RegexModel *regex_model() const {
+ return GetPointer<const RegexModel *>(VT_REGEX_MODEL);
}
const SelectionModelOptions *selection_options() const {
return GetPointer<const SelectionModelOptions *>(VT_SELECTION_OPTIONS);
@@ -672,13 +1198,16 @@
const ClassificationModelOptions *classification_options() const {
return GetPointer<const ClassificationModelOptions *>(VT_CLASSIFICATION_OPTIONS);
}
- const StructuredRegexModel *regex_model() const {
- return GetPointer<const StructuredRegexModel *>(VT_REGEX_MODEL);
+ const DatetimeModel *datetime_model() const {
+ return GetPointer<const DatetimeModel *>(VT_DATETIME_MODEL);
+ }
+ const ModelTriggeringOptions *triggering_options() const {
+ return GetPointer<const ModelTriggeringOptions *>(VT_TRIGGERING_OPTIONS);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
- VerifyOffset(verifier, VT_LANGUAGE) &&
- verifier.Verify(language()) &&
+ VerifyOffset(verifier, VT_LOCALES) &&
+ verifier.Verify(locales()) &&
VerifyField<int32_t>(verifier, VT_VERSION) &&
VerifyOffset(verifier, VT_SELECTION_FEATURE_OPTIONS) &&
verifier.VerifyTable(selection_feature_options()) &&
@@ -690,14 +1219,16 @@
verifier.Verify(classification_model()) &&
VerifyOffset(verifier, VT_EMBEDDING_MODEL) &&
verifier.Verify(embedding_model()) &&
- VerifyOffset(verifier, VT_REGEX_OPTIONS) &&
- verifier.VerifyTable(regex_options()) &&
+ VerifyOffset(verifier, VT_REGEX_MODEL) &&
+ verifier.VerifyTable(regex_model()) &&
VerifyOffset(verifier, VT_SELECTION_OPTIONS) &&
verifier.VerifyTable(selection_options()) &&
VerifyOffset(verifier, VT_CLASSIFICATION_OPTIONS) &&
verifier.VerifyTable(classification_options()) &&
- VerifyOffset(verifier, VT_REGEX_MODEL) &&
- verifier.VerifyTable(regex_model()) &&
+ VerifyOffset(verifier, VT_DATETIME_MODEL) &&
+ verifier.VerifyTable(datetime_model()) &&
+ VerifyOffset(verifier, VT_TRIGGERING_OPTIONS) &&
+ verifier.VerifyTable(triggering_options()) &&
verifier.EndTable();
}
ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -708,8 +1239,8 @@
struct ModelBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
- void add_language(flatbuffers::Offset<flatbuffers::String> language) {
- fbb_.AddOffset(Model::VT_LANGUAGE, language);
+ void add_locales(flatbuffers::Offset<flatbuffers::String> locales) {
+ fbb_.AddOffset(Model::VT_LOCALES, locales);
}
void add_version(int32_t version) {
fbb_.AddElement<int32_t>(Model::VT_VERSION, version, 0);
@@ -729,8 +1260,8 @@
void add_embedding_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> embedding_model) {
fbb_.AddOffset(Model::VT_EMBEDDING_MODEL, embedding_model);
}
- void add_regex_options(flatbuffers::Offset<RegexModelOptions> regex_options) {
- fbb_.AddOffset(Model::VT_REGEX_OPTIONS, regex_options);
+ void add_regex_model(flatbuffers::Offset<RegexModel> regex_model) {
+ fbb_.AddOffset(Model::VT_REGEX_MODEL, regex_model);
}
void add_selection_options(flatbuffers::Offset<SelectionModelOptions> selection_options) {
fbb_.AddOffset(Model::VT_SELECTION_OPTIONS, selection_options);
@@ -738,8 +1269,11 @@
void add_classification_options(flatbuffers::Offset<ClassificationModelOptions> classification_options) {
fbb_.AddOffset(Model::VT_CLASSIFICATION_OPTIONS, classification_options);
}
- void add_regex_model(flatbuffers::Offset<StructuredRegexModel> regex_model) {
- fbb_.AddOffset(Model::VT_REGEX_MODEL, regex_model);
+ void add_datetime_model(flatbuffers::Offset<DatetimeModel> datetime_model) {
+ fbb_.AddOffset(Model::VT_DATETIME_MODEL, datetime_model);
+ }
+ void add_triggering_options(flatbuffers::Offset<ModelTriggeringOptions> triggering_options) {
+ fbb_.AddOffset(Model::VT_TRIGGERING_OPTIONS, triggering_options);
}
explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
@@ -755,58 +1289,62 @@
inline flatbuffers::Offset<Model> CreateModel(
flatbuffers::FlatBufferBuilder &_fbb,
- flatbuffers::Offset<flatbuffers::String> language = 0,
+ flatbuffers::Offset<flatbuffers::String> locales = 0,
int32_t version = 0,
flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options = 0,
flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options = 0,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> selection_model = 0,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> classification_model = 0,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> embedding_model = 0,
- flatbuffers::Offset<RegexModelOptions> regex_options = 0,
+ flatbuffers::Offset<RegexModel> regex_model = 0,
flatbuffers::Offset<SelectionModelOptions> selection_options = 0,
flatbuffers::Offset<ClassificationModelOptions> classification_options = 0,
- flatbuffers::Offset<StructuredRegexModel> regex_model = 0) {
+ flatbuffers::Offset<DatetimeModel> datetime_model = 0,
+ flatbuffers::Offset<ModelTriggeringOptions> triggering_options = 0) {
ModelBuilder builder_(_fbb);
- builder_.add_regex_model(regex_model);
+ builder_.add_triggering_options(triggering_options);
+ builder_.add_datetime_model(datetime_model);
builder_.add_classification_options(classification_options);
builder_.add_selection_options(selection_options);
- builder_.add_regex_options(regex_options);
+ builder_.add_regex_model(regex_model);
builder_.add_embedding_model(embedding_model);
builder_.add_classification_model(classification_model);
builder_.add_selection_model(selection_model);
builder_.add_classification_feature_options(classification_feature_options);
builder_.add_selection_feature_options(selection_feature_options);
builder_.add_version(version);
- builder_.add_language(language);
+ builder_.add_locales(locales);
return builder_.Finish();
}
inline flatbuffers::Offset<Model> CreateModelDirect(
flatbuffers::FlatBufferBuilder &_fbb,
- const char *language = nullptr,
+ const char *locales = nullptr,
int32_t version = 0,
flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options = 0,
flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options = 0,
const std::vector<uint8_t> *selection_model = nullptr,
const std::vector<uint8_t> *classification_model = nullptr,
const std::vector<uint8_t> *embedding_model = nullptr,
- flatbuffers::Offset<RegexModelOptions> regex_options = 0,
+ flatbuffers::Offset<RegexModel> regex_model = 0,
flatbuffers::Offset<SelectionModelOptions> selection_options = 0,
flatbuffers::Offset<ClassificationModelOptions> classification_options = 0,
- flatbuffers::Offset<StructuredRegexModel> regex_model = 0) {
+ flatbuffers::Offset<DatetimeModel> datetime_model = 0,
+ flatbuffers::Offset<ModelTriggeringOptions> triggering_options = 0) {
return libtextclassifier2::CreateModel(
_fbb,
- language ? _fbb.CreateString(language) : 0,
+ locales ? _fbb.CreateString(locales) : 0,
version,
selection_feature_options,
classification_feature_options,
selection_model ? _fbb.CreateVector<uint8_t>(*selection_model) : 0,
classification_model ? _fbb.CreateVector<uint8_t>(*classification_model) : 0,
embedding_model ? _fbb.CreateVector<uint8_t>(*embedding_model) : 0,
- regex_options,
+ regex_model,
selection_options,
classification_options,
- regex_model);
+ datetime_model,
+ triggering_options);
}
flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -901,523 +1439,6 @@
flatbuffers::Offset<TokenizationCodepointRange> CreateTokenizationCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-struct FeatureProcessorOptionsT : public flatbuffers::NativeTable {
- typedef FeatureProcessorOptions TableType;
- int32_t num_buckets;
- int32_t embedding_size;
- int32_t context_size;
- int32_t max_selection_span;
- std::vector<int32_t> chargram_orders;
- int32_t max_word_length;
- bool unicode_aware_features;
- bool extract_case_feature;
- bool extract_selection_mask_feature;
- std::vector<std::string> regexp_feature;
- bool remap_digits;
- bool lowercase_tokens;
- bool selection_reduced_output_space;
- std::vector<std::string> collections;
- int32_t default_collection;
- bool only_use_line_with_click;
- bool split_tokens_on_selection_boundaries;
- std::vector<std::unique_ptr<TokenizationCodepointRangeT>> tokenization_codepoint_config;
- libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method;
- bool snap_label_span_boundaries_to_containing_tokens;
- std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>> supported_codepoint_ranges;
- std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>> internal_tokenizer_codepoint_ranges;
- float min_supported_codepoint_ratio;
- int32_t feature_version;
- libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type;
- bool icu_preserve_whitespace_tokens;
- std::vector<int32_t> ignored_span_boundary_codepoints;
- bool click_random_token_in_selection;
- std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntryT>> alternative_collection_map;
- std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT> bounds_sensitive_features;
- bool split_selection_candidates;
- std::vector<std::string> allowed_chargrams;
- bool tokenize_on_script_change;
- FeatureProcessorOptionsT()
- : num_buckets(-1),
- embedding_size(-1),
- context_size(-1),
- max_selection_span(-1),
- max_word_length(20),
- unicode_aware_features(false),
- extract_case_feature(false),
- extract_selection_mask_feature(false),
- remap_digits(false),
- lowercase_tokens(false),
- selection_reduced_output_space(false),
- default_collection(-1),
- only_use_line_with_click(false),
- split_tokens_on_selection_boundaries(false),
- center_token_selection_method(libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD),
- snap_label_span_boundaries_to_containing_tokens(false),
- min_supported_codepoint_ratio(0.0f),
- feature_version(0),
- tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE),
- icu_preserve_whitespace_tokens(false),
- click_random_token_in_selection(false),
- split_selection_candidates(false),
- tokenize_on_script_change(false) {
- }
-};
-
-struct FeatureProcessorOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
- typedef FeatureProcessorOptionsT NativeTableType;
- enum {
- VT_NUM_BUCKETS = 4,
- VT_EMBEDDING_SIZE = 6,
- VT_CONTEXT_SIZE = 8,
- VT_MAX_SELECTION_SPAN = 10,
- VT_CHARGRAM_ORDERS = 12,
- VT_MAX_WORD_LENGTH = 14,
- VT_UNICODE_AWARE_FEATURES = 16,
- VT_EXTRACT_CASE_FEATURE = 18,
- VT_EXTRACT_SELECTION_MASK_FEATURE = 20,
- VT_REGEXP_FEATURE = 22,
- VT_REMAP_DIGITS = 24,
- VT_LOWERCASE_TOKENS = 26,
- VT_SELECTION_REDUCED_OUTPUT_SPACE = 28,
- VT_COLLECTIONS = 30,
- VT_DEFAULT_COLLECTION = 32,
- VT_ONLY_USE_LINE_WITH_CLICK = 34,
- VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES = 36,
- VT_TOKENIZATION_CODEPOINT_CONFIG = 38,
- VT_CENTER_TOKEN_SELECTION_METHOD = 40,
- VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS = 42,
- VT_SUPPORTED_CODEPOINT_RANGES = 44,
- VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES = 46,
- VT_MIN_SUPPORTED_CODEPOINT_RATIO = 48,
- VT_FEATURE_VERSION = 50,
- VT_TOKENIZATION_TYPE = 52,
- VT_ICU_PRESERVE_WHITESPACE_TOKENS = 54,
- VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS = 56,
- VT_CLICK_RANDOM_TOKEN_IN_SELECTION = 58,
- VT_ALTERNATIVE_COLLECTION_MAP = 60,
- VT_BOUNDS_SENSITIVE_FEATURES = 62,
- VT_SPLIT_SELECTION_CANDIDATES = 64,
- VT_ALLOWED_CHARGRAMS = 66,
- VT_TOKENIZE_ON_SCRIPT_CHANGE = 68
- };
- int32_t num_buckets() const {
- return GetField<int32_t>(VT_NUM_BUCKETS, -1);
- }
- int32_t embedding_size() const {
- return GetField<int32_t>(VT_EMBEDDING_SIZE, -1);
- }
- int32_t context_size() const {
- return GetField<int32_t>(VT_CONTEXT_SIZE, -1);
- }
- int32_t max_selection_span() const {
- return GetField<int32_t>(VT_MAX_SELECTION_SPAN, -1);
- }
- const flatbuffers::Vector<int32_t> *chargram_orders() const {
- return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_CHARGRAM_ORDERS);
- }
- int32_t max_word_length() const {
- return GetField<int32_t>(VT_MAX_WORD_LENGTH, 20);
- }
- bool unicode_aware_features() const {
- return GetField<uint8_t>(VT_UNICODE_AWARE_FEATURES, 0) != 0;
- }
- bool extract_case_feature() const {
- return GetField<uint8_t>(VT_EXTRACT_CASE_FEATURE, 0) != 0;
- }
- bool extract_selection_mask_feature() const {
- return GetField<uint8_t>(VT_EXTRACT_SELECTION_MASK_FEATURE, 0) != 0;
- }
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *regexp_feature() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_REGEXP_FEATURE);
- }
- bool remap_digits() const {
- return GetField<uint8_t>(VT_REMAP_DIGITS, 0) != 0;
- }
- bool lowercase_tokens() const {
- return GetField<uint8_t>(VT_LOWERCASE_TOKENS, 0) != 0;
- }
- bool selection_reduced_output_space() const {
- return GetField<uint8_t>(VT_SELECTION_REDUCED_OUTPUT_SPACE, 0) != 0;
- }
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *collections() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_COLLECTIONS);
- }
- int32_t default_collection() const {
- return GetField<int32_t>(VT_DEFAULT_COLLECTION, -1);
- }
- bool only_use_line_with_click() const {
- return GetField<uint8_t>(VT_ONLY_USE_LINE_WITH_CLICK, 0) != 0;
- }
- bool split_tokens_on_selection_boundaries() const {
- return GetField<uint8_t>(VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES, 0) != 0;
- }
- const flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>> *tokenization_codepoint_config() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>> *>(VT_TOKENIZATION_CODEPOINT_CONFIG);
- }
- libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method() const {
- return static_cast<libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod>(GetField<int32_t>(VT_CENTER_TOKEN_SELECTION_METHOD, 0));
- }
- bool snap_label_span_boundaries_to_containing_tokens() const {
- return GetField<uint8_t>(VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS, 0) != 0;
- }
- const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *supported_codepoint_ranges() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *>(VT_SUPPORTED_CODEPOINT_RANGES);
- }
- const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *>(VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES);
- }
- float min_supported_codepoint_ratio() const {
- return GetField<float>(VT_MIN_SUPPORTED_CODEPOINT_RATIO, 0.0f);
- }
- int32_t feature_version() const {
- return GetField<int32_t>(VT_FEATURE_VERSION, 0);
- }
- libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type() const {
- return static_cast<libtextclassifier2::FeatureProcessorOptions_::TokenizationType>(GetField<int32_t>(VT_TOKENIZATION_TYPE, 0));
- }
- bool icu_preserve_whitespace_tokens() const {
- return GetField<uint8_t>(VT_ICU_PRESERVE_WHITESPACE_TOKENS, 0) != 0;
- }
- const flatbuffers::Vector<int32_t> *ignored_span_boundary_codepoints() const {
- return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS);
- }
- bool click_random_token_in_selection() const {
- return GetField<uint8_t>(VT_CLICK_RANDOM_TOKEN_IN_SELECTION, 0) != 0;
- }
- const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>> *alternative_collection_map() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>> *>(VT_ALTERNATIVE_COLLECTION_MAP);
- }
- const libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures *bounds_sensitive_features() const {
- return GetPointer<const libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures *>(VT_BOUNDS_SENSITIVE_FEATURES);
- }
- bool split_selection_candidates() const {
- return GetField<uint8_t>(VT_SPLIT_SELECTION_CANDIDATES, 0) != 0;
- }
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_ALLOWED_CHARGRAMS);
- }
- bool tokenize_on_script_change() const {
- return GetField<uint8_t>(VT_TOKENIZE_ON_SCRIPT_CHANGE, 0) != 0;
- }
- bool Verify(flatbuffers::Verifier &verifier) const {
- return VerifyTableStart(verifier) &&
- VerifyField<int32_t>(verifier, VT_NUM_BUCKETS) &&
- VerifyField<int32_t>(verifier, VT_EMBEDDING_SIZE) &&
- VerifyField<int32_t>(verifier, VT_CONTEXT_SIZE) &&
- VerifyField<int32_t>(verifier, VT_MAX_SELECTION_SPAN) &&
- VerifyOffset(verifier, VT_CHARGRAM_ORDERS) &&
- verifier.Verify(chargram_orders()) &&
- VerifyField<int32_t>(verifier, VT_MAX_WORD_LENGTH) &&
- VerifyField<uint8_t>(verifier, VT_UNICODE_AWARE_FEATURES) &&
- VerifyField<uint8_t>(verifier, VT_EXTRACT_CASE_FEATURE) &&
- VerifyField<uint8_t>(verifier, VT_EXTRACT_SELECTION_MASK_FEATURE) &&
- VerifyOffset(verifier, VT_REGEXP_FEATURE) &&
- verifier.Verify(regexp_feature()) &&
- verifier.VerifyVectorOfStrings(regexp_feature()) &&
- VerifyField<uint8_t>(verifier, VT_REMAP_DIGITS) &&
- VerifyField<uint8_t>(verifier, VT_LOWERCASE_TOKENS) &&
- VerifyField<uint8_t>(verifier, VT_SELECTION_REDUCED_OUTPUT_SPACE) &&
- VerifyOffset(verifier, VT_COLLECTIONS) &&
- verifier.Verify(collections()) &&
- verifier.VerifyVectorOfStrings(collections()) &&
- VerifyField<int32_t>(verifier, VT_DEFAULT_COLLECTION) &&
- VerifyField<uint8_t>(verifier, VT_ONLY_USE_LINE_WITH_CLICK) &&
- VerifyField<uint8_t>(verifier, VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES) &&
- VerifyOffset(verifier, VT_TOKENIZATION_CODEPOINT_CONFIG) &&
- verifier.Verify(tokenization_codepoint_config()) &&
- verifier.VerifyVectorOfTables(tokenization_codepoint_config()) &&
- VerifyField<int32_t>(verifier, VT_CENTER_TOKEN_SELECTION_METHOD) &&
- VerifyField<uint8_t>(verifier, VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS) &&
- VerifyOffset(verifier, VT_SUPPORTED_CODEPOINT_RANGES) &&
- verifier.Verify(supported_codepoint_ranges()) &&
- verifier.VerifyVectorOfTables(supported_codepoint_ranges()) &&
- VerifyOffset(verifier, VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES) &&
- verifier.Verify(internal_tokenizer_codepoint_ranges()) &&
- verifier.VerifyVectorOfTables(internal_tokenizer_codepoint_ranges()) &&
- VerifyField<float>(verifier, VT_MIN_SUPPORTED_CODEPOINT_RATIO) &&
- VerifyField<int32_t>(verifier, VT_FEATURE_VERSION) &&
- VerifyField<int32_t>(verifier, VT_TOKENIZATION_TYPE) &&
- VerifyField<uint8_t>(verifier, VT_ICU_PRESERVE_WHITESPACE_TOKENS) &&
- VerifyOffset(verifier, VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS) &&
- verifier.Verify(ignored_span_boundary_codepoints()) &&
- VerifyField<uint8_t>(verifier, VT_CLICK_RANDOM_TOKEN_IN_SELECTION) &&
- VerifyOffset(verifier, VT_ALTERNATIVE_COLLECTION_MAP) &&
- verifier.Verify(alternative_collection_map()) &&
- verifier.VerifyVectorOfTables(alternative_collection_map()) &&
- VerifyOffset(verifier, VT_BOUNDS_SENSITIVE_FEATURES) &&
- verifier.VerifyTable(bounds_sensitive_features()) &&
- VerifyField<uint8_t>(verifier, VT_SPLIT_SELECTION_CANDIDATES) &&
- VerifyOffset(verifier, VT_ALLOWED_CHARGRAMS) &&
- verifier.Verify(allowed_chargrams()) &&
- verifier.VerifyVectorOfStrings(allowed_chargrams()) &&
- VerifyField<uint8_t>(verifier, VT_TOKENIZE_ON_SCRIPT_CHANGE) &&
- verifier.EndTable();
- }
- FeatureProcessorOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- void UnPackTo(FeatureProcessorOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- static flatbuffers::Offset<FeatureProcessorOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-};
-
-struct FeatureProcessorOptionsBuilder {
- flatbuffers::FlatBufferBuilder &fbb_;
- flatbuffers::uoffset_t start_;
- void add_num_buckets(int32_t num_buckets) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_NUM_BUCKETS, num_buckets, -1);
- }
- void add_embedding_size(int32_t embedding_size) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_SIZE, embedding_size, -1);
- }
- void add_context_size(int32_t context_size) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CONTEXT_SIZE, context_size, -1);
- }
- void add_max_selection_span(int32_t max_selection_span) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_MAX_SELECTION_SPAN, max_selection_span, -1);
- }
- void add_chargram_orders(flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders) {
- fbb_.AddOffset(FeatureProcessorOptions::VT_CHARGRAM_ORDERS, chargram_orders);
- }
- void add_max_word_length(int32_t max_word_length) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_MAX_WORD_LENGTH, max_word_length, 20);
- }
- void add_unicode_aware_features(bool unicode_aware_features) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_UNICODE_AWARE_FEATURES, static_cast<uint8_t>(unicode_aware_features), 0);
- }
- void add_extract_case_feature(bool extract_case_feature) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_EXTRACT_CASE_FEATURE, static_cast<uint8_t>(extract_case_feature), 0);
- }
- void add_extract_selection_mask_feature(bool extract_selection_mask_feature) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_EXTRACT_SELECTION_MASK_FEATURE, static_cast<uint8_t>(extract_selection_mask_feature), 0);
- }
- void add_regexp_feature(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexp_feature) {
- fbb_.AddOffset(FeatureProcessorOptions::VT_REGEXP_FEATURE, regexp_feature);
- }
- void add_remap_digits(bool remap_digits) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_REMAP_DIGITS, static_cast<uint8_t>(remap_digits), 0);
- }
- void add_lowercase_tokens(bool lowercase_tokens) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_LOWERCASE_TOKENS, static_cast<uint8_t>(lowercase_tokens), 0);
- }
- void add_selection_reduced_output_space(bool selection_reduced_output_space) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SELECTION_REDUCED_OUTPUT_SPACE, static_cast<uint8_t>(selection_reduced_output_space), 0);
- }
- void add_collections(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> collections) {
- fbb_.AddOffset(FeatureProcessorOptions::VT_COLLECTIONS, collections);
- }
- void add_default_collection(int32_t default_collection) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_DEFAULT_COLLECTION, default_collection, -1);
- }
- void add_only_use_line_with_click(bool only_use_line_with_click) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ONLY_USE_LINE_WITH_CLICK, static_cast<uint8_t>(only_use_line_with_click), 0);
- }
- void add_split_tokens_on_selection_boundaries(bool split_tokens_on_selection_boundaries) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES, static_cast<uint8_t>(split_tokens_on_selection_boundaries), 0);
- }
- void add_tokenization_codepoint_config(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>>> tokenization_codepoint_config) {
- fbb_.AddOffset(FeatureProcessorOptions::VT_TOKENIZATION_CODEPOINT_CONFIG, tokenization_codepoint_config);
- }
- void add_center_token_selection_method(libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CENTER_TOKEN_SELECTION_METHOD, static_cast<int32_t>(center_token_selection_method), 0);
- }
- void add_snap_label_span_boundaries_to_containing_tokens(bool snap_label_span_boundaries_to_containing_tokens) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS, static_cast<uint8_t>(snap_label_span_boundaries_to_containing_tokens), 0);
- }
- void add_supported_codepoint_ranges(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> supported_codepoint_ranges) {
- fbb_.AddOffset(FeatureProcessorOptions::VT_SUPPORTED_CODEPOINT_RANGES, supported_codepoint_ranges);
- }
- void add_internal_tokenizer_codepoint_ranges(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges) {
- fbb_.AddOffset(FeatureProcessorOptions::VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES, internal_tokenizer_codepoint_ranges);
- }
- void add_min_supported_codepoint_ratio(float min_supported_codepoint_ratio) {
- fbb_.AddElement<float>(FeatureProcessorOptions::VT_MIN_SUPPORTED_CODEPOINT_RATIO, min_supported_codepoint_ratio, 0.0f);
- }
- void add_feature_version(int32_t feature_version) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_FEATURE_VERSION, feature_version, 0);
- }
- void add_tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_TOKENIZATION_TYPE, static_cast<int32_t>(tokenization_type), 0);
- }
- void add_icu_preserve_whitespace_tokens(bool icu_preserve_whitespace_tokens) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ICU_PRESERVE_WHITESPACE_TOKENS, static_cast<uint8_t>(icu_preserve_whitespace_tokens), 0);
- }
- void add_ignored_span_boundary_codepoints(flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints) {
- fbb_.AddOffset(FeatureProcessorOptions::VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS, ignored_span_boundary_codepoints);
- }
- void add_click_random_token_in_selection(bool click_random_token_in_selection) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_CLICK_RANDOM_TOKEN_IN_SELECTION, static_cast<uint8_t>(click_random_token_in_selection), 0);
- }
- void add_alternative_collection_map(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>>> alternative_collection_map) {
- fbb_.AddOffset(FeatureProcessorOptions::VT_ALTERNATIVE_COLLECTION_MAP, alternative_collection_map);
- }
- void add_bounds_sensitive_features(flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features) {
- fbb_.AddOffset(FeatureProcessorOptions::VT_BOUNDS_SENSITIVE_FEATURES, bounds_sensitive_features);
- }
- void add_split_selection_candidates(bool split_selection_candidates) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SPLIT_SELECTION_CANDIDATES, static_cast<uint8_t>(split_selection_candidates), 0);
- }
- void add_allowed_chargrams(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams) {
- fbb_.AddOffset(FeatureProcessorOptions::VT_ALLOWED_CHARGRAMS, allowed_chargrams);
- }
- void add_tokenize_on_script_change(bool tokenize_on_script_change) {
- fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_TOKENIZE_ON_SCRIPT_CHANGE, static_cast<uint8_t>(tokenize_on_script_change), 0);
- }
- explicit FeatureProcessorOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
- : fbb_(_fbb) {
- start_ = fbb_.StartTable();
- }
- FeatureProcessorOptionsBuilder &operator=(const FeatureProcessorOptionsBuilder &);
- flatbuffers::Offset<FeatureProcessorOptions> Finish() {
- const auto end = fbb_.EndTable(start_);
- auto o = flatbuffers::Offset<FeatureProcessorOptions>(end);
- return o;
- }
-};
-
-inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(
- flatbuffers::FlatBufferBuilder &_fbb,
- int32_t num_buckets = -1,
- int32_t embedding_size = -1,
- int32_t context_size = -1,
- int32_t max_selection_span = -1,
- flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders = 0,
- int32_t max_word_length = 20,
- bool unicode_aware_features = false,
- bool extract_case_feature = false,
- bool extract_selection_mask_feature = false,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexp_feature = 0,
- bool remap_digits = false,
- bool lowercase_tokens = false,
- bool selection_reduced_output_space = false,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> collections = 0,
- int32_t default_collection = -1,
- bool only_use_line_with_click = false,
- bool split_tokens_on_selection_boundaries = false,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>>> tokenization_codepoint_config = 0,
- libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method = libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD,
- bool snap_label_span_boundaries_to_containing_tokens = false,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> supported_codepoint_ranges = 0,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges = 0,
- float min_supported_codepoint_ratio = 0.0f,
- int32_t feature_version = 0,
- libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE,
- bool icu_preserve_whitespace_tokens = false,
- flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints = 0,
- bool click_random_token_in_selection = false,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>>> alternative_collection_map = 0,
- flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0,
- bool split_selection_candidates = false,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams = 0,
- bool tokenize_on_script_change = false) {
- FeatureProcessorOptionsBuilder builder_(_fbb);
- builder_.add_allowed_chargrams(allowed_chargrams);
- builder_.add_bounds_sensitive_features(bounds_sensitive_features);
- builder_.add_alternative_collection_map(alternative_collection_map);
- builder_.add_ignored_span_boundary_codepoints(ignored_span_boundary_codepoints);
- builder_.add_tokenization_type(tokenization_type);
- builder_.add_feature_version(feature_version);
- builder_.add_min_supported_codepoint_ratio(min_supported_codepoint_ratio);
- builder_.add_internal_tokenizer_codepoint_ranges(internal_tokenizer_codepoint_ranges);
- builder_.add_supported_codepoint_ranges(supported_codepoint_ranges);
- builder_.add_center_token_selection_method(center_token_selection_method);
- builder_.add_tokenization_codepoint_config(tokenization_codepoint_config);
- builder_.add_default_collection(default_collection);
- builder_.add_collections(collections);
- builder_.add_regexp_feature(regexp_feature);
- builder_.add_max_word_length(max_word_length);
- builder_.add_chargram_orders(chargram_orders);
- builder_.add_max_selection_span(max_selection_span);
- builder_.add_context_size(context_size);
- builder_.add_embedding_size(embedding_size);
- builder_.add_num_buckets(num_buckets);
- builder_.add_tokenize_on_script_change(tokenize_on_script_change);
- builder_.add_split_selection_candidates(split_selection_candidates);
- builder_.add_click_random_token_in_selection(click_random_token_in_selection);
- builder_.add_icu_preserve_whitespace_tokens(icu_preserve_whitespace_tokens);
- builder_.add_snap_label_span_boundaries_to_containing_tokens(snap_label_span_boundaries_to_containing_tokens);
- builder_.add_split_tokens_on_selection_boundaries(split_tokens_on_selection_boundaries);
- builder_.add_only_use_line_with_click(only_use_line_with_click);
- builder_.add_selection_reduced_output_space(selection_reduced_output_space);
- builder_.add_lowercase_tokens(lowercase_tokens);
- builder_.add_remap_digits(remap_digits);
- builder_.add_extract_selection_mask_feature(extract_selection_mask_feature);
- builder_.add_extract_case_feature(extract_case_feature);
- builder_.add_unicode_aware_features(unicode_aware_features);
- return builder_.Finish();
-}
-
-inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptionsDirect(
- flatbuffers::FlatBufferBuilder &_fbb,
- int32_t num_buckets = -1,
- int32_t embedding_size = -1,
- int32_t context_size = -1,
- int32_t max_selection_span = -1,
- const std::vector<int32_t> *chargram_orders = nullptr,
- int32_t max_word_length = 20,
- bool unicode_aware_features = false,
- bool extract_case_feature = false,
- bool extract_selection_mask_feature = false,
- const std::vector<flatbuffers::Offset<flatbuffers::String>> *regexp_feature = nullptr,
- bool remap_digits = false,
- bool lowercase_tokens = false,
- bool selection_reduced_output_space = false,
- const std::vector<flatbuffers::Offset<flatbuffers::String>> *collections = nullptr,
- int32_t default_collection = -1,
- bool only_use_line_with_click = false,
- bool split_tokens_on_selection_boundaries = false,
- const std::vector<flatbuffers::Offset<TokenizationCodepointRange>> *tokenization_codepoint_config = nullptr,
- libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method = libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD,
- bool snap_label_span_boundaries_to_containing_tokens = false,
- const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *supported_codepoint_ranges = nullptr,
- const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges = nullptr,
- float min_supported_codepoint_ratio = 0.0f,
- int32_t feature_version = 0,
- libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE,
- bool icu_preserve_whitespace_tokens = false,
- const std::vector<int32_t> *ignored_span_boundary_codepoints = nullptr,
- bool click_random_token_in_selection = false,
- const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>> *alternative_collection_map = nullptr,
- flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0,
- bool split_selection_candidates = false,
- const std::vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams = nullptr,
- bool tokenize_on_script_change = false) {
- return libtextclassifier2::CreateFeatureProcessorOptions(
- _fbb,
- num_buckets,
- embedding_size,
- context_size,
- max_selection_span,
- chargram_orders ? _fbb.CreateVector<int32_t>(*chargram_orders) : 0,
- max_word_length,
- unicode_aware_features,
- extract_case_feature,
- extract_selection_mask_feature,
- regexp_feature ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*regexp_feature) : 0,
- remap_digits,
- lowercase_tokens,
- selection_reduced_output_space,
- collections ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*collections) : 0,
- default_collection,
- only_use_line_with_click,
- split_tokens_on_selection_boundaries,
- tokenization_codepoint_config ? _fbb.CreateVector<flatbuffers::Offset<TokenizationCodepointRange>>(*tokenization_codepoint_config) : 0,
- center_token_selection_method,
- snap_label_span_boundaries_to_containing_tokens,
- supported_codepoint_ranges ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>(*supported_codepoint_ranges) : 0,
- internal_tokenizer_codepoint_ranges ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>(*internal_tokenizer_codepoint_ranges) : 0,
- min_supported_codepoint_ratio,
- feature_version,
- tokenization_type,
- icu_preserve_whitespace_tokens,
- ignored_span_boundary_codepoints ? _fbb.CreateVector<int32_t>(*ignored_span_boundary_codepoints) : 0,
- click_random_token_in_selection,
- alternative_collection_map ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>>(*alternative_collection_map) : 0,
- bounds_sensitive_features,
- split_selection_candidates,
- allowed_chargrams ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*allowed_chargrams) : 0,
- tokenize_on_script_change);
-}
-
-flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-
namespace FeatureProcessorOptions_ {
struct CodepointRangeT : public flatbuffers::NativeTable {
@@ -1486,82 +1507,6 @@
flatbuffers::Offset<CodepointRange> CreateCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-struct CollectionMapEntryT : public flatbuffers::NativeTable {
- typedef CollectionMapEntry TableType;
- std::string key;
- std::string value;
- CollectionMapEntryT() {
- }
-};
-
-struct CollectionMapEntry FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
- typedef CollectionMapEntryT NativeTableType;
- enum {
- VT_KEY = 4,
- VT_VALUE = 6
- };
- const flatbuffers::String *key() const {
- return GetPointer<const flatbuffers::String *>(VT_KEY);
- }
- const flatbuffers::String *value() const {
- return GetPointer<const flatbuffers::String *>(VT_VALUE);
- }
- bool Verify(flatbuffers::Verifier &verifier) const {
- return VerifyTableStart(verifier) &&
- VerifyOffset(verifier, VT_KEY) &&
- verifier.Verify(key()) &&
- VerifyOffset(verifier, VT_VALUE) &&
- verifier.Verify(value()) &&
- verifier.EndTable();
- }
- CollectionMapEntryT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- void UnPackTo(CollectionMapEntryT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- static flatbuffers::Offset<CollectionMapEntry> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CollectionMapEntryT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-};
-
-struct CollectionMapEntryBuilder {
- flatbuffers::FlatBufferBuilder &fbb_;
- flatbuffers::uoffset_t start_;
- void add_key(flatbuffers::Offset<flatbuffers::String> key) {
- fbb_.AddOffset(CollectionMapEntry::VT_KEY, key);
- }
- void add_value(flatbuffers::Offset<flatbuffers::String> value) {
- fbb_.AddOffset(CollectionMapEntry::VT_VALUE, value);
- }
- explicit CollectionMapEntryBuilder(flatbuffers::FlatBufferBuilder &_fbb)
- : fbb_(_fbb) {
- start_ = fbb_.StartTable();
- }
- CollectionMapEntryBuilder &operator=(const CollectionMapEntryBuilder &);
- flatbuffers::Offset<CollectionMapEntry> Finish() {
- const auto end = fbb_.EndTable(start_);
- auto o = flatbuffers::Offset<CollectionMapEntry>(end);
- return o;
- }
-};
-
-inline flatbuffers::Offset<CollectionMapEntry> CreateCollectionMapEntry(
- flatbuffers::FlatBufferBuilder &_fbb,
- flatbuffers::Offset<flatbuffers::String> key = 0,
- flatbuffers::Offset<flatbuffers::String> value = 0) {
- CollectionMapEntryBuilder builder_(_fbb);
- builder_.add_value(value);
- builder_.add_key(key);
- return builder_.Finish();
-}
-
-inline flatbuffers::Offset<CollectionMapEntry> CreateCollectionMapEntryDirect(
- flatbuffers::FlatBufferBuilder &_fbb,
- const char *key = nullptr,
- const char *value = nullptr) {
- return libtextclassifier2::FeatureProcessorOptions_::CreateCollectionMapEntry(
- _fbb,
- key ? _fbb.CreateString(key) : 0,
- value ? _fbb.CreateString(value) : 0);
-}
-
-flatbuffers::Offset<CollectionMapEntry> CreateCollectionMapEntry(flatbuffers::FlatBufferBuilder &_fbb, const CollectionMapEntryT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-
struct BoundsSensitiveFeaturesT : public flatbuffers::NativeTable {
typedef BoundsSensitiveFeatures TableType;
bool enabled;
@@ -1688,8 +1633,572 @@
flatbuffers::Offset<BoundsSensitiveFeatures> CreateBoundsSensitiveFeatures(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct AlternativeCollectionMapEntryT : public flatbuffers::NativeTable {
+ typedef AlternativeCollectionMapEntry TableType;
+ std::string key;
+ std::string value;
+ AlternativeCollectionMapEntryT() {
+ }
+};
+
+struct AlternativeCollectionMapEntry FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef AlternativeCollectionMapEntryT NativeTableType;
+ enum {
+ VT_KEY = 4,
+ VT_VALUE = 6
+ };
+ const flatbuffers::String *key() const {
+ return GetPointer<const flatbuffers::String *>(VT_KEY);
+ }
+ const flatbuffers::String *value() const {
+ return GetPointer<const flatbuffers::String *>(VT_VALUE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_KEY) &&
+ verifier.Verify(key()) &&
+ VerifyOffset(verifier, VT_VALUE) &&
+ verifier.Verify(value()) &&
+ verifier.EndTable();
+ }
+ AlternativeCollectionMapEntryT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(AlternativeCollectionMapEntryT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<AlternativeCollectionMapEntry> Pack(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct AlternativeCollectionMapEntryBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_key(flatbuffers::Offset<flatbuffers::String> key) {
+ fbb_.AddOffset(AlternativeCollectionMapEntry::VT_KEY, key);
+ }
+ void add_value(flatbuffers::Offset<flatbuffers::String> value) {
+ fbb_.AddOffset(AlternativeCollectionMapEntry::VT_VALUE, value);
+ }
+ explicit AlternativeCollectionMapEntryBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ AlternativeCollectionMapEntryBuilder &operator=(const AlternativeCollectionMapEntryBuilder &);
+ flatbuffers::Offset<AlternativeCollectionMapEntry> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<AlternativeCollectionMapEntry>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntry(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> key = 0,
+ flatbuffers::Offset<flatbuffers::String> value = 0) {
+ AlternativeCollectionMapEntryBuilder builder_(_fbb);
+ builder_.add_value(value);
+ builder_.add_key(key);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntryDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *key = nullptr,
+ const char *value = nullptr) {
+ return libtextclassifier2::FeatureProcessorOptions_::CreateAlternativeCollectionMapEntry(
+ _fbb,
+ key ? _fbb.CreateString(key) : 0,
+ value ? _fbb.CreateString(value) : 0);
+}
+
+flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntry(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
} // namespace FeatureProcessorOptions_
+struct FeatureProcessorOptionsT : public flatbuffers::NativeTable {
+ typedef FeatureProcessorOptions TableType;
+ int32_t num_buckets;
+ int32_t embedding_size;
+ int32_t context_size;
+ int32_t max_selection_span;
+ std::vector<int32_t> chargram_orders;
+ int32_t max_word_length;
+ bool unicode_aware_features;
+ bool extract_case_feature;
+ bool extract_selection_mask_feature;
+ std::vector<std::string> regexp_feature;
+ bool remap_digits;
+ bool lowercase_tokens;
+ bool selection_reduced_output_space;
+ std::vector<std::string> collections;
+ int32_t default_collection;
+ bool only_use_line_with_click;
+ bool split_tokens_on_selection_boundaries;
+ std::vector<std::unique_ptr<TokenizationCodepointRangeT>> tokenization_codepoint_config;
+ libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method;
+ bool snap_label_span_boundaries_to_containing_tokens;
+ std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>> supported_codepoint_ranges;
+ std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>> internal_tokenizer_codepoint_ranges;
+ float min_supported_codepoint_ratio;
+ int32_t feature_version;
+ libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type;
+ bool icu_preserve_whitespace_tokens;
+ std::vector<int32_t> ignored_span_boundary_codepoints;
+ std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT> bounds_sensitive_features;
+ std::vector<std::string> allowed_chargrams;
+ bool tokenize_on_script_change;
+ int32_t embedding_quantization_bits;
+ FeatureProcessorOptionsT()
+ : num_buckets(-1),
+ embedding_size(-1),
+ context_size(-1),
+ max_selection_span(-1),
+ max_word_length(20),
+ unicode_aware_features(false),
+ extract_case_feature(false),
+ extract_selection_mask_feature(false),
+ remap_digits(false),
+ lowercase_tokens(false),
+ selection_reduced_output_space(true),
+ default_collection(-1),
+ only_use_line_with_click(false),
+ split_tokens_on_selection_boundaries(false),
+ center_token_selection_method(libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD),
+ snap_label_span_boundaries_to_containing_tokens(false),
+ min_supported_codepoint_ratio(0.0f),
+ feature_version(0),
+ tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE),
+ icu_preserve_whitespace_tokens(false),
+ tokenize_on_script_change(false),
+ embedding_quantization_bits(8) {
+ }
+};
+
+struct FeatureProcessorOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FeatureProcessorOptionsT NativeTableType;
+ enum {
+ VT_NUM_BUCKETS = 4,
+ VT_EMBEDDING_SIZE = 6,
+ VT_CONTEXT_SIZE = 8,
+ VT_MAX_SELECTION_SPAN = 10,
+ VT_CHARGRAM_ORDERS = 12,
+ VT_MAX_WORD_LENGTH = 14,
+ VT_UNICODE_AWARE_FEATURES = 16,
+ VT_EXTRACT_CASE_FEATURE = 18,
+ VT_EXTRACT_SELECTION_MASK_FEATURE = 20,
+ VT_REGEXP_FEATURE = 22,
+ VT_REMAP_DIGITS = 24,
+ VT_LOWERCASE_TOKENS = 26,
+ VT_SELECTION_REDUCED_OUTPUT_SPACE = 28,
+ VT_COLLECTIONS = 30,
+ VT_DEFAULT_COLLECTION = 32,
+ VT_ONLY_USE_LINE_WITH_CLICK = 34,
+ VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES = 36,
+ VT_TOKENIZATION_CODEPOINT_CONFIG = 38,
+ VT_CENTER_TOKEN_SELECTION_METHOD = 40,
+ VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS = 42,
+ VT_SUPPORTED_CODEPOINT_RANGES = 44,
+ VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES = 46,
+ VT_MIN_SUPPORTED_CODEPOINT_RATIO = 48,
+ VT_FEATURE_VERSION = 50,
+ VT_TOKENIZATION_TYPE = 52,
+ VT_ICU_PRESERVE_WHITESPACE_TOKENS = 54,
+ VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS = 56,
+ VT_BOUNDS_SENSITIVE_FEATURES = 58,
+ VT_ALLOWED_CHARGRAMS = 60,
+ VT_TOKENIZE_ON_SCRIPT_CHANGE = 62,
+ VT_EMBEDDING_QUANTIZATION_BITS = 64
+ };
+ int32_t num_buckets() const {
+ return GetField<int32_t>(VT_NUM_BUCKETS, -1);
+ }
+ int32_t embedding_size() const {
+ return GetField<int32_t>(VT_EMBEDDING_SIZE, -1);
+ }
+ int32_t context_size() const {
+ return GetField<int32_t>(VT_CONTEXT_SIZE, -1);
+ }
+ int32_t max_selection_span() const {
+ return GetField<int32_t>(VT_MAX_SELECTION_SPAN, -1);
+ }
+ const flatbuffers::Vector<int32_t> *chargram_orders() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_CHARGRAM_ORDERS);
+ }
+ int32_t max_word_length() const {
+ return GetField<int32_t>(VT_MAX_WORD_LENGTH, 20);
+ }
+ bool unicode_aware_features() const {
+ return GetField<uint8_t>(VT_UNICODE_AWARE_FEATURES, 0) != 0;
+ }
+ bool extract_case_feature() const {
+ return GetField<uint8_t>(VT_EXTRACT_CASE_FEATURE, 0) != 0;
+ }
+ bool extract_selection_mask_feature() const {
+ return GetField<uint8_t>(VT_EXTRACT_SELECTION_MASK_FEATURE, 0) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *regexp_feature() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_REGEXP_FEATURE);
+ }
+ bool remap_digits() const {
+ return GetField<uint8_t>(VT_REMAP_DIGITS, 0) != 0;
+ }
+ bool lowercase_tokens() const {
+ return GetField<uint8_t>(VT_LOWERCASE_TOKENS, 0) != 0;
+ }
+ bool selection_reduced_output_space() const {
+ return GetField<uint8_t>(VT_SELECTION_REDUCED_OUTPUT_SPACE, 1) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *collections() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_COLLECTIONS);
+ }
+ int32_t default_collection() const {
+ return GetField<int32_t>(VT_DEFAULT_COLLECTION, -1);
+ }
+ bool only_use_line_with_click() const {
+ return GetField<uint8_t>(VT_ONLY_USE_LINE_WITH_CLICK, 0) != 0;
+ }
+ bool split_tokens_on_selection_boundaries() const {
+ return GetField<uint8_t>(VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES, 0) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>> *tokenization_codepoint_config() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>> *>(VT_TOKENIZATION_CODEPOINT_CONFIG);
+ }
+ libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method() const {
+ return static_cast<libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod>(GetField<int32_t>(VT_CENTER_TOKEN_SELECTION_METHOD, 0));
+ }
+ bool snap_label_span_boundaries_to_containing_tokens() const {
+ return GetField<uint8_t>(VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS, 0) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *supported_codepoint_ranges() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *>(VT_SUPPORTED_CODEPOINT_RANGES);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *>(VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES);
+ }
+ float min_supported_codepoint_ratio() const {
+ return GetField<float>(VT_MIN_SUPPORTED_CODEPOINT_RATIO, 0.0f);
+ }
+ int32_t feature_version() const {
+ return GetField<int32_t>(VT_FEATURE_VERSION, 0);
+ }
+ libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type() const {
+ return static_cast<libtextclassifier2::FeatureProcessorOptions_::TokenizationType>(GetField<int32_t>(VT_TOKENIZATION_TYPE, 0));
+ }
+ bool icu_preserve_whitespace_tokens() const {
+ return GetField<uint8_t>(VT_ICU_PRESERVE_WHITESPACE_TOKENS, 0) != 0;
+ }
+ const flatbuffers::Vector<int32_t> *ignored_span_boundary_codepoints() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS);
+ }
+ const libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures *bounds_sensitive_features() const {
+ return GetPointer<const libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures *>(VT_BOUNDS_SENSITIVE_FEATURES);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_ALLOWED_CHARGRAMS);
+ }
+ bool tokenize_on_script_change() const {
+ return GetField<uint8_t>(VT_TOKENIZE_ON_SCRIPT_CHANGE, 0) != 0;
+ }
+ int32_t embedding_quantization_bits() const {
+ return GetField<int32_t>(VT_EMBEDDING_QUANTIZATION_BITS, 8);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_NUM_BUCKETS) &&
+ VerifyField<int32_t>(verifier, VT_EMBEDDING_SIZE) &&
+ VerifyField<int32_t>(verifier, VT_CONTEXT_SIZE) &&
+ VerifyField<int32_t>(verifier, VT_MAX_SELECTION_SPAN) &&
+ VerifyOffset(verifier, VT_CHARGRAM_ORDERS) &&
+ verifier.Verify(chargram_orders()) &&
+ VerifyField<int32_t>(verifier, VT_MAX_WORD_LENGTH) &&
+ VerifyField<uint8_t>(verifier, VT_UNICODE_AWARE_FEATURES) &&
+ VerifyField<uint8_t>(verifier, VT_EXTRACT_CASE_FEATURE) &&
+ VerifyField<uint8_t>(verifier, VT_EXTRACT_SELECTION_MASK_FEATURE) &&
+ VerifyOffset(verifier, VT_REGEXP_FEATURE) &&
+ verifier.Verify(regexp_feature()) &&
+ verifier.VerifyVectorOfStrings(regexp_feature()) &&
+ VerifyField<uint8_t>(verifier, VT_REMAP_DIGITS) &&
+ VerifyField<uint8_t>(verifier, VT_LOWERCASE_TOKENS) &&
+ VerifyField<uint8_t>(verifier, VT_SELECTION_REDUCED_OUTPUT_SPACE) &&
+ VerifyOffset(verifier, VT_COLLECTIONS) &&
+ verifier.Verify(collections()) &&
+ verifier.VerifyVectorOfStrings(collections()) &&
+ VerifyField<int32_t>(verifier, VT_DEFAULT_COLLECTION) &&
+ VerifyField<uint8_t>(verifier, VT_ONLY_USE_LINE_WITH_CLICK) &&
+ VerifyField<uint8_t>(verifier, VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES) &&
+ VerifyOffset(verifier, VT_TOKENIZATION_CODEPOINT_CONFIG) &&
+ verifier.Verify(tokenization_codepoint_config()) &&
+ verifier.VerifyVectorOfTables(tokenization_codepoint_config()) &&
+ VerifyField<int32_t>(verifier, VT_CENTER_TOKEN_SELECTION_METHOD) &&
+ VerifyField<uint8_t>(verifier, VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS) &&
+ VerifyOffset(verifier, VT_SUPPORTED_CODEPOINT_RANGES) &&
+ verifier.Verify(supported_codepoint_ranges()) &&
+ verifier.VerifyVectorOfTables(supported_codepoint_ranges()) &&
+ VerifyOffset(verifier, VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES) &&
+ verifier.Verify(internal_tokenizer_codepoint_ranges()) &&
+ verifier.VerifyVectorOfTables(internal_tokenizer_codepoint_ranges()) &&
+ VerifyField<float>(verifier, VT_MIN_SUPPORTED_CODEPOINT_RATIO) &&
+ VerifyField<int32_t>(verifier, VT_FEATURE_VERSION) &&
+ VerifyField<int32_t>(verifier, VT_TOKENIZATION_TYPE) &&
+ VerifyField<uint8_t>(verifier, VT_ICU_PRESERVE_WHITESPACE_TOKENS) &&
+ VerifyOffset(verifier, VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS) &&
+ verifier.Verify(ignored_span_boundary_codepoints()) &&
+ VerifyOffset(verifier, VT_BOUNDS_SENSITIVE_FEATURES) &&
+ verifier.VerifyTable(bounds_sensitive_features()) &&
+ VerifyOffset(verifier, VT_ALLOWED_CHARGRAMS) &&
+ verifier.Verify(allowed_chargrams()) &&
+ verifier.VerifyVectorOfStrings(allowed_chargrams()) &&
+ VerifyField<uint8_t>(verifier, VT_TOKENIZE_ON_SCRIPT_CHANGE) &&
+ VerifyField<int32_t>(verifier, VT_EMBEDDING_QUANTIZATION_BITS) &&
+ verifier.EndTable();
+ }
+ FeatureProcessorOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(FeatureProcessorOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<FeatureProcessorOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FeatureProcessorOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_num_buckets(int32_t num_buckets) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_NUM_BUCKETS, num_buckets, -1);
+ }
+ void add_embedding_size(int32_t embedding_size) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_SIZE, embedding_size, -1);
+ }
+ void add_context_size(int32_t context_size) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CONTEXT_SIZE, context_size, -1);
+ }
+ void add_max_selection_span(int32_t max_selection_span) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_MAX_SELECTION_SPAN, max_selection_span, -1);
+ }
+ void add_chargram_orders(flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_CHARGRAM_ORDERS, chargram_orders);
+ }
+ void add_max_word_length(int32_t max_word_length) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_MAX_WORD_LENGTH, max_word_length, 20);
+ }
+ void add_unicode_aware_features(bool unicode_aware_features) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_UNICODE_AWARE_FEATURES, static_cast<uint8_t>(unicode_aware_features), 0);
+ }
+ void add_extract_case_feature(bool extract_case_feature) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_EXTRACT_CASE_FEATURE, static_cast<uint8_t>(extract_case_feature), 0);
+ }
+ void add_extract_selection_mask_feature(bool extract_selection_mask_feature) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_EXTRACT_SELECTION_MASK_FEATURE, static_cast<uint8_t>(extract_selection_mask_feature), 0);
+ }
+ void add_regexp_feature(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexp_feature) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_REGEXP_FEATURE, regexp_feature);
+ }
+ void add_remap_digits(bool remap_digits) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_REMAP_DIGITS, static_cast<uint8_t>(remap_digits), 0);
+ }
+ void add_lowercase_tokens(bool lowercase_tokens) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_LOWERCASE_TOKENS, static_cast<uint8_t>(lowercase_tokens), 0);
+ }
+ void add_selection_reduced_output_space(bool selection_reduced_output_space) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SELECTION_REDUCED_OUTPUT_SPACE, static_cast<uint8_t>(selection_reduced_output_space), 1);
+ }
+ void add_collections(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> collections) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_COLLECTIONS, collections);
+ }
+ void add_default_collection(int32_t default_collection) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_DEFAULT_COLLECTION, default_collection, -1);
+ }
+ void add_only_use_line_with_click(bool only_use_line_with_click) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ONLY_USE_LINE_WITH_CLICK, static_cast<uint8_t>(only_use_line_with_click), 0);
+ }
+ void add_split_tokens_on_selection_boundaries(bool split_tokens_on_selection_boundaries) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES, static_cast<uint8_t>(split_tokens_on_selection_boundaries), 0);
+ }
+ void add_tokenization_codepoint_config(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>>> tokenization_codepoint_config) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_TOKENIZATION_CODEPOINT_CONFIG, tokenization_codepoint_config);
+ }
+ void add_center_token_selection_method(libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CENTER_TOKEN_SELECTION_METHOD, static_cast<int32_t>(center_token_selection_method), 0);
+ }
+ void add_snap_label_span_boundaries_to_containing_tokens(bool snap_label_span_boundaries_to_containing_tokens) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS, static_cast<uint8_t>(snap_label_span_boundaries_to_containing_tokens), 0);
+ }
+ void add_supported_codepoint_ranges(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> supported_codepoint_ranges) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_SUPPORTED_CODEPOINT_RANGES, supported_codepoint_ranges);
+ }
+ void add_internal_tokenizer_codepoint_ranges(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES, internal_tokenizer_codepoint_ranges);
+ }
+ void add_min_supported_codepoint_ratio(float min_supported_codepoint_ratio) {
+ fbb_.AddElement<float>(FeatureProcessorOptions::VT_MIN_SUPPORTED_CODEPOINT_RATIO, min_supported_codepoint_ratio, 0.0f);
+ }
+ void add_feature_version(int32_t feature_version) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_FEATURE_VERSION, feature_version, 0);
+ }
+ void add_tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_TOKENIZATION_TYPE, static_cast<int32_t>(tokenization_type), 0);
+ }
+ void add_icu_preserve_whitespace_tokens(bool icu_preserve_whitespace_tokens) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ICU_PRESERVE_WHITESPACE_TOKENS, static_cast<uint8_t>(icu_preserve_whitespace_tokens), 0);
+ }
+ void add_ignored_span_boundary_codepoints(flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS, ignored_span_boundary_codepoints);
+ }
+ void add_bounds_sensitive_features(flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_BOUNDS_SENSITIVE_FEATURES, bounds_sensitive_features);
+ }
+ void add_allowed_chargrams(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_ALLOWED_CHARGRAMS, allowed_chargrams);
+ }
+ void add_tokenize_on_script_change(bool tokenize_on_script_change) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_TOKENIZE_ON_SCRIPT_CHANGE, static_cast<uint8_t>(tokenize_on_script_change), 0);
+ }
+ void add_embedding_quantization_bits(int32_t embedding_quantization_bits) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_QUANTIZATION_BITS, embedding_quantization_bits, 8);
+ }
+ explicit FeatureProcessorOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FeatureProcessorOptionsBuilder &operator=(const FeatureProcessorOptionsBuilder &);
+ flatbuffers::Offset<FeatureProcessorOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FeatureProcessorOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t num_buckets = -1,
+ int32_t embedding_size = -1,
+ int32_t context_size = -1,
+ int32_t max_selection_span = -1,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders = 0,
+ int32_t max_word_length = 20,
+ bool unicode_aware_features = false,
+ bool extract_case_feature = false,
+ bool extract_selection_mask_feature = false,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexp_feature = 0,
+ bool remap_digits = false,
+ bool lowercase_tokens = false,
+ bool selection_reduced_output_space = true,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> collections = 0,
+ int32_t default_collection = -1,
+ bool only_use_line_with_click = false,
+ bool split_tokens_on_selection_boundaries = false,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>>> tokenization_codepoint_config = 0,
+ libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method = libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD,
+ bool snap_label_span_boundaries_to_containing_tokens = false,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> supported_codepoint_ranges = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges = 0,
+ float min_supported_codepoint_ratio = 0.0f,
+ int32_t feature_version = 0,
+ libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE,
+ bool icu_preserve_whitespace_tokens = false,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints = 0,
+ flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams = 0,
+ bool tokenize_on_script_change = false,
+ int32_t embedding_quantization_bits = 8) {
+ FeatureProcessorOptionsBuilder builder_(_fbb);
+ builder_.add_embedding_quantization_bits(embedding_quantization_bits);
+ builder_.add_allowed_chargrams(allowed_chargrams);
+ builder_.add_bounds_sensitive_features(bounds_sensitive_features);
+ builder_.add_ignored_span_boundary_codepoints(ignored_span_boundary_codepoints);
+ builder_.add_tokenization_type(tokenization_type);
+ builder_.add_feature_version(feature_version);
+ builder_.add_min_supported_codepoint_ratio(min_supported_codepoint_ratio);
+ builder_.add_internal_tokenizer_codepoint_ranges(internal_tokenizer_codepoint_ranges);
+ builder_.add_supported_codepoint_ranges(supported_codepoint_ranges);
+ builder_.add_center_token_selection_method(center_token_selection_method);
+ builder_.add_tokenization_codepoint_config(tokenization_codepoint_config);
+ builder_.add_default_collection(default_collection);
+ builder_.add_collections(collections);
+ builder_.add_regexp_feature(regexp_feature);
+ builder_.add_max_word_length(max_word_length);
+ builder_.add_chargram_orders(chargram_orders);
+ builder_.add_max_selection_span(max_selection_span);
+ builder_.add_context_size(context_size);
+ builder_.add_embedding_size(embedding_size);
+ builder_.add_num_buckets(num_buckets);
+ builder_.add_tokenize_on_script_change(tokenize_on_script_change);
+ builder_.add_icu_preserve_whitespace_tokens(icu_preserve_whitespace_tokens);
+ builder_.add_snap_label_span_boundaries_to_containing_tokens(snap_label_span_boundaries_to_containing_tokens);
+ builder_.add_split_tokens_on_selection_boundaries(split_tokens_on_selection_boundaries);
+ builder_.add_only_use_line_with_click(only_use_line_with_click);
+ builder_.add_selection_reduced_output_space(selection_reduced_output_space);
+ builder_.add_lowercase_tokens(lowercase_tokens);
+ builder_.add_remap_digits(remap_digits);
+ builder_.add_extract_selection_mask_feature(extract_selection_mask_feature);
+ builder_.add_extract_case_feature(extract_case_feature);
+ builder_.add_unicode_aware_features(unicode_aware_features);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptionsDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t num_buckets = -1,
+ int32_t embedding_size = -1,
+ int32_t context_size = -1,
+ int32_t max_selection_span = -1,
+ const std::vector<int32_t> *chargram_orders = nullptr,
+ int32_t max_word_length = 20,
+ bool unicode_aware_features = false,
+ bool extract_case_feature = false,
+ bool extract_selection_mask_feature = false,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *regexp_feature = nullptr,
+ bool remap_digits = false,
+ bool lowercase_tokens = false,
+ bool selection_reduced_output_space = true,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *collections = nullptr,
+ int32_t default_collection = -1,
+ bool only_use_line_with_click = false,
+ bool split_tokens_on_selection_boundaries = false,
+ const std::vector<flatbuffers::Offset<TokenizationCodepointRange>> *tokenization_codepoint_config = nullptr,
+ libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method = libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD,
+ bool snap_label_span_boundaries_to_containing_tokens = false,
+ const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *supported_codepoint_ranges = nullptr,
+ const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges = nullptr,
+ float min_supported_codepoint_ratio = 0.0f,
+ int32_t feature_version = 0,
+ libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE,
+ bool icu_preserve_whitespace_tokens = false,
+ const std::vector<int32_t> *ignored_span_boundary_codepoints = nullptr,
+ flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams = nullptr,
+ bool tokenize_on_script_change = false,
+ int32_t embedding_quantization_bits = 8) {
+ return libtextclassifier2::CreateFeatureProcessorOptions(
+ _fbb,
+ num_buckets,
+ embedding_size,
+ context_size,
+ max_selection_span,
+ chargram_orders ? _fbb.CreateVector<int32_t>(*chargram_orders) : 0,
+ max_word_length,
+ unicode_aware_features,
+ extract_case_feature,
+ extract_selection_mask_feature,
+ regexp_feature ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*regexp_feature) : 0,
+ remap_digits,
+ lowercase_tokens,
+ selection_reduced_output_space,
+ collections ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*collections) : 0,
+ default_collection,
+ only_use_line_with_click,
+ split_tokens_on_selection_boundaries,
+ tokenization_codepoint_config ? _fbb.CreateVector<flatbuffers::Offset<TokenizationCodepointRange>>(*tokenization_codepoint_config) : 0,
+ center_token_selection_method,
+ snap_label_span_boundaries_to_containing_tokens,
+ supported_codepoint_ranges ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>(*supported_codepoint_ranges) : 0,
+ internal_tokenizer_codepoint_ranges ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>(*internal_tokenizer_codepoint_ranges) : 0,
+ min_supported_codepoint_ratio,
+ feature_version,
+ tokenization_type,
+ icu_preserve_whitespace_tokens,
+ ignored_span_boundary_codepoints ? _fbb.CreateVector<int32_t>(*ignored_span_boundary_codepoints) : 0,
+ bounds_sensitive_features,
+ allowed_chargrams ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*allowed_chargrams) : 0,
+ tokenize_on_script_change,
+ embedding_quantization_bits);
+}
+
+flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
inline SelectionModelOptionsT *SelectionModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new SelectionModelOptionsT();
UnPackTo(_o, _resolver);
@@ -1701,6 +2210,7 @@
(void)_resolver;
{ auto _e = strip_unpaired_brackets(); _o->strip_unpaired_brackets = _e; };
{ auto _e = symmetry_context_size(); _o->symmetry_context_size = _e; };
+ { auto _e = batch_size(); _o->batch_size = _e; };
}
inline flatbuffers::Offset<SelectionModelOptions> SelectionModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -1713,10 +2223,12 @@
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SelectionModelOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _strip_unpaired_brackets = _o->strip_unpaired_brackets;
auto _symmetry_context_size = _o->symmetry_context_size;
+ auto _batch_size = _o->batch_size;
return libtextclassifier2::CreateSelectionModelOptions(
_fbb,
_strip_unpaired_brackets,
- _symmetry_context_size);
+ _symmetry_context_size,
+ _batch_size);
}
inline ClassificationModelOptionsT *ClassificationModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -1748,33 +2260,7 @@
_phone_max_num_digits);
}
-inline RegexModelOptionsT *RegexModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
- auto _o = new RegexModelOptionsT();
- UnPackTo(_o, _resolver);
- return _o;
-}
-
-inline void RegexModelOptions::UnPackTo(RegexModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
- (void)_o;
- (void)_resolver;
- { auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<libtextclassifier2::RegexModelOptions_::PatternT>(_e->Get(_i)->UnPack(_resolver)); } } };
-}
-
-inline flatbuffers::Offset<RegexModelOptions> RegexModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
- return CreateRegexModelOptions(_fbb, _o, _rehasher);
-}
-
-inline flatbuffers::Offset<RegexModelOptions> CreateRegexModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
- (void)_rehasher;
- (void)_o;
- struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RegexModelOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
- auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreatePattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0;
- return libtextclassifier2::CreateRegexModelOptions(
- _fbb,
- _patterns);
-}
-
-namespace RegexModelOptions_ {
+namespace RegexModel_ {
inline PatternT *Pattern::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new PatternT();
@@ -1787,6 +2273,11 @@
(void)_resolver;
{ auto _e = collection_name(); if (_e) _o->collection_name = _e->str(); };
{ auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
+ { auto _e = enabled_for_annotation(); _o->enabled_for_annotation = _e; };
+ { auto _e = enabled_for_classification(); _o->enabled_for_classification = _e; };
+ { auto _e = enabled_for_selection(); _o->enabled_for_selection = _e; };
+ { auto _e = target_classification_score(); _o->target_classification_score = _e; };
+ { auto _e = priority_score(); _o->priority_score = _e; };
}
inline flatbuffers::Offset<Pattern> Pattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -1799,72 +2290,174 @@
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _collection_name = _o->collection_name.empty() ? 0 : _fbb.CreateString(_o->collection_name);
auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
- return libtextclassifier2::RegexModelOptions_::CreatePattern(
+ auto _enabled_for_annotation = _o->enabled_for_annotation;
+ auto _enabled_for_classification = _o->enabled_for_classification;
+ auto _enabled_for_selection = _o->enabled_for_selection;
+ auto _target_classification_score = _o->target_classification_score;
+ auto _priority_score = _o->priority_score;
+ return libtextclassifier2::RegexModel_::CreatePattern(
_fbb,
_collection_name,
- _pattern);
+ _pattern,
+ _enabled_for_annotation,
+ _enabled_for_classification,
+ _enabled_for_selection,
+ _target_classification_score,
+ _priority_score);
}
-} // namespace RegexModelOptions_
+} // namespace RegexModel_
-inline StructuredRegexModelT *StructuredRegexModel::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
- auto _o = new StructuredRegexModelT();
+inline RegexModelT *RegexModel::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new RegexModelT();
UnPackTo(_o, _resolver);
return _o;
}
-inline void StructuredRegexModel::UnPackTo(StructuredRegexModelT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+inline void RegexModel::UnPackTo(RegexModelT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
- { auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<libtextclassifier2::StructuredRegexModel_::StructuredPatternT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<libtextclassifier2::RegexModel_::PatternT>(_e->Get(_i)->UnPack(_resolver)); } } };
}
-inline flatbuffers::Offset<StructuredRegexModel> StructuredRegexModel::Pack(flatbuffers::FlatBufferBuilder &_fbb, const StructuredRegexModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
- return CreateStructuredRegexModel(_fbb, _o, _rehasher);
+inline flatbuffers::Offset<RegexModel> RegexModel::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateRegexModel(_fbb, _o, _rehasher);
}
-inline flatbuffers::Offset<StructuredRegexModel> CreateStructuredRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const StructuredRegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+inline flatbuffers::Offset<RegexModel> CreateRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
- struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const StructuredRegexModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
- auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreateStructuredPattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0;
- return libtextclassifier2::CreateStructuredRegexModel(
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RegexModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::RegexModel_::Pattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreatePattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0;
+ return libtextclassifier2::CreateRegexModel(
_fbb,
_patterns);
}
-namespace StructuredRegexModel_ {
-
-inline StructuredPatternT *StructuredPattern::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
- auto _o = new StructuredPatternT();
+inline DatetimeModelPatternT *DatetimeModelPattern::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new DatetimeModelPatternT();
UnPackTo(_o, _resolver);
return _o;
}
-inline void StructuredPattern::UnPackTo(StructuredPatternT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+inline void DatetimeModelPattern::UnPackTo(DatetimeModelPatternT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
- { auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
- { auto _e = node_names(); if (_e) { _o->node_names.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->node_names[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = regexes(); if (_e) { _o->regexes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regexes[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i); } } };
+ { auto _e = target_classification_score(); _o->target_classification_score = _e; };
+ { auto _e = priority_score(); _o->priority_score = _e; };
}
-inline flatbuffers::Offset<StructuredPattern> StructuredPattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const StructuredPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
- return CreateStructuredPattern(_fbb, _o, _rehasher);
+inline flatbuffers::Offset<DatetimeModelPattern> DatetimeModelPattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateDatetimeModelPattern(_fbb, _o, _rehasher);
}
-inline flatbuffers::Offset<StructuredPattern> CreateStructuredPattern(flatbuffers::FlatBufferBuilder &_fbb, const StructuredPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
- struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const StructuredPatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
- auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
- auto _node_names = _o->node_names.size() ? _fbb.CreateVectorOfStrings(_o->node_names) : 0;
- return libtextclassifier2::StructuredRegexModel_::CreateStructuredPattern(
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelPatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _regexes = _o->regexes.size() ? _fbb.CreateVectorOfStrings(_o->regexes) : 0;
+ auto _locales = _o->locales.size() ? _fbb.CreateVector(_o->locales) : 0;
+ auto _target_classification_score = _o->target_classification_score;
+ auto _priority_score = _o->priority_score;
+ return libtextclassifier2::CreateDatetimeModelPattern(
_fbb,
- _pattern,
- _node_names);
+ _regexes,
+ _locales,
+ _target_classification_score,
+ _priority_score);
}
-} // namespace StructuredRegexModel_
+inline DatetimeModelExtractorT *DatetimeModelExtractor::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new DatetimeModelExtractorT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void DatetimeModelExtractor::UnPackTo(DatetimeModelExtractorT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = extractor(); _o->extractor = _e; };
+ { auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
+ { auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i); } } };
+}
+
+inline flatbuffers::Offset<DatetimeModelExtractor> DatetimeModelExtractor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateDatetimeModelExtractor(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractor(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelExtractorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _extractor = _o->extractor;
+ auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
+ auto _locales = _o->locales.size() ? _fbb.CreateVector(_o->locales) : 0;
+ return libtextclassifier2::CreateDatetimeModelExtractor(
+ _fbb,
+ _extractor,
+ _pattern,
+ _locales);
+}
+
+inline DatetimeModelT *DatetimeModel::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new DatetimeModelT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void DatetimeModel::UnPackTo(DatetimeModelT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<DatetimeModelPatternT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = extractors(); if (_e) { _o->extractors.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->extractors[_i] = std::unique_ptr<DatetimeModelExtractorT>(_e->Get(_i)->UnPack(_resolver)); } } };
+}
+
+inline flatbuffers::Offset<DatetimeModel> DatetimeModel::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateDatetimeModel(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<DatetimeModel> CreateDatetimeModel(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _locales = _o->locales.size() ? _fbb.CreateVectorOfStrings(_o->locales) : 0;
+ auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelPattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreateDatetimeModelPattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0;
+ auto _extractors = _o->extractors.size() ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>> (_o->extractors.size(), [](size_t i, _VectorArgs *__va) { return CreateDatetimeModelExtractor(*__va->__fbb, __va->__o->extractors[i].get(), __va->__rehasher); }, &_va ) : 0;
+ return libtextclassifier2::CreateDatetimeModel(
+ _fbb,
+ _locales,
+ _patterns,
+ _extractors);
+}
+
+inline ModelTriggeringOptionsT *ModelTriggeringOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ModelTriggeringOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ModelTriggeringOptions::UnPackTo(ModelTriggeringOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = min_annotate_confidence(); _o->min_annotate_confidence = _e; };
+}
+
+inline flatbuffers::Offset<ModelTriggeringOptions> ModelTriggeringOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateModelTriggeringOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelTriggeringOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _min_annotate_confidence = _o->min_annotate_confidence;
+ return libtextclassifier2::CreateModelTriggeringOptions(
+ _fbb,
+ _min_annotate_confidence);
+}
inline ModelT *Model::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new ModelT();
@@ -1875,17 +2468,18 @@
inline void Model::UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
- { auto _e = language(); if (_e) _o->language = _e->str(); };
+ { auto _e = locales(); if (_e) _o->locales = _e->str(); };
{ auto _e = version(); _o->version = _e; };
{ auto _e = selection_feature_options(); if (_e) _o->selection_feature_options = std::unique_ptr<FeatureProcessorOptionsT>(_e->UnPack(_resolver)); };
{ auto _e = classification_feature_options(); if (_e) _o->classification_feature_options = std::unique_ptr<FeatureProcessorOptionsT>(_e->UnPack(_resolver)); };
{ auto _e = selection_model(); if (_e) { _o->selection_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->selection_model[_i] = _e->Get(_i); } } };
{ auto _e = classification_model(); if (_e) { _o->classification_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->classification_model[_i] = _e->Get(_i); } } };
{ auto _e = embedding_model(); if (_e) { _o->embedding_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_model[_i] = _e->Get(_i); } } };
- { auto _e = regex_options(); if (_e) _o->regex_options = std::unique_ptr<RegexModelOptionsT>(_e->UnPack(_resolver)); };
+ { auto _e = regex_model(); if (_e) _o->regex_model = std::unique_ptr<RegexModelT>(_e->UnPack(_resolver)); };
{ auto _e = selection_options(); if (_e) _o->selection_options = std::unique_ptr<SelectionModelOptionsT>(_e->UnPack(_resolver)); };
{ auto _e = classification_options(); if (_e) _o->classification_options = std::unique_ptr<ClassificationModelOptionsT>(_e->UnPack(_resolver)); };
- { auto _e = regex_model(); if (_e) _o->regex_model = std::unique_ptr<StructuredRegexModelT>(_e->UnPack(_resolver)); };
+ { auto _e = datetime_model(); if (_e) _o->datetime_model = std::unique_ptr<DatetimeModelT>(_e->UnPack(_resolver)); };
+ { auto _e = triggering_options(); if (_e) _o->triggering_options = std::unique_ptr<ModelTriggeringOptionsT>(_e->UnPack(_resolver)); };
}
inline flatbuffers::Offset<Model> Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -1896,30 +2490,32 @@
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
- auto _language = _o->language.empty() ? 0 : _fbb.CreateString(_o->language);
+ auto _locales = _o->locales.empty() ? 0 : _fbb.CreateString(_o->locales);
auto _version = _o->version;
auto _selection_feature_options = _o->selection_feature_options ? CreateFeatureProcessorOptions(_fbb, _o->selection_feature_options.get(), _rehasher) : 0;
auto _classification_feature_options = _o->classification_feature_options ? CreateFeatureProcessorOptions(_fbb, _o->classification_feature_options.get(), _rehasher) : 0;
auto _selection_model = _o->selection_model.size() ? _fbb.CreateVector(_o->selection_model) : 0;
auto _classification_model = _o->classification_model.size() ? _fbb.CreateVector(_o->classification_model) : 0;
auto _embedding_model = _o->embedding_model.size() ? _fbb.CreateVector(_o->embedding_model) : 0;
- auto _regex_options = _o->regex_options ? CreateRegexModelOptions(_fbb, _o->regex_options.get(), _rehasher) : 0;
+ auto _regex_model = _o->regex_model ? CreateRegexModel(_fbb, _o->regex_model.get(), _rehasher) : 0;
auto _selection_options = _o->selection_options ? CreateSelectionModelOptions(_fbb, _o->selection_options.get(), _rehasher) : 0;
auto _classification_options = _o->classification_options ? CreateClassificationModelOptions(_fbb, _o->classification_options.get(), _rehasher) : 0;
- auto _regex_model = _o->regex_model ? CreateStructuredRegexModel(_fbb, _o->regex_model.get(), _rehasher) : 0;
+ auto _datetime_model = _o->datetime_model ? CreateDatetimeModel(_fbb, _o->datetime_model.get(), _rehasher) : 0;
+ auto _triggering_options = _o->triggering_options ? CreateModelTriggeringOptions(_fbb, _o->triggering_options.get(), _rehasher) : 0;
return libtextclassifier2::CreateModel(
_fbb,
- _language,
+ _locales,
_version,
_selection_feature_options,
_classification_feature_options,
_selection_model,
_classification_model,
_embedding_model,
- _regex_options,
+ _regex_model,
_selection_options,
_classification_options,
- _regex_model);
+ _datetime_model,
+ _triggering_options);
}
inline TokenizationCodepointRangeT *TokenizationCodepointRange::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -1957,128 +2553,6 @@
_script_id);
}
-inline FeatureProcessorOptionsT *FeatureProcessorOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
- auto _o = new FeatureProcessorOptionsT();
- UnPackTo(_o, _resolver);
- return _o;
-}
-
-inline void FeatureProcessorOptions::UnPackTo(FeatureProcessorOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
- (void)_o;
- (void)_resolver;
- { auto _e = num_buckets(); _o->num_buckets = _e; };
- { auto _e = embedding_size(); _o->embedding_size = _e; };
- { auto _e = context_size(); _o->context_size = _e; };
- { auto _e = max_selection_span(); _o->max_selection_span = _e; };
- { auto _e = chargram_orders(); if (_e) { _o->chargram_orders.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->chargram_orders[_i] = _e->Get(_i); } } };
- { auto _e = max_word_length(); _o->max_word_length = _e; };
- { auto _e = unicode_aware_features(); _o->unicode_aware_features = _e; };
- { auto _e = extract_case_feature(); _o->extract_case_feature = _e; };
- { auto _e = extract_selection_mask_feature(); _o->extract_selection_mask_feature = _e; };
- { auto _e = regexp_feature(); if (_e) { _o->regexp_feature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regexp_feature[_i] = _e->Get(_i)->str(); } } };
- { auto _e = remap_digits(); _o->remap_digits = _e; };
- { auto _e = lowercase_tokens(); _o->lowercase_tokens = _e; };
- { auto _e = selection_reduced_output_space(); _o->selection_reduced_output_space = _e; };
- { auto _e = collections(); if (_e) { _o->collections.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->collections[_i] = _e->Get(_i)->str(); } } };
- { auto _e = default_collection(); _o->default_collection = _e; };
- { auto _e = only_use_line_with_click(); _o->only_use_line_with_click = _e; };
- { auto _e = split_tokens_on_selection_boundaries(); _o->split_tokens_on_selection_boundaries = _e; };
- { auto _e = tokenization_codepoint_config(); if (_e) { _o->tokenization_codepoint_config.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tokenization_codepoint_config[_i] = std::unique_ptr<TokenizationCodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } };
- { auto _e = center_token_selection_method(); _o->center_token_selection_method = _e; };
- { auto _e = snap_label_span_boundaries_to_containing_tokens(); _o->snap_label_span_boundaries_to_containing_tokens = _e; };
- { auto _e = supported_codepoint_ranges(); if (_e) { _o->supported_codepoint_ranges.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->supported_codepoint_ranges[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } };
- { auto _e = internal_tokenizer_codepoint_ranges(); if (_e) { _o->internal_tokenizer_codepoint_ranges.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->internal_tokenizer_codepoint_ranges[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } };
- { auto _e = min_supported_codepoint_ratio(); _o->min_supported_codepoint_ratio = _e; };
- { auto _e = feature_version(); _o->feature_version = _e; };
- { auto _e = tokenization_type(); _o->tokenization_type = _e; };
- { auto _e = icu_preserve_whitespace_tokens(); _o->icu_preserve_whitespace_tokens = _e; };
- { auto _e = ignored_span_boundary_codepoints(); if (_e) { _o->ignored_span_boundary_codepoints.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->ignored_span_boundary_codepoints[_i] = _e->Get(_i); } } };
- { auto _e = click_random_token_in_selection(); _o->click_random_token_in_selection = _e; };
- { auto _e = alternative_collection_map(); if (_e) { _o->alternative_collection_map.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->alternative_collection_map[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntryT>(_e->Get(_i)->UnPack(_resolver)); } } };
- { auto _e = bounds_sensitive_features(); if (_e) _o->bounds_sensitive_features = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT>(_e->UnPack(_resolver)); };
- { auto _e = split_selection_candidates(); _o->split_selection_candidates = _e; };
- { auto _e = allowed_chargrams(); if (_e) { _o->allowed_chargrams.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->allowed_chargrams[_i] = _e->Get(_i)->str(); } } };
- { auto _e = tokenize_on_script_change(); _o->tokenize_on_script_change = _e; };
-}
-
-inline flatbuffers::Offset<FeatureProcessorOptions> FeatureProcessorOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
- return CreateFeatureProcessorOptions(_fbb, _o, _rehasher);
-}
-
-inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
- (void)_rehasher;
- (void)_o;
- struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FeatureProcessorOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
- auto _num_buckets = _o->num_buckets;
- auto _embedding_size = _o->embedding_size;
- auto _context_size = _o->context_size;
- auto _max_selection_span = _o->max_selection_span;
- auto _chargram_orders = _o->chargram_orders.size() ? _fbb.CreateVector(_o->chargram_orders) : 0;
- auto _max_word_length = _o->max_word_length;
- auto _unicode_aware_features = _o->unicode_aware_features;
- auto _extract_case_feature = _o->extract_case_feature;
- auto _extract_selection_mask_feature = _o->extract_selection_mask_feature;
- auto _regexp_feature = _o->regexp_feature.size() ? _fbb.CreateVectorOfStrings(_o->regexp_feature) : 0;
- auto _remap_digits = _o->remap_digits;
- auto _lowercase_tokens = _o->lowercase_tokens;
- auto _selection_reduced_output_space = _o->selection_reduced_output_space;
- auto _collections = _o->collections.size() ? _fbb.CreateVectorOfStrings(_o->collections) : 0;
- auto _default_collection = _o->default_collection;
- auto _only_use_line_with_click = _o->only_use_line_with_click;
- auto _split_tokens_on_selection_boundaries = _o->split_tokens_on_selection_boundaries;
- auto _tokenization_codepoint_config = _o->tokenization_codepoint_config.size() ? _fbb.CreateVector<flatbuffers::Offset<TokenizationCodepointRange>> (_o->tokenization_codepoint_config.size(), [](size_t i, _VectorArgs *__va) { return CreateTokenizationCodepointRange(*__va->__fbb, __va->__o->tokenization_codepoint_config[i].get(), __va->__rehasher); }, &_va ) : 0;
- auto _center_token_selection_method = _o->center_token_selection_method;
- auto _snap_label_span_boundaries_to_containing_tokens = _o->snap_label_span_boundaries_to_containing_tokens;
- auto _supported_codepoint_ranges = _o->supported_codepoint_ranges.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> (_o->supported_codepoint_ranges.size(), [](size_t i, _VectorArgs *__va) { return CreateCodepointRange(*__va->__fbb, __va->__o->supported_codepoint_ranges[i].get(), __va->__rehasher); }, &_va ) : 0;
- auto _internal_tokenizer_codepoint_ranges = _o->internal_tokenizer_codepoint_ranges.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> (_o->internal_tokenizer_codepoint_ranges.size(), [](size_t i, _VectorArgs *__va) { return CreateCodepointRange(*__va->__fbb, __va->__o->internal_tokenizer_codepoint_ranges[i].get(), __va->__rehasher); }, &_va ) : 0;
- auto _min_supported_codepoint_ratio = _o->min_supported_codepoint_ratio;
- auto _feature_version = _o->feature_version;
- auto _tokenization_type = _o->tokenization_type;
- auto _icu_preserve_whitespace_tokens = _o->icu_preserve_whitespace_tokens;
- auto _ignored_span_boundary_codepoints = _o->ignored_span_boundary_codepoints.size() ? _fbb.CreateVector(_o->ignored_span_boundary_codepoints) : 0;
- auto _click_random_token_in_selection = _o->click_random_token_in_selection;
- auto _alternative_collection_map = _o->alternative_collection_map.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>> (_o->alternative_collection_map.size(), [](size_t i, _VectorArgs *__va) { return CreateCollectionMapEntry(*__va->__fbb, __va->__o->alternative_collection_map[i].get(), __va->__rehasher); }, &_va ) : 0;
- auto _bounds_sensitive_features = _o->bounds_sensitive_features ? CreateBoundsSensitiveFeatures(_fbb, _o->bounds_sensitive_features.get(), _rehasher) : 0;
- auto _split_selection_candidates = _o->split_selection_candidates;
- auto _allowed_chargrams = _o->allowed_chargrams.size() ? _fbb.CreateVectorOfStrings(_o->allowed_chargrams) : 0;
- auto _tokenize_on_script_change = _o->tokenize_on_script_change;
- return libtextclassifier2::CreateFeatureProcessorOptions(
- _fbb,
- _num_buckets,
- _embedding_size,
- _context_size,
- _max_selection_span,
- _chargram_orders,
- _max_word_length,
- _unicode_aware_features,
- _extract_case_feature,
- _extract_selection_mask_feature,
- _regexp_feature,
- _remap_digits,
- _lowercase_tokens,
- _selection_reduced_output_space,
- _collections,
- _default_collection,
- _only_use_line_with_click,
- _split_tokens_on_selection_boundaries,
- _tokenization_codepoint_config,
- _center_token_selection_method,
- _snap_label_span_boundaries_to_containing_tokens,
- _supported_codepoint_ranges,
- _internal_tokenizer_codepoint_ranges,
- _min_supported_codepoint_ratio,
- _feature_version,
- _tokenization_type,
- _icu_preserve_whitespace_tokens,
- _ignored_span_boundary_codepoints,
- _click_random_token_in_selection,
- _alternative_collection_map,
- _bounds_sensitive_features,
- _split_selection_candidates,
- _allowed_chargrams,
- _tokenize_on_script_change);
-}
-
namespace FeatureProcessorOptions_ {
inline CodepointRangeT *CodepointRange::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -2110,35 +2584,6 @@
_end);
}
-inline CollectionMapEntryT *CollectionMapEntry::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
- auto _o = new CollectionMapEntryT();
- UnPackTo(_o, _resolver);
- return _o;
-}
-
-inline void CollectionMapEntry::UnPackTo(CollectionMapEntryT *_o, const flatbuffers::resolver_function_t *_resolver) const {
- (void)_o;
- (void)_resolver;
- { auto _e = key(); if (_e) _o->key = _e->str(); };
- { auto _e = value(); if (_e) _o->value = _e->str(); };
-}
-
-inline flatbuffers::Offset<CollectionMapEntry> CollectionMapEntry::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CollectionMapEntryT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
- return CreateCollectionMapEntry(_fbb, _o, _rehasher);
-}
-
-inline flatbuffers::Offset<CollectionMapEntry> CreateCollectionMapEntry(flatbuffers::FlatBufferBuilder &_fbb, const CollectionMapEntryT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
- (void)_rehasher;
- (void)_o;
- struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CollectionMapEntryT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
- auto _key = _o->key.empty() ? 0 : _fbb.CreateString(_o->key);
- auto _value = _o->value.empty() ? 0 : _fbb.CreateString(_o->value);
- return libtextclassifier2::FeatureProcessorOptions_::CreateCollectionMapEntry(
- _fbb,
- _key,
- _value);
-}
-
inline BoundsSensitiveFeaturesT *BoundsSensitiveFeatures::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new BoundsSensitiveFeaturesT();
UnPackTo(_o, _resolver);
@@ -2183,7 +2628,183 @@
_include_inside_length);
}
+inline AlternativeCollectionMapEntryT *AlternativeCollectionMapEntry::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new AlternativeCollectionMapEntryT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void AlternativeCollectionMapEntry::UnPackTo(AlternativeCollectionMapEntryT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = key(); if (_e) _o->key = _e->str(); };
+ { auto _e = value(); if (_e) _o->value = _e->str(); };
+}
+
+inline flatbuffers::Offset<AlternativeCollectionMapEntry> AlternativeCollectionMapEntry::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateAlternativeCollectionMapEntry(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<AlternativeCollectionMapEntry> CreateAlternativeCollectionMapEntry(flatbuffers::FlatBufferBuilder &_fbb, const AlternativeCollectionMapEntryT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const AlternativeCollectionMapEntryT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _key = _o->key.empty() ? 0 : _fbb.CreateString(_o->key);
+ auto _value = _o->value.empty() ? 0 : _fbb.CreateString(_o->value);
+ return libtextclassifier2::FeatureProcessorOptions_::CreateAlternativeCollectionMapEntry(
+ _fbb,
+ _key,
+ _value);
+}
+
} // namespace FeatureProcessorOptions_
+
+inline FeatureProcessorOptionsT *FeatureProcessorOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new FeatureProcessorOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void FeatureProcessorOptions::UnPackTo(FeatureProcessorOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = num_buckets(); _o->num_buckets = _e; };
+ { auto _e = embedding_size(); _o->embedding_size = _e; };
+ { auto _e = context_size(); _o->context_size = _e; };
+ { auto _e = max_selection_span(); _o->max_selection_span = _e; };
+ { auto _e = chargram_orders(); if (_e) { _o->chargram_orders.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->chargram_orders[_i] = _e->Get(_i); } } };
+ { auto _e = max_word_length(); _o->max_word_length = _e; };
+ { auto _e = unicode_aware_features(); _o->unicode_aware_features = _e; };
+ { auto _e = extract_case_feature(); _o->extract_case_feature = _e; };
+ { auto _e = extract_selection_mask_feature(); _o->extract_selection_mask_feature = _e; };
+ { auto _e = regexp_feature(); if (_e) { _o->regexp_feature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regexp_feature[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = remap_digits(); _o->remap_digits = _e; };
+ { auto _e = lowercase_tokens(); _o->lowercase_tokens = _e; };
+ { auto _e = selection_reduced_output_space(); _o->selection_reduced_output_space = _e; };
+ { auto _e = collections(); if (_e) { _o->collections.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->collections[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = default_collection(); _o->default_collection = _e; };
+ { auto _e = only_use_line_with_click(); _o->only_use_line_with_click = _e; };
+ { auto _e = split_tokens_on_selection_boundaries(); _o->split_tokens_on_selection_boundaries = _e; };
+ { auto _e = tokenization_codepoint_config(); if (_e) { _o->tokenization_codepoint_config.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tokenization_codepoint_config[_i] = std::unique_ptr<TokenizationCodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = center_token_selection_method(); _o->center_token_selection_method = _e; };
+ { auto _e = snap_label_span_boundaries_to_containing_tokens(); _o->snap_label_span_boundaries_to_containing_tokens = _e; };
+ { auto _e = supported_codepoint_ranges(); if (_e) { _o->supported_codepoint_ranges.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->supported_codepoint_ranges[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = internal_tokenizer_codepoint_ranges(); if (_e) { _o->internal_tokenizer_codepoint_ranges.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->internal_tokenizer_codepoint_ranges[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = min_supported_codepoint_ratio(); _o->min_supported_codepoint_ratio = _e; };
+ { auto _e = feature_version(); _o->feature_version = _e; };
+ { auto _e = tokenization_type(); _o->tokenization_type = _e; };
+ { auto _e = icu_preserve_whitespace_tokens(); _o->icu_preserve_whitespace_tokens = _e; };
+ { auto _e = ignored_span_boundary_codepoints(); if (_e) { _o->ignored_span_boundary_codepoints.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->ignored_span_boundary_codepoints[_i] = _e->Get(_i); } } };
+ { auto _e = bounds_sensitive_features(); if (_e) _o->bounds_sensitive_features = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT>(_e->UnPack(_resolver)); };
+ { auto _e = allowed_chargrams(); if (_e) { _o->allowed_chargrams.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->allowed_chargrams[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = tokenize_on_script_change(); _o->tokenize_on_script_change = _e; };
+ { auto _e = embedding_quantization_bits(); _o->embedding_quantization_bits = _e; };
+}
+
+inline flatbuffers::Offset<FeatureProcessorOptions> FeatureProcessorOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateFeatureProcessorOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FeatureProcessorOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _num_buckets = _o->num_buckets;
+ auto _embedding_size = _o->embedding_size;
+ auto _context_size = _o->context_size;
+ auto _max_selection_span = _o->max_selection_span;
+ auto _chargram_orders = _o->chargram_orders.size() ? _fbb.CreateVector(_o->chargram_orders) : 0;
+ auto _max_word_length = _o->max_word_length;
+ auto _unicode_aware_features = _o->unicode_aware_features;
+ auto _extract_case_feature = _o->extract_case_feature;
+ auto _extract_selection_mask_feature = _o->extract_selection_mask_feature;
+ auto _regexp_feature = _o->regexp_feature.size() ? _fbb.CreateVectorOfStrings(_o->regexp_feature) : 0;
+ auto _remap_digits = _o->remap_digits;
+ auto _lowercase_tokens = _o->lowercase_tokens;
+ auto _selection_reduced_output_space = _o->selection_reduced_output_space;
+ auto _collections = _o->collections.size() ? _fbb.CreateVectorOfStrings(_o->collections) : 0;
+ auto _default_collection = _o->default_collection;
+ auto _only_use_line_with_click = _o->only_use_line_with_click;
+ auto _split_tokens_on_selection_boundaries = _o->split_tokens_on_selection_boundaries;
+ auto _tokenization_codepoint_config = _o->tokenization_codepoint_config.size() ? _fbb.CreateVector<flatbuffers::Offset<TokenizationCodepointRange>> (_o->tokenization_codepoint_config.size(), [](size_t i, _VectorArgs *__va) { return CreateTokenizationCodepointRange(*__va->__fbb, __va->__o->tokenization_codepoint_config[i].get(), __va->__rehasher); }, &_va ) : 0;
+ auto _center_token_selection_method = _o->center_token_selection_method;
+ auto _snap_label_span_boundaries_to_containing_tokens = _o->snap_label_span_boundaries_to_containing_tokens;
+ auto _supported_codepoint_ranges = _o->supported_codepoint_ranges.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> (_o->supported_codepoint_ranges.size(), [](size_t i, _VectorArgs *__va) { return CreateCodepointRange(*__va->__fbb, __va->__o->supported_codepoint_ranges[i].get(), __va->__rehasher); }, &_va ) : 0;
+ auto _internal_tokenizer_codepoint_ranges = _o->internal_tokenizer_codepoint_ranges.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> (_o->internal_tokenizer_codepoint_ranges.size(), [](size_t i, _VectorArgs *__va) { return CreateCodepointRange(*__va->__fbb, __va->__o->internal_tokenizer_codepoint_ranges[i].get(), __va->__rehasher); }, &_va ) : 0;
+ auto _min_supported_codepoint_ratio = _o->min_supported_codepoint_ratio;
+ auto _feature_version = _o->feature_version;
+ auto _tokenization_type = _o->tokenization_type;
+ auto _icu_preserve_whitespace_tokens = _o->icu_preserve_whitespace_tokens;
+ auto _ignored_span_boundary_codepoints = _o->ignored_span_boundary_codepoints.size() ? _fbb.CreateVector(_o->ignored_span_boundary_codepoints) : 0;
+ auto _bounds_sensitive_features = _o->bounds_sensitive_features ? CreateBoundsSensitiveFeatures(_fbb, _o->bounds_sensitive_features.get(), _rehasher) : 0;
+ auto _allowed_chargrams = _o->allowed_chargrams.size() ? _fbb.CreateVectorOfStrings(_o->allowed_chargrams) : 0;
+ auto _tokenize_on_script_change = _o->tokenize_on_script_change;
+ auto _embedding_quantization_bits = _o->embedding_quantization_bits;
+ return libtextclassifier2::CreateFeatureProcessorOptions(
+ _fbb,
+ _num_buckets,
+ _embedding_size,
+ _context_size,
+ _max_selection_span,
+ _chargram_orders,
+ _max_word_length,
+ _unicode_aware_features,
+ _extract_case_feature,
+ _extract_selection_mask_feature,
+ _regexp_feature,
+ _remap_digits,
+ _lowercase_tokens,
+ _selection_reduced_output_space,
+ _collections,
+ _default_collection,
+ _only_use_line_with_click,
+ _split_tokens_on_selection_boundaries,
+ _tokenization_codepoint_config,
+ _center_token_selection_method,
+ _snap_label_span_boundaries_to_containing_tokens,
+ _supported_codepoint_ranges,
+ _internal_tokenizer_codepoint_ranges,
+ _min_supported_codepoint_ratio,
+ _feature_version,
+ _tokenization_type,
+ _icu_preserve_whitespace_tokens,
+ _ignored_span_boundary_codepoints,
+ _bounds_sensitive_features,
+ _allowed_chargrams,
+ _tokenize_on_script_change,
+ _embedding_quantization_bits);
+}
+
+inline const libtextclassifier2::Model *GetModel(const void *buf) {
+ return flatbuffers::GetRoot<libtextclassifier2::Model>(buf);
+}
+
+inline const char *ModelIdentifier() {
+ return "TC2 ";
+}
+
+inline bool ModelBufferHasIdentifier(const void *buf) {
+ return flatbuffers::BufferHasIdentifier(
+ buf, ModelIdentifier());
+}
+
+inline bool VerifyModelBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<libtextclassifier2::Model>(ModelIdentifier());
+}
+
+inline void FinishModelBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<libtextclassifier2::Model> root) {
+ fbb.Finish(root, ModelIdentifier());
+}
+
+inline std::unique_ptr<ModelT> UnPackModel(
+ const void *buf,
+ const flatbuffers::resolver_function_t *res = nullptr) {
+ return std::unique_ptr<ModelT>(GetModel(buf)->UnPack(res));
+}
+
} // namespace libtextclassifier2
-#endif // FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_FEATUREPROCESSOROPTIONS__H_
+#endif // FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_H_
diff --git a/models/textclassifier.en.model b/models/textclassifier.en.model
index 9814c93..f2be859 100644
--- a/models/textclassifier.en.model
+++ b/models/textclassifier.en.model
Binary files differ
diff --git a/quantization.cc b/quantization.cc
new file mode 100644
index 0000000..1a34565
--- /dev/null
+++ b/quantization.cc
@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "quantization.h"
+
+#include "util/base/logging.h"
+
+namespace libtextclassifier2 {
+namespace {
+float DequantizeValue(int num_sparse_features, int quantization_bias,
+ float multiplier, int value) {
+ return 1.0 / num_sparse_features * (value - quantization_bias) * multiplier;
+}
+
+void DequantizeAdd8bit(const float* scales, const uint8* embeddings,
+ int bytes_per_embedding, const int num_sparse_features,
+ const int bucket_id, float* dest, int dest_size) {
+ static const int kQuantizationBias8bit = 128;
+ const float multiplier = scales[bucket_id];
+ for (int k = 0; k < dest_size; ++k) {
+ dest[k] +=
+ DequantizeValue(num_sparse_features, kQuantizationBias8bit, multiplier,
+ embeddings[bucket_id * bytes_per_embedding + k]);
+ }
+}
+
+void DequantizeAddNBit(const float* scales, const uint8* embeddings,
+ int bytes_per_embedding, int num_sparse_features,
+ int quantization_bits, int bucket_id, float* dest,
+ int dest_size) {
+ const int quantization_bias = 1 << (quantization_bits - 1);
+ const float multiplier = scales[bucket_id];
+ for (int i = 0; i < dest_size; ++i) {
+ const int bit_offset = i * quantization_bits;
+ const int read16_offset = bit_offset / 8;
+
+ uint16 data = embeddings[bucket_id * bytes_per_embedding + read16_offset];
+ // If we are not at the end of the embedding row, we can read 2-byte uint16,
+ // but if we are, we need to only read uint8.
+ if (read16_offset < bytes_per_embedding - 1) {
+ data |= embeddings[bucket_id * bytes_per_embedding + read16_offset + 1]
+ << 8;
+ }
+ int value = (data >> (bit_offset % 8)) & ((1 << quantization_bits) - 1);
+ dest[i] += DequantizeValue(num_sparse_features, quantization_bias,
+ multiplier, value);
+ }
+}
+} // namespace
+
+bool CheckQuantizationParams(int bytes_per_embedding, int quantization_bits,
+ int output_embedding_size) {
+ if (bytes_per_embedding * 8 / quantization_bits < output_embedding_size) {
+ return false;
+ }
+
+ return true;
+}
+
+bool DequantizeAdd(const float* scales, const uint8* embeddings,
+                   int bytes_per_embedding, int num_sparse_features,
+                   int quantization_bits, int bucket_id, float* dest,
+                   int dest_size) {
+  if (quantization_bits == 8) {
+    DequantizeAdd8bit(scales, embeddings, bytes_per_embedding,
+                      num_sparse_features, bucket_id, dest, dest_size);
+  } else if (quantization_bits > 0 && quantization_bits < 8) {
+    DequantizeAddNBit(scales, embeddings, bytes_per_embedding,
+                      num_sparse_features, quantization_bits, bucket_id, dest,
+                      dest_size);
+  } else {
+    TC_LOG(ERROR) << "Unsupported quantization_bits: " << quantization_bits;
+    return false;
+  }
+
+  return true;
+}
+
+
+} // namespace libtextclassifier2
diff --git a/quantization.h b/quantization.h
new file mode 100644
index 0000000..c486640
--- /dev/null
+++ b/quantization.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_QUANTIZATION_H_
+#define LIBTEXTCLASSIFIER_QUANTIZATION_H_
+
+#include "util/base/integral_types.h"
+
+namespace libtextclassifier2 {
+
+// Returns true if the quantization parameters are valid.
+bool CheckQuantizationParams(int bytes_per_embedding, int quantization_bits,
+ int output_embedding_size);
+
+// Dequantizes embeddings (quantized to 1 to 8 bits) into the floats they
+// represent. The algorithm proceeds by reading 2-byte words from the embedding
+// storage to handle well the cases when the quantized value crosses the byte-
+// boundary.
+bool DequantizeAdd(const float* scales, const uint8* embeddings,
+ int bytes_per_embedding, int num_sparse_features,
+ int quantization_bits, int bucket_id, float* dest,
+ int dest_size);
+
+} // namespace libtextclassifier2
+
+#endif // LIBTEXTCLASSIFIER_QUANTIZATION_H_
diff --git a/quantization_test.cc b/quantization_test.cc
new file mode 100644
index 0000000..088daaf
--- /dev/null
+++ b/quantization_test.cc
@@ -0,0 +1,163 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "quantization.h"
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using testing::ElementsAreArray;
+using testing::FloatEq;
+using testing::Matcher;
+
+namespace libtextclassifier2 {
+namespace {
+
+Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
+ std::vector<Matcher<float>> matchers;
+ for (const float value : values) {
+ matchers.push_back(FloatEq(value));
+ }
+ return ElementsAreArray(matchers);
+}
+
+TEST(QuantizationTest, DequantizeAdd8bit) {
+ std::vector<float> scales{{0.1, 9.0, -7.0}};
+ std::vector<uint8> embeddings{{/*0: */ 0x00, 0xFF, 0x09, 0x00,
+ /*1: */ 0xFF, 0x09, 0x00, 0xFF,
+ /*2: */ 0x09, 0x00, 0xFF, 0x09}};
+
+ const int quantization_bits = 8;
+ const int bytes_per_embedding = 4;
+ const int num_sparse_features = 7;
+ {
+ const int bucket_id = 0;
+ std::vector<float> dest(4, 0.0);
+ DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
+ num_sparse_features, quantization_bits, bucket_id,
+ dest.data(), dest.size());
+
+ EXPECT_THAT(dest,
+ ElementsAreFloat(std::vector<float>{
+ // clang-format off
+ {1.0 / 7 * 0.1 * (0x00 - 128),
+ 1.0 / 7 * 0.1 * (0xFF - 128),
+ 1.0 / 7 * 0.1 * (0x09 - 128),
+ 1.0 / 7 * 0.1 * (0x00 - 128)}
+ // clang-format on
+ }));
+ }
+
+ {
+ const int bucket_id = 1;
+ std::vector<float> dest(4, 0.0);
+ DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
+ num_sparse_features, quantization_bits, bucket_id,
+ dest.data(), dest.size());
+
+ EXPECT_THAT(dest,
+ ElementsAreFloat(std::vector<float>{
+ // clang-format off
+ {1.0 / 7 * 9.0 * (0xFF - 128),
+ 1.0 / 7 * 9.0 * (0x09 - 128),
+ 1.0 / 7 * 9.0 * (0x00 - 128),
+ 1.0 / 7 * 9.0 * (0xFF - 128)}
+ // clang-format on
+ }));
+ }
+}
+
+TEST(QuantizationTest, DequantizeAdd1bitZeros) {
+ const int bytes_per_embedding = 4;
+ const int num_buckets = 3;
+ const int num_sparse_features = 7;
+ const int quantization_bits = 1;
+ const int bucket_id = 1;
+
+ std::vector<float> scales(num_buckets);
+ std::vector<uint8> embeddings(bytes_per_embedding * num_buckets);
+ std::fill(scales.begin(), scales.end(), 1);
+ std::fill(embeddings.begin(), embeddings.end(), 0);
+
+ std::vector<float> dest(32);
+ DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
+ num_sparse_features, quantization_bits, bucket_id, dest.data(),
+ dest.size());
+
+ std::vector<float> expected(32);
+ std::fill(expected.begin(), expected.end(),
+ 1.0 / num_sparse_features * (0 - 1));
+ EXPECT_THAT(dest, ElementsAreFloat(expected));
+}
+
+TEST(QuantizationTest, DequantizeAdd1bitOnes) {
+ const int bytes_per_embedding = 4;
+ const int num_buckets = 3;
+ const int num_sparse_features = 7;
+ const int quantization_bits = 1;
+ const int bucket_id = 1;
+
+ std::vector<float> scales(num_buckets, 1.0);
+ std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0xFF);
+
+ std::vector<float> dest(32);
+ DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
+ num_sparse_features, quantization_bits, bucket_id, dest.data(),
+ dest.size());
+ std::vector<float> expected(32);
+ std::fill(expected.begin(), expected.end(),
+ 1.0 / num_sparse_features * (1 - 1));
+ EXPECT_THAT(dest, ElementsAreFloat(expected));
+}
+
+TEST(QuantizationTest, DequantizeAdd3bit) {
+ const int bytes_per_embedding = 4;
+ const int num_buckets = 3;
+ const int num_sparse_features = 7;
+ const int quantization_bits = 3;
+ const int bucket_id = 1;
+
+ std::vector<float> scales(num_buckets, 1.0);
+ scales[1] = 9.0;
+ std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0);
+ // For bucket_id=1, the embedding has values 0..9 for indices 0..9:
+ embeddings[4] = (1 << 7) | (1 << 6) | (1 << 4) | 1;
+ embeddings[5] = (1 << 6) | (1 << 4) | (1 << 3);
+ embeddings[6] = (1 << 4) | (1 << 3) | (1 << 2) | (1 << 1) | 1;
+
+ std::vector<float> dest(10);
+ DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
+ num_sparse_features, quantization_bits, bucket_id, dest.data(),
+ dest.size());
+
+ std::vector<float> expected;
+ expected.push_back(1.0 / num_sparse_features * (1 - 4) * scales[bucket_id]);
+ expected.push_back(1.0 / num_sparse_features * (2 - 4) * scales[bucket_id]);
+ expected.push_back(1.0 / num_sparse_features * (3 - 4) * scales[bucket_id]);
+ expected.push_back(1.0 / num_sparse_features * (4 - 4) * scales[bucket_id]);
+ expected.push_back(1.0 / num_sparse_features * (5 - 4) * scales[bucket_id]);
+ expected.push_back(1.0 / num_sparse_features * (6 - 4) * scales[bucket_id]);
+ expected.push_back(1.0 / num_sparse_features * (7 - 4) * scales[bucket_id]);
+ expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]);
+ expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]);
+ expected.push_back(1.0 / num_sparse_features * (0 - 4) * scales[bucket_id]);
+ EXPECT_THAT(dest, ElementsAreFloat(expected));
+}
+
+} // namespace
+} // namespace libtextclassifier2
diff --git a/regex-base.cc b/regex-base.cc
deleted file mode 100644
index 790e453..0000000
--- a/regex-base.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "regex-base.h"
-
-namespace libtextclassifier2 {
-
-Rules::Rules(const std::string &locale) : locale_(locale) {}
-
-FlatBufferRules::FlatBufferRules(const std::string &locale, const Model *model)
- : Rules(locale), model_(model), rule_cache_() {
- for (int i = 0; i < model_->regex_model()->patterns()->Length(); i++) {
- auto regex_model = model_->regex_model()->patterns()->Get(i);
- for (int j = 0; j < regex_model->node_names()->Length(); j++) {
- std::string name = regex_model->node_names()->Get(j)->str();
- rule_cache_[name].push_back(regex_model->pattern());
- }
- }
-}
-
-bool FlatBufferRules::RuleForName(const std::string &name,
- std::string *out) const {
- const auto match = rule_cache_.find(name);
- if (match != rule_cache_.end()) {
- if (match->second.size() != 1) {
- TC_LOG(ERROR) << "Found " << match->second.size()
- << " rule where only 1 was expected.";
- return false;
- }
- *out = match->second[0]->str();
- return true;
- }
- return false;
-}
-
-const std::vector<std::string> FlatBufferRules::RulesForName(
- const std::string &name) const {
- std::vector<std::string> results;
- const auto match = rule_cache_.find(name);
- if (match != rule_cache_.end()) {
- for (auto &s : match->second) {
- results.push_back(s->str());
- }
- }
- return results;
-}
-
-} // namespace libtextclassifier2
diff --git a/regex-base.h b/regex-base.h
deleted file mode 100644
index 5856eba..0000000
--- a/regex-base.h
+++ /dev/null
@@ -1,472 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_BASE_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_BASE_H_
-
-#include <iostream>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "model_generated.h"
-#include "util/base/logging.h"
-#include "util/memory/mmap.h"
-#include "unicode/regex.h"
-#include "unicode/uchar.h"
-
-namespace libtextclassifier2 {
-
-// Encapsulates the start and end of a region of a string of a entity that
-// has been mapping to an element of type T
-template <class T>
-class SpanResult {
- private:
- const T data_;
- const int start_;
- const int end_;
-
- public:
- SpanResult(int start, int end, T data)
- : data_(data), start_(start), end_(end) {}
-
- const T &Data() const { return data_; }
-
- int Start() const { return start_; }
-
- int End() const { return end_; }
-};
-
-// Interface supplying class that provides a protocol for encapsulating a unit
-// of text processing that can match against strings and also extract values of
-// SpanResults of type T. Implemenations are expected to be thread-safe.
-template <class T>
-class Node {
- public:
- typedef SpanResult<T> Result;
- typedef std::vector<Result> Results;
-
- // Returns a boolean value if Node can find a region of the string which
- // matches it's logic.
- virtual bool Matches(const std::string &input) const = 0;
-
- // Populates the supplied Results vector with the values obtained by
- // extracted on any matching elements of the string.
- // Returns true if the processing yielded no error.
- // Returns false if their was an error processing the string.
- virtual bool Extract(const std::string &input, Results *result) const = 0;
-
- virtual ~Node() = default;
-};
-
-class Rules {
- public:
- explicit Rules(const std::string &locale);
- virtual ~Rules() = default;
- virtual const std::vector<std::string> RulesForName(
- const std::string &name) const = 0;
- virtual bool RuleForName(const std::string &name, std::string *out) const = 0;
-
- protected:
- const std::string locale_;
-};
-
-class FlatBufferRules : public Rules {
- public:
- FlatBufferRules(const std::string &locale, const Model *model);
- const std::vector<std::string> RulesForName(
- const std::string &name) const override;
- bool RuleForName(const std::string &name, std::string *out) const override;
-
- protected:
- const Model *model_;
- std::unordered_map<std::string, std::vector<const flatbuffers::String *>>
- rule_cache_;
-};
-
-// Abstract supplying class that provides a protocol for encapsulating a unit
-// of regular expression processing that can match against strings and also
-// extract values of SpanResults of type T. Implemenations are expected to be
-// thread-safe.
-// Implementors of this class are expected to call Init() prior to calling
-// Extract() or Match().
-template <class T>
-class RegexNode : public Node<T> {
- public:
- typedef typename Node<T>::Result Result;
- typedef typename Node<T>::Results Results;
-
- RegexNode() : pattern_() {}
-
- // A string representation of the regular expression used in this processing.
- const std::string pattern() const {
- std::string regex;
- pattern_->pattern().toUTF8String(regex);
- return regex;
- }
-
- bool Matches(const std::string &input) const override {
- UErrorCode status = U_ZERO_ERROR;
- const icu::UnicodeString unicode_context(input.c_str(), input.size(),
- "utf-8");
- std::unique_ptr<icu::RegexMatcher> matcher(
- pattern_->matcher(unicode_context, status));
-
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to load regex '" << pattern()
- << "': " << u_errorName(status);
- return false;
- }
- const bool res = matcher->find(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to find with regex '" << pattern()
- << "': " << u_errorName(status);
- return false;
- }
- return res;
- }
-
- bool Extract(const std::string &input, Results *result) const override = 0;
- ~RegexNode() override {}
-
- protected:
- bool Init(const std::string ®ex) {
- UErrorCode status = U_ZERO_ERROR;
- pattern_ = std::unique_ptr<icu::RegexPattern>(
- icu::RegexPattern::compile(regex.c_str(), UREGEX_MULTILINE, status));
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to compile regex '" << pattern()
- << "': " << u_errorName(status);
- return false;
- }
- return true;
- }
-
- std::unique_ptr<icu::RegexPattern> pattern_;
-};
-
-// Class that encapsulates complex object matching and extract values of
-// from multiple sub-patterns, each of which is mapped to a string describing
-// the value.
-// Thread-safe.
-template <class T>
-class CompoundNode : public RegexNode<std::unordered_map<std::string, T>> {
- public:
- typedef std::unordered_map<std::string, std::unique_ptr<const Node<T>>>
- Extractors;
- typedef typename RegexNode<std::unordered_map<std::string, T>>::Result Result;
- typedef
- typename RegexNode<std::unordered_map<std::string, T>>::Results Results;
-
- static std::unique_ptr<CompoundNode<T>> Instance(const std::string &rule,
- const Extractors &extractors,
- const Rules &rules) {
- std::unique_ptr<CompoundNode<T>> node(new CompoundNode(extractors));
- if (!node->Init(rule, rules)) {
- return nullptr;
- }
- return node;
- }
-
- bool Extract(const std::string &input, Results *result) const override {
- UErrorCode status = U_ZERO_ERROR;
- const icu::UnicodeString unicode_context(input.c_str(), input.size(),
- "utf-8");
- std::unique_ptr<icu::RegexMatcher> matcher(
- RegexNode<std::unordered_map<std::string, T>>::pattern_->matcher(
- unicode_context, status));
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "error loading regex '"
- << RegexNode<std::unordered_map<std::string, T>>::pattern()
- << "'" << u_errorName(status);
- return false;
- }
- while (matcher->find() && U_SUCCESS(status)) {
- const int start = matcher->start(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR)
- << "failed to demarshall start '"
- << RegexNode<std::unordered_map<std::string, T>>::pattern() << "'"
- << u_errorName(status);
- return false;
- }
-
- const int end = matcher->end(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR)
- << "failed to demarshall end '"
- << RegexNode<std::unordered_map<std::string, T>>::pattern() << "'"
- << u_errorName(status);
- return false;
- }
-
- std::unordered_map<std::string, T> extraction;
- for (auto name : groupnames_) {
- const int group_number = matcher->pattern().groupNumberFromName(
- icu::UnicodeString(name.c_str(), name.size(), "utf-8"), status);
- if (U_FAILURE(status)) {
- // We expect this to happen for optional named groups.
- continue;
- }
-
- std::string capture;
- matcher->group(group_number, status).toUTF8String(capture);
- std::vector<SpanResult<T>> sub_result;
-
- if (!extractors_->find(name)->second->Extract(capture, &sub_result)) {
- return false;
- }
- if (!sub_result.empty()) {
- extraction[name] = sub_result[0].Data();
- }
- }
- result->push_back(Result(start, end, extraction));
- }
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to extract '"
- << RegexNode<std::unordered_map<std::string, T>>::pattern()
- << "':" << u_errorName(status);
- return false;
- }
- return true;
- }
-
- private:
- const Extractors *extractors_;
- std::vector<std::string> groupnames_;
-
- explicit CompoundNode(const Extractors &extractors)
- : extractors_(&extractors), groupnames_() {}
-
- bool Init(const std::string &rule, const Rules &rules) {
- static const icu::RegexPattern *pattern = []() {
- UErrorCode status = U_ZERO_ERROR;
- return icu::RegexPattern::compile("[?]<([A-Z_]+)>", UREGEX_MULTILINE,
- status);
- }();
- if (!pattern) {
- return false;
- }
- std::string source = rule;
- UErrorCode status = U_ZERO_ERROR;
- const icu::UnicodeString unicode_context(source.c_str(), source.size(),
- "utf-8");
- std::unique_ptr<icu::RegexMatcher> matcher(
- pattern->matcher(unicode_context, status));
-
- std::unordered_map<std::string, std::string> swaps;
- while (matcher->find(status)) {
- std::string name;
- matcher->group(1, status).toUTF8String(name);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to demarshall name " << u_errorName(status);
- return false;
- }
- groupnames_.push_back(name);
- }
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to execute the regex properly"
- << u_errorName(status);
- return false;
- }
- for (auto swap : swaps) {
- std::string::size_type n = 0;
- while ((n = source.find(swap.first, n)) != std::string::npos) {
- source.replace(n, swap.first.size(), swap.second);
- n += swap.second.size();
- }
- }
- return RegexNode<std::unordered_map<std::string, T>>::Init(source);
- }
-};
-
-// Class that managed multiple alternate Nodes that all yield the same result
-// and can be used interchangeably. It returns reults only from the first
-// successfully matching child node - this Nodes can be given to in in
-// order of precedence.
-// Thread-safe.
-template <class T>
-class OrNode : public Node<T> {
- public:
- typedef typename Node<T>::Result Result;
- typedef typename Node<T>::Results Results;
-
- explicit OrNode(std::vector<std::unique_ptr<const Node<T>>> alternatives)
- : alternatives_(std::move(alternatives)) {}
-
- bool Extract(const std::string &input, Results *result) const override {
- for (auto &alternative : alternatives_) {
- typename Node<T>::Results alternative_result;
- // NOTE: We are explicitly choosing to fall through errors on these
- // alternatives to try a lesser match instead of bailing on the user.
- if (alternative->Extract(input, &alternative_result)) {
- if (!alternative_result.empty()) {
- for (auto &s : alternative_result) {
- result->push_back(s);
- }
- return true;
- }
- }
- }
- return true;
- }
-
- bool Matches(const std::string &input) const override {
- for (auto &alternative : alternatives_) {
- if (alternative->Matches(input)) {
- return true;
- }
- }
- return false;
- }
-
- private:
- std::vector<std::unique_ptr<const Node<T>>> alternatives_;
-};
-
-// Class that managed multiple alternate RegexNodes that all yield the same
-// result and can be used interchangeably. It returns reults only from the first
-// successfully matching child node - this Nodes can be given to in in
-// order of precedence.
-// Thread-safe.
-template <class T>
-class OrRegexNode : public RegexNode<T> {
- public:
- typedef typename RegexNode<T>::Result Result;
- typedef typename RegexNode<T>::Results Results;
-
- bool Extract(const std::string &input, Results *result) const override {
- for (auto &alternative : alternatives_) {
- typename RegexNode<T>::Results alternative_result;
- // NOTE: we are explicitly choosing to fall through errors on these
- // alternatives to try a lesser match instead of bailing on the user
- if (alternative->Extract(input, &alternative_result)) {
- if (!alternative_result.empty()) {
- for (typename RegexNode<T>::Result &s : alternative_result) {
- result->push_back(s);
- }
- return true;
- }
- }
- }
- return true;
- }
-
- protected:
- std::vector<std::unique_ptr<const RegexNode<T>>> alternatives_;
-
- explicit OrRegexNode(
- std::vector<std::unique_ptr<const RegexNode<T>>> alternatives)
- : alternatives_(std::move(alternatives)) {}
-
- bool Init() {
- std::string pattern;
- for (int i = 0; i < alternatives_.size(); i++) {
- if (i == 0) {
- pattern = alternatives_[i]->pattern();
- } else {
- pattern += "|";
- pattern += alternatives_[i]->pattern();
- }
- }
- return RegexNode<T>::Init(pattern);
- }
-};
-
-// Class that yields a constant value for any string that matches the input
-// Thread-safe.
-template <class T>
-class MappingNode : public RegexNode<T> {
- public:
- typedef RegexNode<T> Parent;
- typedef typename Parent::Result Result;
- typedef typename Parent::Results Results;
-
- static std::unique_ptr<MappingNode<T>> Instance(const std::string &name,
- const T &value,
- const Rules &rules) {
- std::unique_ptr<MappingNode<T>> node(new MappingNode<T>(value));
- if (!node->Init(name, rules)) {
- return nullptr;
- }
- return node;
- }
-
- bool Extract(const std::string &input, Results *result) const override {
- UErrorCode status = U_ZERO_ERROR;
-
- const icu::UnicodeString unicode_context(input.c_str(), input.size(),
- "utf-8");
- std::unique_ptr<icu::RegexMatcher> matcher(
- Parent::pattern_->matcher(unicode_context, status));
-
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "error loading regex ";
- return false;
- }
-
- while (matcher->find() && U_SUCCESS(status)) {
- const int start = matcher->start(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to demarshall start " << u_errorName(status);
- return false;
- }
- const int end = matcher->end(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to demarshall end " << u_errorName(status);
- return false;
- }
- result->push_back(Result(start, end, value_));
- }
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to demarshall end " << u_errorName(status);
- return false;
- }
- return true;
- }
-
- private:
- explicit MappingNode(const T value) : value_(value) {}
-
- const T value_;
-
- bool Init(const std::string &name, const Rules &rules) {
- std::string pattern;
- if (!rules.RuleForName(name, &pattern)) {
- TC_LOG(ERROR) << "failed to load rule for name '" << name << "'";
- return false;
- }
- return RegexNode<T>::Init(pattern);
- }
-};
-
-template <class T>
-bool BuildMappings(const Rules &rules,
- const std::vector<std::pair<std::string, T>> &pairs,
- std::vector<std::unique_ptr<const RegexNode<T>>> *mappings) {
- for (auto &p : pairs) {
- if (std::unique_ptr<RegexNode<T>> node =
- MappingNode<T>::Instance(p.first, p.second, rules)) {
- mappings->emplace_back(std::move(node));
- } else {
- return false;
- }
- }
- return true;
-}
-
-} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_BASE_H_
diff --git a/regex-number.cc b/regex-number.cc
deleted file mode 100644
index ed5a899..0000000
--- a/regex-number.cc
+++ /dev/null
@@ -1,254 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "regex-number.h"
-
-namespace libtextclassifier2 {
-
-std::unique_ptr<DigitNode> DigitNode::Instance(const Rules& rules) {
- std::unique_ptr<DigitNode> node(new DigitNode());
- if (!node->Init(rules)) {
- return nullptr;
- }
- return node;
-}
-
-DigitNode::DigitNode() {}
-
-constexpr const char* kDigits = "DIGITS";
-
-bool DigitNode::Init(const Rules& rules) {
- std::string pattern;
- if (!rules.RuleForName(kDigits, &pattern)) {
- TC_LOG(ERROR) << "failed to load pattern";
- return false;
- }
- return RegexNode<int>::Init(pattern);
-}
-
-bool DigitNode::Extract(const std::string& input, Results* result) const {
- UErrorCode status = U_ZERO_ERROR;
-
- const icu::UnicodeString unicode_context(input.c_str(), input.size(),
- "utf-8");
- const std::unique_ptr<icu::RegexMatcher> matcher(
- pattern_->matcher(unicode_context, status));
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to compile regex: " << u_errorName(status);
- return false;
- }
-
- while (matcher->find() && U_SUCCESS(status)) {
- const int start = matcher->start(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to bind start: " << u_errorName(status);
- return false;
- }
- const int end = matcher->end(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to bind end: " << u_errorName(status);
- return false;
- }
- std::string digit;
- matcher->group(status).toUTF8String(digit);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to bind digit std::string: "
- << u_errorName(status);
- return false;
- }
- result->push_back(Result(start, end, stoi(digit)));
- }
- return true;
-}
-
-constexpr const char* kSignedDigits = "SIGNEDDIGITS";
-
-std::unique_ptr<SignedDigitNode> SignedDigitNode::Instance(const Rules& rules) {
- std::unique_ptr<SignedDigitNode> node(new SignedDigitNode());
- if (!node->Init(rules)) {
- return nullptr;
- }
- return node;
-}
-
-SignedDigitNode::SignedDigitNode() {}
-
-bool SignedDigitNode::Init(const Rules& rules) {
- std::string pattern;
- if (!rules.RuleForName(kSignedDigits, &pattern)) {
- TC_LOG(ERROR) << "failed to load pattern";
- return false;
- }
- return RegexNode<int>::Init(pattern);
-}
-
-bool SignedDigitNode::Extract(const std::string& input, Results* result) const {
- UErrorCode status = U_ZERO_ERROR;
-
- const icu::UnicodeString unicode_context(input.c_str(), input.size(),
- "utf-8");
- const std::unique_ptr<icu::RegexMatcher> matcher(
- pattern_->matcher(unicode_context, status));
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to compile regex: " << u_errorName(status);
- return false;
- }
-
- while (matcher->find() && U_SUCCESS(status)) {
- const int start = matcher->start(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to bind start: " << u_errorName(status);
- return false;
- }
- const int end = matcher->end(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to bind end: " << u_errorName(status);
- return false;
- }
- std::string digit;
- matcher->group(status).toUTF8String(digit);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to bind digit std::string: "
- << u_errorName(status);
- return false;
- }
- result->push_back(Result(start, end, stoi(digit)));
- }
- return true;
-}
-
-constexpr const char* kZero = "ZERO";
-constexpr const char* kOne = "ONE";
-constexpr const char* kTwo = "TWO";
-constexpr const char* kThree = "THREE";
-constexpr const char* kFour = "FOUR";
-constexpr const char* kFive = "FIVE";
-constexpr const char* kSix = "SIX";
-constexpr const char* kSeven = "SEVEN";
-constexpr const char* kEight = "EIGHT";
-constexpr const char* kNine = "NINE";
-constexpr const char* kTen = "TEN";
-constexpr const char* kEleven = "ELEVEN";
-constexpr const char* kTwelve = "TWELVE";
-constexpr const char* kThirteen = "THIRTEEN";
-constexpr const char* kFourteen = "FOURTEEN";
-constexpr const char* kFifteen = "FIFTEEN";
-constexpr const char* kSixteen = "SIXTEEN";
-constexpr const char* kSeventeen = "SEVENTEEN";
-constexpr const char* kEighteen = "EIGHTEEN";
-constexpr const char* kNineteen = "NINETEEN";
-constexpr const char* kTwenty = "TWENTY";
-constexpr const char* kThirty = "THIRTY";
-constexpr const char* kForty = "FORTY";
-constexpr const char* kFifty = "FIFTY";
-constexpr const char* kSixty = "SIXTY";
-constexpr const char* kSeventy = "SEVENTY";
-constexpr const char* kEighty = "EIGHTY";
-constexpr const char* kNinety = "NINETY";
-constexpr const char* kHundred = "HUNDRED";
-constexpr const char* kThousand = "THOUSAND";
-
-std::unique_ptr<NumberNode> NumberNode::Instance(const Rules& rules) {
- const std::vector<std::pair<std::string, int>> name_values = {
- {kZero, 0}, {kOne, 1}, {kTwo, 2}, {kThree, 3},
- {kFour, 4}, {kFive, 5}, {kSix, 6}, {kSeven, 7},
- {kEight, 8}, {kNine, 9}, {kTen, 10}, {kEleven, 11},
- {kTwelve, 12}, {kThirteen, 13}, {kFourteen, 14}, {kFifteen, 15},
- {kSixteen, 16}, {kSeventeen, 17}, {kEighteen, 18}, {kNineteen, 19},
- {kTwenty, 20}, {kThirty, 30}, {kForty, 40}, {kFifty, 50},
- {kSixty, 60}, {kSeventy, 70}, {kEighty, 80}, {kNinety, 90},
- {kHundred, 100}, {kThousand, 1000},
- };
- std::vector<std::unique_ptr<const RegexNode<int>>> alternatives;
- if (!BuildMappings<int>(rules, name_values, &alternatives)) {
- return nullptr;
- }
- std::unique_ptr<NumberNode> node(new NumberNode(std::move(alternatives)));
- if (!node->Init()) {
- return nullptr;
- }
- return node;
-} // namespace libtextclassifier2
-
-bool NumberNode::Init() { return OrRegexNode<int>::Init(); }
-
-NumberNode::NumberNode(
- std::vector<std::unique_ptr<const RegexNode<int>>> alternatives)
- : OrRegexNode<int>(std::move(alternatives)) {}
-
-bool NumberNode::Extract(const std::string& input, Results* result) const {
- UErrorCode status = U_ZERO_ERROR;
-
- const icu::UnicodeString unicode_context(input.c_str(), input.size(),
- "utf-8");
- const std::unique_ptr<icu::RegexMatcher> matcher(
- RegexNode<int>::pattern_->matcher(unicode_context, status));
-
- OrRegexNode<int>::Results parts;
- int start = 0;
- int end = 0;
- while (matcher->find() && U_SUCCESS(status)) {
- std::string group;
- matcher->group(0, status).toUTF8String(group);
- int span_start = matcher->start(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to demarshall start " << u_errorName(status);
- return false;
- }
- int span_end = matcher->end(status);
- if (U_FAILURE(status)) {
- TC_LOG(ERROR) << "failed to demarshall end " << u_errorName(status);
- }
- if (span_start < start) {
- start = span_start;
- }
- if (span_end < end) {
- end = span_end;
- }
-
- for (auto& child : alternatives_) {
- if (child->Matches(group)) {
- OrRegexNode<int>::Results group_results;
- if (!child->Extract(group, &group_results)) {
- return false;
- }
- for (OrRegexNode<int>::Result span : group_results) {
- parts.push_back(span);
- }
- }
- }
- }
- int sum = 0;
- int running_value = -1;
- // Simple math to make sure we handle written numerical modifiers correctly
- // so that :="fifty one thousand and one" maps to 51001 and not 50 1 1000 1.
- for (OrRegexNode<int>::Result part : parts) {
- if (running_value >= 0) {
- if (running_value > part.Data()) {
- sum += running_value;
- running_value = part.Data();
- } else {
- running_value *= part.Data();
- }
- } else {
- running_value = part.Data();
- }
- }
- sum += running_value;
- result->push_back(Result(start, end, sum));
- return true;
-}
-} // namespace libtextclassifier2
diff --git a/regex-number.h b/regex-number.h
deleted file mode 100644
index 286e57d..0000000
--- a/regex-number.h
+++ /dev/null
@@ -1,110 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_NUMBER_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_NUMBER_H_
-
-#include "regex-base.h"
-
-namespace libtextclassifier2 {
-
-extern const char* const kZero;
-extern const char* const kOne;
-extern const char* const kTwo;
-extern const char* const kThree;
-extern const char* const kFour;
-extern const char* const kFive;
-extern const char* const kSix;
-extern const char* const kSeven;
-extern const char* const kEight;
-extern const char* const kNine;
-extern const char* const kTen;
-extern const char* const kEleven;
-extern const char* const kTwelve;
-extern const char* const kThirteen;
-extern const char* const kFourteen;
-extern const char* const kFifteen;
-extern const char* const kSixteen;
-extern const char* const kSeventeen;
-extern const char* const kEighteen;
-extern const char* const kNineteen;
-extern const char* const kTwenty;
-extern const char* const kThirty;
-extern const char* const kForty;
-extern const char* const kFifty;
-extern const char* const kSixty;
-extern const char* const kSeventy;
-extern const char* const kEighty;
-extern const char* const kNinety;
-extern const char* const kHundred;
-extern const char* const kThousand;
-
-extern const char* const kDigits;
-extern const char* const kSignedDigits;
-
-// Class that encapsulates a unsigned integer matching and extract values of
-// SpanResults of type int.
-// Thread-safe.
-class DigitNode : public RegexNode<int> {
- public:
- typedef typename RegexNode<int>::Result Result;
- typedef typename RegexNode<int>::Results Results;
-
- // Factory method for yielding a pointer to a DigitNode implementation.
- static std::unique_ptr<DigitNode> Instance(const Rules& rules);
- bool Extract(const std::string& input, Results* result) const override;
-
- protected:
- DigitNode();
- bool Init(const Rules& rules);
-};
-
-// Class that encapsulates a signed integer matching and extract values of
-// SpanResults of type int.
-// Thread-safe.
-class SignedDigitNode : public RegexNode<int> {
- public:
- typedef typename RegexNode<int>::Result Result;
- typedef typename RegexNode<int>::Results Results;
-
- // Factory method for yielding a pointer to a DigitNode implementation.
- static std::unique_ptr<SignedDigitNode> Instance(const Rules& rules);
- bool Extract(const std::string& input, Results* result) const override;
-
- protected:
- SignedDigitNode();
- bool Init(const Rules& rules);
-};
-
-// Class that encapsulates a simple natural language integer matching and
-// extract values of SpanResults of type int.
-// Thread-safe.
-class NumberNode : public OrRegexNode<int> {
- public:
- typedef typename OrRegexNode<int>::Result Result;
- typedef typename OrRegexNode<int>::Results Results;
-
- static std::unique_ptr<NumberNode> Instance(const Rules& rules);
- bool Extract(const std::string& input, Results* result) const override;
-
- protected:
- explicit NumberNode(
- std::vector<std::unique_ptr<const RegexNode<int>>> alternatives);
- bool Init();
-};
-
-} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_NUMBER_H_
diff --git a/strip-unpaired-brackets.cc b/strip-unpaired-brackets.cc
index f813e6b..367766b 100644
--- a/strip-unpaired-brackets.cc
+++ b/strip-unpaired-brackets.cc
@@ -25,11 +25,9 @@
namespace {
// Returns true if given codepoint is contained in the given span in context.
-bool IsCodepointInSpan(const char32 codepoint, const std::string& context,
+bool IsCodepointInSpan(const char32 codepoint,
+ const UnicodeText& context_unicode,
const CodepointSpan span) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
-
auto begin_it = context_unicode.begin();
std::advance(begin_it, span.first);
auto end_it = context_unicode.begin();
@@ -39,21 +37,16 @@
}
// Returns the first codepoint of the span.
-char32 FirstSpanCodepoint(const std::string& context,
+char32 FirstSpanCodepoint(const UnicodeText& context_unicode,
const CodepointSpan span) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
-
auto it = context_unicode.begin();
std::advance(it, span.first);
return *it;
}
// Returns the last codepoint of the span.
-char32 LastSpanCodepoint(const std::string& context, const CodepointSpan span) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
-
+char32 LastSpanCodepoint(const UnicodeText& context_unicode,
+ const CodepointSpan span) {
auto it = context_unicode.begin();
std::advance(it, span.second - 1);
return *it;
@@ -61,20 +54,27 @@
} // namespace
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+ CodepointSpan span, const UniLib& unilib) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ return StripUnpairedBrackets(context_unicode, span, unilib);
+}
+
// If the first or the last codepoint of the given span is a bracket, the
// bracket is stripped if the span does not contain its corresponding paired
// version.
-CodepointSpan StripUnpairedBrackets(const std::string& context,
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
CodepointSpan span, const UniLib& unilib) {
- if (context.empty()) {
+ if (context_unicode.empty()) {
return span;
}
- const char32 begin_char = FirstSpanCodepoint(context, span);
+ const char32 begin_char = FirstSpanCodepoint(context_unicode, span);
const char32 paired_begin_char = unilib.GetPairedBracket(begin_char);
if (paired_begin_char != begin_char) {
if (!unilib.IsOpeningBracket(begin_char) ||
- !IsCodepointInSpan(paired_begin_char, context, span)) {
+ !IsCodepointInSpan(paired_begin_char, context_unicode, span)) {
++span.first;
}
}
@@ -83,11 +83,11 @@
return span;
}
- const char32 end_char = LastSpanCodepoint(context, span);
+ const char32 end_char = LastSpanCodepoint(context_unicode, span);
const char32 paired_end_char = unilib.GetPairedBracket(end_char);
if (paired_end_char != end_char) {
if (!unilib.IsClosingBracket(end_char) ||
- !IsCodepointInSpan(paired_end_char, context, span)) {
+ !IsCodepointInSpan(paired_end_char, context_unicode, span)) {
--span.second;
}
}
diff --git a/strip-unpaired-brackets.h b/strip-unpaired-brackets.h
index 2d7893e..4e82c3e 100644
--- a/strip-unpaired-brackets.h
+++ b/strip-unpaired-brackets.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_STRIP_UNPAIRED_BRACKETS_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_STRIP_UNPAIRED_BRACKETS_H_
+#ifndef LIBTEXTCLASSIFIER_STRIP_UNPAIRED_BRACKETS_H_
+#define LIBTEXTCLASSIFIER_STRIP_UNPAIRED_BRACKETS_H_
#include <string>
@@ -28,6 +28,11 @@
// version.
CodepointSpan StripUnpairedBrackets(const std::string& context,
CodepointSpan span, const UniLib& unilib);
+
+// Same as above but takes UnicodeText instance directly.
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
+ CodepointSpan span, const UniLib& unilib);
+
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_STRIP_UNPAIRED_BRACKETS_H_
+#endif // LIBTEXTCLASSIFIER_STRIP_UNPAIRED_BRACKETS_H_
diff --git a/strip-unpaired-brackets_test.cc b/strip-unpaired-brackets_test.cc
index fb99d82..c329157 100644
--- a/strip-unpaired-brackets_test.cc
+++ b/strip-unpaired-brackets_test.cc
@@ -22,7 +22,7 @@
namespace {
TEST(StripUnpairedBracketsTest, StripUnpairedBrackets) {
- UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
// If the brackets match, nothing gets stripped.
EXPECT_EQ(StripUnpairedBrackets("call me (123) 456 today", {8, 17}, unilib),
std::make_pair(8, 17));
diff --git a/tensor-view.h b/tensor-view.h
index 69788c8..00ab08c 100644
--- a/tensor-view.h
+++ b/tensor-view.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TENSOR_VIEW_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TENSOR_VIEW_H_
+#ifndef LIBTEXTCLASSIFIER_TENSOR_VIEW_H_
+#define LIBTEXTCLASSIFIER_TENSOR_VIEW_H_
#include <algorithm>
#include <vector>
@@ -69,4 +69,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TENSOR_VIEW_H_
+#endif // LIBTEXTCLASSIFIER_TENSOR_VIEW_H_
diff --git a/test_data/dummy.fb b/test_data/dummy.fb
deleted file mode 100644
index 4fec970..0000000
--- a/test_data/dummy.fb
+++ /dev/null
Binary files differ
diff --git a/test_data/test_model.fb b/test_data/test_model.fb
index f62d9fe..f2be859 100644
--- a/test_data/test_model.fb
+++ b/test_data/test_model.fb
Binary files differ
diff --git a/test_data/test_model_cc.fb b/test_data/test_model_cc.fb
new file mode 100644
index 0000000..01eafd6
--- /dev/null
+++ b/test_data/test_model_cc.fb
Binary files differ
diff --git a/test_data/wrong_embeddings.fb b/test_data/wrong_embeddings.fb
index 513fcf5..e938bdc 100644
--- a/test_data/wrong_embeddings.fb
+++ b/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/text-classifier.cc b/text-classifier.cc
index 1ee7e56..ef83688 100644
--- a/text-classifier.cc
+++ b/text-classifier.cc
@@ -27,9 +27,16 @@
#include "util/utf8/unicodetext.h"
namespace libtextclassifier2 {
+const std::string& TextClassifier::kOtherCollection =
+ *[]() { return new std::string("other"); }();
+const std::string& TextClassifier::kPhoneCollection =
+ *[]() { return new std::string("phone"); }();
+const std::string& TextClassifier::kDateCollection =
+ *[]() { return new std::string("date"); }();
+
namespace {
const Model* LoadAndVerifyModel(const void* addr, int size) {
- const Model* model = flatbuffers::GetRoot<Model>(addr);
+ const Model* model = GetModel(addr);
flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
if (model->Verify(verifier)) {
@@ -41,13 +48,14 @@
} // namespace
std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer(
- const char* buffer, int size) {
+ const char* buffer, int size, const UniLib* unilib) {
const Model* model = LoadAndVerifyModel(buffer, size);
if (model == nullptr) {
return nullptr;
}
- auto classifier = std::unique_ptr<TextClassifier>(new TextClassifier(model));
+ auto classifier =
+ std::unique_ptr<TextClassifier>(new TextClassifier(model, unilib));
if (!classifier->IsInitialized()) {
return nullptr;
}
@@ -56,7 +64,7 @@
}
std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap(
- std::unique_ptr<ScopedMmap>* mmap) {
+ std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib) {
if (!(*mmap)->handle().ok()) {
TC_VLOG(1) << "Mmap failed.";
return nullptr;
@@ -70,7 +78,7 @@
}
auto classifier =
- std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model));
+ std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model, unilib));
if (!classifier->IsInitialized()) {
return nullptr;
}
@@ -78,101 +86,104 @@
return classifier;
}
-std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(int fd,
- int offset,
- int size) {
+std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(
+ int fd, int offset, int size, const UniLib* unilib) {
std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
- return FromScopedMmap(&mmap);
+ return FromScopedMmap(&mmap, unilib);
}
-std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(int fd) {
+std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(
+ int fd, const UniLib* unilib) {
std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
- return FromScopedMmap(&mmap);
+ return FromScopedMmap(&mmap, unilib);
}
std::unique_ptr<TextClassifier> TextClassifier::FromPath(
- const std::string& path) {
+ const std::string& path, const UniLib* unilib) {
std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
- return FromScopedMmap(&mmap);
+ return FromScopedMmap(&mmap, unilib);
}
void TextClassifier::ValidateAndInitialize() {
+ initialized_ = false;
+
if (model_ == nullptr) {
TC_LOG(ERROR) << "No model specified.";
- initialized_ = false;
return;
}
if (!model_->selection_options()) {
TC_LOG(ERROR) << "No selection options.";
- initialized_ = false;
return;
}
if (!model_->classification_options()) {
TC_LOG(ERROR) << "No classification options.";
- initialized_ = false;
return;
}
if (!model_->selection_feature_options()) {
TC_LOG(ERROR) << "No selection feature options.";
- initialized_ = false;
return;
}
if (!model_->classification_feature_options()) {
TC_LOG(ERROR) << "No classification feature options.";
- initialized_ = false;
return;
}
if (!model_->classification_feature_options()->bounds_sensitive_features()) {
TC_LOG(ERROR) << "No classification bounds sensitive feature options.";
- initialized_ = false;
return;
}
if (!model_->selection_feature_options()->bounds_sensitive_features()) {
TC_LOG(ERROR) << "No selection bounds sensitive feature options.";
- initialized_ = false;
+ return;
+ }
+
+ if (model_->selection_feature_options()->embedding_size() !=
+ model_->classification_feature_options()->embedding_size() ||
+ model_->selection_feature_options()->embedding_quantization_bits() !=
+ model_->classification_feature_options()
+ ->embedding_quantization_bits()) {
+ TC_LOG(ERROR) << "Mismatching embedding size/quantization.";
return;
}
if (!model_->selection_model()) {
TC_LOG(ERROR) << "No selection model.";
- initialized_ = false;
return;
}
if (!model_->embedding_model()) {
TC_LOG(ERROR) << "No embedding model.";
- initialized_ = false;
return;
}
if (!model_->classification_model()) {
TC_LOG(ERROR) << "No clf model.";
- initialized_ = false;
return;
}
- if (model_->regex_options()) {
- InitializeRegexModel();
+ if (model_->regex_model()) {
+ if (!InitializeRegexModel()) {
+ TC_LOG(ERROR) << "Could not initialize regex model.";
+ }
}
embedding_executor_.reset(new TFLiteEmbeddingExecutor(
- flatbuffers::GetRoot<tflite::Model>(model_->embedding_model()->data())));
+ flatbuffers::GetRoot<tflite::Model>(model_->embedding_model()->data()),
+ model_->selection_feature_options()->embedding_size(),
+ model_->selection_feature_options()->embedding_quantization_bits()));
if (!embedding_executor_ || !embedding_executor_->IsReady()) {
TC_LOG(ERROR) << "Could not initialize embedding executor.";
- initialized_ = false;
return;
}
selection_executor_.reset(new ModelExecutor(
flatbuffers::GetRoot<tflite::Model>(model_->selection_model()->data())));
if (!selection_executor_) {
TC_LOG(ERROR) << "Could not initialize selection executor.";
- initialized_ = false;
return;
}
classification_executor_.reset(
@@ -180,39 +191,63 @@
model_->classification_model()->data())));
if (!classification_executor_) {
TC_LOG(ERROR) << "Could not initialize classification executor.";
- initialized_ = false;
return;
}
+ if (model_->datetime_model()) {
+ datetime_parser_ =
+ DatetimeParser::Instance(model_->datetime_model(), *unilib_);
+ if (!datetime_parser_) {
+ TC_LOG(ERROR) << "Could not initialize datetime parser.";
+ return;
+ }
+ }
+
selection_feature_processor_.reset(
- new FeatureProcessor(model_->selection_feature_options(), unilib_.get()));
- classification_feature_processor_.reset(new FeatureProcessor(
- model_->classification_feature_options(), unilib_.get()));
+ new FeatureProcessor(model_->selection_feature_options(), unilib_));
+ classification_feature_processor_.reset(
+ new FeatureProcessor(model_->classification_feature_options(), unilib_));
initialized_ = true;
}
-void TextClassifier::InitializeRegexModel() {
- if (!model_->regex_options()->patterns()) {
+bool TextClassifier::InitializeRegexModel() {
+ if (!model_->regex_model()->patterns()) {
initialized_ = false;
TC_LOG(ERROR) << "No patterns in the regex config.";
- return;
+ return false;
}
// Initialize pattern recognizers.
- for (const auto& regex_pattern : *model_->regex_options()->patterns()) {
+ int regex_pattern_id = 0;
+ for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
std::unique_ptr<UniLib::RegexPattern> compiled_pattern(
- unilib_->CreateRegexPattern(regex_pattern->pattern()->c_str()));
+ unilib_->CreateRegexPattern(UTF8ToUnicodeText(
+ regex_pattern->pattern()->c_str(),
+ regex_pattern->pattern()->Length(), /*do_copy=*/false)));
if (!compiled_pattern) {
- TC_LOG(WARNING) << "Failed to load pattern"
- << regex_pattern->pattern()->str();
+ TC_LOG(INFO) << "Failed to load pattern: "
+ << regex_pattern->pattern()->str();
continue;
}
- regex_patterns_.push_back(
- {regex_pattern->collection_name()->str(), std::move(compiled_pattern)});
+ if (regex_pattern->enabled_for_annotation()) {
+ annotation_regex_patterns_.push_back(regex_pattern_id);
+ }
+ if (regex_pattern->enabled_for_classification()) {
+ classification_regex_patterns_.push_back(regex_pattern_id);
+ }
+ if (regex_pattern->enabled_for_selection()) {
+ selection_regex_patterns_.push_back(regex_pattern_id);
+ }
+ regex_patterns_.push_back({regex_pattern->collection_name()->str(),
+ regex_pattern->target_classification_score(),
+ regex_pattern->priority_score(),
+ std::move(compiled_pattern)});
+ ++regex_pattern_id;
}
+ return true;
}
namespace {
@@ -240,18 +275,19 @@
std::advance(selection_end, selection_indices.second);
return UnicodeText::UTF8Substring(selection_begin, selection_end);
}
-
} // namespace
CodepointSpan TextClassifier::SuggestSelection(
- const std::string& context, CodepointSpan click_indices) const {
+ const std::string& context, CodepointSpan click_indices,
+ const SelectionOptions& options) const {
if (!initialized_) {
TC_LOG(ERROR) << "Not initialized";
return click_indices;
}
- const int context_codepoint_size =
- UTF8ToUnicodeText(context, /*do_copy=*/false).size();
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ const int context_codepoint_size = context_unicode.size_codepoints();
if (click_indices.first < 0 || click_indices.second < 0 ||
click_indices.first >= context_codepoint_size ||
@@ -262,22 +298,184 @@
return click_indices;
}
+ std::vector<AnnotatedSpan> candidates;
+ if (!ModelSuggestSelection(context_unicode, click_indices, &candidates)) {
+ TC_LOG(ERROR) << "Model suggest selection failed.";
+ return click_indices;
+ }
+ if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) {
+ TC_LOG(ERROR) << "Regex suggest selection failed.";
+ return click_indices;
+ }
+ if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
+ options.locales, &candidates)) {
+ TC_LOG(ERROR) << "Datetime suggest selection failed.";
+ return click_indices;
+ }
+
+ // Sort candidates according to their position in the input, so that the next
+ // code can assume that any connected component of overlapping spans forms a
+ // contiguous block.
+ std::sort(candidates.begin(), candidates.end(),
+ [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+ return a.span.first < b.span.first;
+ });
+
+ std::vector<int> candidate_indices;
+ if (!ResolveConflicts(candidates, context, &candidate_indices)) {
+ TC_LOG(ERROR) << "Couldn't resolve conflicts.";
+ return click_indices;
+ }
+
+ for (const int i : candidate_indices) {
+ if (SpansOverlap(candidates[i].span, click_indices)) {
+ return candidates[i].span;
+ }
+ }
+
+ return click_indices;
+}
+
+namespace {
+// Helper function that returns the index of the first candidate that
+// transitively does not overlap with the candidate on 'start_index'. If the end
+// of 'candidates' is reached, it returns the index that points right behind the
+// array.
+int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
+ int start_index) {
+ int first_non_overlapping = start_index + 1;
+ CodepointSpan conflicting_span = candidates[start_index].span;
+ while (
+ first_non_overlapping < candidates.size() &&
+ SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
+ // Grow the span to include the current one.
+ conflicting_span.second = std::max(
+ conflicting_span.second, candidates[first_non_overlapping].span.second);
+
+ ++first_non_overlapping;
+ }
+ return first_non_overlapping;
+}
+} // namespace
+
+bool TextClassifier::ResolveConflicts(
+ const std::vector<AnnotatedSpan>& candidates, const std::string& context,
+ std::vector<int>* result) const {
+ result->clear();
+ result->reserve(candidates.size());
+ for (int i = 0; i < candidates.size();) {
+ int first_non_overlapping =
+ FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
+
+ const bool conflict_found = first_non_overlapping != (i + 1);
+ if (conflict_found) {
+ std::vector<int> candidate_indices;
+ if (!ResolveConflict(context, candidates, i, first_non_overlapping,
+ &candidate_indices)) {
+ return false;
+ }
+ result->insert(result->end(), candidate_indices.begin(),
+ candidate_indices.end());
+ } else {
+ result->push_back(i);
+ }
+
+ // Skip over the whole conflicting group/go to next candidate.
+ i = first_non_overlapping;
+ }
+ return true;
+}
+
+namespace {
+inline bool ClassifiedAsOther(
+ const std::vector<ClassificationResult>& classification) {
+ return !classification.empty() &&
+ classification[0].collection == TextClassifier::kOtherCollection;
+}
+
+float GetPriorityScore(
+ const std::vector<ClassificationResult>& classification) {
+ if (!ClassifiedAsOther(classification)) {
+ return classification[0].priority_score;
+ } else {
+ return -1.0;
+ }
+}
+} // namespace
+
+bool TextClassifier::ResolveConflict(
+ const std::string& context, const std::vector<AnnotatedSpan>& candidates,
+ int start_index, int end_index, std::vector<int>* chosen_indices) const {
+ std::vector<int> conflicting_indices;
+ std::unordered_map<int, float> scores;
+ for (int i = start_index; i < end_index; ++i) {
+ conflicting_indices.push_back(i);
+ if (!candidates[i].classification.empty()) {
+ scores[i] = GetPriorityScore(candidates[i].classification);
+ continue;
+ }
+
+ // OPTIMIZATION: So that we don't have to classify all the ML model
+ // spans apriori, we wait until we get here, when they conflict with
+ // something and we need the actual classification scores. So if the
+ // candidate conflicts and comes from the model, we need to run a
+ // classification to determine its priority:
+ std::vector<ClassificationResult> classification;
+ if (!ModelClassifyText(context, candidates[i].span, &classification)) {
+ return false;
+ }
+
+ if (!classification.empty()) {
+ scores[i] = GetPriorityScore(classification);
+ }
+ }
+
+ std::sort(conflicting_indices.begin(), conflicting_indices.end(),
+ [&scores](int i, int j) { return scores[i] > scores[j]; });
+
+ // Keeps the chosen candidates sorted by their position in the text (their
+ // left span index) for fast conflict lookup below.
+ std::set<int, std::function<bool(int, int)>> chosen_indices_set(
+ [&candidates](int a, int b) {
+ return candidates[a].span.first < candidates[b].span.first;
+ });
+
+ // Greedily place the candidates if they don't conflict with the already
+ // placed ones.
+ for (int i = 0; i < conflicting_indices.size(); ++i) {
+ const int considered_candidate = conflicting_indices[i];
+ if (!DoesCandidateConflict(considered_candidate, candidates,
+ chosen_indices_set)) {
+ chosen_indices_set.insert(considered_candidate);
+ }
+ }
+
+ *chosen_indices =
+ std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end());
+
+ return true;
+}
+
+bool TextClassifier::ModelSuggestSelection(
+ const UnicodeText& context_unicode, CodepointSpan click_indices,
+ std::vector<AnnotatedSpan>* result) const {
std::vector<Token> tokens;
int click_pos;
- selection_feature_processor_->TokenizeAndFindClick(context, click_indices,
- &tokens, &click_pos);
+ selection_feature_processor_->TokenizeAndFindClick(
+ context_unicode, click_indices,
+ selection_feature_processor_->GetOptions()->only_use_line_with_click(),
+ &tokens, &click_pos);
if (click_pos == kInvalidIndex) {
TC_VLOG(1) << "Could not calculate the click position.";
- return click_indices;
+ return false;
}
const int symmetry_context_size =
model_->selection_options()->symmetry_context_size();
- const int max_selection_span =
- selection_feature_processor_->GetOptions()->max_selection_span();
const FeatureProcessorOptions_::BoundsSensitiveFeatures*
- bounds_sensitive_features =
- model_->selection_feature_options()->bounds_sensitive_features();
+ bounds_sensitive_features = selection_feature_processor_->GetOptions()
+ ->bounds_sensitive_features();
// The symmetry context span is the clicked token with symmetry_context_size
// tokens on either side.
@@ -287,57 +485,236 @@
/*num_tokens_right=*/symmetry_context_size),
{0, tokens.size()});
- // The extraction span is the symmetry context span expanded to include
- // max_selection_span tokens on either side, which is how far a selection can
- // stretch from the click, plus a relevant number of tokens outside of the
- // bounds of the selection.
- const TokenSpan extraction_span = IntersectTokenSpans(
- ExpandTokenSpan(symmetry_context_span,
- /*num_tokens_left=*/max_selection_span +
- bounds_sensitive_features->num_tokens_before(),
- /*num_tokens_right=*/max_selection_span +
- bounds_sensitive_features->num_tokens_after()),
- {0, tokens.size()});
+ // Compute the extraction span based on the model type.
+ TokenSpan extraction_span;
+ if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+ // The extraction span is the symmetry context span expanded to include
+ // max_selection_span tokens on either side, which is how far a selection
+ // can stretch from the click, plus a relevant number of tokens outside of
+ // the bounds of the selection.
+ const int max_selection_span =
+ selection_feature_processor_->GetOptions()->max_selection_span();
+ extraction_span =
+ ExpandTokenSpan(symmetry_context_span,
+ /*num_tokens_left=*/max_selection_span +
+ bounds_sensitive_features->num_tokens_before(),
+ /*num_tokens_right=*/max_selection_span +
+ bounds_sensitive_features->num_tokens_after());
+ } else {
+ // The extraction span is the symmetry context span expanded to include
+ // context_size tokens on either side.
+ const int context_size =
+ selection_feature_processor_->GetOptions()->context_size();
+ extraction_span = ExpandTokenSpan(symmetry_context_span,
+ /*num_tokens_left=*/context_size,
+ /*num_tokens_right=*/context_size);
+ }
+ extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});
+
+ std::unique_ptr<CachedFeatures> cached_features;
+ if (!selection_feature_processor_->ExtractFeatures(
+ tokens, extraction_span,
+ /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+ embedding_executor_.get(),
+ selection_feature_processor_->EmbeddingSize() +
+ selection_feature_processor_->DenseFeaturesCount(),
+ &cached_features)) {
+ TC_LOG(ERROR) << "Could not extract features.";
+ return false;
+ }
+
+ // Produce selection model candidates.
+ std::vector<TokenSpan> chunks;
+ if (!ModelChunk(tokens.size(), /*span_of_interest=*/symmetry_context_span,
+ *cached_features, &chunks)) {
+ TC_LOG(ERROR) << "Could not chunk.";
+ return false;
+ }
+
+ for (const TokenSpan& chunk : chunks) {
+ AnnotatedSpan candidate;
+ candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
+ context_unicode, TokenSpanToCodepointSpan(tokens, chunk));
+ if (model_->selection_options()->strip_unpaired_brackets()) {
+ candidate.span =
+ StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
+ }
+
+ // Only output non-empty spans.
+ if (candidate.span.first != candidate.span.second) {
+ result->push_back(candidate);
+ }
+ }
+ return true;
+}
+
+bool TextClassifier::ModelClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ std::vector<ClassificationResult>* classification_results) const {
+ std::vector<Token> tokens;
+ int click_pos;
+ classification_feature_processor_->TokenizeAndFindClick(
+ context, selection_indices,
+ classification_feature_processor_->GetOptions()
+ ->only_use_line_with_click(),
+ &tokens, &click_pos);
+ const TokenSpan selection_token_span =
+ CodepointSpanToTokenSpan(tokens, selection_indices);
+ const FeatureProcessorOptions_::BoundsSensitiveFeatures*
+ bounds_sensitive_features =
+ classification_feature_processor_->GetOptions()
+ ->bounds_sensitive_features();
+ if (selection_token_span.first == kInvalidIndex ||
+ selection_token_span.second == kInvalidIndex) {
+ TC_LOG(ERROR) << "Could not determine span.";
+ return false;
+ }
+
+ // Compute the extraction span based on the model type.
+ TokenSpan extraction_span;
+ if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+ // The extraction span is the selection span expanded to include a relevant
+ // number of tokens outside of the bounds of the selection.
+ extraction_span = ExpandTokenSpan(
+ selection_token_span,
+ /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
+ /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
+ } else {
+ if (click_pos == kInvalidIndex) {
+ TC_LOG(ERROR) << "Couldn't choose a click position.";
+ return false;
+ }
+ // The extraction span is the clicked token with context_size tokens on
+ // either side.
+ const int context_size =
+ selection_feature_processor_->GetOptions()->context_size();
+ extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
+ /*num_tokens_left=*/context_size,
+ /*num_tokens_right=*/context_size);
+ }
+ extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});
std::unique_ptr<CachedFeatures> cached_features;
if (!classification_feature_processor_->ExtractFeatures(
- tokens, extraction_span, embedding_executor_.get(),
+ tokens, extraction_span, selection_indices, embedding_executor_.get(),
classification_feature_processor_->EmbeddingSize() +
classification_feature_processor_->DenseFeaturesCount(),
&cached_features)) {
TC_LOG(ERROR) << "Could not extract features.";
- return click_indices;
+ return false;
}
- std::vector<TokenSpan> chunks;
- if (!Chunk(tokens.size(), /*span_of_interest=*/symmetry_context_span,
- *cached_features, &chunks)) {
- TC_LOG(ERROR) << "Could not chunk.";
- return click_indices;
+ std::vector<float> features;
+ features.reserve(cached_features->OutputFeaturesSize());
+ if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+ cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
+ &features);
+ } else {
+ cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
}
- CodepointSpan result = click_indices;
- for (const TokenSpan& chunk : chunks) {
- if (chunk.first <= click_pos && click_pos < chunk.second) {
- result = selection_feature_processor_->StripBoundaryCodepoints(
- context, TokenSpanToCodepointSpan(tokens, chunk));
- break;
+ TensorView<float> logits =
+ classification_executor_->ComputeLogits(TensorView<float>(
+ features.data(), {1, static_cast<int>(features.size())}));
+ if (!logits.is_valid()) {
+ TC_LOG(ERROR) << "Couldn't compute logits.";
+ return false;
+ }
+
+ if (logits.dims() != 2 || logits.dim(0) != 1 ||
+ logits.dim(1) != classification_feature_processor_->NumCollections()) {
+ TC_LOG(ERROR) << "Mismatching output";
+ return false;
+ }
+
+ const std::vector<float> scores =
+ ComputeSoftmax(logits.data(), logits.dim(1));
+
+ classification_results->resize(scores.size());
+ for (int i = 0; i < scores.size(); i++) {
+ (*classification_results)[i] = {
+ classification_feature_processor_->LabelToCollection(i), scores[i]};
+ }
+ std::sort(classification_results->begin(), classification_results->end(),
+ [](const ClassificationResult& a, const ClassificationResult& b) {
+ return a.score > b.score;
+ });
+
+ // Phone class sanity check.
+ if (!classification_results->empty() &&
+ classification_results->begin()->collection == kPhoneCollection) {
+ const int digit_count = CountDigits(context, selection_indices);
+ if (digit_count <
+ model_->classification_options()->phone_min_num_digits() ||
+ digit_count >
+ model_->classification_options()->phone_max_num_digits()) {
+ *classification_results = {{kOtherCollection, 1.0}};
}
}
- if (model_->selection_options()->strip_unpaired_brackets()) {
- const CodepointSpan stripped_result =
- StripUnpairedBrackets(context, result, *unilib_);
- if (stripped_result.first != stripped_result.second) {
- result = stripped_result;
- }
- }
-
- return result;
+ return true;
}
-std::vector<std::pair<std::string, float>> TextClassifier::ClassifyText(
- const std::string& context, CodepointSpan selection_indices) const {
+bool TextClassifier::RegexClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ ClassificationResult* classification_result) const {
+ const std::string selection_text =
+ ExtractSelection(context, selection_indices);
+ const UnicodeText selection_text_unicode(
+ UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
+
+ // Check whether any of the regular expressions match.
+ for (const int pattern_id : classification_regex_patterns_) {
+ const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
+ const std::unique_ptr<UniLib::RegexMatcher> matcher =
+ regex_pattern.pattern->Matcher(selection_text_unicode);
+ int status = UniLib::RegexMatcher::kNoError;
+ if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
+ *classification_result = {regex_pattern.collection_name,
+ regex_pattern.target_classification_score,
+ regex_pattern.priority_score};
+ return true;
+ }
+ if (status != UniLib::RegexMatcher::kNoError) {
+ TC_LOG(ERROR) << "Couldn't match regex: " << pattern_id;
+ }
+ }
+
+ return false;
+}
+
+bool TextClassifier::DatetimeClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ const ClassificationOptions& options,
+ ClassificationResult* classification_result) const {
+ const std::string selection_text =
+ ExtractSelection(context, selection_indices);
+
+ std::vector<DatetimeParseResultSpan> datetime_spans;
+ if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
+ options.reference_timezone, options.locales,
+ &datetime_spans)) {
+ TC_LOG(ERROR) << "Error during parsing datetime.";
+ return false;
+ }
+ for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+ // Only consider the result valid if the selection and extracted datetime
+ // spans exactly match.
+ if (std::make_pair(datetime_span.span.first + selection_indices.first,
+ datetime_span.span.second + selection_indices.first) ==
+ selection_indices) {
+ *classification_result = {kDateCollection,
+ datetime_span.target_classification_score};
+ classification_result->datetime_parse_result = datetime_span.data;
+ return true;
+ }
+ }
+ return false;
+}
+
+std::vector<ClassificationResult> TextClassifier::ClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ const ClassificationOptions& options) const {
if (!initialized_) {
TC_LOG(ERROR) << "Not initialized";
return {};
@@ -350,125 +727,70 @@
return {};
}
- // Check whether any of the regular expressions match.
- const std::string selection_text =
- ExtractSelection(context, selection_indices);
- for (const CompiledRegexPattern& regex_pattern : regex_patterns_) {
- if (regex_pattern.pattern->Matches(selection_text)) {
- return {{regex_pattern.collection_name, 1.0}};
- }
+ // Try the regular expression models.
+ ClassificationResult regex_result;
+ if (RegexClassifyText(context, selection_indices, ®ex_result)) {
+ return {regex_result};
}
- const FeatureProcessorOptions_::BoundsSensitiveFeatures*
- bounds_sensitive_features =
- model_->classification_feature_options()->bounds_sensitive_features();
-
- std::vector<Token> tokens;
- classification_feature_processor_->TokenizeAndFindClick(
- context, selection_indices, &tokens, /*click_pos=*/nullptr);
- const TokenSpan selection_token_span =
- CodepointSpanToTokenSpan(tokens, selection_indices);
-
- if (selection_token_span.first == kInvalidIndex ||
- selection_token_span.second == kInvalidIndex) {
- return {};
+ // Try the date model.
+ ClassificationResult datetime_result;
+ if (DatetimeClassifyText(context, selection_indices, options,
+ &datetime_result)) {
+ return {datetime_result};
}
- // The extraction span is the selection span expanded to include a relevant
- // number of tokens outside of the bounds of the selection.
- const TokenSpan extraction_span = IntersectTokenSpans(
- ExpandTokenSpan(selection_token_span,
- bounds_sensitive_features->num_tokens_before(),
- bounds_sensitive_features->num_tokens_after()),
- {0, tokens.size()});
-
- std::unique_ptr<CachedFeatures> cached_features;
- if (!classification_feature_processor_->ExtractFeatures(
- tokens, extraction_span, embedding_executor_.get(),
- classification_feature_processor_->EmbeddingSize() +
- classification_feature_processor_->DenseFeaturesCount(),
- &cached_features)) {
- TC_LOG(ERROR) << "Could not extract features.";
- return {};
+ // Fallback to the model.
+ std::vector<ClassificationResult> model_result;
+ if (ModelClassifyText(context, selection_indices, &model_result)) {
+ return model_result;
}
- const std::vector<float> features =
- cached_features->Get(selection_token_span);
-
- TensorView<float> logits =
- classification_executor_->ComputeLogits(TensorView<float>(
- features.data(), {1, static_cast<int>(features.size())}));
- if (!logits.is_valid()) {
- TC_LOG(ERROR) << "Couldn't compute logits.";
- return {};
- }
-
- if (logits.dims() != 2 || logits.dim(0) != 1 ||
- logits.dim(1) != classification_feature_processor_->NumCollections()) {
- TC_LOG(ERROR) << "Mismatching output";
- return {};
- }
-
- const std::vector<float> scores =
- ComputeSoftmax(logits.data(), logits.dim(1));
-
- std::vector<std::pair<std::string, float>> result(scores.size());
- for (int i = 0; i < scores.size(); i++) {
- result[i] = {classification_feature_processor_->LabelToCollection(i),
- scores[i]};
- }
- std::sort(result.begin(), result.end(),
- [](const std::pair<std::string, float>& a,
- const std::pair<std::string, float>& b) {
- return a.second > b.second;
- });
-
- // Phone class sanity check.
- if (result.begin()->first == kPhoneCollection) {
- const int digit_count = CountDigits(context, selection_indices);
- if (digit_count <
- model_->classification_options()->phone_min_num_digits() ||
- digit_count >
- model_->classification_options()->phone_max_num_digits()) {
- return {{kOtherCollection, 1.0}};
- }
- }
-
- return result;
+ // No classifications.
+ return {};
}
-std::vector<AnnotatedSpan> TextClassifier::Annotate(
- const std::string& context) const {
+bool TextClassifier::ModelAnnotate(const std::string& context,
+ std::vector<AnnotatedSpan>* result) const {
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
+ std::vector<UnicodeTextRange> lines;
+ if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
+ lines.push_back({context_unicode.begin(), context_unicode.end()});
+ } else {
+ lines = selection_feature_processor_->SplitContext(context_unicode);
+ }
std::vector<TokenSpan> chunks;
- for (const UnicodeTextRange& line :
- selection_feature_processor_->SplitContext(context_unicode)) {
+ for (const UnicodeTextRange& line : lines) {
const std::string line_str =
UnicodeText::UTF8Substring(line.first, line.second);
std::vector<Token> tokens;
selection_feature_processor_->TokenizeAndFindClick(
- line_str, {0, std::distance(line.first, line.second)}, &tokens,
+ line_str, {0, std::distance(line.first, line.second)},
+ selection_feature_processor_->GetOptions()->only_use_line_with_click(),
+ &tokens,
/*click_pos=*/nullptr);
const TokenSpan full_line_span = {0, tokens.size()};
std::unique_ptr<CachedFeatures> cached_features;
- if (!classification_feature_processor_->ExtractFeatures(
- tokens, full_line_span, embedding_executor_.get(),
- classification_feature_processor_->EmbeddingSize() +
- classification_feature_processor_->DenseFeaturesCount(),
+ if (!selection_feature_processor_->ExtractFeatures(
+ tokens, full_line_span,
+ /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+ embedding_executor_.get(),
+ selection_feature_processor_->EmbeddingSize() +
+ selection_feature_processor_->DenseFeaturesCount(),
&cached_features)) {
TC_LOG(ERROR) << "Could not extract features.";
- continue;
+ return false;
}
std::vector<TokenSpan> local_chunks;
- if (!Chunk(tokens.size(), /*span_of_interest=*/full_line_span,
- *cached_features, &local_chunks)) {
+ if (!ModelChunk(tokens.size(), /*span_of_interest=*/full_line_span,
+ *cached_features, &local_chunks)) {
TC_LOG(ERROR) << "Could not chunk.";
- continue;
+ return false;
}
const int offset = std::distance(context_unicode.begin(), line.first);
@@ -476,37 +798,130 @@
const CodepointSpan codepoint_span =
selection_feature_processor_->StripBoundaryCodepoints(
line_str, TokenSpanToCodepointSpan(tokens, chunk));
- chunks.push_back(
- {codepoint_span.first + offset, codepoint_span.second + offset});
+
+ // Skip empty spans.
+ if (codepoint_span.first != codepoint_span.second) {
+ chunks.push_back(
+ {codepoint_span.first + offset, codepoint_span.second + offset});
+ }
}
}
- std::vector<AnnotatedSpan> result;
for (const CodepointSpan& chunk : chunks) {
- result.emplace_back();
- result.back().span = chunk;
- result.back().classification = ClassifyText(context, chunk);
+ std::vector<ClassificationResult> classification;
+ if (!ModelClassifyText(context, chunk, &classification)) {
+ TC_LOG(ERROR) << "Could not classify text: " << chunk.first << " "
+ << chunk.second;
+ return false;
+ }
+
+ // Do not include the span if it's classified as "other".
+ if (!classification.empty() && !ClassifiedAsOther(classification)) {
+ AnnotatedSpan result_span;
+ result_span.span = chunk;
+ result_span.classification = std::move(classification);
+ result->push_back(std::move(result_span));
+ }
}
+ return true;
+}
+
+const FeatureProcessor& TextClassifier::SelectionFeatureProcessorForTests()
+ const {
+ return *selection_feature_processor_;
+}
+
+std::vector<AnnotatedSpan> TextClassifier::Annotate(
+ const std::string& context, const AnnotationOptions& options) const {
+ std::vector<AnnotatedSpan> candidates;
+
+ // Annotate with the selection model.
+ if (!ModelAnnotate(context, &candidates)) {
+ TC_LOG(ERROR) << "Couldn't run ModelAnnotate.";
+ return {};
+ }
+
+ // Annotate with the regular expression models.
+ if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ annotation_regex_patterns_, &candidates)) {
+ TC_LOG(ERROR) << "Couldn't run RegexChunk.";
+ return {};
+ }
+
+ // Annotate with the datetime model.
+ if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ options.reference_time_ms_utc, options.reference_timezone,
+ options.locales, &candidates)) {
+ TC_LOG(ERROR) << "Couldn't run DatetimeChunk.";
+ return {};
+ }
+
+ // Sort candidates according to their position in the input, so that the next
+ // code can assume that any connected component of overlapping spans forms a
+ // contiguous block.
+ std::sort(candidates.begin(), candidates.end(),
+ [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+ return a.span.first < b.span.first;
+ });
+
+ std::vector<int> candidate_indices;
+ if (!ResolveConflicts(candidates, context, &candidate_indices)) {
+ TC_LOG(ERROR) << "Couldn't resolve conflicts.";
+ return {};
+ }
+
+ const float min_annotate_confidence =
+ (model_->triggering_options() != nullptr
+ ? model_->triggering_options()->min_annotate_confidence()
+ : 0.f);
+
+ std::vector<AnnotatedSpan> result;
+ result.reserve(candidate_indices.size());
+ for (const int i : candidate_indices) {
+ if (!candidates[i].classification.empty() &&
+ !ClassifiedAsOther(candidates[i].classification) &&
+ candidates[i].classification[0].score >= min_annotate_confidence) {
+ result.push_back(std::move(candidates[i]));
+ }
+ }
+
return result;
}
-bool TextClassifier::Chunk(int num_tokens, const TokenSpan& span_of_interest,
- const CachedFeatures& cached_features,
- std::vector<TokenSpan>* chunks) const {
+bool TextClassifier::RegexChunk(const UnicodeText& context_unicode,
+ const std::vector<int>& rules,
+ std::vector<AnnotatedSpan>* result) const {
+ for (int pattern_id : rules) {
+ const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
+ const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
+ if (!matcher) {
+ TC_LOG(ERROR) << "Could not get regex matcher for pattern: "
+ << pattern_id;
+ return false;
+ }
+
+ int status = UniLib::RegexMatcher::kNoError;
+ while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ result->emplace_back();
+ // Selection/annotation regular expressions need to specify a capturing
+ // group specifying the selection.
+ result->back().span = {matcher->Start(1, &status),
+ matcher->End(1, &status)};
+ result->back().classification = {
+ {regex_pattern.collection_name,
+ regex_pattern.target_classification_score,
+ regex_pattern.priority_score}};
+ }
+ }
+ return true;
+}
+
+bool TextClassifier::ModelChunk(int num_tokens,
+ const TokenSpan& span_of_interest,
+ const CachedFeatures& cached_features,
+ std::vector<TokenSpan>* chunks) const {
const int max_selection_span =
selection_feature_processor_->GetOptions()->max_selection_span();
- const int max_chunk_length = selection_feature_processor_->GetOptions()
- ->selection_reduced_output_space()
- ? max_selection_span + 1
- : 2 * max_selection_span + 1;
-
- struct ScoredChunk {
- bool operator<(const ScoredChunk& that) const { return score < that.score; }
-
- TokenSpan token_span;
- float score;
- };
-
// The inference span is the span of interest expanded to include
// max_selection_span tokens on either side, which is how far a selection can
// stretch from the click.
@@ -517,37 +932,25 @@
{0, num_tokens});
std::vector<ScoredChunk> scored_chunks;
- // Iterate over chunk candidates that:
- // - Are contained in the inference span
- // - Have a non-empty intersection with the span of interest
- // - Are at least one token long
- // - Are not longer than the maximum chunk length
- for (int start = inference_span.first; start < span_of_interest.second;
- ++start) {
- const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
- for (int end = leftmost_end_index;
- end <= inference_span.second && end - start <= max_chunk_length;
- ++end) {
- const std::vector<float> features = cached_features.Get({start, end});
- TensorView<float> logits =
- selection_executor_->ComputeLogits(TensorView<float>(
- features.data(), {1, static_cast<int>(features.size())}));
-
- if (!logits.is_valid()) {
- TC_LOG(ERROR) << "Couldn't compute logits.";
- return false;
- }
-
- if (logits.dims() != 2 || logits.dim(0) != 1 || logits.dim(1) != 1) {
- TC_LOG(ERROR) << "Mismatching output";
- return false;
- }
-
- scored_chunks.push_back(ScoredChunk{{start, end}, logits.data()[0]});
+ if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
+ selection_feature_processor_->GetOptions()
+ ->bounds_sensitive_features()
+ ->enabled()) {
+ if (!ModelBoundsSensitiveScoreChunks(num_tokens, span_of_interest,
+ inference_span, cached_features,
+ &scored_chunks)) {
+ return false;
+ }
+ } else {
+ if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
+ cached_features, &scored_chunks)) {
+ return false;
}
}
-
- std::sort(scored_chunks.rbegin(), scored_chunks.rend());
+ std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
+ [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
+ return lhs.score < rhs.score;
+ });
// Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
// them greedily as long as they do not overlap with any previously picked
@@ -581,6 +984,183 @@
return true;
}
+namespace {
+// Updates the value at the given key in the map to the maximum of the current
+// and the given value, or simply inserts the value if the key is not yet there.
+template <typename Map>
+void UpdateMax(Map* map, typename Map::key_type key,
+ typename Map::mapped_type value) {
+ const auto it = map->find(key);
+ if (it != map->end()) {
+ it->second = std::max(it->second, value);
+ } else {
+ (*map)[key] = value;
+ }
+}
+} // namespace
+
+bool TextClassifier::ModelClickContextScoreChunks(
+ int num_tokens, const TokenSpan& span_of_interest,
+ const CachedFeatures& cached_features,
+ std::vector<ScoredChunk>* scored_chunks) const {
+ const int max_batch_size = model_->selection_options()->batch_size();
+
+ std::vector<float> all_features;
+ std::map<TokenSpan, float> chunk_scores;
+ for (int batch_start = span_of_interest.first;
+ batch_start < span_of_interest.second; batch_start += max_batch_size) {
+ const int batch_end =
+ std::min(batch_start + max_batch_size, span_of_interest.second);
+
+ // Prepare features for the whole batch.
+ all_features.clear();
+ all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
+ for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
+ cached_features.AppendClickContextFeaturesForClick(click_pos,
+ &all_features);
+ }
+
+ // Run batched inference.
+ const int batch_size = batch_end - batch_start;
+ const int features_size = cached_features.OutputFeaturesSize();
+ TensorView<float> logits = selection_executor_->ComputeLogits(
+ TensorView<float>(all_features.data(), {batch_size, features_size}));
+ if (!logits.is_valid()) {
+ TC_LOG(ERROR) << "Couldn't compute logits.";
+ return false;
+ }
+ if (logits.dims() != 2 || logits.dim(0) != batch_size ||
+ logits.dim(1) !=
+ selection_feature_processor_->GetSelectionLabelCount()) {
+ TC_LOG(ERROR) << "Mismatching output.";
+ return false;
+ }
+
+ // Save results.
+ for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
+ const std::vector<float> scores = ComputeSoftmax(
+ logits.data() + logits.dim(1) * (click_pos - batch_start),
+ logits.dim(1));
+ for (int j = 0;
+ j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
+ TokenSpan relative_token_span;
+ if (!selection_feature_processor_->LabelToTokenSpan(
+ j, &relative_token_span)) {
+ TC_LOG(ERROR) << "Couldn't map the label to a token span.";
+ return false;
+ }
+ const TokenSpan candidate_span = ExpandTokenSpan(
+ SingleTokenSpan(click_pos), relative_token_span.first,
+ relative_token_span.second);
+ if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
+ UpdateMax(&chunk_scores, candidate_span, scores[j]);
+ }
+ }
+ }
+ }
+
+ scored_chunks->clear();
+ scored_chunks->reserve(chunk_scores.size());
+ for (const auto& entry : chunk_scores) {
+ scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
+ }
+
+ return true;
+}
+
+bool TextClassifier::ModelBoundsSensitiveScoreChunks(
+ int num_tokens, const TokenSpan& span_of_interest,
+ const TokenSpan& inference_span, const CachedFeatures& cached_features,
+ std::vector<ScoredChunk>* scored_chunks) const {
+ const int max_selection_span =
+ selection_feature_processor_->GetOptions()->max_selection_span();
+ const int max_chunk_length = selection_feature_processor_->GetOptions()
+ ->selection_reduced_output_space()
+ ? max_selection_span + 1
+ : 2 * max_selection_span + 1;
+
+  // Prepare into one batch all chunk candidates that:
+ // - Are contained in the inference span
+ // - Have a non-empty intersection with the span of interest
+ // - Are at least one token long
+ // - Are not longer than the maximum chunk length
+ std::vector<TokenSpan> candidate_spans;
+ for (int start = inference_span.first; start < span_of_interest.second;
+ ++start) {
+ const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
+ for (int end = leftmost_end_index;
+ end <= inference_span.second && end - start <= max_chunk_length;
+ ++end) {
+ candidate_spans.emplace_back(start, end);
+ }
+ }
+
+ const int max_batch_size = model_->selection_options()->batch_size();
+
+ std::vector<float> all_features;
+ scored_chunks->clear();
+ scored_chunks->reserve(candidate_spans.size());
+ for (int batch_start = 0; batch_start < candidate_spans.size();
+ batch_start += max_batch_size) {
+ const int batch_end = std::min(batch_start + max_batch_size,
+ static_cast<int>(candidate_spans.size()));
+
+ // Prepare features for the whole batch.
+ all_features.clear();
+ all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
+ for (int i = batch_start; i < batch_end; ++i) {
+ cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
+ &all_features);
+ }
+
+ // Run batched inference.
+ const int batch_size = batch_end - batch_start;
+ const int features_size = cached_features.OutputFeaturesSize();
+ TensorView<float> logits = selection_executor_->ComputeLogits(
+ TensorView<float>(all_features.data(), {batch_size, features_size}));
+ if (!logits.is_valid()) {
+ TC_LOG(ERROR) << "Couldn't compute logits.";
+ return false;
+ }
+ if (logits.dims() != 2 || logits.dim(0) != batch_size ||
+ logits.dim(1) != 1) {
+ TC_LOG(ERROR) << "Mismatching output.";
+ return false;
+ }
+
+ // Save results.
+ for (int i = batch_start; i < batch_end; ++i) {
+ scored_chunks->push_back(
+ ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
+ }
+ }
+
+ return true;
+}
+
+bool TextClassifier::DatetimeChunk(const UnicodeText& context_unicode,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& locales,
+ std::vector<AnnotatedSpan>* result) const {
+ std::vector<DatetimeParseResultSpan> datetime_spans;
+ if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
+ reference_timezone, locales, &datetime_spans)) {
+ return false;
+ }
+ for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+ AnnotatedSpan annotated_span;
+ annotated_span.span = datetime_span.span;
+ annotated_span.classification = {{kDateCollection,
+ datetime_span.target_classification_score,
+ datetime_span.priority_score}};
+ annotated_span.classification[0].datetime_parse_result = datetime_span.data;
+
+ result->push_back(std::move(annotated_span));
+ }
+ return true;
+}
+
const Model* ViewModel(const void* buffer, int size) {
if (!buffer) {
return nullptr;
diff --git a/text-classifier.h b/text-classifier.h
index cd84eb4..33d6357 100644
--- a/text-classifier.h
+++ b/text-classifier.h
@@ -16,14 +16,15 @@
// Inference code for the text classification model.
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXT_CLASSIFIER_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXT_CLASSIFIER_H_
+#ifndef LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_
+#define LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_
#include <memory>
#include <set>
#include <string>
#include <vector>
+#include "datetime/parser.h"
#include "feature-processor.h"
#include "model-executor.h"
#include "model_generated.h"
@@ -34,20 +35,62 @@
namespace libtextclassifier2 {
+struct SelectionOptions {
+ // Comma-separated list of locale specification for the input text (BCP 47
+ // tags).
+ std::string locales;
+
+ static SelectionOptions Default() { return SelectionOptions(); }
+};
+
+struct ClassificationOptions {
+ // For parsing relative datetimes, the reference now time against which the
+ // relative datetimes get resolved.
+ // UTC milliseconds since epoch.
+ int64 reference_time_ms_utc = 0;
+
+ // Timezone in which the input text was written (format as accepted by ICU).
+ std::string reference_timezone;
+
+ // Comma-separated list of locale specification for the input text (BCP 47
+ // tags).
+ std::string locales;
+
+ static ClassificationOptions Default() { return ClassificationOptions(); }
+};
+
+struct AnnotationOptions {
+ // For parsing relative datetimes, the reference now time against which the
+ // relative datetimes get resolved.
+ // UTC milliseconds since epoch.
+ int64 reference_time_ms_utc = 0;
+
+ // Timezone in which the input text was written (format as accepted by ICU).
+ std::string reference_timezone;
+
+ // Comma-separated list of locale specification for the input text (BCP 47
+ // tags).
+ std::string locales;
+
+ static AnnotationOptions Default() { return AnnotationOptions(); }
+};
+
// A text processing model that provides text classification, annotation,
// selection suggestion for various types.
// NOTE: This class is not thread-safe.
class TextClassifier {
public:
- static std::unique_ptr<TextClassifier> FromUnownedBuffer(const char* buffer,
- int size);
+ static std::unique_ptr<TextClassifier> FromUnownedBuffer(
+ const char* buffer, int size, const UniLib* unilib = nullptr);
// Takes ownership of the mmap.
static std::unique_ptr<TextClassifier> FromScopedMmap(
- std::unique_ptr<ScopedMmap>* mmap);
- static std::unique_ptr<TextClassifier> FromFileDescriptor(int fd, int offset,
- int size);
- static std::unique_ptr<TextClassifier> FromFileDescriptor(int fd);
- static std::unique_ptr<TextClassifier> FromPath(const std::string& path);
+ std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr);
+ static std::unique_ptr<TextClassifier> FromFileDescriptor(
+ int fd, int offset, int size, const UniLib* unilib = nullptr);
+ static std::unique_ptr<TextClassifier> FromFileDescriptor(
+ int fd, const UniLib* unilib = nullptr);
+ static std::unique_ptr<TextClassifier> FromPath(
+ const std::string& path, const UniLib* unilib = nullptr);
// Returns true if the model is ready for use.
bool IsInitialized() { return initialized_; }
@@ -60,31 +103,54 @@
// NOTE: The selection indices are passed in and returned in terms of
// UTF8 codepoints (not bytes).
// Requires that the model is a smart selection model.
- CodepointSpan SuggestSelection(const std::string& context,
- CodepointSpan click_indices) const;
+ CodepointSpan SuggestSelection(
+ const std::string& context, CodepointSpan click_indices,
+ const SelectionOptions& options = SelectionOptions::Default()) const;
// Classifies the selected text given the context string.
// Returns an empty result if an error occurs.
- std::vector<std::pair<std::string, float>> ClassifyText(
- const std::string& context, CodepointSpan selection_indices) const;
+ std::vector<ClassificationResult> ClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ const ClassificationOptions& options =
+ ClassificationOptions::Default()) const;
- // Annotates given input text. The annotations should cover the whole input
- // context except for whitespaces, and are sorted by their position in the
- // context string.
- std::vector<AnnotatedSpan> Annotate(const std::string& context) const;
+ // Annotates given input text. The annotations are sorted by their position
+ // in the context string and exclude spans classified as 'other'.
+ std::vector<AnnotatedSpan> Annotate(
+ const std::string& context,
+ const AnnotationOptions& options = AnnotationOptions::Default()) const;
+
+ // Exposes the selection feature processor for tests and evaluations.
+ const FeatureProcessor& SelectionFeatureProcessorForTests() const;
+
+ // String collection names for various classes.
+ static const std::string& kOtherCollection;
+ static const std::string& kPhoneCollection;
+ static const std::string& kDateCollection;
protected:
+ struct ScoredChunk {
+ TokenSpan token_span;
+ float score;
+ };
+
// Constructs and initializes text classifier from given model.
// Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
- TextClassifier(std::unique_ptr<ScopedMmap>* mmap, const Model* model)
- : model_(model), mmap_(std::move(*mmap)), unilib_(new UniLib()) {
+ TextClassifier(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
+ const UniLib* unilib)
+ : model_(model),
+ mmap_(std::move(*mmap)),
+ owned_unilib_(nullptr),
+ unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) {
ValidateAndInitialize();
}
// Constructs, validates and initializes text classifier from given model.
// Does not own the buffer that backs 'model'.
- explicit TextClassifier(const Model* model)
- : model_(model), unilib_(new UniLib()) {
+ explicit TextClassifier(const Model* model, const UniLib* unilib)
+ : model_(model),
+ owned_unilib_(nullptr),
+ unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) {
ValidateAndInitialize();
}
@@ -93,7 +159,55 @@
void ValidateAndInitialize();
// Initializes regular expressions for the regex model.
- void InitializeRegexModel();
+ bool InitializeRegexModel();
+
+ // Resolves conflicts in the list of candidates by removing some overlapping
+ // ones. Returns indices of the surviving ones.
+ // NOTE: Assumes that the candidates are sorted according to their position in
+ // the span.
+ bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
+ const std::string& context,
+ std::vector<int>* result) const;
+
+ // Resolves one conflict between candidates on indices 'start_index'
+ // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate
+ // indices to 'chosen_indices'. Returns false if a problem arises.
+ bool ResolveConflict(const std::string& context,
+ const std::vector<AnnotatedSpan>& candidates,
+ int start_index, int end_index,
+ std::vector<int>* chosen_indices) const;
+
+ // Gets selection candidates from the ML model.
+ bool ModelSuggestSelection(const UnicodeText& context_unicode,
+ CodepointSpan click_indices,
+ std::vector<AnnotatedSpan>* result) const;
+
+ // Classifies the selected text given the context string with the
+ // classification model.
+ // Returns true if no error occurred.
+ bool ModelClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ std::vector<ClassificationResult>* classification_results) const;
+
+ // Classifies the selected text with the regular expressions models.
+ // Returns true if any regular expression matched and the result was set.
+ bool RegexClassifyText(const std::string& context,
+ CodepointSpan selection_indices,
+ ClassificationResult* classification_result) const;
+
+ // Classifies the selected text with the date time model.
+ // Returns true if there was a match and the result was set.
+ bool DatetimeClassifyText(const std::string& context,
+ CodepointSpan selection_indices,
+ const ClassificationOptions& options,
+ ClassificationResult* classification_result) const;
+
+ // Chunks given input text with the selection model and classifies the spans
+ // with the classification model.
+ // The annotations are sorted by their position in the context string and
+ // exclude spans classified as 'other'.
+ bool ModelAnnotate(const std::string& context,
+ std::vector<AnnotatedSpan>* result) const;
// Groups the tokens into chunks. A chunk is a token span that should be the
// suggested selection when any of its contained tokens is clicked. The chunks
@@ -104,15 +218,37 @@
// The resulting chunks all have to overlap with it and they cover this span
// completely. The first and last chunk might extend beyond it.
// The chunks vector is cleared before filling.
- bool Chunk(int num_tokens, const TokenSpan& span_of_interest,
- const CachedFeatures& cached_features,
- std::vector<TokenSpan>* chunks) const;
+ bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
+ const CachedFeatures& cached_features,
+ std::vector<TokenSpan>* chunks) const;
- // Collection name for other.
- const std::string kOtherCollection = "other";
+ // A helper method for ModelChunk(). It generates scored chunk candidates for
+ // a click context model.
+ // NOTE: The returned chunks can (and most likely do) overlap.
+ bool ModelClickContextScoreChunks(
+ int num_tokens, const TokenSpan& span_of_interest,
+ const CachedFeatures& cached_features,
+ std::vector<ScoredChunk>* scored_chunks) const;
- // Collection name for phone.
- const std::string kPhoneCollection = "phone";
+ // A helper method for ModelChunk(). It generates scored chunk candidates for
+ // a bounds-sensitive model.
+ // NOTE: The returned chunks can (and most likely do) overlap.
+ bool ModelBoundsSensitiveScoreChunks(
+ int num_tokens, const TokenSpan& span_of_interest,
+ const TokenSpan& inference_span, const CachedFeatures& cached_features,
+ std::vector<ScoredChunk>* scored_chunks) const;
+
+ // Produces chunks isolated by a set of regular expressions.
+ bool RegexChunk(const UnicodeText& context_unicode,
+ const std::vector<int>& rules,
+ std::vector<AnnotatedSpan>* result) const;
+
+ // Produces chunks from the datetime parser.
+ bool DatetimeChunk(const UnicodeText& context_unicode,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& locales,
+ std::vector<AnnotatedSpan>* result) const;
const Model* model_;
@@ -123,16 +259,27 @@
std::unique_ptr<FeatureProcessor> selection_feature_processor_;
std::unique_ptr<FeatureProcessor> classification_feature_processor_;
+ std::unique_ptr<DatetimeParser> datetime_parser_;
+
private:
struct CompiledRegexPattern {
std::string collection_name;
+ float target_classification_score;
+ float priority_score;
std::unique_ptr<UniLib::RegexPattern> pattern;
};
std::unique_ptr<ScopedMmap> mmap_;
bool initialized_ = false;
+
std::vector<CompiledRegexPattern> regex_patterns_;
- std::unique_ptr<UniLib> unilib_;
+
+ // Indices into regex_patterns_ for the different modes.
+ std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
+ selection_regex_patterns_;
+
+ std::unique_ptr<UniLib> owned_unilib_;
+ const UniLib* unilib_;
};
// Interprets the buffer as a Model flatbuffer and returns it for reading.
@@ -140,4 +287,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXT_CLASSIFIER_H_
+#endif // LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_
diff --git a/text-classifier_test.cc b/text-classifier_test.cc
index 82904e5..b6cf16a 100644
--- a/text-classifier_test.cc
+++ b/text-classifier_test.cc
@@ -21,6 +21,8 @@
#include <memory>
#include <string>
+#include "model_generated.h"
+#include "types-test-util.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -29,13 +31,13 @@
using testing::ElementsAreArray;
using testing::Pair;
+using testing::Values;
-std::string FirstResult(
- const std::vector<std::pair<std::string, float>>& results) {
+std::string FirstResult(const std::vector<ClassificationResult>& results) {
if (results.empty()) {
return "<INVALID RESULTS>";
}
- return results[0].first;
+ return results[0].collection;
}
MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
@@ -53,27 +55,30 @@
}
TEST(TextClassifierTest, EmbeddingExecutorLoadingFails) {
+ CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
- TextClassifier::FromPath(GetModelPath() + "wrong_embeddings.fb");
+ TextClassifier::FromPath(GetModelPath() + "wrong_embeddings.fb", &unilib);
EXPECT_FALSE(classifier);
}
-TEST(TextClassifierTest, ClassifyText) {
+class TextClassifierTest : public ::testing::TestWithParam<const char*> {};
+
+INSTANTIATE_TEST_CASE_P(ClickContext, TextClassifierTest,
+ Values("test_model_cc.fb"));
+INSTANTIATE_TEST_CASE_P(BoundsSensitive, TextClassifierTest,
+ Values("test_model.fb"));
+
+TEST_P(TextClassifierTest, ClassifyText) {
+ CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
- TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
ASSERT_TRUE(classifier);
EXPECT_EQ("other",
FirstResult(classifier->ClassifyText(
"this afternoon Barack Obama gave a speech at", {15, 27})));
- EXPECT_EQ("other",
- FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
- EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
- "Contact me at you@android.com", {14, 29})));
EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
"Call me at (800) 123-456 today", {11, 24})));
- EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
- "Visit www.google.com every today!", {6, 20})));
// More lines.
EXPECT_EQ("other",
@@ -81,11 +86,6 @@
"this afternoon Barack Obama gave a speech at|Visit "
"www.google.com every today!|Call me at (800) 123-456 today.",
{15, 27})));
- EXPECT_EQ("other",
- FirstResult(classifier->ClassifyText(
- "this afternoon Barack Obama gave a speech at|Visit "
- "www.google.com every today!|Call me at (800) 123-456 today.",
- {51, 65})));
EXPECT_EQ("phone",
FirstResult(classifier->ClassifyText(
"this afternoon Barack Obama gave a speech at|Visit "
@@ -105,9 +105,216 @@
"a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
}
-TEST(TextClassifierTest, PhoneFiltering) {
+std::unique_ptr<RegexModel_::PatternT> MakePattern(
+ const std::string& collection_name, const std::string& pattern,
+ const bool enabled_for_classification, const bool enabled_for_selection,
+ const bool enabled_for_annotation, const float score) {
+ std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
+ result->collection_name = collection_name;
+ result->pattern = pattern;
+ result->enabled_for_selection = enabled_for_selection;
+ result->enabled_for_classification = enabled_for_classification;
+ result->enabled_for_annotation = enabled_for_annotation;
+ result->target_classification_score = score;
+ result->priority_score = score;
+ return result;
+}
+
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST_P(TextClassifierTest, ClassifyTextRegularExpression) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", "Barack Obama", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
std::unique_ptr<TextClassifier> classifier =
- TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("flight",
+ FirstResult(classifier->ClassifyText(
+ "Your flight LX373 is delayed by 3 hours.", {12, 17})));
+ EXPECT_EQ("person",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at", {15, 27})));
+ EXPECT_EQ("email",
+ FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
+ EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
+ "Contact me at you@android.com", {14, 29})));
+
+ EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
+ "Visit www.google.com every today!", {6, 20})));
+
+ EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
+ EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
+ {7, 12})));
+
+ // More lines.
+ EXPECT_EQ("url",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {51, 65})));
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+
+TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model.reset(new RegexModelT);
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ // Check regular expression selection.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
+ std::make_pair(12, 19));
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon Barack Obama gave a speech at", {15, 21}),
+ std::make_pair(15, 27));
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST_P(TextClassifierTest,
+ SuggestSelectionRegularExpressionConflictsModelWins) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model.reset(new RegexModelT);
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.back()->priority_score = 0.5;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+ ASSERT_TRUE(classifier);
+
+ // Check conflict resolution.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
+ {55, 57}),
+ std::make_pair(26, 62));
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST_P(TextClassifierTest,
+ SuggestSelectionRegularExpressionConflictsRegexWins) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model.reset(new RegexModelT);
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+ ASSERT_TRUE(classifier);
+
+ // Check conflict resolution.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
+ {55, 57}),
+ std::make_pair(55, 62));
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST_P(TextClassifierTest, AnnotateRegex) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model.reset(new RegexModelT);
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+ IsAnnotatedSpan(6, 18, "person"),
+ IsAnnotatedSpan(19, 24, "date"),
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ }));
+}
+
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
+TEST_P(TextClassifierTest, PhoneFiltering) {
+ CREATE_UNILIB_FOR_TESTING;
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
ASSERT_TRUE(classifier);
EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
@@ -118,9 +325,10 @@
"phone: (123) 456 789,0001112", {7, 28})));
}
-TEST(TextClassifierTest, SuggestSelection) {
+TEST_P(TextClassifierTest, SuggestSelection) {
+ CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
- TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
ASSERT_TRUE(classifier);
EXPECT_EQ(classifier->SuggestSelection(
@@ -147,15 +355,12 @@
EXPECT_EQ(
classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
std::make_pair(11, 25));
- EXPECT_EQ(
- classifier->SuggestSelection("call me at (857 225 3556 today", {11, 15}),
- std::make_pair(12, 24));
- EXPECT_EQ(
- classifier->SuggestSelection("call me at 857 225 3556) today", {11, 14}),
- std::make_pair(11, 23));
- EXPECT_EQ(
- classifier->SuggestSelection("call me at )857 225 3556( today", {11, 15}),
- std::make_pair(12, 24));
+ EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}),
+ std::make_pair(12, 15));
+ EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}),
+ std::make_pair(11, 15));
+ EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}),
+ std::make_pair(12, 15));
// If the resulting selection would be empty, the original span is returned.
EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}),
@@ -166,9 +371,10 @@
std::make_pair(11, 12));
}
-TEST(TextClassifierTest, SuggestSelectionsAreSymmetric) {
+TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) {
+ CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
- TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
ASSERT_TRUE(classifier);
EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}),
@@ -183,20 +389,26 @@
std::make_pair(6, 33));
}
-TEST(TextClassifierTest, SuggestSelectionWithNewLine) {
+TEST_P(TextClassifierTest, SuggestSelectionWithNewLine) {
+ CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
- TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
ASSERT_TRUE(classifier);
EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}),
std::make_pair(4, 16));
EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}),
std::make_pair(0, 12));
+
+ SelectionOptions options;
+ EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
+ std::make_pair(0, 7));
}
-TEST(TextClassifierTest, SuggestSelectionWithPunctuation) {
+TEST_P(TextClassifierTest, SuggestSelectionWithPunctuation) {
+ CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
- TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
ASSERT_TRUE(classifier);
// From the right.
@@ -220,9 +432,10 @@
std::make_pair(16, 27));
}
-TEST(TextClassifierTest, SuggestSelectionNoCrashWithJunk) {
+TEST_P(TextClassifierTest, SuggestSelectionNoCrashWithJunk) {
+ CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
- TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
ASSERT_TRUE(classifier);
// Try passing in bunch of invalid selections.
@@ -239,33 +452,261 @@
std::make_pair(100, 17));
}
-TEST(TextClassifierTest, Annotate) {
+TEST_P(TextClassifierTest, Annotate) {
+ CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
- TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
ASSERT_TRUE(classifier);
const std::string test_string =
- "& saw Barak Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225 3556.";
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
EXPECT_THAT(classifier->Annotate(test_string),
ElementsAreArray({
- IsAnnotatedSpan(0, 0, "<INVALID RESULTS>"),
- IsAnnotatedSpan(2, 5, "other"),
- IsAnnotatedSpan(6, 11, "other"),
- IsAnnotatedSpan(12, 17, "other"),
- IsAnnotatedSpan(18, 23, "other"),
- IsAnnotatedSpan(24, 24, "<INVALID RESULTS>"),
- IsAnnotatedSpan(27, 54, "address"),
- IsAnnotatedSpan(55, 58, "other"),
- IsAnnotatedSpan(59, 61, "other"),
- IsAnnotatedSpan(62, 67, "other"),
- IsAnnotatedSpan(68, 74, "other"),
- IsAnnotatedSpan(75, 77, "other"),
- IsAnnotatedSpan(78, 90, "phone"),
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+ IsAnnotatedSpan(19, 24, "date"),
+#endif
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
}));
+
+ AnnotationOptions options;
+ EXPECT_THAT(classifier->Annotate("853 225 3556", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
+ EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
}
-// TODO(jacekj): Test the regex functionality.
+TEST_P(TextClassifierTest, AnnotateSmallBatches) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Set the batch size.
+ unpacked_model->selection_options->batch_size = 4;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+ IsAnnotatedSpan(19, 24, "date"),
+#endif
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ }));
+
+ AnnotationOptions options;
+ EXPECT_THAT(classifier->Annotate("853 225 3556", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
+ EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
+}
+
+TEST_P(TextClassifierTest, AnnotateFilteringDiscardAll) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test thresholds.
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->min_annotate_confidence =
+ 2.f; // Discards all results.
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+ EXPECT_TRUE(classifier->Annotate(test_string).empty());
+}
+
+TEST_P(TextClassifierTest, AnnotateFilteringKeepAll) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test thresholds.
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->min_annotate_confidence =
+ 0.f; // Keeps all results.
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+ EXPECT_EQ(classifier->Annotate(test_string).size(), 3);
+#else
+ // In non-ICU mode there is no "date" result.
+ EXPECT_EQ(classifier->Annotate(test_string).size(), 2);
+#endif
+}
+
+#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
+TEST_P(TextClassifierTest, ClassifyTextDate) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + GetParam());
+ EXPECT_TRUE(classifier);
+
+ std::vector<ClassificationResult> result;
+ ClassificationOptions options;
+
+ options.reference_timezone = "Europe/Zurich";
+ result = classifier->ClassifyText("january 1, 2017", {0, 15}, options);
+
+ ASSERT_EQ(result.size(), 1);
+ EXPECT_THAT(result[0].collection, "date");
+ EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
+ EXPECT_EQ(result[0].datetime_parse_result.granularity,
+ DatetimeGranularity::GRANULARITY_DAY);
+ result.clear();
+
+ options.reference_timezone = "America/Los_Angeles";
+ result = classifier->ClassifyText("march 1, 2017", {0, 13}, options);
+ ASSERT_EQ(result.size(), 1);
+ EXPECT_THAT(result[0].collection, "date");
+ EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000);
+ EXPECT_EQ(result[0].datetime_parse_result.granularity,
+ DatetimeGranularity::GRANULARITY_DAY);
+ result.clear();
+
+ options.reference_timezone = "America/Los_Angeles";
+ result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options);
+ ASSERT_EQ(result.size(), 1);
+ EXPECT_THAT(result[0].collection, "date");
+ EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
+ EXPECT_EQ(result[0].datetime_parse_result.granularity,
+ DatetimeGranularity::GRANULARITY_SECOND);
+ result.clear();
+
+ // Date on another line.
+ options.reference_timezone = "Europe/Zurich";
+ result = classifier->ClassifyText(
+ "hello world this is the first line\n"
+ "january 1, 2017",
+ {35, 50}, options);
+ ASSERT_EQ(result.size(), 1);
+ EXPECT_THAT(result[0].collection, "date");
+ EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
+ EXPECT_EQ(result[0].datetime_parse_result.granularity,
+ DatetimeGranularity::GRANULARITY_DAY);
+ result.clear();
+}
+#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU
+
+class TestingTextClassifier : public TextClassifier {
+ public:
+ TestingTextClassifier(const std::string& model, const UniLib* unilib)
+ : TextClassifier(ViewModel(model.data(), model.size()), unilib) {}
+
+ using TextClassifier::ResolveConflicts;
+};
+
+AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
+ const std::string& collection,
+ const float score) {
+ AnnotatedSpan result;
+ result.span = span;
+ result.classification.push_back({collection, score});
+ return result;
+}
+
+TEST(TextClassifierTest, ResolveConflictsTrivial) {
+ CREATE_UNILIB_FOR_TESTING;
+ TestingTextClassifier classifier("", &unilib);
+
+ std::vector<AnnotatedSpan> candidates{
+ {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
+
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0}));
+}
+
+TEST(TextClassifierTest, ResolveConflictsSequence) {
+ CREATE_UNILIB_FOR_TESTING;
+ TestingTextClassifier classifier("", &unilib);
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 1}, "phone", 1.0),
+ MakeAnnotatedSpan({1, 2}, "phone", 1.0),
+ MakeAnnotatedSpan({2, 3}, "phone", 1.0),
+ MakeAnnotatedSpan({3, 4}, "phone", 1.0),
+ MakeAnnotatedSpan({4, 5}, "phone", 1.0),
+ }};
+
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
+}
+
+TEST(TextClassifierTest, ResolveConflictsThreeSpans) {
+ CREATE_UNILIB_FOR_TESTING;
+ TestingTextClassifier classifier("", &unilib);
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 3}, "phone", 1.0),
+      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loser!
+ MakeAnnotatedSpan({3, 7}, "phone", 1.0),
+ }};
+
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
+}
+
+TEST(TextClassifierTest, ResolveConflictsThreeSpansReversed) {
+ CREATE_UNILIB_FOR_TESTING;
+ TestingTextClassifier classifier("", &unilib);
+
+ std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loser!
+      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
+      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loser!
+ }};
+
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({1}));
+}
+
+TEST(TextClassifierTest, ResolveConflictsFiveSpans) {
+ CREATE_UNILIB_FOR_TESTING;
+ TestingTextClassifier classifier("", &unilib);
+
+ std::vector<AnnotatedSpan> candidates{{
+ MakeAnnotatedSpan({0, 3}, "phone", 0.5),
+      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loser!
+      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
+      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loser!
+ MakeAnnotatedSpan({11, 15}, "phone", 0.9),
+ }};
+
+ std::vector<int> chosen;
+ classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
+}
} // namespace
} // namespace libtextclassifier2
diff --git a/textclassifier_jni.cc b/textclassifier_jni.cc
index ecc6500..d98fcae 100644
--- a/textclassifier_jni.cc
+++ b/textclassifier_jni.cc
@@ -19,93 +19,173 @@
#include "textclassifier_jni.h"
#include <jni.h>
+#include <type_traits>
#include <vector>
#include "text-classifier.h"
+#include "util/base/integral_types.h"
#include "util/java/scoped_local_ref.h"
+#include "util/java/string_utils.h"
#include "util/memory/mmap.h"
+#include "util/utf8/unilib.h"
using libtextclassifier2::AnnotatedSpan;
+using libtextclassifier2::AnnotationOptions;
+using libtextclassifier2::ClassificationOptions;
+using libtextclassifier2::ClassificationResult;
+using libtextclassifier2::CodepointSpan;
+using libtextclassifier2::JStringToUtf8String;
using libtextclassifier2::Model;
+using libtextclassifier2::ScopedLocalRef;
+using libtextclassifier2::SelectionOptions;
using libtextclassifier2::TextClassifier;
+#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
+using libtextclassifier2::UniLib;
+#endif
+
+namespace libtextclassifier2 {
+
+using libtextclassifier2::CodepointSpan;
namespace {
-bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
- std::string* result) {
- if (jstr == nullptr) {
- *result = std::string();
- return false;
- }
-
- jclass string_class = env->FindClass("java/lang/String");
- if (!string_class) {
- TC_LOG(ERROR) << "Can't find String class";
- return false;
- }
-
- jmethodID get_bytes_id =
- env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
-
- jstring encoding = env->NewStringUTF("UTF-8");
- jbyteArray array = reinterpret_cast<jbyteArray>(
- env->CallObjectMethod(jstr, get_bytes_id, encoding));
-
- jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
- int length = env->GetArrayLength(array);
-
- *result = std::string(reinterpret_cast<char*>(array_bytes), length);
-
- // Release the array.
- env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
- env->DeleteLocalRef(array);
- env->DeleteLocalRef(string_class);
- env->DeleteLocalRef(encoding);
-
- return true;
-}
-
std::string ToStlString(JNIEnv* env, const jstring& str) {
std::string result;
JStringToUtf8String(env, str, &result);
return result;
}
-jobjectArray ScoredStringsToJObjectArray(
- JNIEnv* env, const std::string& result_class_name,
- const std::vector<std::pair<std::string, float>>& classification_result) {
- jclass result_class = env->FindClass(result_class_name.c_str());
+jobjectArray ClassificationResultsToJObjectArray(
+ JNIEnv* env,
+ const std::vector<ClassificationResult>& classification_result) {
+ const ScopedLocalRef<jclass> result_class(
+ env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"),
+ env);
if (!result_class) {
- TC_LOG(ERROR) << "Couldn't find result class: " << result_class_name;
+ TC_LOG(ERROR) << "Couldn't find ClassificationResult class.";
+ return nullptr;
+ }
+ const ScopedLocalRef<jclass> datetime_parse_class(
+ env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env);
+ if (!datetime_parse_class) {
+ TC_LOG(ERROR) << "Couldn't find DatetimeResult class.";
return nullptr;
}
- jmethodID result_class_constructor =
- env->GetMethodID(result_class, "<init>", "(Ljava/lang/String;F)V");
+ const jmethodID result_class_constructor =
+ env->GetMethodID(result_class.get(), "<init>",
+ "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR
+ "$DatetimeResult;)V");
+ const jmethodID datetime_parse_class_constructor =
+ env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
- jobjectArray results =
- env->NewObjectArray(classification_result.size(), result_class, nullptr);
-
+ const jobjectArray results = env->NewObjectArray(classification_result.size(),
+ result_class.get(), nullptr);
for (int i = 0; i < classification_result.size(); i++) {
jstring row_string =
- env->NewStringUTF(classification_result[i].first.c_str());
+ env->NewStringUTF(classification_result[i].collection.c_str());
+ jobject row_datetime_parse = nullptr;
+ if (classification_result[i].datetime_parse_result.IsSet()) {
+ row_datetime_parse = env->NewObject(
+ datetime_parse_class.get(), datetime_parse_class_constructor,
+ classification_result[i].datetime_parse_result.time_ms_utc,
+ classification_result[i].datetime_parse_result.granularity);
+ }
jobject result =
- env->NewObject(result_class, result_class_constructor, row_string,
- static_cast<jfloat>(classification_result[i].second));
+ env->NewObject(result_class.get(), result_class_constructor, row_string,
+ static_cast<jfloat>(classification_result[i].score),
+ row_datetime_parse);
env->SetObjectArrayElement(results, i, result);
env->DeleteLocalRef(result);
}
- env->DeleteLocalRef(result_class);
return results;
}
-} // namespace
+template <typename T, typename F>
+std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object,
+ jclass class_object, F function,
+ const std::string& method_name,
+ const std::string& return_java_type) {
+ const jmethodID method = env->GetMethodID(class_object, method_name.c_str(),
+ ("()" + return_java_type).c_str());
+ if (!method) {
+ return std::make_pair(false, T());
+ }
+ return std::make_pair(true, (env->*function)(object, method));
+}
-namespace libtextclassifier2 {
+SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
+ if (!joptions) {
+ return {};
+ }
-using libtextclassifier2::CodepointSpan;
+ const ScopedLocalRef<jclass> options_class(
+ env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"),
+ env);
+ const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
+ env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
+ "getLocales", "Ljava/lang/String;");
+ if (!status_or_locales.first) {
+ return {};
+ }
-namespace {
+ SelectionOptions options;
+ options.locales =
+ ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
+
+ return options;
+}
+
+template <typename T>
+T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
+ const std::string& class_name) {
+ if (!joptions) {
+ return {};
+ }
+
+ const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()),
+ env);
+ if (!options_class) {
+ return {};
+ }
+
+ const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
+ env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
+ "getLocale", "Ljava/lang/String;");
+ const std::pair<bool, jobject> status_or_reference_timezone =
+ CallJniMethod0<jobject>(env, joptions, options_class.get(),
+ &JNIEnv::CallObjectMethod, "getReferenceTimezone",
+ "Ljava/lang/String;");
+ const std::pair<bool, int64> status_or_reference_time_ms_utc =
+ CallJniMethod0<int64>(env, joptions, options_class.get(),
+ &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
+ "J");
+
+ if (!status_or_locales.first || !status_or_reference_timezone.first ||
+ !status_or_reference_time_ms_utc.first) {
+ return {};
+ }
+
+ T options;
+ options.locales =
+ ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
+ options.reference_timezone = ToStlString(
+ env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
+ options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
+ return options;
+}
+
+ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
+ jobject joptions) {
+ return FromJavaOptionsInternal<ClassificationOptions>(
+ env, joptions,
+ TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions");
+}
+
+AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
+ return FromJavaOptionsInternal<AnnotationOptions>(
+ env, joptions, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions");
+}
CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
CodepointSpan orig_indices,
@@ -166,21 +246,34 @@
} // namespace libtextclassifier2
-using libtextclassifier2::CodepointSpan;
+using libtextclassifier2::ClassificationResultsToJObjectArray;
using libtextclassifier2::ConvertIndicesBMPToUTF8;
using libtextclassifier2::ConvertIndicesUTF8ToBMP;
-using libtextclassifier2::ScopedLocalRef;
+using libtextclassifier2::FromJavaAnnotationOptions;
+using libtextclassifier2::FromJavaClassificationOptions;
+using libtextclassifier2::FromJavaSelectionOptions;
+using libtextclassifier2::ToStlString;
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject thiz, jint fd) {
+#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
+  return reinterpret_cast<jlong>(
+      TextClassifier::FromFileDescriptor(fd, new UniLib(env)).release());
+#else
return reinterpret_cast<jlong>(
TextClassifier::FromFileDescriptor(fd).release());
+#endif
}
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
const std::string path_str = ToStlString(env, path);
+#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
+ return reinterpret_cast<jlong>(
+ TextClassifier::FromPath(path_str, new UniLib(env)).release());
+#else
return reinterpret_cast<jlong>(TextClassifier::FromPath(path_str).release());
+#endif
}
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor)
@@ -215,13 +308,19 @@
jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
jint bundle_cfd = env->GetIntField(bundle_jfd, fd_class_descriptor);
+#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
+ return reinterpret_cast<jlong>(TextClassifier::FromFileDescriptor(
+ bundle_cfd, offset, size, new UniLib(env))
+ .release());
+#else
return reinterpret_cast<jlong>(
TextClassifier::FromFileDescriptor(bundle_cfd, offset, size).release());
+#endif
}
-JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggest)
+JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end) {
+ jint selection_end, jobject options) {
if (!ptr) {
return nullptr;
}
@@ -231,8 +330,8 @@
const std::string context_utf8 = ToStlString(env, context);
CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
- CodepointSpan selection =
- model->SuggestSelection(context_utf8, input_indices);
+ CodepointSpan selection = model->SuggestSelection(
+ context_utf8, input_indices, FromJavaSelectionOptions(env, options));
selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
jintArray result = env->NewIntArray(2);
@@ -243,28 +342,31 @@
JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end) {
+ jint selection_end, jobject options) {
if (!ptr) {
return nullptr;
}
TextClassifier* ff_model = reinterpret_cast<TextClassifier*>(ptr);
- const std::vector<std::pair<std::string, float>> classification_result =
- ff_model->ClassifyText(ToStlString(env, context),
- {selection_begin, selection_end});
- return ScoredStringsToJObjectArray(
- env, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult",
- classification_result);
+ const std::string context_utf8 = ToStlString(env, context);
+ const CodepointSpan input_indices =
+ ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
+ const std::vector<ClassificationResult> classification_result =
+ ff_model->ClassifyText(context_utf8, input_indices,
+ FromJavaClassificationOptions(env, options));
+
+ return ClassificationResultsToJObjectArray(env, classification_result);
}
JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring context) {
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
if (!ptr) {
return nullptr;
}
TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
std::string context_utf8 = ToStlString(env, context);
- std::vector<AnnotatedSpan> annotations = model->Annotate(context_utf8);
+ std::vector<AnnotatedSpan> annotations =
+ model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options));
jclass result_class =
env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan");
@@ -287,9 +389,9 @@
jobject result = env->NewObject(
result_class, result_class_constructor,
static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
- ScoredStringsToJObjectArray(
- env, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult",
- annotations[i].classification));
+ ClassificationResultsToJObjectArray(env,
+
+ annotations[i].classification));
env->SetObjectArrayElement(results, i, result);
env->DeleteLocalRef(result);
}
@@ -305,6 +407,12 @@
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage)
(JNIEnv* env, jobject clazz, jint fd) {
+ TC_LOG(WARNING) << "Using deprecated getLanguage().";
+ return JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd);
+}
+
+JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd) {
std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
new libtextclassifier2::ScopedMmap(fd));
if (!mmap->handle().ok()) {
@@ -312,10 +420,10 @@
}
const Model* model = libtextclassifier2::ViewModel(
mmap->handle().start(), mmap->handle().num_bytes());
- if (!model || !model->language()) {
+ if (!model || !model->locales()) {
return env->NewStringUTF("");
}
- return env->NewStringUTF(model->language()->c_str());
+ return env->NewStringUTF(model->locales()->c_str());
}
JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion)
diff --git a/textclassifier_jni.h b/textclassifier_jni.h
index 1f64fff..9ae9388 100644
--- a/textclassifier_jni.h
+++ b/textclassifier_jni.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXTCLASSIFIER_JNI_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXTCLASSIFIER_JNI_H_
+#ifndef LIBTEXTCLASSIFIER_TEXTCLASSIFIER_JNI_H_
+#define LIBTEXTCLASSIFIER_TEXTCLASSIFIER_JNI_H_
#include <jni.h>
#include <string>
@@ -32,7 +32,7 @@
#endif
#ifndef TC_CLASS_NAME
-#define TC_CLASS_NAME SmartSelection
+#define TC_CLASS_NAME TextClassifierImplNative
#endif
#define TC_CLASS_NAME_STR ADD_QUOTES(TC_CLASS_NAME)
@@ -40,10 +40,13 @@
#define TC_PACKAGE_PATH "android/view/textclassifier/"
#endif
+#define JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name) \
+ Java_##package_name##_##class_name##_##method_name
+
#define JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, \
method_name) \
- JNIEXPORT return_type JNICALL \
- Java_##package_name##_##class_name##_##method_name
+ JNIEXPORT return_type JNICALL JNI_METHOD_NAME_INTERNAL( \
+ package_name, class_name, method_name)
// The indirection is needed to correctly expand the TC_PACKAGE_NAME macro.
// See the explanation near ADD_QUOTES macro.
@@ -53,6 +56,12 @@
#define JNI_METHOD(return_type, class_name, method_name) \
JNI_METHOD2(return_type, TC_PACKAGE_NAME, class_name, method_name)
+#define JNI_METHOD_NAME2(package_name, class_name, method_name) \
+ JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name)
+
+#define JNI_METHOD_NAME(class_name, method_name) \
+ JNI_METHOD_NAME2(TC_PACKAGE_NAME, class_name, method_name)
+
#ifdef __cplusplus
extern "C" {
#endif
@@ -67,23 +76,27 @@
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
-JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggest)
+JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end);
+ jint selection_end, jobject options);
JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end);
+ jint selection_end, jobject options);
JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring context);
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options);
JNI_METHOD(void, TC_CLASS_NAME, nativeClose)
(JNIEnv* env, jobject thiz, jlong ptr);
+// DEPRECATED. Use nativeGetLocales instead.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage)
(JNIEnv* env, jobject clazz, jint fd);
+JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd);
+
JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd);
@@ -106,4 +119,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXTCLASSIFIER_JNI_H_
+#endif // LIBTEXTCLASSIFIER_TEXTCLASSIFIER_JNI_H_
diff --git a/token-feature-extractor.cc b/token-feature-extractor.cc
index 33c4d75..e194179 100644
--- a/token-feature-extractor.cc
+++ b/token-feature-extractor.cc
@@ -74,10 +74,88 @@
: options_(options), unilib_(unilib) {
for (const std::string& pattern : options.regexp_features) {
regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>(
- unilib_.CreateRegexPattern(pattern)));
+ unilib_.CreateRegexPattern(UTF8ToUnicodeText(
+ pattern.c_str(), pattern.size(), /*do_copy=*/false))));
}
}
+bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
+ std::vector<int>* sparse_features,
+ std::vector<float>* dense_features) const {
+ if (sparse_features == nullptr || dense_features == nullptr) {
+ return false;
+ }
+ *sparse_features = ExtractCharactergramFeatures(token);
+ *dense_features = ExtractDenseFeatures(token, is_in_span);
+ return true;
+}
+
+std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
+ const Token& token) const {
+ if (options_.unicode_aware_features) {
+ return ExtractCharactergramFeaturesUnicode(token);
+ } else {
+ return ExtractCharactergramFeaturesAscii(token);
+ }
+}
+
+std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures(
+ const Token& token, bool is_in_span) const {
+ std::vector<float> dense_features;
+
+ if (options_.extract_case_feature) {
+ if (options_.unicode_aware_features) {
+ UnicodeText token_unicode =
+ UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+ const bool is_upper = unilib_.IsUpper(*token_unicode.begin());
+ if (!token.value.empty() && is_upper) {
+ dense_features.push_back(1.0);
+ } else {
+ dense_features.push_back(-1.0);
+ }
+ } else {
+ if (!token.value.empty() && isupper(*token.value.begin())) {
+ dense_features.push_back(1.0);
+ } else {
+ dense_features.push_back(-1.0);
+ }
+ }
+ }
+
+ if (options_.extract_selection_mask_feature) {
+ if (is_in_span) {
+ dense_features.push_back(1.0);
+ } else {
+ if (options_.unicode_aware_features) {
+ dense_features.push_back(-1.0);
+ } else {
+ dense_features.push_back(0.0);
+ }
+ }
+ }
+
+ // Add regexp features.
+ if (!regex_patterns_.empty()) {
+ UnicodeText token_unicode =
+ UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+ for (int i = 0; i < regex_patterns_.size(); ++i) {
+ if (!regex_patterns_[i].get()) {
+ dense_features.push_back(-1.0);
+ continue;
+ }
+ auto matcher = regex_patterns_[i]->Matcher(token_unicode);
+ int status;
+ if (matcher->Matches(&status)) {
+ dense_features.push_back(1.0);
+ } else {
+ dense_features.push_back(-1.0);
+ }
+ }
+ }
+
+ return dense_features;
+}
+
int TokenFeatureExtractor::HashToken(StringPiece token) const {
if (options_.allowed_chargrams.empty()) {
return tc2farmhash::Fingerprint64(token) % options_.num_buckets;
@@ -101,15 +179,6 @@
}
}
-std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
- const Token& token) const {
- if (options_.unicode_aware_features) {
- return ExtractCharactergramFeaturesUnicode(token);
- } else {
- return ExtractCharactergramFeaturesAscii(token);
- }
-}
-
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
const Token& token) const {
std::vector<int> result;
@@ -237,63 +306,4 @@
return result;
}
-bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
- std::vector<int>* sparse_features,
- std::vector<float>* dense_features) const {
- if (sparse_features == nullptr || dense_features == nullptr) {
- return false;
- }
-
- *sparse_features = ExtractCharactergramFeatures(token);
-
- if (options_.extract_case_feature) {
- if (options_.unicode_aware_features) {
- UnicodeText token_unicode =
- UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- const bool is_upper = unilib_.IsUpper(*token_unicode.begin());
- if (!token.value.empty() && is_upper) {
- dense_features->push_back(1.0);
- } else {
- dense_features->push_back(-1.0);
- }
- } else {
- if (!token.value.empty() && isupper(*token.value.begin())) {
- dense_features->push_back(1.0);
- } else {
- dense_features->push_back(-1.0);
- }
- }
- }
-
- if (options_.extract_selection_mask_feature) {
- if (is_in_span) {
- dense_features->push_back(1.0);
- } else {
- if (options_.unicode_aware_features) {
- dense_features->push_back(-1.0);
- } else {
- dense_features->push_back(0.0);
- }
- }
- }
-
- // Add regexp features.
- if (!regex_patterns_.empty()) {
- for (int i = 0; i < regex_patterns_.size(); ++i) {
- if (!regex_patterns_[i].get()) {
- dense_features->push_back(-1.0);
- continue;
- }
-
- if (regex_patterns_[i]->Matches(token.value)) {
- dense_features->push_back(1.0);
- } else {
- dense_features->push_back(-1.0);
- }
- }
- }
-
- return true;
-}
-
} // namespace libtextclassifier2
diff --git a/token-feature-extractor.h b/token-feature-extractor.h
index 9d476ba..1646f74 100644
--- a/token-feature-extractor.h
+++ b/token-feature-extractor.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKEN_FEATURE_EXTRACTOR_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKEN_FEATURE_EXTRACTOR_H_
+#ifndef LIBTEXTCLASSIFIER_TOKEN_FEATURE_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_TOKEN_FEATURE_EXTRACTOR_H_
#include <memory>
#include <unordered_set>
@@ -68,17 +68,22 @@
TokenFeatureExtractor(const TokenFeatureExtractorOptions& options,
const UniLib& unilib);
- // Extracts features from a token.
- // - is_in_span is a bool indicator whether the token is a part of the
- // selection span (true) or not (false).
- // - sparse_features are indices into a sparse feature vector of size
- // options.num_buckets which are set to 1.0 (others are implicitly 0.0).
- // - dense_features are values of a dense feature vector of size 0-2
- // (depending on the options) for the token
+ // Extracts both the sparse (charactergram) and the dense features from a
+ // token. is_in_span is a bool indicator whether the token is a part of the
+ // selection span (true) or not (false).
+  // Fails and returns false if either of the output pointers is a nullptr.
bool Extract(const Token& token, bool is_in_span,
std::vector<int>* sparse_features,
std::vector<float>* dense_features) const;
+ // Extracts the sparse (charactergram) features from the token.
+ std::vector<int> ExtractCharactergramFeatures(const Token& token) const;
+
+ // Extracts the dense features from the token. is_in_span is a bool indicator
+ // whether the token is a part of the selection span (true) or not (false).
+ std::vector<float> ExtractDenseFeatures(const Token& token,
+ bool is_in_span) const;
+
int DenseFeaturesCount() const {
int feature_count =
options_.extract_case_feature + options_.extract_selection_mask_feature;
@@ -90,9 +95,6 @@
// Hashes given token to given number of buckets.
int HashToken(StringPiece token) const;
- // Extracts the charactergram features from the token.
- std::vector<int> ExtractCharactergramFeatures(const Token& token) const;
-
// Extracts the charactergram features from the token in a non-unicode-aware
// way.
std::vector<int> ExtractCharactergramFeaturesAscii(const Token& token) const;
@@ -109,4 +111,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKEN_FEATURE_EXTRACTOR_H_
+#endif // LIBTEXTCLASSIFIER_TOKEN_FEATURE_EXTRACTOR_H_
diff --git a/token-feature-extractor_test.cc b/token-feature-extractor_test.cc
index d6e48bb..4b7e011 100644
--- a/token-feature-extractor_test.cc
+++ b/token-feature-extractor_test.cc
@@ -35,7 +35,7 @@
options.extract_case_feature = true;
options.unicode_aware_features = false;
options.extract_selection_mask_feature = true;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -106,7 +106,7 @@
options.extract_case_feature = true;
options.unicode_aware_features = false;
options.extract_selection_mask_feature = true;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -136,7 +136,7 @@
options.extract_case_feature = true;
options.unicode_aware_features = true;
options.extract_selection_mask_feature = true;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -207,7 +207,7 @@
options.extract_case_feature = true;
options.unicode_aware_features = true;
options.extract_selection_mask_feature = true;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -239,7 +239,7 @@
options.extract_case_feature = true;
options.unicode_aware_features = true;
options.extract_selection_mask_feature = false;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -274,7 +274,7 @@
options.chargram_orders = std::vector<int>{1, 2};
options.remap_digits = true;
options.unicode_aware_features = false;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -299,7 +299,7 @@
options.chargram_orders = std::vector<int>{1, 2};
options.remap_digits = true;
options.unicode_aware_features = true;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -324,7 +324,7 @@
options.chargram_orders = std::vector<int>{1, 2};
options.lowercase_tokens = true;
options.unicode_aware_features = false;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -349,7 +349,7 @@
options.chargram_orders = std::vector<int>{1, 2};
options.lowercase_tokens = true;
options.unicode_aware_features = true;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -372,7 +372,7 @@
options.unicode_aware_features = false;
options.regexp_features.push_back("^[a-z]+$"); // all lower case.
options.regexp_features.push_back("^[0-9]+$"); // all digits.
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -405,7 +405,7 @@
options.extract_case_feature = true;
options.unicode_aware_features = true;
options.extract_selection_mask_feature = true;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
// Test that this runs. ASAN should catch problems.
@@ -431,7 +431,7 @@
options.unicode_aware_features = true;
options.extract_selection_mask_feature = true;
- UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor_unicode(options, unilib);
options.unicode_aware_features = false;
@@ -466,7 +466,7 @@
options.unicode_aware_features = false;
options.extract_selection_mask_feature = true;
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
@@ -493,7 +493,7 @@
options.allowed_chargrams.insert("!");
options.allowed_chargrams.insert("\xc4"); // UTF8 control character.
- const UniLib unilib;
+ CREATE_UNILIB_FOR_TESTING
TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
diff --git a/tokenizer.cc b/tokenizer.cc
index 456826d..ebc9696 100644
--- a/tokenizer.cc
+++ b/tokenizer.cc
@@ -20,7 +20,6 @@
#include "util/base/logging.h"
#include "util/strings/utf8.h"
-#include "util/utf8/unicodetext.h"
namespace libtextclassifier2 {
@@ -73,15 +72,18 @@
}
}
-std::vector<Token> Tokenizer::Tokenize(const std::string& utf8_text) const {
- UnicodeText context_unicode = UTF8ToUnicodeText(utf8_text, /*do_copy=*/false);
+std::vector<Token> Tokenizer::Tokenize(const std::string& text) const {
+ UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
+ return Tokenize(text_unicode);
+}
+std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const {
std::vector<Token> result;
Token new_token("", 0, 0);
int codepoint_index = 0;
int last_script = kInvalidScript;
- for (auto it = context_unicode.begin(); it != context_unicode.end();
+ for (auto it = text_unicode.begin(); it != text_unicode.end();
++it, ++codepoint_index) {
TokenizationCodepointRange_::Role role;
int script;
diff --git a/tokenizer.h b/tokenizer.h
index 72a9fbd..9ce2c7c 100644
--- a/tokenizer.h
+++ b/tokenizer.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKENIZER_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKENIZER_H_
+#ifndef LIBTEXTCLASSIFIER_TOKENIZER_H_
+#define LIBTEXTCLASSIFIER_TOKENIZER_H_
#include <string>
#include <vector>
@@ -23,6 +23,7 @@
#include "model_generated.h"
#include "types.h"
#include "util/base/integral_types.h"
+#include "util/utf8/unicodetext.h"
namespace libtextclassifier2 {
@@ -38,7 +39,10 @@
bool split_on_script_change);
// Tokenizes the input string using the selected tokenization method.
- std::vector<Token> Tokenize(const std::string& utf8_text) const;
+ std::vector<Token> Tokenize(const std::string& text) const;
+
+ // Same as above but takes UnicodeText.
+ std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
protected:
// Finds the tokenization codepoint range config for given codepoint.
@@ -63,4 +67,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKENIZER_H_
+#endif // LIBTEXTCLASSIFIER_TOKENIZER_H_
diff --git a/types-test-util.h b/types-test-util.h
new file mode 100644
index 0000000..1679e7c
--- /dev/null
+++ b/types-test-util.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_TYPES_TEST_UTIL_H_
+#define LIBTEXTCLASSIFIER_TYPES_TEST_UTIL_H_
+
+#include <ostream>
+
+#include "types.h"
+#include "util/base/logging.h"
+
+namespace libtextclassifier2 {
+
+inline std::ostream& operator<<(std::ostream& stream, const Token& value) {
+ logging::LoggingStringStream tmp_stream;
+ tmp_stream << value;
+ return stream << tmp_stream.message;
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const AnnotatedSpan& value) {
+ logging::LoggingStringStream tmp_stream;
+ tmp_stream << value;
+ return stream << tmp_stream.message;
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const DatetimeParseResultSpan& value) {
+ logging::LoggingStringStream tmp_stream;
+ tmp_stream << value;
+ return stream << tmp_stream.message;
+}
+
+} // namespace libtextclassifier2
+
+#endif // LIBTEXTCLASSIFIER_TYPES_TEST_UTIL_H_
diff --git a/types.h b/types.h
index d50d438..828a109 100644
--- a/types.h
+++ b/types.h
@@ -14,14 +14,17 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TYPES_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TYPES_H_
+#ifndef LIBTEXTCLASSIFIER_TYPES_H_
+#define LIBTEXTCLASSIFIER_TYPES_H_
#include <algorithm>
+#include <cmath>
#include <functional>
+#include <set>
#include <string>
#include <utility>
#include <vector>
+#include "util/base/integral_types.h"
#include "util/base/logging.h"
@@ -41,6 +44,42 @@
// TODO(b/71982294): Make it a struct.
using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
+inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
+ return a.first < b.second && b.first < a.second;
+}
+
+template <typename T>
+bool DoesCandidateConflict(
+ const int considered_candidate, const std::vector<T>& candidates,
+ const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
+ if (chosen_indices_set.empty()) {
+ return false;
+ }
+
+ auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate);
+ // Check conflict on the right.
+ if (conflicting_it != chosen_indices_set.end() &&
+ SpansOverlap(candidates[considered_candidate].span,
+ candidates[*conflicting_it].span)) {
+ return true;
+ }
+
+ // Check conflict on the left.
+ // If we can't go more left, there can't be a conflict:
+ if (conflicting_it == chosen_indices_set.begin()) {
+ return false;
+ }
+ // Otherwise move one span left and insert if it doesn't overlap with the
+ // candidate.
+ --conflicting_it;
+ if (!SpansOverlap(candidates[considered_candidate].span,
+ candidates[*conflicting_it].span)) {
+ return false;
+ }
+
+ return true;
+}
+
// Marks a span in a sequence of tokens. The first element is the index of the
// first token in the span, and the second element is the index of the token one
// past the end of the span.
@@ -112,13 +151,116 @@
}
}
+enum DatetimeGranularity {
+ GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this
+ // structure being uninitialized.
+ GRANULARITY_YEAR = 0,
+ GRANULARITY_MONTH = 1,
+ GRANULARITY_WEEK = 2,
+ GRANULARITY_DAY = 3,
+ GRANULARITY_HOUR = 4,
+ GRANULARITY_MINUTE = 5,
+ GRANULARITY_SECOND = 6
+};
+
+struct DatetimeParseResult {
+ // The absolute time in milliseconds since the epoch in UTC. This is derived
+ // from the reference time and the fields specified in the text - so it may
+ // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm)
+ int64 time_ms_utc;
+
+  // The precision of the estimate that went into calculating time_ms_utc.
+ DatetimeGranularity granularity;
+
+ DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
+
+ DatetimeParseResult(int64 arg_time_ms_utc,
+ DatetimeGranularity arg_granularity)
+ : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {}
+
+ bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
+
+ bool operator==(const DatetimeParseResult& other) const {
+ return granularity == other.granularity && time_ms_utc == other.time_ms_utc;
+ }
+};
+
+const float kFloatCompareEpsilon = 1e-5;
+
+struct DatetimeParseResultSpan {
+ CodepointSpan span;
+ DatetimeParseResult data;
+ float target_classification_score;
+ float priority_score;
+
+ bool operator==(const DatetimeParseResultSpan& other) const {
+ return span == other.span && data.granularity == other.data.granularity &&
+ data.time_ms_utc == other.data.time_ms_utc &&
+ std::abs(target_classification_score -
+ other.target_classification_score) < kFloatCompareEpsilon &&
+ std::abs(priority_score - other.priority_score) <
+ kFloatCompareEpsilon;
+ }
+};
+
+// Pretty-printing function for DatetimeParseResultSpan.
+inline logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream,
+ const DatetimeParseResultSpan& value) {
+ return stream << "DatetimeParseResultSpan({" << value.span.first << ", "
+ << value.span.second << "}, {/*time_ms_utc=*/ "
+ << value.data.time_ms_utc << ", /*granularity=*/ "
+ << value.data.granularity << "})";
+}
+
+struct ClassificationResult {
+ std::string collection;
+ float score;
+ DatetimeParseResult datetime_parse_result;
+
+ // Internal score used for conflict resolution.
+ float priority_score;
+
+ explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
+
+ ClassificationResult(const std::string& arg_collection, float arg_score)
+ : collection(arg_collection),
+ score(arg_score),
+ priority_score(arg_score) {}
+
+ ClassificationResult(const std::string& arg_collection, float arg_score,
+ float arg_priority_score)
+ : collection(arg_collection),
+ score(arg_score),
+ priority_score(arg_priority_score) {}
+};
+
+// Pretty-printing function for ClassificationResult.
+inline logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream, const ClassificationResult& result) {
+ return stream << "ClassificationResult(" << result.collection << ", "
+ << result.score << ")";
+}
+
+// Pretty-printing function for std::vector<ClassificationResult>.
+inline logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream,
+ const std::vector<ClassificationResult>& results) {
+ stream = stream << "{\n";
+ for (const ClassificationResult& result : results) {
+ stream = stream << " " << result << "\n";
+ }
+ stream = stream << "}";
+ return stream;
+}
+
// Represents a result of Annotate call.
struct AnnotatedSpan {
// Unicode codepoint indices in the input string.
CodepointSpan span = {kInvalidIndex, kInvalidIndex};
// Classification result for the span.
- std::vector<std::pair<std::string, float>> classification;
+ std::vector<ClassificationResult> classification;
};
// Pretty-printing function for AnnotatedSpan.
@@ -127,8 +269,8 @@
std::string best_class;
float best_score = -1;
if (!span.classification.empty()) {
- best_class = span.classification[0].first;
- best_score = span.classification[0].second;
+ best_class = span.classification[0].collection;
+ best_score = span.classification[0].score;
}
return stream << "Span(" << span.span.first << ", " << span.span.second
<< ", " << best_class << ", " << best_score << ")";
@@ -161,4 +303,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TYPES_H_
+#endif // LIBTEXTCLASSIFIER_TYPES_H_
diff --git a/util/base/casts.h b/util/base/casts.h
index c33173a..a1d2056 100644
--- a/util/base/casts.h
+++ b/util/base/casts.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CASTS_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CASTS_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_CASTS_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_CASTS_H_
#include <string.h> // for memcpy
@@ -89,4 +89,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CASTS_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_CASTS_H_
diff --git a/util/base/config.h b/util/base/config.h
index 41b99a9..8844b14 100644
--- a/util/base/config.h
+++ b/util/base/config.h
@@ -16,8 +16,8 @@
// Define macros to indicate C++ standard / platform / etc we use.
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CONFIG_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CONFIG_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_
namespace libtextclassifier2 {
@@ -40,4 +40,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CONFIG_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_
diff --git a/util/base/endian.h b/util/base/endian.h
index 2a6e654..2dfbfd6 100644
--- a/util/base/endian.h
+++ b/util/base/endian.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_ENDIAN_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_ENDIAN_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
#include "util/base/integral_types.h"
@@ -40,7 +40,7 @@
// The following guarantees declaration of the byte swap functions, and
// defines __BYTE_ORDER for MSVC
-#if defined(__GLIBC__) || defined(__CYGWIN__)
+#if defined(__GLIBC__) || defined(__BIONIC__) || defined(__CYGWIN__)
#include <byteswap.h> // IWYU pragma: export
// The following section defines the byte swap functions for OS X / iOS,
// which does not ship with byteswap.h.
@@ -135,4 +135,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_ENDIAN_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
diff --git a/util/base/integral_types.h b/util/base/integral_types.h
index a599f3c..f82c9cd 100644
--- a/util/base/integral_types.h
+++ b/util/base/integral_types.h
@@ -16,8 +16,8 @@
// Basic integer type definitions.
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_INTEGRAL_TYPES_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_INTEGRAL_TYPES_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_INTEGRAL_TYPES_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_INTEGRAL_TYPES_H_
#include "util/base/config.h"
@@ -58,4 +58,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_INTEGRAL_TYPES_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_INTEGRAL_TYPES_H_
diff --git a/util/base/logging.h b/util/base/logging.h
index cebbbf2..4391d46 100644
--- a/util/base/logging.h
+++ b/util/base/logging.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_
#include <cassert>
#include <string>
@@ -23,22 +23,6 @@
#include "util/base/logging_levels.h"
#include "util/base/port.h"
-// TC_STRIP
-namespace libtextclassifier2 {
-// string class that can't be instantiated. Makes sure that the code does not
-// compile when non std::string is used.
-//
-// NOTE: defined here because most files directly or transitively include this
-// file. Asking people to include a special header just to make sure they don't
-// use the unqualified string doesn't work: as that header doesn't produce any
-// immediate benefit, one can easily forget about it.
-class string {
- public:
- // Makes the class non-instantiable.
- virtual ~string() = 0;
-};
-} // namespace libtextclassifier2
-// TC_END_STRIP
namespace libtextclassifier2 {
namespace logging {
@@ -180,4 +164,4 @@
#define TC_VLOG(severity) TC_NULLSTREAM
#endif
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_
diff --git a/util/base/logging_levels.h b/util/base/logging_levels.h
index 7d7dff2..17c882f 100644
--- a/util/base/logging_levels.h
+++ b/util/base/logging_levels.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_LEVELS_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_LEVELS_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_
namespace libtextclassifier2 {
namespace logging {
@@ -30,4 +30,4 @@
} // namespace logging
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_LEVELS_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_
diff --git a/util/base/logging_raw.h b/util/base/logging_raw.h
index 6cae105..e6265c7 100644
--- a/util/base/logging_raw.h
+++ b/util/base/logging_raw.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_RAW_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_RAW_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_RAW_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_RAW_H_
#include <string>
@@ -33,4 +33,4 @@
} // namespace logging
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_RAW_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_RAW_H_
diff --git a/util/base/macros.h b/util/base/macros.h
index 7aca681..edb980e 100644
--- a/util/base/macros.h
+++ b/util/base/macros.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_MACROS_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_MACROS_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_MACROS_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_MACROS_H_
#include "util/base/config.h"
@@ -80,4 +80,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_MACROS_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_MACROS_H_
diff --git a/util/base/port.h b/util/base/port.h
index 5a68daa..90a2bce 100644
--- a/util/base/port.h
+++ b/util/base/port.h
@@ -16,8 +16,8 @@
// Various portability macros, type definitions, and inline functions.
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_PORT_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_PORT_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_
namespace libtextclassifier2 {
@@ -42,4 +42,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_PORT_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_
diff --git a/util/calendar/calendar-icu.cc b/util/calendar/calendar-icu.cc
new file mode 100644
index 0000000..99deeb2
--- /dev/null
+++ b/util/calendar/calendar-icu.cc
@@ -0,0 +1,382 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/calendar/calendar-icu.h"
+
+#include <memory>
+
+#include "unicode/gregocal.h"
+#include "unicode/timezone.h"
+#include "unicode/ucal.h"
+
+namespace libtextclassifier2 {
+namespace {
+int MapToDayOfWeekOrDefault(int relation_type, int default_value) {
+ switch (relation_type) {
+ case DateParseData::MONDAY:
+ return UCalendarDaysOfWeek::UCAL_MONDAY;
+ case DateParseData::TUESDAY:
+ return UCalendarDaysOfWeek::UCAL_TUESDAY;
+ case DateParseData::WEDNESDAY:
+ return UCalendarDaysOfWeek::UCAL_WEDNESDAY;
+ case DateParseData::THURSDAY:
+ return UCalendarDaysOfWeek::UCAL_THURSDAY;
+ case DateParseData::FRIDAY:
+ return UCalendarDaysOfWeek::UCAL_FRIDAY;
+ case DateParseData::SATURDAY:
+ return UCalendarDaysOfWeek::UCAL_SATURDAY;
+ case DateParseData::SUNDAY:
+ return UCalendarDaysOfWeek::UCAL_SUNDAY;
+ default:
+ return default_value;
+ }
+}
+
+bool DispatchToRecedeOrToLastDayOfWeek(icu::Calendar* date, int relation_type,
+ int distance) {
+ UErrorCode status = U_ZERO_ERROR;
+ switch (relation_type) {
+ case DateParseData::MONDAY:
+ case DateParseData::TUESDAY:
+ case DateParseData::WEDNESDAY:
+ case DateParseData::THURSDAY:
+ case DateParseData::FRIDAY:
+ case DateParseData::SATURDAY:
+ case DateParseData::SUNDAY:
+ for (int i = 0; i < distance; i++) {
+ do {
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error day of week";
+ return false;
+ }
+ date->add(UCalendarDateFields::UCAL_DATE, 1, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a day";
+ return false;
+ }
+ } while (date->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status) !=
+ MapToDayOfWeekOrDefault(relation_type, 1));
+ }
+ return true;
+ case DateParseData::DAY:
+ date->add(UCalendarDateFields::UCAL_DATE, -1 * distance, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a day";
+ return false;
+ }
+
+ return true;
+ case DateParseData::WEEK:
+ date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1);
+ date->add(UCalendarDateFields::UCAL_DATE, -7 * (distance - 1), status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a week";
+ return false;
+ }
+
+ return true;
+ case DateParseData::MONTH:
+ date->set(UCalendarDateFields::UCAL_DATE, 1);
+ date->add(UCalendarDateFields::UCAL_MONTH, -1 * (distance - 1), status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a month";
+ return false;
+ }
+ return true;
+ case DateParseData::YEAR:
+ date->set(UCalendarDateFields::UCAL_DAY_OF_YEAR, 1);
+ date->add(UCalendarDateFields::UCAL_YEAR, -1 * (distance - 1), status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a year";
+
+ return true;
+ default:
+ return false;
+ }
+ return false;
+ }
+}
+
+bool DispatchToAdvancerOrToNextOrSameDayOfWeek(icu::Calendar* date,
+ int relation_type) {
+ UErrorCode status = U_ZERO_ERROR;
+ switch (relation_type) {
+ case DateParseData::MONDAY:
+ case DateParseData::TUESDAY:
+ case DateParseData::WEDNESDAY:
+ case DateParseData::THURSDAY:
+ case DateParseData::FRIDAY:
+ case DateParseData::SATURDAY:
+ case DateParseData::SUNDAY:
+ while (date->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status) !=
+ MapToDayOfWeekOrDefault(relation_type, 1)) {
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error day of week";
+ return false;
+ }
+ date->add(UCalendarDateFields::UCAL_DATE, 1, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a day";
+ return false;
+ }
+ }
+ return true;
+ case DateParseData::DAY:
+ date->add(UCalendarDateFields::UCAL_DATE, 1, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a day";
+ return false;
+ }
+
+ return true;
+ case DateParseData::WEEK:
+ date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1);
+ date->add(UCalendarDateFields::UCAL_DATE, 7, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a week";
+ return false;
+ }
+
+ return true;
+ case DateParseData::MONTH:
+ date->set(UCalendarDateFields::UCAL_DATE, 1);
+ date->add(UCalendarDateFields::UCAL_MONTH, 1, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a month";
+ return false;
+ }
+ return true;
+ case DateParseData::YEAR:
+ date->set(UCalendarDateFields::UCAL_DAY_OF_YEAR, 1);
+ date->add(UCalendarDateFields::UCAL_YEAR, 1, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a year";
+
+ return true;
+ default:
+ return false;
+ }
+ return false;
+ }
+}
+
+bool DispatchToAdvancerOrToNextDayOfWeek(icu::Calendar* date, int relation_type,
+ int distance) {
+ UErrorCode status = U_ZERO_ERROR;
+ switch (relation_type) {
+ case DateParseData::MONDAY:
+ case DateParseData::TUESDAY:
+ case DateParseData::WEDNESDAY:
+ case DateParseData::THURSDAY:
+ case DateParseData::FRIDAY:
+ case DateParseData::SATURDAY:
+ case DateParseData::SUNDAY:
+ for (int i = 0; i < distance; i++) {
+ do {
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error day of week";
+ return false;
+ }
+ date->add(UCalendarDateFields::UCAL_DATE, 1, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a day";
+ return false;
+ }
+ } while (date->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status) !=
+ MapToDayOfWeekOrDefault(relation_type, 1));
+ }
+ return true;
+ case DateParseData::DAY:
+ date->add(UCalendarDateFields::UCAL_DATE, distance, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a day";
+ return false;
+ }
+
+ return true;
+ case DateParseData::WEEK:
+ date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1);
+ date->add(UCalendarDateFields::UCAL_DATE, 7 * distance, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a week";
+ return false;
+ }
+
+ return true;
+ case DateParseData::MONTH:
+ date->set(UCalendarDateFields::UCAL_DATE, 1);
+ date->add(UCalendarDateFields::UCAL_MONTH, 1 * distance, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a month";
+ return false;
+ }
+ return true;
+ case DateParseData::YEAR:
+ date->set(UCalendarDateFields::UCAL_DAY_OF_YEAR, 1);
+ date->add(UCalendarDateFields::UCAL_YEAR, 1 * distance, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a year";
+
+ return true;
+ default:
+ return false;
+ }
+ return false;
+ }
+}
+
+} // namespace
+
+bool CalendarLib::InterpretParseData(const DateParseData& parse_data,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ int64* interpreted_time_ms_utc) const {
+ UErrorCode status = U_ZERO_ERROR;
+
+ std::unique_ptr<icu::Calendar> date(icu::Calendar::createInstance(status));
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error getting calendar instance";
+ return false;
+ }
+
+ date->adoptTimeZone(icu::TimeZone::createTimeZone(
+ icu::UnicodeString::fromUTF8(reference_timezone)));
+ date->setTime(reference_time_ms_utc, status);
+
+ // By default, the parsed time is interpreted to be on the reference day. But
+  // a parsed date should have time 0:00:00 unless specified.
+ date->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, 0);
+ date->set(UCalendarDateFields::UCAL_MINUTE, 0);
+ date->set(UCalendarDateFields::UCAL_SECOND, 0);
+ date->set(UCalendarDateFields::UCAL_MILLISECOND, 0);
+
+ static const int64 kMillisInHour = 1000 * 60 * 60;
+ if (parse_data.field_set_mask & DateParseData::Fields::ZONE_OFFSET_FIELD) {
+ date->set(UCalendarDateFields::UCAL_ZONE_OFFSET,
+ parse_data.zone_offset * kMillisInHour);
+ }
+ if (parse_data.field_set_mask & DateParseData::Fields::DST_OFFSET_FIELD) {
+ // convert from hours to milliseconds
+ date->set(UCalendarDateFields::UCAL_DST_OFFSET,
+ parse_data.dst_offset * kMillisInHour);
+ }
+
+ if (parse_data.field_set_mask & DateParseData::Fields::RELATION_FIELD) {
+ switch (parse_data.relation) {
+ case DateParseData::Relation::NEXT:
+ if (parse_data.field_set_mask &
+ DateParseData::Fields::RELATION_TYPE_FIELD) {
+ if (!DispatchToAdvancerOrToNextDayOfWeek(
+ date.get(), parse_data.relation_type, 1)) {
+ return false;
+ }
+ }
+ break;
+ case DateParseData::Relation::NEXT_OR_SAME:
+ if (parse_data.field_set_mask &
+ DateParseData::Fields::RELATION_TYPE_FIELD) {
+ if (!DispatchToAdvancerOrToNextOrSameDayOfWeek(
+ date.get(), parse_data.relation_type)) {
+ return false;
+ }
+ }
+ break;
+ case DateParseData::Relation::LAST:
+ if (parse_data.field_set_mask &
+ DateParseData::Fields::RELATION_TYPE_FIELD) {
+ if (!DispatchToRecedeOrToLastDayOfWeek(date.get(),
+ parse_data.relation_type, 1)) {
+ return false;
+ }
+ }
+ break;
+ case DateParseData::Relation::NOW:
+ // NOOP
+ break;
+ case DateParseData::Relation::TOMORROW:
+ date->add(UCalendarDateFields::UCAL_DATE, 1, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error adding a day";
+ return false;
+ }
+ break;
+ case DateParseData::Relation::YESTERDAY:
+ date->add(UCalendarDateFields::UCAL_DATE, -1, status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error subtracting a day";
+ return false;
+ }
+ break;
+ case DateParseData::Relation::PAST:
+ if (parse_data.field_set_mask &
+ DateParseData::Fields::RELATION_TYPE_FIELD) {
+ if (parse_data.field_set_mask &
+ DateParseData::Fields::RELATION_DISTANCE_FIELD) {
+ if (!DispatchToRecedeOrToLastDayOfWeek(
+ date.get(), parse_data.relation_type,
+ parse_data.relation_distance)) {
+ return false;
+ }
+ }
+ }
+ break;
+ case DateParseData::Relation::FUTURE:
+ if (parse_data.field_set_mask &
+ DateParseData::Fields::RELATION_TYPE_FIELD) {
+ if (parse_data.field_set_mask &
+ DateParseData::Fields::RELATION_DISTANCE_FIELD) {
+ if (!DispatchToAdvancerOrToNextDayOfWeek(
+ date.get(), parse_data.relation_type,
+ parse_data.relation_distance)) {
+ return false;
+ }
+ }
+ }
+ break;
+ }
+ }
+ if (parse_data.field_set_mask & DateParseData::Fields::YEAR_FIELD) {
+ date->set(UCalendarDateFields::UCAL_YEAR, parse_data.year);
+ }
+ if (parse_data.field_set_mask & DateParseData::Fields::MONTH_FIELD) {
+ // NOTE: Java and ICU disagree on month formats
+ date->set(UCalendarDateFields::UCAL_MONTH, parse_data.month - 1);
+ }
+ if (parse_data.field_set_mask & DateParseData::Fields::DAY_FIELD) {
+ date->set(UCalendarDateFields::UCAL_DATE, parse_data.day_of_month);
+ }
+ if (parse_data.field_set_mask & DateParseData::Fields::HOUR_FIELD) {
+ if (parse_data.field_set_mask & DateParseData::Fields::AMPM_FIELD &&
+ parse_data.ampm == 1 && parse_data.hour < 12) {
+ date->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, parse_data.hour + 12);
+ } else {
+ date->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, parse_data.hour);
+ }
+ }
+ if (parse_data.field_set_mask & DateParseData::Fields::MINUTE_FIELD) {
+ date->set(UCalendarDateFields::UCAL_MINUTE, parse_data.minute);
+ }
+ if (parse_data.field_set_mask & DateParseData::Fields::SECOND_FIELD) {
+ date->set(UCalendarDateFields::UCAL_SECOND, parse_data.second);
+ }
+ *interpreted_time_ms_utc = date->getTime(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error getting time from instance";
+ return false;
+ }
+ return true;
+}
+} // namespace libtextclassifier2
diff --git a/util/calendar/calendar-icu.h b/util/calendar/calendar-icu.h
new file mode 100644
index 0000000..50cb716
--- /dev/null
+++ b/util/calendar/calendar-icu.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_ICU_H_
+#define LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_ICU_H_
+
+#include <string>
+
+#include "util/base/integral_types.h"
+#include "util/base/logging.h"
+#include "util/calendar/types.h"
+
+namespace libtextclassifier2 {
+
+class CalendarLib {
+ public:
+  // Interprets parse_data as milliseconds since epoch. Relative times are
+ // resolved against the current time (reference_time_ms_utc). Returns true if
+  // the interpretation was successful, false otherwise.
+ bool InterpretParseData(const DateParseData& parse_data,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ int64* interpreted_time_ms_utc) const;
+};
+} // namespace libtextclassifier2
+#endif // LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_ICU_H_
diff --git a/util/calendar/calendar.h b/util/calendar/calendar.h
new file mode 100644
index 0000000..b0cf2e6
--- /dev/null
+++ b/util/calendar/calendar.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_H_
+#define LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_H_
+
+#include "util/calendar/calendar-icu.h"
+
+#endif // LIBTEXTCLASSIFIER_UTIL_CALENDAR_CALENDAR_H_
diff --git a/util/calendar/calendar_test.cc b/util/calendar/calendar_test.cc
new file mode 100644
index 0000000..7065a95
--- /dev/null
+++ b/util/calendar/calendar_test.cc
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This test serves the purpose of making sure all the different implementations
+// of the unspoken CalendarLib interface support the same methods.
+
+#include "util/calendar/calendar.h"
+#include "util/base/logging.h"
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier2 {
+namespace {
+
+TEST(CalendarTest, Interface) {
+ CalendarLib calendar;
+ int64 time;
+ std::string timezone;
+ bool result = calendar.InterpretParseData(
+ DateParseData{0l, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ static_cast<DateParseData::Relation>(0),
+ static_cast<DateParseData::RelationType>(0), 0},
+ 0L, "Zurich", &time);
+ TC_LOG(INFO) << result;
+}
+
+} // namespace
+} // namespace libtextclassifier2
diff --git a/util/calendar/types.h b/util/calendar/types.h
new file mode 100644
index 0000000..4f58911
--- /dev/null
+++ b/util/calendar/types.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTIL_CALENDAR_TYPES_H_
+#define LIBTEXTCLASSIFIER_UTIL_CALENDAR_TYPES_H_
+
+struct DateParseData {
+ enum Relation {
+ NEXT = 1,
+ NEXT_OR_SAME = 2,
+ LAST = 3,
+ NOW = 4,
+ TOMORROW = 5,
+ YESTERDAY = 6,
+ PAST = 7,
+ FUTURE = 8
+ };
+
+ enum RelationType {
+ MONDAY = 1,
+ TUESDAY = 2,
+ WEDNESDAY = 3,
+ THURSDAY = 4,
+ FRIDAY = 5,
+ SATURDAY = 6,
+ SUNDAY = 7,
+ DAY = 8,
+ WEEK = 9,
+ MONTH = 10,
+ YEAR = 11
+ };
+
+ enum Fields {
+ YEAR_FIELD = 1 << 0,
+ MONTH_FIELD = 1 << 1,
+ DAY_FIELD = 1 << 2,
+ HOUR_FIELD = 1 << 3,
+ MINUTE_FIELD = 1 << 4,
+ SECOND_FIELD = 1 << 5,
+ AMPM_FIELD = 1 << 6,
+ ZONE_OFFSET_FIELD = 1 << 7,
+ DST_OFFSET_FIELD = 1 << 8,
+ RELATION_FIELD = 1 << 9,
+ RELATION_TYPE_FIELD = 1 << 10,
+ RELATION_DISTANCE_FIELD = 1 << 11
+ };
+
+ enum AMPM { AM = 0, PM = 1 };
+
+ enum TimeUnit {
+ DAYS = 1,
+ WEEKS = 2,
+ MONTHS = 3,
+ HOURS = 4,
+ MINUTES = 5,
+ SECONDS = 6,
+ YEARS = 7
+ };
+
+ // Bit mask of fields which have been set on the struct
+ int field_set_mask;
+
+ // Fields describing absolute date fields.
+ // Year of the date seen in the text match.
+ int year;
+ // Month of the year starting with January = 1.
+ int month;
+ // Day of the month starting with 1.
+ int day_of_month;
+ // Hour of the day with a range of 0-23,
+ // values less than 12 need the AMPM field below or heuristics
+ // to definitively determine the time.
+ int hour;
+  // Minute of the hour with a range of 0-59.
+ int minute;
+  // Second of the minute with a range of 0-59.
+ int second;
+ // 0 == AM, 1 == PM
+ int ampm;
+ // Number of hours offset from UTC this date time is in.
+ int zone_offset;
+  // Number of hours offset for DST.
+ int dst_offset;
+
+ // The permutation from now that was made to find the date time.
+ Relation relation;
+ // The unit of measure of the change to the date time.
+ RelationType relation_type;
+ // The number of units of change that were made.
+ int relation_distance;
+};
+
+#endif // LIBTEXTCLASSIFIER_UTIL_CALENDAR_TYPES_H_
diff --git a/util/gtl/map_util.h b/util/gtl/map_util.h
index d14071e..bd020f8 100644
--- a/util/gtl/map_util.h
+++ b/util/gtl/map_util.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_MAP_UTIL_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_MAP_UTIL_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_
+#define LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_
namespace libtextclassifier2 {
@@ -62,4 +62,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_MAP_UTIL_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_
diff --git a/util/gtl/stl_util.h b/util/gtl/stl_util.h
index 9d93c03..7b88e05 100644
--- a/util/gtl/stl_util.h
+++ b/util/gtl/stl_util.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_STL_UTIL_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_STL_UTIL_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_
+#define LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_
namespace libtextclassifier2 {
@@ -52,4 +52,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_STL_UTIL_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_
diff --git a/util/hash/farmhash.h b/util/hash/farmhash.h
index 3bbe294..477b7a8 100644
--- a/util/hash/farmhash.h
+++ b/util/hash/farmhash.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_FARMHASH_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_FARMHASH_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_HASH_FARMHASH_H_
+#define LIBTEXTCLASSIFIER_UTIL_HASH_FARMHASH_H_
#include <assert.h>
#include <stdint.h>
@@ -261,4 +261,4 @@
} // namespace NAMESPACE_FOR_HASH_FUNCTIONS
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_FARMHASH_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_HASH_FARMHASH_H_
diff --git a/util/hash/hash.h b/util/hash/hash.h
index beabd6e..b7a3b53 100644
--- a/util/hash/hash.h
+++ b/util/hash/hash.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_HASH_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_HASH_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
+#define LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
#include <string>
@@ -35,4 +35,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_HASH_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
diff --git a/util/java/scoped_local_ref.h b/util/java/scoped_local_ref.h
index e716df5..8476767 100644
--- a/util/java/scoped_local_ref.h
+++ b/util/java/scoped_local_ref.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_JAVA_SCOPED_LOCAL_REF_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_JAVA_SCOPED_LOCAL_REF_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
+#define LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
#include <jni.h>
#include <memory>
@@ -28,6 +28,8 @@
// A deleter to be used with std::unique_ptr to delete JNI local references.
class LocalRefDeleter {
public:
+ LocalRefDeleter() : env_(nullptr) {}
+
// Style guide violating implicit constructor so that the LocalRefDeleter
// is implicitly constructed from the second argument to ScopedLocalRef.
LocalRefDeleter(JNIEnv* env) : env_(env) {} // NOLINT(runtime/explicit)
@@ -43,7 +45,11 @@
}
// The delete operator.
- void operator()(jobject o) const { env_->DeleteLocalRef(o); }
+ void operator()(jobject object) const {
+ if (env_) {
+ env_->DeleteLocalRef(object);
+ }
+ }
private:
// The env_ stashed to use for deletion. Thread-local, don't share!
@@ -62,4 +68,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_JAVA_SCOPED_LOCAL_REF_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
diff --git a/util/java/string_utils.cc b/util/java/string_utils.cc
new file mode 100644
index 0000000..ffd5b11
--- /dev/null
+++ b/util/java/string_utils.cc
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/java/string_utils.h"
+
+#include "util/base/logging.h"
+
+namespace libtextclassifier2 {
+
+bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
+ std::string* result) {
+ if (jstr == nullptr) {
+ *result = std::string();
+ return false;
+ }
+
+ jclass string_class = env->FindClass("java/lang/String");
+ if (!string_class) {
+ TC_LOG(ERROR) << "Can't find String class";
+ return false;
+ }
+
+ jmethodID get_bytes_id =
+ env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
+
+ jstring encoding = env->NewStringUTF("UTF-8");
+ jbyteArray array = reinterpret_cast<jbyteArray>(
+ env->CallObjectMethod(jstr, get_bytes_id, encoding));
+
+ jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
+ int length = env->GetArrayLength(array);
+
+ *result = std::string(reinterpret_cast<char*>(array_bytes), length);
+
+ // Release the array.
+ env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
+ env->DeleteLocalRef(array);
+ env->DeleteLocalRef(string_class);
+ env->DeleteLocalRef(encoding);
+
+ return true;
+}
+
+} // namespace libtextclassifier2
diff --git a/util/java/string_utils.h b/util/java/string_utils.h
new file mode 100644
index 0000000..6a85856
--- /dev/null
+++ b/util/java/string_utils.h
@@ -0,0 +1,29 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTIL_JAVA_STRING_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTIL_JAVA_STRING_UTILS_H_
+
+#include <jni.h>
+#include <string>
+
+namespace libtextclassifier2 {
+
+bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, std::string* result);
+
+} // namespace libtextclassifier2
+
+#endif // LIBTEXTCLASSIFIER_UTIL_JAVA_STRING_UTILS_H_
diff --git a/util/math/fastexp.h b/util/math/fastexp.h
index acc1453..af7a08c 100644
--- a/util/math/fastexp.h
+++ b/util/math/fastexp.h
@@ -16,8 +16,8 @@
// Fast approximation for exp.
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_FASTEXP_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_FASTEXP_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_MATH_FASTEXP_H_
+#define LIBTEXTCLASSIFIER_UTIL_MATH_FASTEXP_H_
#include <cassert>
#include <cmath>
@@ -65,4 +65,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_FASTEXP_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_MATH_FASTEXP_H_
diff --git a/util/math/softmax.h b/util/math/softmax.h
index 57bf832..f70a9ab 100644
--- a/util/math/softmax.h
+++ b/util/math/softmax.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_SOFTMAX_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_SOFTMAX_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_MATH_SOFTMAX_H_
+#define LIBTEXTCLASSIFIER_UTIL_MATH_SOFTMAX_H_
#include <vector>
@@ -35,4 +35,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_SOFTMAX_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_MATH_SOFTMAX_H_
diff --git a/util/memory/mmap.h b/util/memory/mmap.h
index 781f222..7d28b64 100644
--- a/util/memory/mmap.h
+++ b/util/memory/mmap.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MEMORY_MMAP_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MEMORY_MMAP_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_MEMORY_MMAP_H_
+#define LIBTEXTCLASSIFIER_UTIL_MEMORY_MMAP_H_
#include <stddef.h>
@@ -138,4 +138,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MEMORY_MMAP_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_MEMORY_MMAP_H_
diff --git a/util/strings/numbers.h b/util/strings/numbers.h
index 096954e..a2c8c6e 100644
--- a/util/strings/numbers.h
+++ b/util/strings/numbers.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_NUMBERS_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_NUMBERS_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_NUMBERS_H_
+#define LIBTEXTCLASSIFIER_UTIL_STRINGS_NUMBERS_H_
#include <string>
@@ -49,4 +49,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_NUMBERS_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_NUMBERS_H_
diff --git a/util/strings/split.h b/util/strings/split.h
index 9860265..abd453b 100644
--- a/util/strings/split.h
+++ b/util/strings/split.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_SPLIT_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_SPLIT_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_SPLIT_H_
+#define LIBTEXTCLASSIFIER_UTIL_STRINGS_SPLIT_H_
#include <string>
#include <vector>
@@ -28,4 +28,4 @@
} // namespace strings
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_SPLIT_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_SPLIT_H_
diff --git a/util/strings/stringpiece.h b/util/strings/stringpiece.h
index f6187e9..bd62274 100644
--- a/util/strings/stringpiece.h
+++ b/util/strings/stringpiece.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_STRINGPIECE_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_STRINGPIECE_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_STRINGPIECE_H_
+#define LIBTEXTCLASSIFIER_UTIL_STRINGS_STRINGPIECE_H_
#include <stddef.h>
@@ -63,4 +63,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_STRINGPIECE_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_STRINGPIECE_H_
diff --git a/util/strings/utf8.h b/util/strings/utf8.h
index 89823e2..e54c18a 100644
--- a/util/strings/utf8.h
+++ b/util/strings/utf8.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_UTF8_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_UTF8_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_
+#define LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_
namespace libtextclassifier2 {
@@ -46,4 +46,4 @@
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_UTF8_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_
diff --git a/util/utf8/unicodetext.cc b/util/utf8/unicodetext.cc
index c814a2e..79381bf 100644
--- a/util/utf8/unicodetext.cc
+++ b/util/utf8/unicodetext.cc
@@ -27,6 +27,16 @@
// *************** Data representation **********
// Note: the copy constructor is undefined.
+UnicodeText::Repr& UnicodeText::Repr::operator=(Repr&& src) {
+ if (ours_ && data_) delete[] data_;
+ data_ = src.data_;
+ size_ = src.size_;
+ capacity_ = src.capacity_;
+ ours_ = src.ours_;
+ src.ours_ = false;
+ return *this;
+}
+
void UnicodeText::Repr::PointTo(const char* data, int size) {
if (ours_ && data_) delete[] data_; // If we owned the old buffer, free it.
data_ = const_cast<char*>(data);
@@ -89,6 +99,11 @@
UnicodeText::UnicodeText(const UnicodeText& src) { Copy(src); }
+UnicodeText& UnicodeText::operator=(UnicodeText&& src) {
+ this->repr_ = std::move(src.repr_);
+ return *this;
+}
+
UnicodeText& UnicodeText::Copy(const UnicodeText& src) {
repr_.Copy(src.repr_.data_, src.repr_.size_);
return *this;
@@ -109,6 +124,10 @@
return *this;
}
+const char* UnicodeText::data() const { return repr_.data_; }
+
+int UnicodeText::size_bytes() const { return repr_.size_; }
+
namespace {
enum {
@@ -166,7 +185,22 @@
void UnicodeText::clear() { repr_.clear(); }
-int UnicodeText::size() const { return std::distance(begin(), end()); }
+int UnicodeText::size_codepoints() const {
+ return std::distance(begin(), end());
+}
+
+bool UnicodeText::empty() const { return size_bytes() == 0; }
+
+bool UnicodeText::operator==(const UnicodeText& other) const {
+ if (repr_.size_ != other.repr_.size_) {
+ return false;
+ }
+ return memcmp(repr_.data_, other.repr_.data_, repr_.size_) == 0;
+}
+
+std::string UnicodeText::ToUTF8String() const {
+ return std::string(begin(), end());
+}
std::string UnicodeText::UTF8Substring(const const_iterator& first,
const const_iterator& last) {
@@ -246,8 +280,13 @@
return t;
}
+UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy) {
+ return UTF8ToUnicodeText(utf8_buf, strlen(utf8_buf), do_copy);
+}
+
UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy) {
return UTF8ToUnicodeText(str.data(), str.size(), do_copy);
}
+
} // namespace libtextclassifier2
diff --git a/util/utf8/unicodetext.h b/util/utf8/unicodetext.h
index d331f9b..7fb8ac1 100644
--- a/util/utf8/unicodetext.h
+++ b/util/utf8/unicodetext.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNICODETEXT_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNICODETEXT_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_
+#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_
#include <iterator>
#include <string>
@@ -68,6 +68,7 @@
UnicodeText(); // Create an empty text.
UnicodeText(const UnicodeText& src);
+ UnicodeText& operator=(UnicodeText&& src);
~UnicodeText();
class const_iterator {
@@ -77,7 +78,7 @@
typedef std::input_iterator_tag iterator_category;
typedef char32 value_type;
typedef int difference_type;
- typedef void pointer; // (Not needed.)
+ typedef void pointer; // (Not needed.)
typedef const char32 reference; // (Needed for const_reverse_iterator)
// Iterators are default-constructible.
@@ -88,7 +89,7 @@
char32 operator*() const; // Dereference
- const_iterator& operator++(); // Advance (++iter)
+ const_iterator& operator++(); // Advance (++iter)
const_iterator operator++(int) { // (iter++)
const_iterator result(*this);
++*this;
@@ -132,14 +133,28 @@
private:
friend class UnicodeText;
- explicit const_iterator(const char *it) : it_(it) {}
+ explicit const_iterator(const char* it) : it_(it) {}
- const char *it_;
+ const char* it_;
};
const_iterator begin() const;
const_iterator end() const;
- int size() const; // the number of Unicode characters (codepoints)
+
+ // Gets pointer to the underlying utf8 data.
+ const char* data() const;
+
+ // Gets length (in bytes) of the underlying utf8 data.
+ int size_bytes() const;
+
+ // Computes length (in number of Unicode codepoints) of the underlying utf8
+ // data.
+ // NOTE: Complexity O(n).
+ int size_codepoints() const;
+
+ bool empty() const;
+
+ bool operator==(const UnicodeText& other) const;
// x.PointToUTF8(buf,len) changes x so that it points to buf
// ("becomes an alias"). It does not take ownership or copy buf.
@@ -153,6 +168,7 @@
UnicodeText& AppendCodepoint(char32 ch);
void clear();
+ std::string ToUTF8String() const;
static std::string UTF8Substring(const const_iterator& first,
const const_iterator& last);
@@ -167,6 +183,7 @@
bool ours_; // Do we own data_?
Repr() : data_(nullptr), size_(0), capacity_(0), ours_(true) {}
+ Repr& operator=(Repr&& src);
~Repr() {
if (ours_) delete[] data_;
}
@@ -191,9 +208,14 @@
typedef std::pair<UnicodeText::const_iterator, UnicodeText::const_iterator>
UnicodeTextRange;
+// NOTE: The following are needed to avoid implicit conversion from char* to
+// std::string, or from ::string to std::string, because if this happens it
+// often results in invalid memory access to a temporary object created during
+// such conversion (if do_copy == false).
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len, bool do_copy);
+UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy);
UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy);
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNICODETEXT_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_
diff --git a/util/utf8/unilib-icu.cc b/util/utf8/unilib-icu.cc
index 147a364..3d5500a 100644
--- a/util/utf8/unilib-icu.cc
+++ b/util/utf8/unilib-icu.cc
@@ -16,10 +16,30 @@
#include "util/utf8/unilib-icu.h"
-#include "util/base/logging.h"
+#include <utility>
namespace libtextclassifier2 {
+bool UniLib::ParseInt32(const UnicodeText& text, int* result) const {
+ UErrorCode status = U_ZERO_ERROR;
+ UNumberFormat* format_alias =
+ unum_open(UNUM_DECIMAL, nullptr, 0, "en_US_POSIX", nullptr, &status);
+ if (U_FAILURE(status)) {
+ return false;
+ }
+ icu::UnicodeString utf8_string = icu::UnicodeString::fromUTF8(
+ icu::StringPiece(text.data(), text.size_bytes()));
+ int parse_index = 0;
+ const int32 integer = unum_parse(format_alias, utf8_string.getBuffer(),
+ utf8_string.length(), &parse_index, &status);
+ *result = integer;
+ unum_close(format_alias);
+ if (U_FAILURE(status) || parse_index != utf8_string.length()) {
+ return false;
+ }
+ return true;
+}
+
bool UniLib::IsOpeningBracket(char32 codepoint) const {
return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) ==
U_BPT_OPEN;
@@ -44,28 +64,169 @@
return u_getBidiPairedBracket(codepoint);
}
-bool UniLib::RegexPattern::Matches(const std::string& text) {
- const icu::UnicodeString unicode_text(text.c_str(), text.size(), "utf-8");
- UErrorCode status;
- status = U_ZERO_ERROR;
- std::unique_ptr<icu::RegexMatcher> matcher(
- pattern_->matcher(unicode_text, status));
- if (U_FAILURE(status) || !matcher) {
- return false;
- }
-
- status = U_ZERO_ERROR;
- const bool result = matcher->matches(/*startIndex=*/0, status);
+UniLib::RegexMatcher::RegexMatcher(icu::RegexPattern* pattern,
+ icu::UnicodeString text)
+ : pattern_(pattern), text_(std::move(text)) {
+ UErrorCode status = U_ZERO_ERROR;
+ matcher_.reset(pattern->matcher(text_, status));
if (U_FAILURE(status)) {
+ matcher_.reset(nullptr);
+ }
+}
+
+std::unique_ptr<UniLib::RegexMatcher> UniLib::RegexPattern::Matcher(
+ const UnicodeText& input) const {
+ return std::unique_ptr<UniLib::RegexMatcher>(new UniLib::RegexMatcher(
+ pattern_.get(), icu::UnicodeString::fromUTF8(
+ icu::StringPiece(input.data(), input.size_bytes()))));
+}
+
+constexpr int UniLib::RegexMatcher::kError;
+constexpr int UniLib::RegexMatcher::kNoError;
+
+bool UniLib::RegexMatcher::Matches(int* status) const {
+ std::string text = "";
+ text_.toUTF8String(text);
+ if (!matcher_) {
+ *status = kError;
return false;
}
-
+ UErrorCode icu_status = U_ZERO_ERROR;
+ const bool result = matcher_->matches(/*startIndex=*/0, icu_status);
+ if (U_FAILURE(icu_status)) {
+ *status = kError;
+ return false;
+ }
+ *status = kNoError;
return result;
}
+bool UniLib::RegexMatcher::Find(int* status) {
+ if (!matcher_) {
+ *status = kError;
+ return false;
+ }
+ UErrorCode icu_status = U_ZERO_ERROR;
+ const bool result = matcher_->find(icu_status);
+ if (U_FAILURE(icu_status)) {
+ *status = kError;
+ return false;
+ }
+ *status = kNoError;
+ return result;
+}
+
+int UniLib::RegexMatcher::Start(int* status) const {
+ if (!matcher_) {
+ *status = kError;
+ return kError;
+ }
+ UErrorCode icu_status = U_ZERO_ERROR;
+ const int result = matcher_->start(icu_status);
+ if (U_FAILURE(icu_status)) {
+ *status = kError;
+ return kError;
+ }
+ *status = kNoError;
+ return result;
+}
+
+int UniLib::RegexMatcher::Start(int group_idx, int* status) const {
+ if (!matcher_) {
+ *status = kError;
+ return kError;
+ }
+ UErrorCode icu_status = U_ZERO_ERROR;
+ const int result = matcher_->start(group_idx, icu_status);
+ if (U_FAILURE(icu_status)) {
+ *status = kError;
+ return kError;
+ }
+ *status = kNoError;
+ return result;
+}
+
+int UniLib::RegexMatcher::End(int* status) const {
+ if (!matcher_) {
+ *status = kError;
+ return kError;
+ }
+ UErrorCode icu_status = U_ZERO_ERROR;
+ const int result = matcher_->end(icu_status);
+ if (U_FAILURE(icu_status)) {
+ *status = kError;
+ return kError;
+ }
+ *status = kNoError;
+ return result;
+}
+
+int UniLib::RegexMatcher::End(int group_idx, int* status) const {
+ if (!matcher_) {
+ *status = kError;
+ return kError;
+ }
+ UErrorCode icu_status = U_ZERO_ERROR;
+ const int result = matcher_->end(group_idx, icu_status);
+ if (U_FAILURE(icu_status)) {
+ *status = kError;
+ return kError;
+ }
+ *status = kNoError;
+ return result;
+}
+
+UnicodeText UniLib::RegexMatcher::Group(int* status) const {
+ if (!matcher_) {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+ std::string result = "";
+ UErrorCode icu_status = U_ZERO_ERROR;
+ matcher_->group(icu_status).toUTF8String(result);
+ if (U_FAILURE(icu_status)) {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+ *status = kNoError;
+ return UTF8ToUnicodeText(result, /*do_copy=*/true);
+}
+
+UnicodeText UniLib::RegexMatcher::Group(int group_idx, int* status) const {
+ if (!matcher_) {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+ std::string result = "";
+ UErrorCode icu_status = U_ZERO_ERROR;
+ matcher_->group(group_idx, icu_status).toUTF8String(result);
+ if (U_FAILURE(icu_status)) {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+ *status = kNoError;
+ return UTF8ToUnicodeText(result, /*do_copy=*/true);
+}
+
+UnicodeText UniLib::RegexMatcher::Group(const std::string& group_name,
+ int* status) const {
+ UErrorCode icu_status = U_ZERO_ERROR;
+ const int group_idx = pattern_->groupNumberFromName(
+ icu::UnicodeString::fromUTF8(
+ icu::StringPiece(group_name.c_str(), group_name.size())),
+ icu_status);
+ if (U_FAILURE(icu_status)) {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+ return Group(group_idx, status);
+}
+
constexpr int UniLib::BreakIterator::kDone;
-UniLib::BreakIterator::BreakIterator(const std::string& text) {
+UniLib::BreakIterator::BreakIterator(const UnicodeText& text)
+ : text_(icu::UnicodeString::fromUTF8(
+ icu::StringPiece(text.data(), text.size_bytes()))) {
icu::ErrorCode status;
break_iterator_.reset(
icu::BreakIterator::createWordInstance(icu::Locale("en"), status));
@@ -73,9 +234,7 @@
break_iterator_.reset();
return;
}
-
- const icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(text);
- break_iterator_->setText(unicode_text);
+ break_iterator_->setText(text_);
}
int UniLib::BreakIterator::Next() {
@@ -88,11 +247,12 @@
}
std::unique_ptr<UniLib::RegexPattern> UniLib::CreateRegexPattern(
- const std::string& regex) const {
+ const UnicodeText& regex) const {
UErrorCode status = U_ZERO_ERROR;
- std::unique_ptr<icu::RegexPattern> pattern(icu::RegexPattern::compile(
- icu::UnicodeString(regex.c_str(), regex.size(), "utf-8"), /*flags=*/0,
- status));
+ std::unique_ptr<icu::RegexPattern> pattern(
+ icu::RegexPattern::compile(icu::UnicodeString::fromUTF8(icu::StringPiece(
+ regex.data(), regex.size_bytes())),
+ /*flags=*/UREGEX_MULTILINE, status));
if (U_FAILURE(status) || !pattern) {
return nullptr;
}
@@ -101,7 +261,7 @@
}
std::unique_ptr<UniLib::BreakIterator> UniLib::CreateBreakIterator(
- const std::string& text) const {
+ const UnicodeText& text) const {
return std::unique_ptr<UniLib::BreakIterator>(
new UniLib::BreakIterator(text));
}
diff --git a/util/utf8/unilib-icu.h b/util/utf8/unilib-icu.h
index 0d34b74..8070d24 100644
--- a/util/utf8/unilib-icu.h
+++ b/util/utf8/unilib-icu.h
@@ -17,22 +17,24 @@
// UniLib implementation with the help of ICU. UniLib is basically a wrapper
// around the ICU functionality.
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_ICU_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_ICU_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_ICU_H_
+#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_ICU_H_
#include <memory>
-#include <string>
#include "util/base/integral_types.h"
+#include "util/utf8/unicodetext.h"
#include "unicode/brkiter.h"
#include "unicode/errorcode.h"
#include "unicode/regex.h"
#include "unicode/uchar.h"
+#include "unicode/unum.h"
namespace libtextclassifier2 {
class UniLib {
public:
+ bool ParseInt32(const UnicodeText& text, int* result) const;
bool IsOpeningBracket(char32 codepoint) const;
bool IsClosingBracket(char32 codepoint) const;
bool IsWhitespace(char32 codepoint) const;
@@ -42,10 +44,72 @@
char32 ToLower(char32 codepoint) const;
char32 GetPairedBracket(char32 codepoint) const;
+ // Forward declaration for friend.
+ class RegexPattern;
+
+ class RegexMatcher {
+ public:
+ static constexpr int kError = -1;
+ static constexpr int kNoError = 0;
+
+ // Checks whether the input text matches the pattern exactly.
+ bool Matches(int* status) const;
+
+ // Finds occurrences of the pattern in the input text.
+  // Can be called repeatedly to find all occurrences. A call will update
+ // internal state, so that 'Start', 'End' and 'Group' can be called to get
+ // information about the match.
+ bool Find(int* status);
+
+ // Gets the start offset of the last match (from 'Find').
+ // Sets status to 'kError' if 'Find'
+ // was not called previously.
+ int Start(int* status) const;
+
+ // Gets the start offset of the specified group of the last match.
+ // (from 'Find').
+ // Sets status to 'kError' if an invalid group was specified or if 'Find'
+ // was not called previously.
+ int Start(int group_idx, int* status) const;
+
+ // Gets the end offset of the last match (from 'Find').
+ // Sets status to 'kError' if 'Find'
+ // was not called previously.
+ int End(int* status) const;
+
+ // Gets the end offset of the specified group of the last match.
+ // (from 'Find').
+ // Sets status to 'kError' if an invalid group was specified or if 'Find'
+ // was not called previously.
+ int End(int group_idx, int* status) const;
+
+ // Gets the text of the last match (from 'Find').
+ // Sets status to 'kError' if 'Find' was not called previously.
+ UnicodeText Group(int* status) const;
+
+ // Gets the text of the specified group of the last match (from 'Find').
+ // Sets status to 'kError' if an invalid group was specified or if 'Find'
+ // was not called previously.
+ UnicodeText Group(int group_idx, int* status) const;
+
+ // Gets the text of the specified group of the last match (from 'Find').
+ // Sets status to 'kError' if an invalid group was specified or if 'Find'
+ // was not called previously.
+ UnicodeText Group(const std::string& group_name, int* status) const;
+
+ protected:
+ friend class RegexPattern;
+ explicit RegexMatcher(icu::RegexPattern* pattern, icu::UnicodeString text);
+
+ private:
+ std::unique_ptr<icu::RegexMatcher> matcher_;
+ icu::RegexPattern* pattern_;
+ icu::UnicodeString text_;
+ };
+
class RegexPattern {
public:
- // Returns true if the whole input matches with the regex.
- bool Matches(const std::string& text);
+ std::unique_ptr<RegexMatcher> Matcher(const UnicodeText& input) const;
protected:
friend class UniLib;
@@ -64,18 +128,19 @@
protected:
friend class UniLib;
- explicit BreakIterator(const std::string& text);
+ explicit BreakIterator(const UnicodeText& text);
private:
std::unique_ptr<icu::BreakIterator> break_iterator_;
+ icu::UnicodeString text_;
};
std::unique_ptr<RegexPattern> CreateRegexPattern(
- const std::string& regex) const;
+ const UnicodeText& regex) const;
std::unique_ptr<BreakIterator> CreateBreakIterator(
- const std::string& text) const;
+ const UnicodeText& text) const;
};
} // namespace libtextclassifier2
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_ICU_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_ICU_H_
diff --git a/util/utf8/unilib.h b/util/utf8/unilib.h
index b583d72..29b4575 100644
--- a/util/utf8/unilib.h
+++ b/util/utf8/unilib.h
@@ -14,15 +14,10 @@
* limitations under the License.
*/
-#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_H_
-#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_H_
+#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_H_
-#if defined LIBTEXTCLASSIFIER_UNILIB_ICU
#include "util/utf8/unilib-icu.h"
-#elif defined LIBTEXTCLASSIFIER_UNILIB_DUMMY
-#include "util/utf8/unilib-dummy.h"
-#else
-#error No LIBTEXTCLASSIFIER_UNILIB implementation specified.
-#endif
+#define CREATE_UNILIB_FOR_TESTING const UniLib unilib;
-#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_UTF8_UNILIB_H_
diff --git a/util/utf8/unilib_test.cc b/util/utf8/unilib_test.cc
index a1bbdf4..bff2ffc 100644
--- a/util/utf8/unilib_test.cc
+++ b/util/utf8/unilib_test.cc
@@ -15,6 +15,7 @@
*/
#include "util/utf8/unilib.h"
+#include "util/utf8/unicodetext.h"
#include "util/base/logging.h"
@@ -23,23 +24,104 @@
namespace libtextclassifier2 {
namespace {
-TEST(UniLibTest, Interface) {
- UniLib unilib;
- TC_LOG(INFO) << unilib.IsOpeningBracket('(');
- TC_LOG(INFO) << unilib.IsClosingBracket(')');
- TC_LOG(INFO) << unilib.IsWhitespace(')');
- TC_LOG(INFO) << unilib.IsDigit(')');
- TC_LOG(INFO) << unilib.IsUpper(')');
- TC_LOG(INFO) << unilib.ToLower(')');
- TC_LOG(INFO) << unilib.GetPairedBracket(')');
+TEST(UniLibTest, CharacterClassesAscii) {
+ CREATE_UNILIB_FOR_TESTING
+ EXPECT_TRUE(unilib.IsOpeningBracket('('));
+ EXPECT_TRUE(unilib.IsClosingBracket(')'));
+ EXPECT_FALSE(unilib.IsWhitespace(')'));
+ EXPECT_TRUE(unilib.IsWhitespace(' '));
+ EXPECT_FALSE(unilib.IsDigit(')'));
+ EXPECT_TRUE(unilib.IsDigit('0'));
+ EXPECT_TRUE(unilib.IsDigit('9'));
+ EXPECT_FALSE(unilib.IsUpper(')'));
+ EXPECT_TRUE(unilib.IsUpper('A'));
+ EXPECT_TRUE(unilib.IsUpper('Z'));
+ EXPECT_EQ(unilib.ToLower('A'), 'a');
+ EXPECT_EQ(unilib.ToLower('Z'), 'z');
+ EXPECT_EQ(unilib.ToLower(')'), ')');
+ EXPECT_EQ(unilib.GetPairedBracket(')'), '(');
+ EXPECT_EQ(unilib.GetPairedBracket('}'), '{');
+}
+
+#ifndef LIBTEXTCLASSIFIER_UNILIB_DUMMY
+TEST(UniLibTest, CharacterClassesUnicode) {
+ CREATE_UNILIB_FOR_TESTING
+ EXPECT_TRUE(unilib.IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON
+ EXPECT_TRUE(unilib.IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS
+ EXPECT_FALSE(unilib.IsWhitespace(0x23F0)); // ALARM CLOCK
+ EXPECT_TRUE(unilib.IsWhitespace(0x2003)); // EM SPACE
+ EXPECT_FALSE(unilib.IsDigit(0xA619)); // VAI SYMBOL JONG
+ EXPECT_TRUE(unilib.IsDigit(0xA620)); // VAI DIGIT ZERO
+ EXPECT_TRUE(unilib.IsDigit(0xA629)); // VAI DIGIT NINE
+ EXPECT_FALSE(unilib.IsDigit(0xA62A)); // VAI SYLLABLE NDOLE MA
+ EXPECT_FALSE(unilib.IsUpper(0x0211)); // SMALL R WITH DOUBLE GRAVE
+ EXPECT_TRUE(unilib.IsUpper(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
+ EXPECT_TRUE(unilib.IsUpper(0x0391)); // GREEK CAPITAL ALPHA
+ EXPECT_TRUE(unilib.IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL
+ EXPECT_FALSE(unilib.IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
+ EXPECT_EQ(unilib.ToLower(0x0391), 0x03B1); // GREEK ALPHA
+ EXPECT_EQ(unilib.ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA
+ EXPECT_EQ(unilib.ToLower(0x03C0), 0x03C0); // GREEK SMALL PI
+
+ EXPECT_EQ(unilib.GetPairedBracket(0x0F3C), 0x0F3D);
+ EXPECT_EQ(unilib.GetPairedBracket(0x0F3D), 0x0F3C);
+}
+#endif // ndef LIBTEXTCLASSIFIER_UNILIB_DUMMY
+
+TEST(UniLibTest, Regex) {
+ CREATE_UNILIB_FOR_TESTING
+ const UnicodeText regex_pattern =
+ UTF8ToUnicodeText("[0-9]+", /*do_copy=*/true);
std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib.CreateRegexPattern("[0-9]");
- TC_LOG(INFO) << pattern->Matches("Hello");
+ unilib.CreateRegexPattern(regex_pattern);
+ const UnicodeText input = UTF8ToUnicodeText("hello 0123", /*do_copy=*/false);
+ int status;
+ std::unique_ptr<UniLib::RegexMatcher> matcher = pattern->Matcher(input);
+ TC_LOG(INFO) << matcher->Matches(&status);
+ TC_LOG(INFO) << matcher->Find(&status);
+ TC_LOG(INFO) << matcher->Start(0, &status);
+ TC_LOG(INFO) << matcher->End(0, &status);
+ TC_LOG(INFO) << matcher->Group(0, &status).size_codepoints();
+}
+
+TEST(UniLibTest, BreakIterator) {
+ CREATE_UNILIB_FOR_TESTING
+ const UnicodeText text = UTF8ToUnicodeText("some text", /*do_copy=*/false);
std::unique_ptr<UniLib::BreakIterator> iterator =
- unilib.CreateBreakIterator("some text");
+ unilib.CreateBreakIterator(text);
TC_LOG(INFO) << iterator->Next();
TC_LOG(INFO) << UniLib::BreakIterator::kDone;
}
+TEST(UniLibTest, IntegerParse) {
+#if !defined LIBTEXTCLASSIFIER_UNILIB_JAVAICU
+ CREATE_UNILIB_FOR_TESTING;
+ int result;
+ EXPECT_TRUE(
+ unilib.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
+ EXPECT_EQ(result, 123);
+#endif
+}
+
+TEST(UniLibTest, IntegerParseFullWidth) {
+#if defined LIBTEXTCLASSIFIER_UNILIB_ICU
+ CREATE_UNILIB_FOR_TESTING;
+ int result;
+ // The input string here is full width
+  EXPECT_TRUE(unilib.ParseInt32(UTF8ToUnicodeText("１２３", /*do_copy=*/false),
+ &result));
+ EXPECT_EQ(result, 123);
+#endif
+}
+
+TEST(UniLibTest, IntegerParseFullWidthWithAlpha) {
+#if defined LIBTEXTCLASSIFIER_UNILIB_ICU
+ CREATE_UNILIB_FOR_TESTING;
+ int result;
+ // The input string here is full width
+  EXPECT_FALSE(unilib.ParseInt32(UTF8ToUnicodeText("１a３", /*do_copy=*/false),
+ &result));
+#endif
+}
} // namespace
} // namespace libtextclassifier2