Snap for 4653471 from ba849e7b63cdf4a38e6ef1a5a9ffd60567d7c40b to pi-release
Change-Id: I12ec2cd602c19108909495e99737ca4e9aeaeee7
diff --git a/cached-features.cc b/cached-features.cc
index 0863c0e..2a46780 100644
--- a/cached-features.cc
+++ b/cached-features.cc
@@ -23,37 +23,6 @@
namespace {
-// Populates the features for one token into the target vector at an offset
-// corresponding to the given token index. It builds the features to populate by
-// embedding the sparse features and combining them with the dense featues.
-// Embeds sparse features and the features of one token into the features
-// vector.
-bool PopulateTokenFeatures(int target_feature_index,
- const std::vector<int>& sparse_features,
- const std::vector<float>& dense_features,
- int feature_vector_size,
- EmbeddingExecutor* embedding_executor,
- std::vector<float>* target_features) {
- const int sparse_embedding_size = feature_vector_size - dense_features.size();
- float* dest =
- target_features->data() + target_feature_index * feature_vector_size;
-
- // Embed sparse features.
- if (!embedding_executor->AddEmbedding(
- TensorView<int>(sparse_features.data(),
- {static_cast<int>(sparse_features.size())}),
- dest, sparse_embedding_size)) {
- return false;
- }
-
- // Copy dense features.
- for (int j = 0; j < dense_features.size(); ++j) {
- dest[sparse_embedding_size + j] = dense_features[j];
- }
-
- return true;
-}
-
int CalculateOutputFeaturesSize(const FeatureProcessorOptions* options,
int feature_vector_size) {
const bool bounds_sensitive_enabled =
@@ -89,12 +58,9 @@
std::unique_ptr<CachedFeatures> CachedFeatures::Create(
const TokenSpan& extraction_span,
- const std::vector<std::vector<int>>& sparse_features,
- const std::vector<std::vector<float>>& dense_features,
- const std::vector<int>& padding_sparse_features,
- const std::vector<float>& padding_dense_features,
- const FeatureProcessorOptions* options,
- EmbeddingExecutor* embedding_executor, int feature_vector_size) {
+ std::unique_ptr<std::vector<float>> features,
+ std::unique_ptr<std::vector<float>> padding_features,
+ const FeatureProcessorOptions* options, int feature_vector_size) {
const int min_feature_version =
options->bounds_sensitive_features() &&
options->bounds_sensitive_features()->enabled()
@@ -107,32 +73,13 @@
std::unique_ptr<CachedFeatures> cached_features(new CachedFeatures());
cached_features->extraction_span_ = extraction_span;
+ cached_features->features_ = std::move(features);
+ cached_features->padding_features_ = std::move(padding_features);
cached_features->options_ = options;
cached_features->output_features_size_ =
CalculateOutputFeaturesSize(options, feature_vector_size);
- cached_features->features_.resize(feature_vector_size *
- TokenSpanSize(extraction_span));
- for (int i = 0; i < TokenSpanSize(extraction_span); ++i) {
- if (!PopulateTokenFeatures(/*target_feature_index=*/i, sparse_features[i],
- dense_features[i], feature_vector_size,
- embedding_executor,
- &cached_features->features_)) {
- TC_LOG(ERROR) << "Could not embed sparse token features.";
- return nullptr;
- }
- }
-
- cached_features->padding_features_.resize(feature_vector_size);
- if (!PopulateTokenFeatures(/*target_feature_index=*/0,
- padding_sparse_features, padding_dense_features,
- feature_vector_size, embedding_executor,
- &cached_features->padding_features_)) {
- TC_LOG(ERROR) << "Could not embed sparse padding token features.";
- return nullptr;
- }
-
return cached_features;
}
@@ -194,8 +141,8 @@
}
output_features->insert(
output_features->end(),
- features_.begin() + copy_span.first * NumFeaturesPerToken(),
- features_.begin() + copy_span.second * NumFeaturesPerToken());
+ features_->begin() + copy_span.first * NumFeaturesPerToken(),
+ features_->begin() + copy_span.second * NumFeaturesPerToken());
for (int i = copy_span.second; i < intended_span.second; ++i) {
AppendPaddingFeatures(output_features);
}
@@ -203,8 +150,8 @@
void CachedFeatures::AppendPaddingFeatures(
std::vector<float>* output_features) const {
- output_features->insert(output_features->end(), padding_features_.begin(),
- padding_features_.end());
+ output_features->insert(output_features->end(), padding_features_->begin(),
+ padding_features_->end());
}
void CachedFeatures::AppendBagFeatures(
@@ -214,13 +161,13 @@
for (int i = bag_span.first; i < bag_span.second; ++i) {
for (int j = 0; j < NumFeaturesPerToken(); ++j) {
(*output_features)[offset + j] +=
- features_[i * NumFeaturesPerToken() + j] / TokenSpanSize(bag_span);
+ (*features_)[i * NumFeaturesPerToken() + j] / TokenSpanSize(bag_span);
}
}
}
int CachedFeatures::NumFeaturesPerToken() const {
- return padding_features_.size();
+ return padding_features_->size();
}
} // namespace libtextclassifier2
diff --git a/cached-features.h b/cached-features.h
index 86b700f..0224d86 100644
--- a/cached-features.h
+++ b/cached-features.h
@@ -32,12 +32,9 @@
public:
static std::unique_ptr<CachedFeatures> Create(
const TokenSpan& extraction_span,
- const std::vector<std::vector<int>>& sparse_features,
- const std::vector<std::vector<float>>& dense_features,
- const std::vector<int>& padding_sparse_features,
- const std::vector<float>& padding_dense_features,
- const FeatureProcessorOptions* options,
- EmbeddingExecutor* embedding_executor, int feature_vector_size);
+ std::unique_ptr<std::vector<float>> features,
+ std::unique_ptr<std::vector<float>> padding_features,
+ const FeatureProcessorOptions* options, int feature_vector_size);
// Appends the click context features for the given click position to
// 'output_features'.
@@ -77,8 +74,8 @@
TokenSpan extraction_span_;
const FeatureProcessorOptions* options_;
int output_features_size_;
- std::vector<float> features_;
- std::vector<float> padding_features_;
+ std::unique_ptr<std::vector<float>> features_;
+ std::unique_ptr<std::vector<float>> padding_features_;
};
} // namespace libtextclassifier2
diff --git a/cached-features_test.cc b/cached-features_test.cc
index 9566a8d..f064a63 100644
--- a/cached-features_test.cc
+++ b/cached-features_test.cc
@@ -37,22 +37,15 @@
return ElementsAreArray(matchers);
}
-// EmbeddingExecutor that always returns features based on
-class FakeEmbeddingExecutor : public EmbeddingExecutor {
- public:
- bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- int dest_size) override {
- TC_CHECK_GE(dest_size, 2);
- EXPECT_EQ(sparse_features.size(), 1);
-
- dest[0] = sparse_features.data()[0] * 11.0f;
- dest[1] = -sparse_features.data()[0] * 11.0f;
- return true;
+std::unique_ptr<std::vector<float>> MakeFeatures(int num_tokens) {
+ std::unique_ptr<std::vector<float>> features(new std::vector<float>());
+ for (int i = 1; i <= num_tokens; ++i) {
+ features->push_back(i * 11.0f);
+ features->push_back(-i * 11.0f);
+ features->push_back(i * 0.1f);
}
-
- private:
- std::vector<float> storage_;
-};
+ return features;
+}
std::vector<float> GetCachedClickContextFeatures(
const CachedFeatures& cached_features, int click_pos) {
@@ -78,25 +71,15 @@
builder.Finish(CreateFeatureProcessorOptions(builder, &options));
flatbuffers::DetachedBuffer options_fb = builder.Release();
- std::vector<std::vector<int>> sparse_features(9);
- for (int i = 0; i < sparse_features.size(); ++i) {
- sparse_features[i].push_back(i + 1);
- }
- std::vector<std::vector<float>> dense_features(9);
- for (int i = 0; i < dense_features.size(); ++i) {
- dense_features[i].push_back((i + 1) * 0.1);
- }
+ std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
+ std::unique_ptr<std::vector<float>> padding_features(
+ new std::vector<float>{112233.0, -112233.0, 321.0});
- std::vector<int> padding_sparse_features = {10203};
- std::vector<float> padding_dense_features = {321.0};
-
- FakeEmbeddingExecutor executor;
const std::unique_ptr<CachedFeatures> cached_features =
CachedFeatures::Create(
- {3, 10}, sparse_features, dense_features, padding_sparse_features,
- padding_dense_features,
+ {3, 10}, std::move(features), std::move(padding_features),
flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &executor, /*feature_vector_size=*/3);
+ /*feature_vector_size=*/3);
ASSERT_TRUE(cached_features);
EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
@@ -129,25 +112,15 @@
builder.Finish(CreateFeatureProcessorOptions(builder, &options));
flatbuffers::DetachedBuffer options_fb = builder.Release();
- std::vector<std::vector<int>> sparse_features(6);
- for (int i = 0; i < sparse_features.size(); ++i) {
- sparse_features[i].push_back(i + 1);
- }
- std::vector<std::vector<float>> dense_features(6);
- for (int i = 0; i < dense_features.size(); ++i) {
- dense_features[i].push_back((i + 1) * 0.1);
- }
+ std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
+ std::unique_ptr<std::vector<float>> padding_features(
+ new std::vector<float>{112233.0, -112233.0, 321.0});
- std::vector<int> padding_sparse_features = {10203};
- std::vector<float> padding_dense_features = {321.0};
-
- FakeEmbeddingExecutor executor;
const std::unique_ptr<CachedFeatures> cached_features =
CachedFeatures::Create(
- {3, 9}, sparse_features, dense_features, padding_sparse_features,
- padding_dense_features,
+ {3, 9}, std::move(features), std::move(padding_features),
flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &executor, /*feature_vector_size=*/3);
+ /*feature_vector_size=*/3);
ASSERT_TRUE(cached_features);
EXPECT_THAT(
diff --git a/datetime/extractor.cc b/datetime/extractor.cc
index 79d7686..8c6c3ff 100644
--- a/datetime/extractor.cc
+++ b/datetime/extractor.cc
@@ -30,9 +30,15 @@
constexpr char const* kGroupRelationDistance = "RELATIONDISTANCE";
constexpr char const* kGroupRelation = "RELATION";
constexpr char const* kGroupRelationType = "RELATIONTYPE";
+// Dummy groups serve just as an inflator of the selection. E.g. we might want
+// to select more text than was contained in an envelope of all extractor spans.
+constexpr char const* kGroupDummy1 = "DUMMY1";
+constexpr char const* kGroupDummy2 = "DUMMY2";
-bool DatetimeExtractor::Extract(DateParseData* result) const {
+bool DatetimeExtractor::Extract(DateParseData* result,
+ CodepointSpan* result_span) const {
result->field_set_mask = 0;
+ *result_span = {kInvalidIndex, kInvalidIndex};
UnicodeText group_text;
if (GroupNotEmpty(kGroupYear, &group_text)) {
@@ -41,6 +47,10 @@
TC_LOG(ERROR) << "Couldn't extract YEAR.";
return false;
}
+ if (!UpdateMatchSpan(kGroupYear, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
}
if (GroupNotEmpty(kGroupMonth, &group_text)) {
@@ -49,6 +59,10 @@
TC_LOG(ERROR) << "Couldn't extract MONTH.";
return false;
}
+ if (!UpdateMatchSpan(kGroupMonth, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
}
if (GroupNotEmpty(kGroupDay, &group_text)) {
@@ -57,6 +71,10 @@
TC_LOG(ERROR) << "Couldn't extract DAY.";
return false;
}
+ if (!UpdateMatchSpan(kGroupDay, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
}
if (GroupNotEmpty(kGroupHour, &group_text)) {
@@ -65,6 +83,10 @@
TC_LOG(ERROR) << "Couldn't extract HOUR.";
return false;
}
+ if (!UpdateMatchSpan(kGroupHour, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
}
if (GroupNotEmpty(kGroupMinute, &group_text)) {
@@ -73,6 +95,10 @@
TC_LOG(ERROR) << "Couldn't extract MINUTE.";
return false;
}
+ if (!UpdateMatchSpan(kGroupMinute, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
}
if (GroupNotEmpty(kGroupSecond, &group_text)) {
@@ -81,6 +107,10 @@
TC_LOG(ERROR) << "Couldn't extract SECOND.";
return false;
}
+ if (!UpdateMatchSpan(kGroupSecond, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
}
if (GroupNotEmpty(kGroupAmpm, &group_text)) {
@@ -89,6 +119,10 @@
TC_LOG(ERROR) << "Couldn't extract AMPM.";
return false;
}
+ if (!UpdateMatchSpan(kGroupAmpm, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
}
if (GroupNotEmpty(kGroupRelationDistance, &group_text)) {
@@ -97,6 +131,10 @@
TC_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD.";
return false;
}
+ if (!UpdateMatchSpan(kGroupRelationDistance, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
}
if (GroupNotEmpty(kGroupRelation, &group_text)) {
@@ -105,6 +143,10 @@
TC_LOG(ERROR) << "Couldn't extract RELATION_FIELD.";
return false;
}
+ if (!UpdateMatchSpan(kGroupRelation, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
}
if (GroupNotEmpty(kGroupRelationType, &group_text)) {
@@ -113,6 +155,29 @@
TC_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
return false;
}
+ if (!UpdateMatchSpan(kGroupRelationType, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupDummy1, &group_text)) {
+ if (!UpdateMatchSpan(kGroupDummy1, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
+ }
+
+ if (GroupNotEmpty(kGroupDummy2, &group_text)) {
+ if (!UpdateMatchSpan(kGroupDummy2, result_span)) {
+ TC_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
+ }
+
+ if (result_span->first == kInvalidIndex ||
+ result_span->second == kInvalidIndex) {
+ *result_span = {kInvalidIndex, kInvalidIndex};
}
return true;
@@ -161,7 +226,7 @@
return true;
}
-bool DatetimeExtractor::GroupNotEmpty(const std::string& name,
+bool DatetimeExtractor::GroupNotEmpty(StringPiece name,
UnicodeText* result) const {
int status;
*result = matcher_.Group(name, &status);
@@ -171,6 +236,27 @@
return !result->empty();
}
+bool DatetimeExtractor::UpdateMatchSpan(StringPiece name,
+ CodepointSpan* span) const {
+ int status;
+ const int match_start = matcher_.Start(name, &status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ const int match_end = matcher_.End(name, &status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ if (span->first == kInvalidIndex || span->first > match_start) {
+ span->first = match_start;
+ }
+ if (span->second == kInvalidIndex || span->second < match_end) {
+ span->second = match_end;
+ }
+
+ return true;
+}
+
template <typename T>
bool DatetimeExtractor::MapInput(
const UnicodeText& input,
diff --git a/datetime/extractor.h b/datetime/extractor.h
index f068dff..ceeb9cf 100644
--- a/datetime/extractor.h
+++ b/datetime/extractor.h
@@ -22,7 +22,8 @@
#include <vector>
#include "model_generated.h"
-#include "util/calendar/types.h"
+#include "types.h"
+#include "util/strings/stringpiece.h"
#include "util/utf8/unicodetext.h"
#include "util/utf8/unilib.h"
@@ -34,7 +35,8 @@
public:
DatetimeExtractor(
const UniLib::RegexMatcher& matcher, int locale_id, const UniLib& unilib,
- const std::vector<std::unique_ptr<UniLib::RegexPattern>>& extractor_rules,
+ const std::vector<std::unique_ptr<const UniLib::RegexPattern>>&
+ extractor_rules,
const std::unordered_map<DatetimeExtractorType,
std::unordered_map<int, int>>&
type_and_locale_to_extractor_rule)
@@ -43,7 +45,7 @@
unilib_(unilib),
rules_(extractor_rules),
type_and_locale_to_rule_(type_and_locale_to_extractor_rule) {}
- bool Extract(DateParseData* result) const;
+ bool Extract(DateParseData* result, CodepointSpan* result_span) const;
private:
bool RuleIdForType(DatetimeExtractorType type, int* rule_id) const;
@@ -55,7 +57,10 @@
DatetimeExtractorType extractor_type,
UnicodeText* match_result = nullptr) const;
- bool GroupNotEmpty(const std::string& name, UnicodeText* result) const;
+ bool GroupNotEmpty(StringPiece name, UnicodeText* result) const;
+
+ // Updates the span to include the current match for the given group.
+ bool UpdateMatchSpan(StringPiece group_name, CodepointSpan* span) const;
// Returns true if any of the extractors from 'mapping' matched. If it did,
// will fill 'result' with the associated value from 'mapping'.
@@ -82,7 +87,7 @@
const UniLib::RegexMatcher& matcher_;
int locale_id_;
const UniLib& unilib_;
- const std::vector<std::unique_ptr<UniLib::RegexPattern>>& rules_;
+ const std::vector<std::unique_ptr<const UniLib::RegexPattern>>& rules_;
const std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>&
type_and_locale_to_rule_;
};
diff --git a/datetime/parser.cc b/datetime/parser.cc
index 27d0a00..8ad3d33 100644
--- a/datetime/parser.cc
+++ b/datetime/parser.cc
@@ -78,21 +78,24 @@
locale_string_to_id_[model->locales()->Get(i)->str()] = i;
}
+ use_extractors_for_locating_ = model->use_extractors_for_locating();
+
initialized_ = true;
}
bool DatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- std::vector<DatetimeParseResultSpan>* results) const {
+ ModeFlag mode, std::vector<DatetimeParseResultSpan>* results) const {
return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
- reference_time_ms_utc, reference_timezone, locales, results);
+ reference_time_ms_utc, reference_timezone, locales, mode,
+ results);
}
bool DatetimeParser::Parse(
const UnicodeText& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- std::vector<DatetimeParseResultSpan>* results) const {
+ ModeFlag mode, std::vector<DatetimeParseResultSpan>* results) const {
std::vector<DatetimeParseResultSpan> found_spans;
std::unordered_set<int> executed_rules;
for (const int locale_id : ParseLocales(locales)) {
@@ -106,6 +109,11 @@
if (executed_rules.find(rule_id) != executed_rules.end()) {
continue;
}
+
+ if (!(rule_id_to_pattern_[rule_id]->enabled_modes() & mode)) {
+ continue;
+ }
+
executed_rules.insert(rule_id);
if (!ParseWithRule(*rules_[rule_id], rule_id_to_pattern_[rule_id], input,
@@ -157,11 +165,13 @@
}
DatetimeParseResultSpan parse_result;
- parse_result.span = {start, end};
if (!ExtractDatetime(*matcher, reference_time_ms_utc, reference_timezone,
- locale_id, &(parse_result.data))) {
+ locale_id, &(parse_result.data), &parse_result.span)) {
return false;
}
+ if (!use_extractors_for_locating_) {
+ parse_result.span = {start, end};
+ }
parse_result.target_classification_score =
pattern->target_classification_score();
parse_result.priority_score = pattern->priority_score();
@@ -196,13 +206,34 @@
DatetimeGranularity GetGranularity(const DateParseData& data) {
DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
- if (data.field_set_mask & DateParseData::YEAR_FIELD) {
+ if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::YEAR))) {
granularity = DatetimeGranularity::GRANULARITY_YEAR;
}
- if (data.field_set_mask & DateParseData::MONTH_FIELD) {
+ if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::MONTH))) {
granularity = DatetimeGranularity::GRANULARITY_MONTH;
}
- if (data.field_set_mask & DateParseData::DAY_FIELD) {
+ if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::WEEK)) {
+ granularity = DatetimeGranularity::GRANULARITY_WEEK;
+ }
+ if (data.field_set_mask & DateParseData::DAY_FIELD ||
+ (data.field_set_mask & DateParseData::RELATION_FIELD &&
+ (data.relation == DateParseData::Relation::NOW ||
+ data.relation == DateParseData::Relation::TOMORROW ||
+ data.relation == DateParseData::Relation::YESTERDAY)) ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::MONDAY ||
+ data.relation_type == DateParseData::RelationType::TUESDAY ||
+ data.relation_type == DateParseData::RelationType::WEDNESDAY ||
+ data.relation_type == DateParseData::RelationType::THURSDAY ||
+ data.relation_type == DateParseData::RelationType::FRIDAY ||
+ data.relation_type == DateParseData::RelationType::SATURDAY ||
+ data.relation_type == DateParseData::RelationType::SUNDAY ||
+ data.relation_type == DateParseData::RelationType::DAY))) {
granularity = DatetimeGranularity::GRANULARITY_DAY;
}
if (data.field_set_mask & DateParseData::HOUR_FIELD) {
@@ -222,12 +253,12 @@
bool DatetimeParser::ExtractDatetime(const UniLib::RegexMatcher& matcher,
const int64 reference_time_ms_utc,
const std::string& reference_timezone,
- int locale_id,
- DatetimeParseResult* result) const {
+ int locale_id, DatetimeParseResult* result,
+ CodepointSpan* result_span) const {
DateParseData parse;
DatetimeExtractor extractor(matcher, locale_id, unilib_, extractor_rules_,
type_and_locale_to_extractor_rule_);
- if (!extractor.Extract(&parse)) {
+ if (!extractor.Extract(&parse, result_span)) {
return false;
}
diff --git a/datetime/parser.h b/datetime/parser.h
index a56f83d..9f31142 100644
--- a/datetime/parser.h
+++ b/datetime/parser.h
@@ -41,11 +41,13 @@
// do not overlap.
bool Parse(const std::string& input, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode,
std::vector<DatetimeParseResultSpan>* results) const;
// Same as above but takes UnicodeText.
bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode,
std::vector<DatetimeParseResultSpan>* results) const;
protected:
@@ -64,19 +66,21 @@
bool ExtractDatetime(const UniLib::RegexMatcher& matcher,
int64 reference_time_ms_utc,
const std::string& reference_timezone, int locale_id,
- DatetimeParseResult* result) const;
+ DatetimeParseResult* result,
+ CodepointSpan* result_span) const;
private:
bool initialized_;
const UniLib& unilib_;
std::vector<const DatetimeModelPattern*> rule_id_to_pattern_;
- std::vector<std::unique_ptr<UniLib::RegexPattern>> rules_;
+ std::vector<std::unique_ptr<const UniLib::RegexPattern>> rules_;
std::unordered_map<int, std::vector<int>> locale_to_rules_;
- std::vector<std::unique_ptr<UniLib::RegexPattern>> extractor_rules_;
+ std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
type_and_locale_to_extractor_rule_;
std::unordered_map<std::string, int> locale_string_to_id_;
CalendarLib calendar_lib_;
+ bool use_extractors_for_locating_;
};
} // namespace libtextclassifier2
diff --git a/datetime/parser_test.cc b/datetime/parser_test.cc
index 721cb86..1df959f 100644
--- a/datetime/parser_test.cc
+++ b/datetime/parser_test.cc
@@ -79,7 +79,8 @@
std::vector<DatetimeParseResultSpan> results;
- if (!parser_->Parse(text, 0, timezone, /*locales=*/"", &results)) {
+ if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
+ &results)) {
TC_LOG(ERROR) << text;
TC_CHECK(false);
}
@@ -122,8 +123,7 @@
TEST_F(ParserTest, ParseShort) {
EXPECT_TRUE(
ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
- EXPECT_TRUE(
- ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_YEAR));
+ EXPECT_TRUE(ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_DAY));
}
TEST_F(ParserTest, Parse) {
@@ -148,37 +148,37 @@
ParsesCorrectly("{Mar 16 08:12:04}", 6419524000, GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{2012-10-14T22:11:20}", 1350245480000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2014-07-01T14:59:55.711Z}", 1404219595000,
+ EXPECT_TRUE(ParsesCorrectly("{2014-07-01T14:59:55}.711Z", 1404219595000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29,573}", 1277512289000,
+ EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29},573", 1277512289000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}", 1137899465000,
GRANULARITY_SECOND));
EXPECT_TRUE(
ParsesCorrectly("{150423 11:42:35}", 1429782155000, GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{11:42:35}", 38555000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{11:42:35.173}", 38555000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{11:42:35}.173", 38555000, GRANULARITY_SECOND));
EXPECT_TRUE(
- ParsesCorrectly("{23/Apr 11:42:35,173}", 9715355000, GRANULARITY_SECOND));
+ ParsesCorrectly("{23/Apr 11:42:35},173", 9715355000, GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015:11:42:35}", 1429782155000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}", 1429782155000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}", 1429782155000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35.883}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}.883", 1429782155000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}", 1429782155000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35.883}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}.883", 1429782155000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}", 1429782155000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}", 1429782155000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35.883}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}.883", 1429782155000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{8/5/2011 3:31:18 AM:234}", 1312507878000,
+ EXPECT_TRUE(ParsesCorrectly("{8/5/2011 3:31:18 AM}:234}", 1312507878000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000,
GRANULARITY_SECOND));
@@ -203,30 +203,25 @@
EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000,
GRANULARITY_HOUR));
- // WARNING: The following cases have incorrect granularity.
- // TODO(zilka): Fix this, when granularity is correctly implemented in
- // InterpretParseData.
- EXPECT_TRUE(ParsesCorrectly("{today}", -3600000, GRANULARITY_YEAR));
- EXPECT_TRUE(ParsesCorrectly("{today}", -57600000, GRANULARITY_YEAR,
+ EXPECT_TRUE(ParsesCorrectly("{today}", -3600000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly("{today}", -57600000, GRANULARITY_DAY,
"America/Los_Angeles"));
- EXPECT_TRUE(ParsesCorrectly("{next week}", 255600000, GRANULARITY_YEAR));
- EXPECT_TRUE(ParsesCorrectly("{next day}", 82800000, GRANULARITY_YEAR));
- EXPECT_TRUE(ParsesCorrectly("{in three days}", 255600000, GRANULARITY_YEAR));
+ EXPECT_TRUE(ParsesCorrectly("{next week}", 255600000, GRANULARITY_WEEK));
+ EXPECT_TRUE(ParsesCorrectly("{next day}", 82800000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly("{in three days}", 255600000, GRANULARITY_DAY));
EXPECT_TRUE(
- ParsesCorrectly("{in three weeks}", 1465200000, GRANULARITY_YEAR));
- EXPECT_TRUE(ParsesCorrectly("{tomorrow}", 82800000, GRANULARITY_YEAR));
+ ParsesCorrectly("{in three weeks}", 1465200000, GRANULARITY_WEEK));
+ EXPECT_TRUE(ParsesCorrectly("{tomorrow}", 82800000, GRANULARITY_DAY));
EXPECT_TRUE(
ParsesCorrectly("{tomorrow at 4:00}", 97200000, GRANULARITY_MINUTE));
EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4}", 97200000, GRANULARITY_HOUR));
- EXPECT_TRUE(ParsesCorrectly("{next wednesday}", 514800000, GRANULARITY_YEAR));
+ EXPECT_TRUE(ParsesCorrectly("{next wednesday}", 514800000, GRANULARITY_DAY));
EXPECT_TRUE(
ParsesCorrectly("{next wednesday at 4}", 529200000, GRANULARITY_HOUR));
EXPECT_TRUE(ParsesCorrectly("last seen {today at 9:01 PM}", 72060000,
GRANULARITY_MINUTE));
- EXPECT_TRUE(
- ParsesCorrectly("{Three days ago}", -262800000, GRANULARITY_YEAR));
- EXPECT_TRUE(
- ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_YEAR));
+ EXPECT_TRUE(ParsesCorrectly("{Three days ago}", -262800000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_DAY));
}
// TODO(zilka): Add a test that tests multiple locales.
diff --git a/feature-processor.cc b/feature-processor.cc
index 61e5735..e4df94f 100644
--- a/feature-processor.cc
+++ b/feature-processor.cc
@@ -695,24 +695,22 @@
}
}
-void FeatureProcessor::TokenizeAndFindClick(const std::string& context,
- CodepointSpan input_span,
- bool only_use_line_with_click,
- std::vector<Token>* tokens,
- int* click_pos) const {
+void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
+ CodepointSpan input_span,
+ bool only_use_line_with_click,
+ std::vector<Token>* tokens,
+ int* click_pos) const {
const UnicodeText context_unicode =
UTF8ToUnicodeText(context, /*do_copy=*/false);
- TokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
- tokens, click_pos);
+ RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
+ tokens, click_pos);
}
-void FeatureProcessor::TokenizeAndFindClick(const UnicodeText& context_unicode,
- CodepointSpan input_span,
- bool only_use_line_with_click,
- std::vector<Token>* tokens,
- int* click_pos) const {
+void FeatureProcessor::RetokenizeAndFindClick(
+ const UnicodeText& context_unicode, CodepointSpan input_span,
+ bool only_use_line_with_click, std::vector<Token>* tokens,
+ int* click_pos) const {
TC_CHECK(tokens != nullptr);
- *tokens = Tokenize(context_unicode);
if (options_->split_tokens_on_selection_boundaries()) {
internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
@@ -775,7 +773,8 @@
bool FeatureProcessor::ExtractFeatures(
const std::vector<Token>& tokens, TokenSpan token_span,
CodepointSpan selection_span_for_feature,
- EmbeddingExecutor* embedding_executor, int feature_vector_size,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache, int feature_vector_size,
std::unique_ptr<CachedFeatures>* cached_features) const {
if (options_->min_supported_codepoint_ratio() > 0) {
const float supported_codepoint_ratio =
@@ -787,32 +786,30 @@
}
}
- std::vector<std::vector<int>> sparse_features(TokenSpanSize(token_span));
- std::vector<std::vector<float>> dense_features(TokenSpanSize(token_span));
+ std::unique_ptr<std::vector<float>> features(new std::vector<float>());
+ features->reserve(feature_vector_size * TokenSpanSize(token_span));
for (int i = token_span.first; i < token_span.second; ++i) {
- const Token& token = tokens[i];
- const int features_index = i - token_span.first;
- if (!feature_extractor_.Extract(
- token, token.IsContainedInSpan(selection_span_for_feature),
- &(sparse_features[features_index]),
- &(dense_features[features_index]))) {
- TC_LOG(ERROR) << "Could not extract token's features: " << token;
+ if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
+ embedding_executor, embedding_cache,
+ features.get())) {
+ TC_LOG(ERROR) << "Could not get token features.";
return false;
}
}
- std::vector<int> padding_sparse_features;
- std::vector<float> padding_dense_features;
- if (!feature_extractor_.Extract(Token(), false, &padding_sparse_features,
- &padding_dense_features)) {
- TC_LOG(ERROR) << "Could not extract padding token's features.";
+ std::unique_ptr<std::vector<float>> padding_features(
+ new std::vector<float>());
+ padding_features->reserve(feature_vector_size);
+ if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
+ embedding_executor, embedding_cache,
+ padding_features.get())) {
+    TC_LOG(ERROR) << "Could not get padding token features.";
return false;
}
- *cached_features =
- CachedFeatures::Create(token_span, sparse_features, dense_features,
- padding_sparse_features, padding_dense_features,
- options_, embedding_executor, feature_vector_size);
+ *cached_features = CachedFeatures::Create(token_span, std::move(features),
+ std::move(padding_features),
+ options_, feature_vector_size);
if (!*cached_features) {
TC_LOG(ERROR) << "Cound not create cached features.";
return false;
@@ -928,4 +925,69 @@
}
}
+bool FeatureProcessor::AppendTokenFeaturesWithCache(
+ const Token& token, CodepointSpan selection_span_for_feature,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache,
+ std::vector<float>* output_features) const {
+ // Look for the embedded features for the token in the cache, if there is one.
+ if (embedding_cache) {
+ const auto it = embedding_cache->find({token.start, token.end});
+ if (it != embedding_cache->end()) {
+ // The embedded features were found in the cache, extract only the dense
+ // features.
+ std::vector<float> dense_features;
+ if (!feature_extractor_.Extract(
+ token, token.IsContainedInSpan(selection_span_for_feature),
+ /*sparse_features=*/nullptr, &dense_features)) {
+ TC_LOG(ERROR) << "Could not extract token's dense features.";
+ return false;
+ }
+
+ // Append both embedded and dense features to the output and return.
+ output_features->insert(output_features->end(), it->second.begin(),
+ it->second.end());
+ output_features->insert(output_features->end(), dense_features.begin(),
+ dense_features.end());
+ return true;
+ }
+ }
+
+ // Extract the sparse and dense features.
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ if (!feature_extractor_.Extract(
+ token, token.IsContainedInSpan(selection_span_for_feature),
+ &sparse_features, &dense_features)) {
+ TC_LOG(ERROR) << "Could not extract token's features.";
+ return false;
+ }
+
+ // Embed the sparse features, appending them directly to the output.
+ const int embedding_size = GetOptions()->embedding_size();
+ output_features->resize(output_features->size() + embedding_size);
+ float* output_features_end =
+ output_features->data() + output_features->size();
+ if (!embedding_executor->AddEmbedding(
+ TensorView<int>(sparse_features.data(),
+ {static_cast<int>(sparse_features.size())}),
+ /*dest=*/output_features_end - embedding_size,
+ /*dest_size=*/embedding_size)) {
+ TC_LOG(ERROR) << "Could not embed token's sparse features.";
+ return false;
+ }
+
+ // If there is a cache, the embedded features for the token were not in it,
+ // so insert them.
+ if (embedding_cache) {
+ (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
+ output_features_end - embedding_size, output_features_end);
+ }
+
+ // Append the dense features to the output.
+ output_features->insert(output_features->end(), dense_features.begin(),
+ dense_features.end());
+ return true;
+}
+
} // namespace libtextclassifier2
diff --git a/feature-processor.h b/feature-processor.h
index e6f33d6..553bd1e 100644
--- a/feature-processor.h
+++ b/feature-processor.h
@@ -86,6 +86,13 @@
// Takes care of preparing features for the span prediction model.
class FeatureProcessor {
public:
+ // A cache mapping codepoint spans to embedded tokens features. An instance
+ // can be provided to multiple calls to ExtractFeatures() operating on the
+ // same context (the same codepoint spans corresponding to the same tokens),
+ // as an optimization. Note that the tokenizations do not have to be
+ // identical.
+ typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
+
// If unilib is nullptr, will create and own an instance of a UniLib,
// otherwise will use what's passed in.
explicit FeatureProcessor(const FeatureProcessorOptions* options,
@@ -139,24 +146,25 @@
const FeatureProcessorOptions* GetOptions() const { return options_; }
- // Tokenizes the context and input span, and finds the click position.
- void TokenizeAndFindClick(const std::string& context,
- CodepointSpan input_span,
- bool only_use_line_with_click,
- std::vector<Token>* tokens, int* click_pos) const;
+ // Retokenizes the context and input span, and finds the click position.
+ // Depending on the options, might modify tokens (split them or remove them).
+ void RetokenizeAndFindClick(const std::string& context,
+ CodepointSpan input_span,
+ bool only_use_line_with_click,
+ std::vector<Token>* tokens, int* click_pos) const;
// Same as above but takes UnicodeText.
- void TokenizeAndFindClick(const UnicodeText& context_unicode,
- CodepointSpan input_span,
- bool only_use_line_with_click,
- std::vector<Token>* tokens, int* click_pos) const;
+ void RetokenizeAndFindClick(const UnicodeText& context_unicode,
+ CodepointSpan input_span,
+ bool only_use_line_with_click,
+ std::vector<Token>* tokens, int* click_pos) const;
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
CodepointSpan selection_span_for_feature,
- EmbeddingExecutor* embedding_executor,
- int feature_vector_size,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache, int feature_vector_size,
std::unique_ptr<CachedFeatures>* cached_features) const;
// Fills selection_label_spans with CodepointSpans that correspond to the
@@ -280,6 +288,15 @@
CodepointSpan span,
std::vector<Token>* tokens) const;
+ // Extracts the features of a token and appends them to the output vector.
+ // Uses the embedding cache to avoid re-extracting and re-embedding the
+ // sparse features for the same token.
+ bool AppendTokenFeaturesWithCache(const Token& token,
+ CodepointSpan selection_span_for_feature,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache,
+ std::vector<float>* output_features) const;
+
private:
std::unique_ptr<UniLib> owned_unilib_;
const UniLib* unilib_;
diff --git a/feature-processor_test.cc b/feature-processor_test.cc
index 78977d4..70ef0a7 100644
--- a/feature-processor_test.cc
+++ b/feature-processor_test.cc
@@ -27,6 +27,7 @@
using testing::ElementsAreArray;
using testing::FloatEq;
+using testing::Matcher;
flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
const FeatureProcessorOptionsT& options) {
@@ -35,6 +36,19 @@
return builder.Release();
}
+template <typename T>
+std::vector<T> Subvector(const std::vector<T>& vector, int start, int end) {
+ return std::vector<T>(vector.begin() + start, vector.begin() + end);
+}
+
+Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
+ std::vector<Matcher<float>> matchers;
+ for (const float value : values) {
+ matchers.push_back(FloatEq(value));
+ }
+ return ElementsAreArray(matchers);
+}
+
class TestingFeatureProcessor : public FeatureProcessor {
public:
using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
@@ -51,7 +65,7 @@
class FakeEmbeddingExecutor : public EmbeddingExecutor {
public:
bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- int dest_size) override {
+ int dest_size) const override {
TC_CHECK_GE(dest_size, 4);
EXPECT_EQ(sparse_features.size(), 1);
dest[0] = sparse_features.data()[0];
@@ -147,7 +161,7 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickFirst) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
FeatureProcessorOptionsT options;
options.only_use_line_with_click = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
@@ -173,7 +187,7 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickSecond) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
FeatureProcessorOptionsT options;
options.only_use_line_with_click = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
@@ -199,7 +213,7 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickThird) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
FeatureProcessorOptionsT options;
options.only_use_line_with_click = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
@@ -225,7 +239,7 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
FeatureProcessorOptionsT options;
options.only_use_line_with_click = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
@@ -251,7 +265,7 @@
}
TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
FeatureProcessorOptionsT options;
options.only_use_line_with_click = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
@@ -279,7 +293,7 @@
}
TEST(FeatureProcessorTest, SpanToLabel) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
FeatureProcessorOptionsT options;
options.context_size = 1;
options.max_selection_span = 1;
@@ -354,7 +368,7 @@
}
TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
FeatureProcessorOptionsT options;
options.context_size = 1;
options.max_selection_span = 1;
@@ -542,7 +556,7 @@
}
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
TestingFeatureProcessor feature_processor(
flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
&unilib);
@@ -590,7 +604,7 @@
EXPECT_TRUE(feature_processor2.ExtractFeatures(
tokens, /*token_span=*/{0, 3},
/*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
- &embedding_executor,
+ &embedding_executor, /*embedding_cache=*/nullptr,
/*feature_vector_size=*/4, &cached_features));
options.min_supported_codepoint_ratio = 0.2;
@@ -602,7 +616,7 @@
EXPECT_TRUE(feature_processor3.ExtractFeatures(
tokens, /*token_span=*/{0, 3},
/*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
- &embedding_executor,
+ &embedding_executor, /*embedding_cache=*/nullptr,
/*feature_vector_size=*/4, &cached_features));
options.min_supported_codepoint_ratio = 0.5;
@@ -614,7 +628,7 @@
EXPECT_FALSE(feature_processor4.ExtractFeatures(
tokens, /*token_span=*/{0, 3},
/*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
- &embedding_executor,
+ &embedding_executor, /*embedding_cache=*/nullptr,
/*feature_vector_size=*/4, &cached_features));
}
@@ -628,7 +642,7 @@
options.extract_selection_mask_feature = true;
flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
TestingFeatureProcessor feature_processor(
flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
&unilib);
@@ -643,7 +657,8 @@
EXPECT_TRUE(feature_processor.ExtractFeatures(
tokens, /*token_span=*/{0, 4},
/*selection_span_for_feature=*/{4, 11}, &embedding_executor,
- /*feature_vector_size=*/5, &cached_features));
+ /*embedding_cache=*/nullptr, /*feature_vector_size=*/5,
+ &cached_features));
std::vector<float> features;
cached_features->AppendClickContextFeaturesForClick(1, &features);
ASSERT_EQ(features.size(), 25);
@@ -654,6 +669,76 @@
EXPECT_THAT(features[24], FloatEq(0.0));
}
+TEST(FeatureProcessorTest, EmbeddingCache) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.feature_version = 2;
+ options.embedding_size = 4;
+ options.bounds_sensitive_features.reset(
+ new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+ options.bounds_sensitive_features->enabled = true;
+ options.bounds_sensitive_features->num_tokens_before = 3;
+ options.bounds_sensitive_features->num_tokens_inside_left = 2;
+ options.bounds_sensitive_features->num_tokens_inside_right = 2;
+ options.bounds_sensitive_features->num_tokens_after = 3;
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ CREATE_UNILIB_FOR_TESTING;
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &unilib);
+
+ std::unique_ptr<CachedFeatures> cached_features;
+
+ FakeEmbeddingExecutor embedding_executor;
+
+ const std::vector<Token> tokens = {
+ Token("aaa", 0, 3), Token("bbb", 4, 7), Token("ccc", 8, 11),
+ Token("ddd", 12, 15), Token("eee", 16, 19), Token("fff", 20, 23)};
+
+ // We pre-populate the cache with dummy embeddings, to make sure they are
+ // used when populating the features vector.
+ const std::vector<float> cached_padding_features = {10.0, -10.0, 10.0, -10.0};
+ const std::vector<float> cached_features1 = {1.0, 2.0, 3.0, 4.0};
+ const std::vector<float> cached_features2 = {5.0, 6.0, 7.0, 8.0};
+ FeatureProcessor::EmbeddingCache embedding_cache = {
+ {{kInvalidIndex, kInvalidIndex}, cached_padding_features},
+ {{4, 7}, cached_features1},
+ {{12, 15}, cached_features2},
+ };
+
+ EXPECT_TRUE(feature_processor.ExtractFeatures(
+ tokens, /*token_span=*/{0, 6},
+ /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+ &embedding_executor, &embedding_cache, /*feature_vector_size=*/4,
+ &cached_features));
+ std::vector<float> features;
+ cached_features->AppendBoundsSensitiveFeaturesForSpan({2, 4}, &features);
+ ASSERT_EQ(features.size(), 40);
+ // Check that the dummy embeddings were used.
+ EXPECT_THAT(Subvector(features, 0, 4),
+ ElementsAreFloat(cached_padding_features));
+ EXPECT_THAT(Subvector(features, 8, 12), ElementsAreFloat(cached_features1));
+ EXPECT_THAT(Subvector(features, 16, 20), ElementsAreFloat(cached_features2));
+ EXPECT_THAT(Subvector(features, 24, 28), ElementsAreFloat(cached_features2));
+ EXPECT_THAT(Subvector(features, 36, 40),
+ ElementsAreFloat(cached_padding_features));
+ // Check that the real embeddings were cached.
+ EXPECT_EQ(embedding_cache.size(), 7);
+ EXPECT_THAT(Subvector(features, 4, 8),
+ ElementsAreFloat(embedding_cache.at({0, 3})));
+ EXPECT_THAT(Subvector(features, 12, 16),
+ ElementsAreFloat(embedding_cache.at({8, 11})));
+ EXPECT_THAT(Subvector(features, 20, 24),
+ ElementsAreFloat(embedding_cache.at({8, 11})));
+ EXPECT_THAT(Subvector(features, 28, 32),
+ ElementsAreFloat(embedding_cache.at({16, 19})));
+ EXPECT_THAT(Subvector(features, 32, 36),
+ ElementsAreFloat(embedding_cache.at({20, 23})));
+}
+
TEST(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
std::vector<Token> tokens_orig{
Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
@@ -767,7 +852,7 @@
}
TEST(FeatureProcessorTest, InternalTokenizeOnScriptChange) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
FeatureProcessorOptionsT options;
options.tokenization_codepoint_config.emplace_back(
new TokenizationCodepointRangeT());
@@ -907,7 +992,7 @@
#endif
TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
FeatureProcessorOptionsT options;
options.ignored_span_boundary_codepoints.push_back('.');
options.ignored_span_boundary_codepoints.push_back(',');
diff --git a/model-executor.cc b/model-executor.cc
index fc7d4ae..c79056f 100644
--- a/model-executor.cc
+++ b/model-executor.cc
@@ -22,33 +22,34 @@
namespace libtextclassifier2 {
namespace internal {
bool FromModelSpec(const tflite::Model* model_spec,
- std::unique_ptr<tflite::FlatBufferModel>* model,
- std::unique_ptr<tflite::Interpreter>* interpreter) {
+ std::unique_ptr<const tflite::FlatBufferModel>* model) {
*model = tflite::FlatBufferModel::BuildFromModel(model_spec);
if (!(*model) || !(*model)->initialized()) {
TC_LOG(ERROR) << "Could not build TFLite model from a model spec. ";
return false;
}
-
- tflite::ops::builtin::BuiltinOpResolver builtins;
- tflite::InterpreterBuilder(**model, builtins)(interpreter);
- if (!interpreter) {
- TC_LOG(ERROR) << "Could not build TFLite interpreter.";
- return false;
- }
return true;
}
} // namespace internal
+std::unique_ptr<tflite::Interpreter> ModelExecutor::CreateInterpreter() const {
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::InterpreterBuilder(*model_, builtins_)(&interpreter);
+ return interpreter;
+}
+
TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor(
const tflite::Model* model_spec, const int embedding_size,
const int quantization_bits)
: quantization_bits_(quantization_bits),
output_embedding_size_(embedding_size) {
- internal::FromModelSpec(model_spec, &model_, &interpreter_);
+ internal::FromModelSpec(model_spec, &model_);
+ tflite::InterpreterBuilder(*model_, builtins_)(&interpreter_);
if (!interpreter_) {
+ TC_LOG(ERROR) << "Could not build TFLite interpreter for embeddings.";
return;
}
+
if (interpreter_->tensors_size() != 2) {
return;
}
@@ -73,7 +74,7 @@
}
bool TFLiteEmbeddingExecutor::AddEmbedding(
- const TensorView<int>& sparse_features, float* dest, int dest_size) {
+ const TensorView<int>& sparse_features, float* dest, int dest_size) const {
if (!initialized_ || dest_size != output_embedding_size_) {
TC_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: "
<< dest_size << " " << output_embedding_size_;
@@ -99,6 +100,9 @@
const int output_index_logits,
const TensorView<float>& features,
tflite::Interpreter* interpreter) {
+ if (!interpreter) {
+ return TensorView<float>::Invalid();
+ }
interpreter->ResizeInputTensor(input_index_features, features.shape());
if (interpreter->AllocateTensors() != kTfLiteOk) {
TC_VLOG(1) << "Allocation failed.";
diff --git a/model-executor.h b/model-executor.h
index b495427..547d596 100644
--- a/model-executor.h
+++ b/model-executor.h
@@ -32,8 +32,7 @@
namespace internal {
bool FromModelSpec(const tflite::Model* model_spec,
- std::unique_ptr<tflite::FlatBufferModel>* model,
- std::unique_ptr<tflite::Interpreter>* interpreter);
+ std::unique_ptr<const tflite::FlatBufferModel>* model);
} // namespace internal
// A helper function that given indices of feature and logits tensor, feature
@@ -44,24 +43,28 @@
tflite::Interpreter* interpreter);
// Executor for the text selection prediction and classification models.
-// NOTE: This class is not thread-safe.
class ModelExecutor {
public:
explicit ModelExecutor(const tflite::Model* model_spec) {
- internal::FromModelSpec(model_spec, &model_, &interpreter_);
+ internal::FromModelSpec(model_spec, &model_);
}
- TensorView<float> ComputeLogits(const TensorView<float>& features) {
+ // Creates an Interpreter for the model that serves as a scratch-pad for the
+ // inference. The Interpreter is NOT thread-safe.
+ std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;
+
+ TensorView<float> ComputeLogits(const TensorView<float>& features,
+ tflite::Interpreter* interpreter) const {
return ComputeLogitsHelper(kInputIndexFeatures, kOutputIndexLogits,
- features, interpreter_.get());
+ features, interpreter);
}
protected:
static const int kInputIndexFeatures = 0;
static const int kOutputIndexLogits = 0;
- std::unique_ptr<tflite::FlatBufferModel> model_ = nullptr;
- std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
+ std::unique_ptr<const tflite::FlatBufferModel> model_ = nullptr;
+ tflite::ops::builtin::BuiltinOpResolver builtins_;
};
// Executor for embedding sparse features into a dense vector.
@@ -72,21 +75,20 @@
// Embeds the sparse_features into a dense embedding and adds (+) it
// element-wise to the dest vector.
virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- int dest_size) = 0;
+ int dest_size) const = 0;
// Returns true when the model is ready to be used, false otherwise.
- virtual bool IsReady() { return true; }
+ virtual bool IsReady() const { return true; }
};
-// NOTE: This class is not thread-safe.
class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
public:
explicit TFLiteEmbeddingExecutor(const tflite::Model* model_spec,
int embedding_size, int quantization_bits);
bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- int dest_size) override;
+ int dest_size) const override;
- bool IsReady() override { return initialized_; }
+ bool IsReady() const override { return initialized_; }
protected:
int quantization_bits_;
@@ -97,8 +99,11 @@
const TfLiteTensor* scales_ = nullptr;
const TfLiteTensor* embeddings_ = nullptr;
- std::unique_ptr<tflite::FlatBufferModel> model_ = nullptr;
+ std::unique_ptr<const tflite::FlatBufferModel> model_ = nullptr;
+ // NOTE: This interpreter is used in a read-only way (as a storage for the
+ // model params), thus is still thread-safe.
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
+ tflite::ops::builtin::BuiltinOpResolver builtins_;
};
} // namespace libtextclassifier2
diff --git a/model.fbs b/model.fbs
index 2d69fe7..590c815 100755
--- a/model.fbs
+++ b/model.fbs
@@ -16,6 +16,19 @@
file_identifier "TC2 ";
+// The possible model modes, represents a bit field.
+namespace libtextclassifier2;
+enum ModeFlag : int {
+ NONE = 0,
+ ANNOTATION = 1,
+ CLASSIFICATION = 2,
+ ANNOTATION_AND_CLASSIFICATION = 3,
+ SELECTION = 4,
+ ANNOTATION_AND_SELECTION = 5,
+ CLASSIFICATION_AND_SELECTION = 6,
+ ALL = 7,
+}
+
namespace libtextclassifier2;
enum DatetimeExtractorType : int {
UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0,
@@ -129,20 +142,19 @@
// Can specify a single capturing group used as match boundaries.
pattern:string;
- // Whether to apply the pattern for annotation.
- enabled_for_annotation:bool = 0;
-
- // Whether to apply the pattern for classification.
- enabled_for_classification:bool = 0;
-
- // Whether to apply the pattern for selection.
- enabled_for_selection:bool = 0;
+ // The modes for which to apply the patterns.
+ enabled_modes:libtextclassifier2.ModeFlag = ALL;
// The final score to assign to the results of this pattern.
target_classification_score:float = 1;
// Priority score used for conflict resolution with the other models.
priority_score:float = 0;
+
+ // If true, will use an approximate matching implementation, implemented
+ // using Find() instead of the true Match(). This approximate matching will
+ // use the first Find() result and then check that it spans the whole input.
+ use_approximate_matching:bool = 0;
}
namespace libtextclassifier2;
@@ -164,6 +176,9 @@
// Priority score used for conflict resulution with the other models.
priority_score:float = 0;
+
+ // The modes for which to apply the patterns.
+ enabled_modes:libtextclassifier2.ModeFlag = ALL;
}
namespace libtextclassifier2;
@@ -181,13 +196,20 @@
patterns:[libtextclassifier2.DatetimeModelPattern];
extractors:[libtextclassifier2.DatetimeModelExtractor];
+
+ // If true, will use the extractors for determining the match location as
+ // opposed to using the location where the global pattern matched.
+ use_extractors_for_locating:bool = 1;
}
-// Options controlling the output of the models.
+// Options controlling the output of the Tensorflow Lite models.
namespace libtextclassifier2;
table ModelTriggeringOptions {
// Lower bound threshold for filtering annotation model outputs.
min_annotate_confidence:float = 0;
+
+ // The modes for which to enable the models.
+ enabled_modes:libtextclassifier2.ModeFlag = ALL;
}
namespace libtextclassifier2;
@@ -196,27 +218,32 @@
locales:string;
version:int;
+
+ // A name for the model that can be used for e.g. logging.
+ name:string;
+
selection_feature_options:libtextclassifier2.FeatureProcessorOptions;
classification_feature_options:libtextclassifier2.FeatureProcessorOptions;
- // TFLite models.
+ // Tensorflow Lite models.
selection_model:[ubyte] (force_align: 16);
classification_model:[ubyte] (force_align: 16);
embedding_model:[ubyte] (force_align: 16);
- regex_model:libtextclassifier2.RegexModel;
// Options for the different models.
selection_options:libtextclassifier2.SelectionModelOptions;
classification_options:libtextclassifier2.ClassificationModelOptions;
+ regex_model:libtextclassifier2.RegexModel;
datetime_model:libtextclassifier2.DatetimeModel;
// Options controlling the output of the models.
triggering_options:libtextclassifier2.ModelTriggeringOptions;
- // A name for the model that can be used for e.g. logging.
- name:string;
+ // Global switch that controls if SuggestSelection(), ClassifyText() and
+ // Annotate() will run. If a mode is disabled it returns empty/no-op results.
+ enabled_modes:libtextclassifier2.ModeFlag = ALL;
}
// Role of the codepoints in the range.
@@ -321,6 +348,10 @@
// If true, includes the selection length (in the number of tokens) as a
// feature.
include_inside_length:bool;
+
+ // If true, for selection, single token spans are not run through the model
+ // and their score is assumed to be zero.
+ score_single_token_spans_as_zero:bool;
}
namespace libtextclassifier2.FeatureProcessorOptions_;
@@ -337,6 +368,9 @@
// Size of the embedding.
embedding_size:int = -1;
+ // Number of bits for quantization for embeddings.
+ embedding_quantization_bits:int = 8;
+
// Context size defines the number of words to the left and to the right of
// the selected word to be used as context. For example, if context size is
// N, then we take N words to the left and N words to the right of the
@@ -428,7 +462,7 @@
// to it. So the resulting feature vector has two regions.
feature_version:int = 0;
- tokenization_type:libtextclassifier2.FeatureProcessorOptions_.TokenizationType;
+ tokenization_type:libtextclassifier2.FeatureProcessorOptions_.TokenizationType = INTERNAL_TOKENIZER;
icu_preserve_whitespace_tokens:bool = 0;
// List of codepoints that will be stripped from beginning and end of
@@ -447,9 +481,6 @@
// If true, tokens will be also split when the codepoint's script_id changes
// as defined in TokenizationCodepointRange.
tokenize_on_script_change:bool = 0;
-
- // Number of bits for quantization for embeddings.
- embedding_quantization_bits:int = 8;
}
root_type libtextclassifier2.Model;
diff --git a/model_generated.h b/model_generated.h
index 40f5a4c..21c1b85 100755
--- a/model_generated.h
+++ b/model_generated.h
@@ -74,6 +74,53 @@
struct FeatureProcessorOptions;
struct FeatureProcessorOptionsT;
+enum ModeFlag {
+ ModeFlag_NONE = 0,
+ ModeFlag_ANNOTATION = 1,
+ ModeFlag_CLASSIFICATION = 2,
+ ModeFlag_ANNOTATION_AND_CLASSIFICATION = 3,
+ ModeFlag_SELECTION = 4,
+ ModeFlag_ANNOTATION_AND_SELECTION = 5,
+ ModeFlag_CLASSIFICATION_AND_SELECTION = 6,
+ ModeFlag_ALL = 7,
+ ModeFlag_MIN = ModeFlag_NONE,
+ ModeFlag_MAX = ModeFlag_ALL
+};
+
+inline ModeFlag (&EnumValuesModeFlag())[8] {
+ static ModeFlag values[] = {
+ ModeFlag_NONE,
+ ModeFlag_ANNOTATION,
+ ModeFlag_CLASSIFICATION,
+ ModeFlag_ANNOTATION_AND_CLASSIFICATION,
+ ModeFlag_SELECTION,
+ ModeFlag_ANNOTATION_AND_SELECTION,
+ ModeFlag_CLASSIFICATION_AND_SELECTION,
+ ModeFlag_ALL
+ };
+ return values;
+}
+
+inline const char **EnumNamesModeFlag() {
+ static const char *names[] = {
+ "NONE",
+ "ANNOTATION",
+ "CLASSIFICATION",
+ "ANNOTATION_AND_CLASSIFICATION",
+ "SELECTION",
+ "ANNOTATION_AND_SELECTION",
+ "CLASSIFICATION_AND_SELECTION",
+ "ALL",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameModeFlag(ModeFlag e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesModeFlag()[index];
+}
+
enum DatetimeExtractorType {
DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0,
DatetimeExtractorType_AM = 1,
@@ -584,17 +631,15 @@
typedef Pattern TableType;
std::string collection_name;
std::string pattern;
- bool enabled_for_annotation;
- bool enabled_for_classification;
- bool enabled_for_selection;
+ libtextclassifier2::ModeFlag enabled_modes;
float target_classification_score;
float priority_score;
+ bool use_approximate_matching;
PatternT()
- : enabled_for_annotation(false),
- enabled_for_classification(false),
- enabled_for_selection(false),
+ : enabled_modes(libtextclassifier2::ModeFlag_ALL),
target_classification_score(1.0f),
- priority_score(0.0f) {
+ priority_score(0.0f),
+ use_approximate_matching(false) {
}
};
@@ -603,11 +648,10 @@
enum {
VT_COLLECTION_NAME = 4,
VT_PATTERN = 6,
- VT_ENABLED_FOR_ANNOTATION = 8,
- VT_ENABLED_FOR_CLASSIFICATION = 10,
- VT_ENABLED_FOR_SELECTION = 12,
- VT_TARGET_CLASSIFICATION_SCORE = 14,
- VT_PRIORITY_SCORE = 16
+ VT_ENABLED_MODES = 8,
+ VT_TARGET_CLASSIFICATION_SCORE = 10,
+ VT_PRIORITY_SCORE = 12,
+ VT_USE_APPROXIMATE_MATCHING = 14
};
const flatbuffers::String *collection_name() const {
return GetPointer<const flatbuffers::String *>(VT_COLLECTION_NAME);
@@ -615,14 +659,8 @@
const flatbuffers::String *pattern() const {
return GetPointer<const flatbuffers::String *>(VT_PATTERN);
}
- bool enabled_for_annotation() const {
- return GetField<uint8_t>(VT_ENABLED_FOR_ANNOTATION, 0) != 0;
- }
- bool enabled_for_classification() const {
- return GetField<uint8_t>(VT_ENABLED_FOR_CLASSIFICATION, 0) != 0;
- }
- bool enabled_for_selection() const {
- return GetField<uint8_t>(VT_ENABLED_FOR_SELECTION, 0) != 0;
+ libtextclassifier2::ModeFlag enabled_modes() const {
+ return static_cast<libtextclassifier2::ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7));
}
float target_classification_score() const {
return GetField<float>(VT_TARGET_CLASSIFICATION_SCORE, 1.0f);
@@ -630,17 +668,19 @@
float priority_score() const {
return GetField<float>(VT_PRIORITY_SCORE, 0.0f);
}
+ bool use_approximate_matching() const {
+ return GetField<uint8_t>(VT_USE_APPROXIMATE_MATCHING, 0) != 0;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_COLLECTION_NAME) &&
verifier.Verify(collection_name()) &&
VerifyOffset(verifier, VT_PATTERN) &&
verifier.Verify(pattern()) &&
- VerifyField<uint8_t>(verifier, VT_ENABLED_FOR_ANNOTATION) &&
- VerifyField<uint8_t>(verifier, VT_ENABLED_FOR_CLASSIFICATION) &&
- VerifyField<uint8_t>(verifier, VT_ENABLED_FOR_SELECTION) &&
+ VerifyField<int32_t>(verifier, VT_ENABLED_MODES) &&
VerifyField<float>(verifier, VT_TARGET_CLASSIFICATION_SCORE) &&
VerifyField<float>(verifier, VT_PRIORITY_SCORE) &&
+ VerifyField<uint8_t>(verifier, VT_USE_APPROXIMATE_MATCHING) &&
verifier.EndTable();
}
PatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -657,14 +697,8 @@
void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) {
fbb_.AddOffset(Pattern::VT_PATTERN, pattern);
}
- void add_enabled_for_annotation(bool enabled_for_annotation) {
- fbb_.AddElement<uint8_t>(Pattern::VT_ENABLED_FOR_ANNOTATION, static_cast<uint8_t>(enabled_for_annotation), 0);
- }
- void add_enabled_for_classification(bool enabled_for_classification) {
- fbb_.AddElement<uint8_t>(Pattern::VT_ENABLED_FOR_CLASSIFICATION, static_cast<uint8_t>(enabled_for_classification), 0);
- }
- void add_enabled_for_selection(bool enabled_for_selection) {
- fbb_.AddElement<uint8_t>(Pattern::VT_ENABLED_FOR_SELECTION, static_cast<uint8_t>(enabled_for_selection), 0);
+ void add_enabled_modes(libtextclassifier2::ModeFlag enabled_modes) {
+ fbb_.AddElement<int32_t>(Pattern::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7);
}
void add_target_classification_score(float target_classification_score) {
fbb_.AddElement<float>(Pattern::VT_TARGET_CLASSIFICATION_SCORE, target_classification_score, 1.0f);
@@ -672,6 +706,9 @@
void add_priority_score(float priority_score) {
fbb_.AddElement<float>(Pattern::VT_PRIORITY_SCORE, priority_score, 0.0f);
}
+ void add_use_approximate_matching(bool use_approximate_matching) {
+ fbb_.AddElement<uint8_t>(Pattern::VT_USE_APPROXIMATE_MATCHING, static_cast<uint8_t>(use_approximate_matching), 0);
+ }
explicit PatternBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -688,19 +725,17 @@
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<flatbuffers::String> collection_name = 0,
flatbuffers::Offset<flatbuffers::String> pattern = 0,
- bool enabled_for_annotation = false,
- bool enabled_for_classification = false,
- bool enabled_for_selection = false,
+ libtextclassifier2::ModeFlag enabled_modes = libtextclassifier2::ModeFlag_ALL,
float target_classification_score = 1.0f,
- float priority_score = 0.0f) {
+ float priority_score = 0.0f,
+ bool use_approximate_matching = false) {
PatternBuilder builder_(_fbb);
builder_.add_priority_score(priority_score);
builder_.add_target_classification_score(target_classification_score);
+ builder_.add_enabled_modes(enabled_modes);
builder_.add_pattern(pattern);
builder_.add_collection_name(collection_name);
- builder_.add_enabled_for_selection(enabled_for_selection);
- builder_.add_enabled_for_classification(enabled_for_classification);
- builder_.add_enabled_for_annotation(enabled_for_annotation);
+ builder_.add_use_approximate_matching(use_approximate_matching);
return builder_.Finish();
}
@@ -708,20 +743,18 @@
flatbuffers::FlatBufferBuilder &_fbb,
const char *collection_name = nullptr,
const char *pattern = nullptr,
- bool enabled_for_annotation = false,
- bool enabled_for_classification = false,
- bool enabled_for_selection = false,
+ libtextclassifier2::ModeFlag enabled_modes = libtextclassifier2::ModeFlag_ALL,
float target_classification_score = 1.0f,
- float priority_score = 0.0f) {
+ float priority_score = 0.0f,
+ bool use_approximate_matching = false) {
return libtextclassifier2::RegexModel_::CreatePattern(
_fbb,
collection_name ? _fbb.CreateString(collection_name) : 0,
pattern ? _fbb.CreateString(pattern) : 0,
- enabled_for_annotation,
- enabled_for_classification,
- enabled_for_selection,
+ enabled_modes,
target_classification_score,
- priority_score);
+ priority_score,
+ use_approximate_matching);
}
flatbuffers::Offset<Pattern> CreatePattern(flatbuffers::FlatBufferBuilder &_fbb, const PatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -797,9 +830,11 @@
std::vector<int32_t> locales;
float target_classification_score;
float priority_score;
+ ModeFlag enabled_modes;
DatetimeModelPatternT()
: target_classification_score(1.0f),
- priority_score(0.0f) {
+ priority_score(0.0f),
+ enabled_modes(ModeFlag_ALL) {
}
};
@@ -809,7 +844,8 @@
VT_REGEXES = 4,
VT_LOCALES = 6,
VT_TARGET_CLASSIFICATION_SCORE = 8,
- VT_PRIORITY_SCORE = 10
+ VT_PRIORITY_SCORE = 10,
+ VT_ENABLED_MODES = 12
};
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *regexes() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_REGEXES);
@@ -823,6 +859,9 @@
float priority_score() const {
return GetField<float>(VT_PRIORITY_SCORE, 0.0f);
}
+ ModeFlag enabled_modes() const {
+ return static_cast<ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7));
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_REGEXES) &&
@@ -832,6 +871,7 @@
verifier.Verify(locales()) &&
VerifyField<float>(verifier, VT_TARGET_CLASSIFICATION_SCORE) &&
VerifyField<float>(verifier, VT_PRIORITY_SCORE) &&
+ VerifyField<int32_t>(verifier, VT_ENABLED_MODES) &&
verifier.EndTable();
}
DatetimeModelPatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -854,6 +894,9 @@
void add_priority_score(float priority_score) {
fbb_.AddElement<float>(DatetimeModelPattern::VT_PRIORITY_SCORE, priority_score, 0.0f);
}
+ void add_enabled_modes(ModeFlag enabled_modes) {
+ fbb_.AddElement<int32_t>(DatetimeModelPattern::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7);
+ }
explicit DatetimeModelPatternBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -871,8 +914,10 @@
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexes = 0,
flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales = 0,
float target_classification_score = 1.0f,
- float priority_score = 0.0f) {
+ float priority_score = 0.0f,
+ ModeFlag enabled_modes = ModeFlag_ALL) {
DatetimeModelPatternBuilder builder_(_fbb);
+ builder_.add_enabled_modes(enabled_modes);
builder_.add_priority_score(priority_score);
builder_.add_target_classification_score(target_classification_score);
builder_.add_locales(locales);
@@ -885,13 +930,15 @@
const std::vector<flatbuffers::Offset<flatbuffers::String>> *regexes = nullptr,
const std::vector<int32_t> *locales = nullptr,
float target_classification_score = 1.0f,
- float priority_score = 0.0f) {
+ float priority_score = 0.0f,
+ ModeFlag enabled_modes = ModeFlag_ALL) {
return libtextclassifier2::CreateDatetimeModelPattern(
_fbb,
regexes ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*regexes) : 0,
locales ? _fbb.CreateVector<int32_t>(*locales) : 0,
target_classification_score,
- priority_score);
+ priority_score,
+ enabled_modes);
}
flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -991,7 +1038,9 @@
std::vector<std::string> locales;
std::vector<std::unique_ptr<DatetimeModelPatternT>> patterns;
std::vector<std::unique_ptr<DatetimeModelExtractorT>> extractors;
- DatetimeModelT() {
+ bool use_extractors_for_locating;
+ DatetimeModelT()
+ : use_extractors_for_locating(true) {
}
};
@@ -1000,7 +1049,8 @@
enum {
VT_LOCALES = 4,
VT_PATTERNS = 6,
- VT_EXTRACTORS = 8
+ VT_EXTRACTORS = 8,
+ VT_USE_EXTRACTORS_FOR_LOCATING = 10
};
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *locales() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_LOCALES);
@@ -1011,6 +1061,9 @@
const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>> *extractors() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>> *>(VT_EXTRACTORS);
}
+ bool use_extractors_for_locating() const {
+ return GetField<uint8_t>(VT_USE_EXTRACTORS_FOR_LOCATING, 1) != 0;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_LOCALES) &&
@@ -1022,6 +1075,7 @@
VerifyOffset(verifier, VT_EXTRACTORS) &&
verifier.Verify(extractors()) &&
verifier.VerifyVectorOfTables(extractors()) &&
+ VerifyField<uint8_t>(verifier, VT_USE_EXTRACTORS_FOR_LOCATING) &&
verifier.EndTable();
}
DatetimeModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1041,6 +1095,9 @@
void add_extractors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>>> extractors) {
fbb_.AddOffset(DatetimeModel::VT_EXTRACTORS, extractors);
}
+ void add_use_extractors_for_locating(bool use_extractors_for_locating) {
+ fbb_.AddElement<uint8_t>(DatetimeModel::VT_USE_EXTRACTORS_FOR_LOCATING, static_cast<uint8_t>(use_extractors_for_locating), 1);
+ }
explicit DatetimeModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1057,11 +1114,13 @@
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> locales = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>>> patterns = 0,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>>> extractors = 0) {
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>>> extractors = 0,
+ bool use_extractors_for_locating = true) {
DatetimeModelBuilder builder_(_fbb);
builder_.add_extractors(extractors);
builder_.add_patterns(patterns);
builder_.add_locales(locales);
+ builder_.add_use_extractors_for_locating(use_extractors_for_locating);
return builder_.Finish();
}
@@ -1069,12 +1128,14 @@
flatbuffers::FlatBufferBuilder &_fbb,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *locales = nullptr,
const std::vector<flatbuffers::Offset<DatetimeModelPattern>> *patterns = nullptr,
- const std::vector<flatbuffers::Offset<DatetimeModelExtractor>> *extractors = nullptr) {
+ const std::vector<flatbuffers::Offset<DatetimeModelExtractor>> *extractors = nullptr,
+ bool use_extractors_for_locating = true) {
return libtextclassifier2::CreateDatetimeModel(
_fbb,
locales ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*locales) : 0,
patterns ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelPattern>>(*patterns) : 0,
- extractors ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>>(*extractors) : 0);
+ extractors ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>>(*extractors) : 0,
+ use_extractors_for_locating);
}
flatbuffers::Offset<DatetimeModel> CreateDatetimeModel(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -1082,22 +1143,29 @@
struct ModelTriggeringOptionsT : public flatbuffers::NativeTable {
typedef ModelTriggeringOptions TableType;
float min_annotate_confidence;
+ ModeFlag enabled_modes;
ModelTriggeringOptionsT()
- : min_annotate_confidence(0.0f) {
+ : min_annotate_confidence(0.0f),
+ enabled_modes(ModeFlag_ALL) {
}
};
struct ModelTriggeringOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef ModelTriggeringOptionsT NativeTableType;
enum {
- VT_MIN_ANNOTATE_CONFIDENCE = 4
+ VT_MIN_ANNOTATE_CONFIDENCE = 4,
+ VT_ENABLED_MODES = 6
};
float min_annotate_confidence() const {
return GetField<float>(VT_MIN_ANNOTATE_CONFIDENCE, 0.0f);
}
+ ModeFlag enabled_modes() const {
+ return static_cast<ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7));
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<float>(verifier, VT_MIN_ANNOTATE_CONFIDENCE) &&
+ VerifyField<int32_t>(verifier, VT_ENABLED_MODES) &&
verifier.EndTable();
}
ModelTriggeringOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1111,6 +1179,9 @@
void add_min_annotate_confidence(float min_annotate_confidence) {
fbb_.AddElement<float>(ModelTriggeringOptions::VT_MIN_ANNOTATE_CONFIDENCE, min_annotate_confidence, 0.0f);
}
+ void add_enabled_modes(ModeFlag enabled_modes) {
+ fbb_.AddElement<int32_t>(ModelTriggeringOptions::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7);
+ }
explicit ModelTriggeringOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1125,8 +1196,10 @@
inline flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions(
flatbuffers::FlatBufferBuilder &_fbb,
- float min_annotate_confidence = 0.0f) {
+ float min_annotate_confidence = 0.0f,
+ ModeFlag enabled_modes = ModeFlag_ALL) {
ModelTriggeringOptionsBuilder builder_(_fbb);
+ builder_.add_enabled_modes(enabled_modes);
builder_.add_min_annotate_confidence(min_annotate_confidence);
return builder_.Finish();
}
@@ -1137,19 +1210,21 @@
typedef Model TableType;
std::string locales;
int32_t version;
+ std::string name;
std::unique_ptr<FeatureProcessorOptionsT> selection_feature_options;
std::unique_ptr<FeatureProcessorOptionsT> classification_feature_options;
std::vector<uint8_t> selection_model;
std::vector<uint8_t> classification_model;
std::vector<uint8_t> embedding_model;
- std::unique_ptr<RegexModelT> regex_model;
std::unique_ptr<SelectionModelOptionsT> selection_options;
std::unique_ptr<ClassificationModelOptionsT> classification_options;
+ std::unique_ptr<RegexModelT> regex_model;
std::unique_ptr<DatetimeModelT> datetime_model;
std::unique_ptr<ModelTriggeringOptionsT> triggering_options;
- std::string name;
+ ModeFlag enabled_modes;
ModelT()
- : version(0) {
+ : version(0),
+ enabled_modes(ModeFlag_ALL) {
}
};
@@ -1158,17 +1233,18 @@
enum {
VT_LOCALES = 4,
VT_VERSION = 6,
- VT_SELECTION_FEATURE_OPTIONS = 8,
- VT_CLASSIFICATION_FEATURE_OPTIONS = 10,
- VT_SELECTION_MODEL = 12,
- VT_CLASSIFICATION_MODEL = 14,
- VT_EMBEDDING_MODEL = 16,
- VT_REGEX_MODEL = 18,
+ VT_NAME = 8,
+ VT_SELECTION_FEATURE_OPTIONS = 10,
+ VT_CLASSIFICATION_FEATURE_OPTIONS = 12,
+ VT_SELECTION_MODEL = 14,
+ VT_CLASSIFICATION_MODEL = 16,
+ VT_EMBEDDING_MODEL = 18,
VT_SELECTION_OPTIONS = 20,
VT_CLASSIFICATION_OPTIONS = 22,
- VT_DATETIME_MODEL = 24,
- VT_TRIGGERING_OPTIONS = 26,
- VT_NAME = 28
+ VT_REGEX_MODEL = 24,
+ VT_DATETIME_MODEL = 26,
+ VT_TRIGGERING_OPTIONS = 28,
+ VT_ENABLED_MODES = 30
};
const flatbuffers::String *locales() const {
return GetPointer<const flatbuffers::String *>(VT_LOCALES);
@@ -1176,6 +1252,9 @@
int32_t version() const {
return GetField<int32_t>(VT_VERSION, 0);
}
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
const FeatureProcessorOptions *selection_feature_options() const {
return GetPointer<const FeatureProcessorOptions *>(VT_SELECTION_FEATURE_OPTIONS);
}
@@ -1191,29 +1270,31 @@
const flatbuffers::Vector<uint8_t> *embedding_model() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_EMBEDDING_MODEL);
}
- const RegexModel *regex_model() const {
- return GetPointer<const RegexModel *>(VT_REGEX_MODEL);
- }
const SelectionModelOptions *selection_options() const {
return GetPointer<const SelectionModelOptions *>(VT_SELECTION_OPTIONS);
}
const ClassificationModelOptions *classification_options() const {
return GetPointer<const ClassificationModelOptions *>(VT_CLASSIFICATION_OPTIONS);
}
+ const RegexModel *regex_model() const {
+ return GetPointer<const RegexModel *>(VT_REGEX_MODEL);
+ }
const DatetimeModel *datetime_model() const {
return GetPointer<const DatetimeModel *>(VT_DATETIME_MODEL);
}
const ModelTriggeringOptions *triggering_options() const {
return GetPointer<const ModelTriggeringOptions *>(VT_TRIGGERING_OPTIONS);
}
- const flatbuffers::String *name() const {
- return GetPointer<const flatbuffers::String *>(VT_NAME);
+ ModeFlag enabled_modes() const {
+ return static_cast<ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7));
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_LOCALES) &&
verifier.Verify(locales()) &&
VerifyField<int32_t>(verifier, VT_VERSION) &&
+ VerifyOffset(verifier, VT_NAME) &&
+ verifier.Verify(name()) &&
VerifyOffset(verifier, VT_SELECTION_FEATURE_OPTIONS) &&
verifier.VerifyTable(selection_feature_options()) &&
VerifyOffset(verifier, VT_CLASSIFICATION_FEATURE_OPTIONS) &&
@@ -1224,18 +1305,17 @@
verifier.Verify(classification_model()) &&
VerifyOffset(verifier, VT_EMBEDDING_MODEL) &&
verifier.Verify(embedding_model()) &&
- VerifyOffset(verifier, VT_REGEX_MODEL) &&
- verifier.VerifyTable(regex_model()) &&
VerifyOffset(verifier, VT_SELECTION_OPTIONS) &&
verifier.VerifyTable(selection_options()) &&
VerifyOffset(verifier, VT_CLASSIFICATION_OPTIONS) &&
verifier.VerifyTable(classification_options()) &&
+ VerifyOffset(verifier, VT_REGEX_MODEL) &&
+ verifier.VerifyTable(regex_model()) &&
VerifyOffset(verifier, VT_DATETIME_MODEL) &&
verifier.VerifyTable(datetime_model()) &&
VerifyOffset(verifier, VT_TRIGGERING_OPTIONS) &&
verifier.VerifyTable(triggering_options()) &&
- VerifyOffset(verifier, VT_NAME) &&
- verifier.Verify(name()) &&
+ VerifyField<int32_t>(verifier, VT_ENABLED_MODES) &&
verifier.EndTable();
}
ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1252,6 +1332,9 @@
void add_version(int32_t version) {
fbb_.AddElement<int32_t>(Model::VT_VERSION, version, 0);
}
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(Model::VT_NAME, name);
+ }
void add_selection_feature_options(flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options) {
fbb_.AddOffset(Model::VT_SELECTION_FEATURE_OPTIONS, selection_feature_options);
}
@@ -1267,23 +1350,23 @@
void add_embedding_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> embedding_model) {
fbb_.AddOffset(Model::VT_EMBEDDING_MODEL, embedding_model);
}
- void add_regex_model(flatbuffers::Offset<RegexModel> regex_model) {
- fbb_.AddOffset(Model::VT_REGEX_MODEL, regex_model);
- }
void add_selection_options(flatbuffers::Offset<SelectionModelOptions> selection_options) {
fbb_.AddOffset(Model::VT_SELECTION_OPTIONS, selection_options);
}
void add_classification_options(flatbuffers::Offset<ClassificationModelOptions> classification_options) {
fbb_.AddOffset(Model::VT_CLASSIFICATION_OPTIONS, classification_options);
}
+ void add_regex_model(flatbuffers::Offset<RegexModel> regex_model) {
+ fbb_.AddOffset(Model::VT_REGEX_MODEL, regex_model);
+ }
void add_datetime_model(flatbuffers::Offset<DatetimeModel> datetime_model) {
fbb_.AddOffset(Model::VT_DATETIME_MODEL, datetime_model);
}
void add_triggering_options(flatbuffers::Offset<ModelTriggeringOptions> triggering_options) {
fbb_.AddOffset(Model::VT_TRIGGERING_OPTIONS, triggering_options);
}
- void add_name(flatbuffers::Offset<flatbuffers::String> name) {
- fbb_.AddOffset(Model::VT_NAME, name);
+ void add_enabled_modes(ModeFlag enabled_modes) {
+ fbb_.AddElement<int32_t>(Model::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7);
}
explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
@@ -1301,29 +1384,31 @@
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<flatbuffers::String> locales = 0,
int32_t version = 0,
+ flatbuffers::Offset<flatbuffers::String> name = 0,
flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options = 0,
flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options = 0,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> selection_model = 0,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> classification_model = 0,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> embedding_model = 0,
- flatbuffers::Offset<RegexModel> regex_model = 0,
flatbuffers::Offset<SelectionModelOptions> selection_options = 0,
flatbuffers::Offset<ClassificationModelOptions> classification_options = 0,
+ flatbuffers::Offset<RegexModel> regex_model = 0,
flatbuffers::Offset<DatetimeModel> datetime_model = 0,
flatbuffers::Offset<ModelTriggeringOptions> triggering_options = 0,
- flatbuffers::Offset<flatbuffers::String> name = 0) {
+ ModeFlag enabled_modes = ModeFlag_ALL) {
ModelBuilder builder_(_fbb);
- builder_.add_name(name);
+ builder_.add_enabled_modes(enabled_modes);
builder_.add_triggering_options(triggering_options);
builder_.add_datetime_model(datetime_model);
+ builder_.add_regex_model(regex_model);
builder_.add_classification_options(classification_options);
builder_.add_selection_options(selection_options);
- builder_.add_regex_model(regex_model);
builder_.add_embedding_model(embedding_model);
builder_.add_classification_model(classification_model);
builder_.add_selection_model(selection_model);
builder_.add_classification_feature_options(classification_feature_options);
builder_.add_selection_feature_options(selection_feature_options);
+ builder_.add_name(name);
builder_.add_version(version);
builder_.add_locales(locales);
return builder_.Finish();
@@ -1333,32 +1418,34 @@
flatbuffers::FlatBufferBuilder &_fbb,
const char *locales = nullptr,
int32_t version = 0,
+ const char *name = nullptr,
flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options = 0,
flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options = 0,
const std::vector<uint8_t> *selection_model = nullptr,
const std::vector<uint8_t> *classification_model = nullptr,
const std::vector<uint8_t> *embedding_model = nullptr,
- flatbuffers::Offset<RegexModel> regex_model = 0,
flatbuffers::Offset<SelectionModelOptions> selection_options = 0,
flatbuffers::Offset<ClassificationModelOptions> classification_options = 0,
+ flatbuffers::Offset<RegexModel> regex_model = 0,
flatbuffers::Offset<DatetimeModel> datetime_model = 0,
flatbuffers::Offset<ModelTriggeringOptions> triggering_options = 0,
- const char *name = nullptr) {
+ ModeFlag enabled_modes = ModeFlag_ALL) {
return libtextclassifier2::CreateModel(
_fbb,
locales ? _fbb.CreateString(locales) : 0,
version,
+ name ? _fbb.CreateString(name) : 0,
selection_feature_options,
classification_feature_options,
selection_model ? _fbb.CreateVector<uint8_t>(*selection_model) : 0,
classification_model ? _fbb.CreateVector<uint8_t>(*classification_model) : 0,
embedding_model ? _fbb.CreateVector<uint8_t>(*embedding_model) : 0,
- regex_model,
selection_options,
classification_options,
+ regex_model,
datetime_model,
triggering_options,
- name ? _fbb.CreateString(name) : 0);
+ enabled_modes);
}
flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -1530,6 +1617,7 @@
int32_t num_tokens_after;
bool include_inside_bag;
bool include_inside_length;
+ bool score_single_token_spans_as_zero;
BoundsSensitiveFeaturesT()
: enabled(false),
num_tokens_before(0),
@@ -1537,7 +1625,8 @@
num_tokens_inside_right(0),
num_tokens_after(0),
include_inside_bag(false),
- include_inside_length(false) {
+ include_inside_length(false),
+ score_single_token_spans_as_zero(false) {
}
};
@@ -1550,7 +1639,8 @@
VT_NUM_TOKENS_INSIDE_RIGHT = 10,
VT_NUM_TOKENS_AFTER = 12,
VT_INCLUDE_INSIDE_BAG = 14,
- VT_INCLUDE_INSIDE_LENGTH = 16
+ VT_INCLUDE_INSIDE_LENGTH = 16,
+ VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO = 18
};
bool enabled() const {
return GetField<uint8_t>(VT_ENABLED, 0) != 0;
@@ -1573,6 +1663,9 @@
bool include_inside_length() const {
return GetField<uint8_t>(VT_INCLUDE_INSIDE_LENGTH, 0) != 0;
}
+ bool score_single_token_spans_as_zero() const {
+ return GetField<uint8_t>(VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO, 0) != 0;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_ENABLED) &&
@@ -1582,6 +1675,7 @@
VerifyField<int32_t>(verifier, VT_NUM_TOKENS_AFTER) &&
VerifyField<uint8_t>(verifier, VT_INCLUDE_INSIDE_BAG) &&
VerifyField<uint8_t>(verifier, VT_INCLUDE_INSIDE_LENGTH) &&
+ VerifyField<uint8_t>(verifier, VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO) &&
verifier.EndTable();
}
BoundsSensitiveFeaturesT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1613,6 +1707,9 @@
void add_include_inside_length(bool include_inside_length) {
fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_INCLUDE_INSIDE_LENGTH, static_cast<uint8_t>(include_inside_length), 0);
}
+ void add_score_single_token_spans_as_zero(bool score_single_token_spans_as_zero) {
+ fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_SCORE_SINGLE_TOKEN_SPANS_AS_ZERO, static_cast<uint8_t>(score_single_token_spans_as_zero), 0);
+ }
explicit BoundsSensitiveFeaturesBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1633,12 +1730,14 @@
int32_t num_tokens_inside_right = 0,
int32_t num_tokens_after = 0,
bool include_inside_bag = false,
- bool include_inside_length = false) {
+ bool include_inside_length = false,
+ bool score_single_token_spans_as_zero = false) {
BoundsSensitiveFeaturesBuilder builder_(_fbb);
builder_.add_num_tokens_after(num_tokens_after);
builder_.add_num_tokens_inside_right(num_tokens_inside_right);
builder_.add_num_tokens_inside_left(num_tokens_inside_left);
builder_.add_num_tokens_before(num_tokens_before);
+ builder_.add_score_single_token_spans_as_zero(score_single_token_spans_as_zero);
builder_.add_include_inside_length(include_inside_length);
builder_.add_include_inside_bag(include_inside_bag);
builder_.add_enabled(enabled);
@@ -1729,6 +1828,7 @@
typedef FeatureProcessorOptions TableType;
int32_t num_buckets;
int32_t embedding_size;
+ int32_t embedding_quantization_bits;
int32_t context_size;
int32_t max_selection_span;
std::vector<int32_t> chargram_orders;
@@ -1757,10 +1857,10 @@
std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT> bounds_sensitive_features;
std::vector<std::string> allowed_chargrams;
bool tokenize_on_script_change;
- int32_t embedding_quantization_bits;
FeatureProcessorOptionsT()
: num_buckets(-1),
embedding_size(-1),
+ embedding_quantization_bits(8),
context_size(-1),
max_selection_span(-1),
max_word_length(20),
@@ -1777,10 +1877,9 @@
snap_label_span_boundaries_to_containing_tokens(false),
min_supported_codepoint_ratio(0.0f),
feature_version(0),
- tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE),
+ tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER),
icu_preserve_whitespace_tokens(false),
- tokenize_on_script_change(false),
- embedding_quantization_bits(8) {
+ tokenize_on_script_change(false) {
}
};
@@ -1789,35 +1888,35 @@
enum {
VT_NUM_BUCKETS = 4,
VT_EMBEDDING_SIZE = 6,
- VT_CONTEXT_SIZE = 8,
- VT_MAX_SELECTION_SPAN = 10,
- VT_CHARGRAM_ORDERS = 12,
- VT_MAX_WORD_LENGTH = 14,
- VT_UNICODE_AWARE_FEATURES = 16,
- VT_EXTRACT_CASE_FEATURE = 18,
- VT_EXTRACT_SELECTION_MASK_FEATURE = 20,
- VT_REGEXP_FEATURE = 22,
- VT_REMAP_DIGITS = 24,
- VT_LOWERCASE_TOKENS = 26,
- VT_SELECTION_REDUCED_OUTPUT_SPACE = 28,
- VT_COLLECTIONS = 30,
- VT_DEFAULT_COLLECTION = 32,
- VT_ONLY_USE_LINE_WITH_CLICK = 34,
- VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES = 36,
- VT_TOKENIZATION_CODEPOINT_CONFIG = 38,
- VT_CENTER_TOKEN_SELECTION_METHOD = 40,
- VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS = 42,
- VT_SUPPORTED_CODEPOINT_RANGES = 44,
- VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES = 46,
- VT_MIN_SUPPORTED_CODEPOINT_RATIO = 48,
- VT_FEATURE_VERSION = 50,
- VT_TOKENIZATION_TYPE = 52,
- VT_ICU_PRESERVE_WHITESPACE_TOKENS = 54,
- VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS = 56,
- VT_BOUNDS_SENSITIVE_FEATURES = 58,
- VT_ALLOWED_CHARGRAMS = 60,
- VT_TOKENIZE_ON_SCRIPT_CHANGE = 62,
- VT_EMBEDDING_QUANTIZATION_BITS = 64
+ VT_EMBEDDING_QUANTIZATION_BITS = 8,
+ VT_CONTEXT_SIZE = 10,
+ VT_MAX_SELECTION_SPAN = 12,
+ VT_CHARGRAM_ORDERS = 14,
+ VT_MAX_WORD_LENGTH = 16,
+ VT_UNICODE_AWARE_FEATURES = 18,
+ VT_EXTRACT_CASE_FEATURE = 20,
+ VT_EXTRACT_SELECTION_MASK_FEATURE = 22,
+ VT_REGEXP_FEATURE = 24,
+ VT_REMAP_DIGITS = 26,
+ VT_LOWERCASE_TOKENS = 28,
+ VT_SELECTION_REDUCED_OUTPUT_SPACE = 30,
+ VT_COLLECTIONS = 32,
+ VT_DEFAULT_COLLECTION = 34,
+ VT_ONLY_USE_LINE_WITH_CLICK = 36,
+ VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES = 38,
+ VT_TOKENIZATION_CODEPOINT_CONFIG = 40,
+ VT_CENTER_TOKEN_SELECTION_METHOD = 42,
+ VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS = 44,
+ VT_SUPPORTED_CODEPOINT_RANGES = 46,
+ VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES = 48,
+ VT_MIN_SUPPORTED_CODEPOINT_RATIO = 50,
+ VT_FEATURE_VERSION = 52,
+ VT_TOKENIZATION_TYPE = 54,
+ VT_ICU_PRESERVE_WHITESPACE_TOKENS = 56,
+ VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS = 58,
+ VT_BOUNDS_SENSITIVE_FEATURES = 60,
+ VT_ALLOWED_CHARGRAMS = 62,
+ VT_TOKENIZE_ON_SCRIPT_CHANGE = 64
};
int32_t num_buckets() const {
return GetField<int32_t>(VT_NUM_BUCKETS, -1);
@@ -1825,6 +1924,9 @@
int32_t embedding_size() const {
return GetField<int32_t>(VT_EMBEDDING_SIZE, -1);
}
+ int32_t embedding_quantization_bits() const {
+ return GetField<int32_t>(VT_EMBEDDING_QUANTIZATION_BITS, 8);
+ }
int32_t context_size() const {
return GetField<int32_t>(VT_CONTEXT_SIZE, -1);
}
@@ -1892,7 +1994,7 @@
return GetField<int32_t>(VT_FEATURE_VERSION, 0);
}
libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type() const {
- return static_cast<libtextclassifier2::FeatureProcessorOptions_::TokenizationType>(GetField<int32_t>(VT_TOKENIZATION_TYPE, 0));
+ return static_cast<libtextclassifier2::FeatureProcessorOptions_::TokenizationType>(GetField<int32_t>(VT_TOKENIZATION_TYPE, 1));
}
bool icu_preserve_whitespace_tokens() const {
return GetField<uint8_t>(VT_ICU_PRESERVE_WHITESPACE_TOKENS, 0) != 0;
@@ -1909,13 +2011,11 @@
bool tokenize_on_script_change() const {
return GetField<uint8_t>(VT_TOKENIZE_ON_SCRIPT_CHANGE, 0) != 0;
}
- int32_t embedding_quantization_bits() const {
- return GetField<int32_t>(VT_EMBEDDING_QUANTIZATION_BITS, 8);
- }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_NUM_BUCKETS) &&
VerifyField<int32_t>(verifier, VT_EMBEDDING_SIZE) &&
+ VerifyField<int32_t>(verifier, VT_EMBEDDING_QUANTIZATION_BITS) &&
VerifyField<int32_t>(verifier, VT_CONTEXT_SIZE) &&
VerifyField<int32_t>(verifier, VT_MAX_SELECTION_SPAN) &&
VerifyOffset(verifier, VT_CHARGRAM_ORDERS) &&
@@ -1959,7 +2059,6 @@
verifier.Verify(allowed_chargrams()) &&
verifier.VerifyVectorOfStrings(allowed_chargrams()) &&
VerifyField<uint8_t>(verifier, VT_TOKENIZE_ON_SCRIPT_CHANGE) &&
- VerifyField<int32_t>(verifier, VT_EMBEDDING_QUANTIZATION_BITS) &&
verifier.EndTable();
}
FeatureProcessorOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1976,6 +2075,9 @@
void add_embedding_size(int32_t embedding_size) {
fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_SIZE, embedding_size, -1);
}
+ void add_embedding_quantization_bits(int32_t embedding_quantization_bits) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_QUANTIZATION_BITS, embedding_quantization_bits, 8);
+ }
void add_context_size(int32_t context_size) {
fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CONTEXT_SIZE, context_size, -1);
}
@@ -2043,7 +2145,7 @@
fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_FEATURE_VERSION, feature_version, 0);
}
void add_tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_TOKENIZATION_TYPE, static_cast<int32_t>(tokenization_type), 0);
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_TOKENIZATION_TYPE, static_cast<int32_t>(tokenization_type), 1);
}
void add_icu_preserve_whitespace_tokens(bool icu_preserve_whitespace_tokens) {
fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ICU_PRESERVE_WHITESPACE_TOKENS, static_cast<uint8_t>(icu_preserve_whitespace_tokens), 0);
@@ -2060,9 +2162,6 @@
void add_tokenize_on_script_change(bool tokenize_on_script_change) {
fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_TOKENIZE_ON_SCRIPT_CHANGE, static_cast<uint8_t>(tokenize_on_script_change), 0);
}
- void add_embedding_quantization_bits(int32_t embedding_quantization_bits) {
- fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_QUANTIZATION_BITS, embedding_quantization_bits, 8);
- }
explicit FeatureProcessorOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2079,6 +2178,7 @@
flatbuffers::FlatBufferBuilder &_fbb,
int32_t num_buckets = -1,
int32_t embedding_size = -1,
+ int32_t embedding_quantization_bits = 8,
int32_t context_size = -1,
int32_t max_selection_span = -1,
flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders = 0,
@@ -2101,15 +2201,13 @@
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges = 0,
float min_supported_codepoint_ratio = 0.0f,
int32_t feature_version = 0,
- libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE,
+ libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER,
bool icu_preserve_whitespace_tokens = false,
flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints = 0,
flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams = 0,
- bool tokenize_on_script_change = false,
- int32_t embedding_quantization_bits = 8) {
+ bool tokenize_on_script_change = false) {
FeatureProcessorOptionsBuilder builder_(_fbb);
- builder_.add_embedding_quantization_bits(embedding_quantization_bits);
builder_.add_allowed_chargrams(allowed_chargrams);
builder_.add_bounds_sensitive_features(bounds_sensitive_features);
builder_.add_ignored_span_boundary_codepoints(ignored_span_boundary_codepoints);
@@ -2127,6 +2225,7 @@
builder_.add_chargram_orders(chargram_orders);
builder_.add_max_selection_span(max_selection_span);
builder_.add_context_size(context_size);
+ builder_.add_embedding_quantization_bits(embedding_quantization_bits);
builder_.add_embedding_size(embedding_size);
builder_.add_num_buckets(num_buckets);
builder_.add_tokenize_on_script_change(tokenize_on_script_change);
@@ -2147,6 +2246,7 @@
flatbuffers::FlatBufferBuilder &_fbb,
int32_t num_buckets = -1,
int32_t embedding_size = -1,
+ int32_t embedding_quantization_bits = 8,
int32_t context_size = -1,
int32_t max_selection_span = -1,
const std::vector<int32_t> *chargram_orders = nullptr,
@@ -2169,17 +2269,17 @@
const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges = nullptr,
float min_supported_codepoint_ratio = 0.0f,
int32_t feature_version = 0,
- libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE,
+ libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER,
bool icu_preserve_whitespace_tokens = false,
const std::vector<int32_t> *ignored_span_boundary_codepoints = nullptr,
flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams = nullptr,
- bool tokenize_on_script_change = false,
- int32_t embedding_quantization_bits = 8) {
+ bool tokenize_on_script_change = false) {
return libtextclassifier2::CreateFeatureProcessorOptions(
_fbb,
num_buckets,
embedding_size,
+ embedding_quantization_bits,
context_size,
max_selection_span,
chargram_orders ? _fbb.CreateVector<int32_t>(*chargram_orders) : 0,
@@ -2207,8 +2307,7 @@
ignored_span_boundary_codepoints ? _fbb.CreateVector<int32_t>(*ignored_span_boundary_codepoints) : 0,
bounds_sensitive_features,
allowed_chargrams ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*allowed_chargrams) : 0,
- tokenize_on_script_change,
- embedding_quantization_bits);
+ tokenize_on_script_change);
}
flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -2287,11 +2386,10 @@
(void)_resolver;
{ auto _e = collection_name(); if (_e) _o->collection_name = _e->str(); };
{ auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
- { auto _e = enabled_for_annotation(); _o->enabled_for_annotation = _e; };
- { auto _e = enabled_for_classification(); _o->enabled_for_classification = _e; };
- { auto _e = enabled_for_selection(); _o->enabled_for_selection = _e; };
+ { auto _e = enabled_modes(); _o->enabled_modes = _e; };
{ auto _e = target_classification_score(); _o->target_classification_score = _e; };
{ auto _e = priority_score(); _o->priority_score = _e; };
+ { auto _e = use_approximate_matching(); _o->use_approximate_matching = _e; };
}
inline flatbuffers::Offset<Pattern> Pattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2304,20 +2402,18 @@
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _collection_name = _o->collection_name.empty() ? 0 : _fbb.CreateString(_o->collection_name);
auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
- auto _enabled_for_annotation = _o->enabled_for_annotation;
- auto _enabled_for_classification = _o->enabled_for_classification;
- auto _enabled_for_selection = _o->enabled_for_selection;
+ auto _enabled_modes = _o->enabled_modes;
auto _target_classification_score = _o->target_classification_score;
auto _priority_score = _o->priority_score;
+ auto _use_approximate_matching = _o->use_approximate_matching;
return libtextclassifier2::RegexModel_::CreatePattern(
_fbb,
_collection_name,
_pattern,
- _enabled_for_annotation,
- _enabled_for_classification,
- _enabled_for_selection,
+ _enabled_modes,
_target_classification_score,
- _priority_score);
+ _priority_score,
+ _use_approximate_matching);
}
} // namespace RegexModel_
@@ -2361,6 +2457,7 @@
{ auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i); } } };
{ auto _e = target_classification_score(); _o->target_classification_score = _e; };
{ auto _e = priority_score(); _o->priority_score = _e; };
+ { auto _e = enabled_modes(); _o->enabled_modes = _e; };
}
inline flatbuffers::Offset<DatetimeModelPattern> DatetimeModelPattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2375,12 +2472,14 @@
auto _locales = _o->locales.size() ? _fbb.CreateVector(_o->locales) : 0;
auto _target_classification_score = _o->target_classification_score;
auto _priority_score = _o->priority_score;
+ auto _enabled_modes = _o->enabled_modes;
return libtextclassifier2::CreateDatetimeModelPattern(
_fbb,
_regexes,
_locales,
_target_classification_score,
- _priority_score);
+ _priority_score,
+ _enabled_modes);
}
inline DatetimeModelExtractorT *DatetimeModelExtractor::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -2427,6 +2526,7 @@
{ auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i)->str(); } } };
{ auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<DatetimeModelPatternT>(_e->Get(_i)->UnPack(_resolver)); } } };
{ auto _e = extractors(); if (_e) { _o->extractors.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->extractors[_i] = std::unique_ptr<DatetimeModelExtractorT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = use_extractors_for_locating(); _o->use_extractors_for_locating = _e; };
}
inline flatbuffers::Offset<DatetimeModel> DatetimeModel::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2440,11 +2540,13 @@
auto _locales = _o->locales.size() ? _fbb.CreateVectorOfStrings(_o->locales) : 0;
auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelPattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreateDatetimeModelPattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0;
auto _extractors = _o->extractors.size() ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>> (_o->extractors.size(), [](size_t i, _VectorArgs *__va) { return CreateDatetimeModelExtractor(*__va->__fbb, __va->__o->extractors[i].get(), __va->__rehasher); }, &_va ) : 0;
+ auto _use_extractors_for_locating = _o->use_extractors_for_locating;
return libtextclassifier2::CreateDatetimeModel(
_fbb,
_locales,
_patterns,
- _extractors);
+ _extractors,
+ _use_extractors_for_locating);
}
inline ModelTriggeringOptionsT *ModelTriggeringOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -2457,6 +2559,7 @@
(void)_o;
(void)_resolver;
{ auto _e = min_annotate_confidence(); _o->min_annotate_confidence = _e; };
+ { auto _e = enabled_modes(); _o->enabled_modes = _e; };
}
inline flatbuffers::Offset<ModelTriggeringOptions> ModelTriggeringOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2468,9 +2571,11 @@
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelTriggeringOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _min_annotate_confidence = _o->min_annotate_confidence;
+ auto _enabled_modes = _o->enabled_modes;
return libtextclassifier2::CreateModelTriggeringOptions(
_fbb,
- _min_annotate_confidence);
+ _min_annotate_confidence,
+ _enabled_modes);
}
inline ModelT *Model::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -2484,17 +2589,18 @@
(void)_resolver;
{ auto _e = locales(); if (_e) _o->locales = _e->str(); };
{ auto _e = version(); _o->version = _e; };
+ { auto _e = name(); if (_e) _o->name = _e->str(); };
{ auto _e = selection_feature_options(); if (_e) _o->selection_feature_options = std::unique_ptr<FeatureProcessorOptionsT>(_e->UnPack(_resolver)); };
{ auto _e = classification_feature_options(); if (_e) _o->classification_feature_options = std::unique_ptr<FeatureProcessorOptionsT>(_e->UnPack(_resolver)); };
{ auto _e = selection_model(); if (_e) { _o->selection_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->selection_model[_i] = _e->Get(_i); } } };
{ auto _e = classification_model(); if (_e) { _o->classification_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->classification_model[_i] = _e->Get(_i); } } };
{ auto _e = embedding_model(); if (_e) { _o->embedding_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_model[_i] = _e->Get(_i); } } };
- { auto _e = regex_model(); if (_e) _o->regex_model = std::unique_ptr<RegexModelT>(_e->UnPack(_resolver)); };
{ auto _e = selection_options(); if (_e) _o->selection_options = std::unique_ptr<SelectionModelOptionsT>(_e->UnPack(_resolver)); };
{ auto _e = classification_options(); if (_e) _o->classification_options = std::unique_ptr<ClassificationModelOptionsT>(_e->UnPack(_resolver)); };
+ { auto _e = regex_model(); if (_e) _o->regex_model = std::unique_ptr<RegexModelT>(_e->UnPack(_resolver)); };
{ auto _e = datetime_model(); if (_e) _o->datetime_model = std::unique_ptr<DatetimeModelT>(_e->UnPack(_resolver)); };
{ auto _e = triggering_options(); if (_e) _o->triggering_options = std::unique_ptr<ModelTriggeringOptionsT>(_e->UnPack(_resolver)); };
- { auto _e = name(); if (_e) _o->name = _e->str(); };
+ { auto _e = enabled_modes(); _o->enabled_modes = _e; };
}
inline flatbuffers::Offset<Model> Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2507,32 +2613,34 @@
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _locales = _o->locales.empty() ? 0 : _fbb.CreateString(_o->locales);
auto _version = _o->version;
+ auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name);
auto _selection_feature_options = _o->selection_feature_options ? CreateFeatureProcessorOptions(_fbb, _o->selection_feature_options.get(), _rehasher) : 0;
auto _classification_feature_options = _o->classification_feature_options ? CreateFeatureProcessorOptions(_fbb, _o->classification_feature_options.get(), _rehasher) : 0;
auto _selection_model = _o->selection_model.size() ? _fbb.CreateVector(_o->selection_model) : 0;
auto _classification_model = _o->classification_model.size() ? _fbb.CreateVector(_o->classification_model) : 0;
auto _embedding_model = _o->embedding_model.size() ? _fbb.CreateVector(_o->embedding_model) : 0;
- auto _regex_model = _o->regex_model ? CreateRegexModel(_fbb, _o->regex_model.get(), _rehasher) : 0;
auto _selection_options = _o->selection_options ? CreateSelectionModelOptions(_fbb, _o->selection_options.get(), _rehasher) : 0;
auto _classification_options = _o->classification_options ? CreateClassificationModelOptions(_fbb, _o->classification_options.get(), _rehasher) : 0;
+ auto _regex_model = _o->regex_model ? CreateRegexModel(_fbb, _o->regex_model.get(), _rehasher) : 0;
auto _datetime_model = _o->datetime_model ? CreateDatetimeModel(_fbb, _o->datetime_model.get(), _rehasher) : 0;
auto _triggering_options = _o->triggering_options ? CreateModelTriggeringOptions(_fbb, _o->triggering_options.get(), _rehasher) : 0;
- auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name);
+ auto _enabled_modes = _o->enabled_modes;
return libtextclassifier2::CreateModel(
_fbb,
_locales,
_version,
+ _name,
_selection_feature_options,
_classification_feature_options,
_selection_model,
_classification_model,
_embedding_model,
- _regex_model,
_selection_options,
_classification_options,
+ _regex_model,
_datetime_model,
_triggering_options,
- _name);
+ _enabled_modes);
}
inline TokenizationCodepointRangeT *TokenizationCodepointRange::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -2617,6 +2725,7 @@
{ auto _e = num_tokens_after(); _o->num_tokens_after = _e; };
{ auto _e = include_inside_bag(); _o->include_inside_bag = _e; };
{ auto _e = include_inside_length(); _o->include_inside_length = _e; };
+ { auto _e = score_single_token_spans_as_zero(); _o->score_single_token_spans_as_zero = _e; };
}
inline flatbuffers::Offset<BoundsSensitiveFeatures> BoundsSensitiveFeatures::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2634,6 +2743,7 @@
auto _num_tokens_after = _o->num_tokens_after;
auto _include_inside_bag = _o->include_inside_bag;
auto _include_inside_length = _o->include_inside_length;
+ auto _score_single_token_spans_as_zero = _o->score_single_token_spans_as_zero;
return libtextclassifier2::FeatureProcessorOptions_::CreateBoundsSensitiveFeatures(
_fbb,
_enabled,
@@ -2642,7 +2752,8 @@
_num_tokens_inside_right,
_num_tokens_after,
_include_inside_bag,
- _include_inside_length);
+ _include_inside_length,
+ _score_single_token_spans_as_zero);
}
inline AlternativeCollectionMapEntryT *AlternativeCollectionMapEntry::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -2687,6 +2798,7 @@
(void)_resolver;
{ auto _e = num_buckets(); _o->num_buckets = _e; };
{ auto _e = embedding_size(); _o->embedding_size = _e; };
+ { auto _e = embedding_quantization_bits(); _o->embedding_quantization_bits = _e; };
{ auto _e = context_size(); _o->context_size = _e; };
{ auto _e = max_selection_span(); _o->max_selection_span = _e; };
{ auto _e = chargram_orders(); if (_e) { _o->chargram_orders.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->chargram_orders[_i] = _e->Get(_i); } } };
@@ -2715,7 +2827,6 @@
{ auto _e = bounds_sensitive_features(); if (_e) _o->bounds_sensitive_features = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT>(_e->UnPack(_resolver)); };
{ auto _e = allowed_chargrams(); if (_e) { _o->allowed_chargrams.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->allowed_chargrams[_i] = _e->Get(_i)->str(); } } };
{ auto _e = tokenize_on_script_change(); _o->tokenize_on_script_change = _e; };
- { auto _e = embedding_quantization_bits(); _o->embedding_quantization_bits = _e; };
}
inline flatbuffers::Offset<FeatureProcessorOptions> FeatureProcessorOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2728,6 +2839,7 @@
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FeatureProcessorOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _num_buckets = _o->num_buckets;
auto _embedding_size = _o->embedding_size;
+ auto _embedding_quantization_bits = _o->embedding_quantization_bits;
auto _context_size = _o->context_size;
auto _max_selection_span = _o->max_selection_span;
auto _chargram_orders = _o->chargram_orders.size() ? _fbb.CreateVector(_o->chargram_orders) : 0;
@@ -2756,11 +2868,11 @@
auto _bounds_sensitive_features = _o->bounds_sensitive_features ? CreateBoundsSensitiveFeatures(_fbb, _o->bounds_sensitive_features.get(), _rehasher) : 0;
auto _allowed_chargrams = _o->allowed_chargrams.size() ? _fbb.CreateVectorOfStrings(_o->allowed_chargrams) : 0;
auto _tokenize_on_script_change = _o->tokenize_on_script_change;
- auto _embedding_quantization_bits = _o->embedding_quantization_bits;
return libtextclassifier2::CreateFeatureProcessorOptions(
_fbb,
_num_buckets,
_embedding_size,
+ _embedding_quantization_bits,
_context_size,
_max_selection_span,
_chargram_orders,
@@ -2788,8 +2900,7 @@
_ignored_span_boundary_codepoints,
_bounds_sensitive_features,
_allowed_chargrams,
- _tokenize_on_script_change,
- _embedding_quantization_bits);
+ _tokenize_on_script_change);
}
inline const libtextclassifier2::Model *GetModel(const void *buf) {
diff --git a/models/textclassifier.en.model b/models/textclassifier.en.model
index daeb2a4..0452c1e 100644
--- a/models/textclassifier.en.model
+++ b/models/textclassifier.en.model
Binary files differ
diff --git a/test_data/test_model.fb b/test_data/test_model.fb
index daeb2a4..fc8353a 100644
--- a/test_data/test_model.fb
+++ b/test_data/test_model.fb
Binary files differ
diff --git a/test_data/test_model_cc.fb b/test_data/test_model_cc.fb
index 05b29d1..b396943 100644
--- a/test_data/test_model_cc.fb
+++ b/test_data/test_model_cc.fb
Binary files differ
diff --git a/test_data/wrong_embeddings.fb b/test_data/wrong_embeddings.fb
index 79c5799..000f739 100644
--- a/test_data/wrong_embeddings.fb
+++ b/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/text-classifier.cc b/text-classifier.cc
index 303afee..67346b9 100644
--- a/text-classifier.cc
+++ b/text-classifier.cc
@@ -47,6 +47,28 @@
}
} // namespace
+tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
+ if (!selection_interpreter_) {
+ TC_CHECK(selection_executor_);
+ selection_interpreter_ = selection_executor_->CreateInterpreter();
+ if (!selection_interpreter_) {
+ TC_LOG(ERROR) << "Could not build TFLite interpreter.";
+ }
+ }
+ return selection_interpreter_.get();
+}
+
+tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
+ if (!classification_interpreter_) {
+ TC_CHECK(classification_executor_);
+ classification_interpreter_ = classification_executor_->CreateInterpreter();
+ if (!classification_interpreter_) {
+ TC_LOG(ERROR) << "Could not build TFLite interpreter.";
+ }
+ }
+ return classification_interpreter_.get();
+}
+
std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer(
const char* buffer, int size, const UniLib* unilib) {
const Model* model = LoadAndVerifyModel(buffer, size);
@@ -112,58 +134,113 @@
return;
}
- if (!model_->selection_options()) {
- TC_LOG(ERROR) << "No selection options.";
- return;
+ const bool model_enabled_for_annotation =
+ (model_->triggering_options() != nullptr &&
+ (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
+ const bool model_enabled_for_classification =
+ (model_->triggering_options() != nullptr &&
+ (model_->triggering_options()->enabled_modes() &
+ ModeFlag_CLASSIFICATION));
+ const bool model_enabled_for_selection =
+ (model_->triggering_options() != nullptr &&
+ (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
+
+ // Annotation requires the selection model.
+ if (model_enabled_for_annotation || model_enabled_for_selection) {
+ if (!model_->selection_options()) {
+ TC_LOG(ERROR) << "No selection options.";
+ return;
+ }
+ if (!model_->selection_feature_options()) {
+ TC_LOG(ERROR) << "No selection feature options.";
+ return;
+ }
+ if (!model_->selection_feature_options()->bounds_sensitive_features()) {
+ TC_LOG(ERROR) << "No selection bounds sensitive feature options.";
+ return;
+ }
+ if (!model_->selection_model()) {
+ TC_LOG(ERROR) << "No selection model.";
+ return;
+ }
+ selection_executor_.reset(
+ new ModelExecutor(flatbuffers::GetRoot<tflite::Model>(
+ model_->selection_model()->data())));
+ if (!selection_executor_) {
+ TC_LOG(ERROR) << "Could not initialize selection executor.";
+ return;
+ }
+ selection_feature_processor_.reset(
+ new FeatureProcessor(model_->selection_feature_options(), unilib_));
}
- if (!model_->classification_options()) {
- TC_LOG(ERROR) << "No classification options.";
- return;
+ // Annotation requires the classification model for conflict resolution and
+ // scoring.
+ // Selection requires the classification model for conflict resolution.
+ if (model_enabled_for_annotation || model_enabled_for_classification ||
+ model_enabled_for_selection) {
+ if (!model_->classification_options()) {
+ TC_LOG(ERROR) << "No classification options.";
+ return;
+ }
+
+ if (!model_->classification_feature_options()) {
+ TC_LOG(ERROR) << "No classification feature options.";
+ return;
+ }
+
+ if (!model_->classification_feature_options()
+ ->bounds_sensitive_features()) {
+ TC_LOG(ERROR) << "No classification bounds sensitive feature options.";
+ return;
+ }
+ if (!model_->classification_model()) {
+ TC_LOG(ERROR) << "No clf model.";
+ return;
+ }
+
+ classification_executor_.reset(
+ new ModelExecutor(flatbuffers::GetRoot<tflite::Model>(
+ model_->classification_model()->data())));
+ if (!classification_executor_) {
+ TC_LOG(ERROR) << "Could not initialize classification executor.";
+ return;
+ }
+
+ classification_feature_processor_.reset(new FeatureProcessor(
+ model_->classification_feature_options(), unilib_));
}
- if (!model_->selection_feature_options()) {
- TC_LOG(ERROR) << "No selection feature options.";
- return;
- }
+ // The embeddings need to be specified if the model is to be used for
+ // classification or selection.
+ if (model_enabled_for_annotation || model_enabled_for_classification ||
+ model_enabled_for_selection) {
+ if (!model_->embedding_model()) {
+ TC_LOG(ERROR) << "No embedding model.";
+ return;
+ }
- if (!model_->classification_feature_options()) {
- TC_LOG(ERROR) << "No classification feature options.";
- return;
- }
+ // Check that the embedding size of the selection and classification model
+ // matches, as they are using the same embeddings.
+ if (model_enabled_for_selection &&
+ (model_->selection_feature_options()->embedding_size() !=
+ model_->classification_feature_options()->embedding_size() ||
+ model_->selection_feature_options()->embedding_quantization_bits() !=
+ model_->classification_feature_options()
+ ->embedding_quantization_bits())) {
+ TC_LOG(ERROR) << "Mismatching embedding size/quantization.";
+ return;
+ }
- if (!model_->classification_feature_options()->bounds_sensitive_features()) {
- TC_LOG(ERROR) << "No classification bounds sensitive feature options.";
- return;
- }
-
- if (!model_->selection_feature_options()->bounds_sensitive_features()) {
- TC_LOG(ERROR) << "No selection bounds sensitive feature options.";
- return;
- }
-
- if (model_->selection_feature_options()->embedding_size() !=
- model_->classification_feature_options()->embedding_size() ||
- model_->selection_feature_options()->embedding_quantization_bits() !=
- model_->classification_feature_options()
- ->embedding_quantization_bits()) {
- TC_LOG(ERROR) << "Mismatching embedding size/quantization.";
- return;
- }
-
- if (!model_->selection_model()) {
- TC_LOG(ERROR) << "No selection model.";
- return;
- }
-
- if (!model_->embedding_model()) {
- TC_LOG(ERROR) << "No embedding model.";
- return;
- }
-
- if (!model_->classification_model()) {
- TC_LOG(ERROR) << "No clf model.";
- return;
+ embedding_executor_.reset(new TFLiteEmbeddingExecutor(
+ flatbuffers::GetRoot<tflite::Model>(model_->embedding_model()->data()),
+ model_->classification_feature_options()->embedding_size(),
+ model_->classification_feature_options()
+ ->embedding_quantization_bits()));
+ if (!embedding_executor_ || !embedding_executor_->IsReady()) {
+ TC_LOG(ERROR) << "Could not initialize embedding executor.";
+ return;
+ }
}
if (model_->regex_model()) {
@@ -172,28 +249,6 @@
}
}
- embedding_executor_.reset(new TFLiteEmbeddingExecutor(
- flatbuffers::GetRoot<tflite::Model>(model_->embedding_model()->data()),
- model_->selection_feature_options()->embedding_size(),
- model_->selection_feature_options()->embedding_quantization_bits()));
- if (!embedding_executor_ || !embedding_executor_->IsReady()) {
- TC_LOG(ERROR) << "Could not initialize embedding executor.";
- return;
- }
- selection_executor_.reset(new ModelExecutor(
- flatbuffers::GetRoot<tflite::Model>(model_->selection_model()->data())));
- if (!selection_executor_) {
- TC_LOG(ERROR) << "Could not initialize selection executor.";
- return;
- }
- classification_executor_.reset(
- new ModelExecutor(flatbuffers::GetRoot<tflite::Model>(
- model_->classification_model()->data())));
- if (!classification_executor_) {
- TC_LOG(ERROR) << "Could not initialize classification executor.";
- return;
- }
-
if (model_->datetime_model()) {
datetime_parser_ =
DatetimeParser::Instance(model_->datetime_model(), *unilib_);
@@ -203,11 +258,6 @@
}
}
- selection_feature_processor_.reset(
- new FeatureProcessor(model_->selection_feature_options(), unilib_));
- classification_feature_processor_.reset(
- new FeatureProcessor(model_->classification_feature_options(), unilib_));
-
initialized_ = true;
}
@@ -232,21 +282,25 @@
continue;
}
- if (regex_pattern->enabled_for_annotation()) {
+ if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
annotation_regex_patterns_.push_back(regex_pattern_id);
}
- if (regex_pattern->enabled_for_classification()) {
+ if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
classification_regex_patterns_.push_back(regex_pattern_id);
}
- if (regex_pattern->enabled_for_selection()) {
+ if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
selection_regex_patterns_.push_back(regex_pattern_id);
}
regex_patterns_.push_back({regex_pattern->collection_name()->str(),
regex_pattern->target_classification_score(),
regex_pattern->priority_score(),
std::move(compiled_pattern)});
+ if (regex_pattern->use_approximate_matching()) {
+ regex_approximate_match_pattern_ids_.insert(regex_pattern_id);
+ }
++regex_pattern_id;
}
+
return true;
}
@@ -284,6 +338,10 @@
TC_LOG(ERROR) << "Not initialized";
return click_indices;
}
+ if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
+ return click_indices;
+ }
+
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
const int context_codepoint_size = context_unicode.size_codepoints();
@@ -298,7 +356,10 @@
}
std::vector<AnnotatedSpan> candidates;
- if (!ModelSuggestSelection(context_unicode, click_indices, &candidates)) {
+ InterpreterManager interpreter_manager(selection_executor_.get(),
+ classification_executor_.get());
+ if (!ModelSuggestSelection(context_unicode, click_indices,
+ &interpreter_manager, &candidates)) {
TC_LOG(ERROR) << "Model suggest selection failed.";
return click_indices;
}
@@ -308,7 +369,7 @@
}
if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
/*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
- options.locales, &candidates)) {
+ options.locales, ModeFlag_SELECTION, &candidates)) {
TC_LOG(ERROR) << "Datetime suggest selection failed.";
return click_indices;
}
@@ -322,7 +383,8 @@
});
std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, &candidate_indices)) {
+ if (!ResolveConflicts(candidates, context, &interpreter_manager,
+ &candidate_indices)) {
TC_LOG(ERROR) << "Couldn't resolve conflicts.";
return click_indices;
}
@@ -360,7 +422,7 @@
bool TextClassifier::ResolveConflicts(
const std::vector<AnnotatedSpan>& candidates, const std::string& context,
- std::vector<int>* result) const {
+ InterpreterManager* interpreter_manager, std::vector<int>* result) const {
result->clear();
result->reserve(candidates.size());
for (int i = 0; i < candidates.size();) {
@@ -371,7 +433,7 @@
if (conflict_found) {
std::vector<int> candidate_indices;
if (!ResolveConflict(context, candidates, i, first_non_overlapping,
- &candidate_indices)) {
+ interpreter_manager, &candidate_indices)) {
return false;
}
result->insert(result->end(), candidate_indices.begin(),
@@ -405,7 +467,8 @@
bool TextClassifier::ResolveConflict(
const std::string& context, const std::vector<AnnotatedSpan>& candidates,
- int start_index, int end_index, std::vector<int>* chosen_indices) const {
+ int start_index, int end_index, InterpreterManager* interpreter_manager,
+ std::vector<int>* chosen_indices) const {
std::vector<int> conflicting_indices;
std::unordered_map<int, float> scores;
for (int i = start_index; i < end_index; ++i) {
@@ -421,7 +484,8 @@
// candidate conflicts and comes from the model, we need to run a
// classification to determine its priority:
std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(context, candidates[i].span, &classification)) {
+ if (!ModelClassifyText(context, candidates[i].span, interpreter_manager,
+ /*embedding_cache=*/nullptr, &classification)) {
return false;
}
@@ -458,10 +522,17 @@
bool TextClassifier::ModelSuggestSelection(
const UnicodeText& context_unicode, CodepointSpan click_indices,
+ InterpreterManager* interpreter_manager,
std::vector<AnnotatedSpan>* result) const {
- std::vector<Token> tokens;
+ if (model_->triggering_options() == nullptr ||
+ !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
+ return true;
+ }
+
int click_pos;
- selection_feature_processor_->TokenizeAndFindClick(
+ std::vector<Token> tokens =
+ selection_feature_processor_->Tokenize(context_unicode);
+ selection_feature_processor_->RetokenizeAndFindClick(
context_unicode, click_indices,
selection_feature_processor_->GetOptions()->only_use_line_with_click(),
&tokens, &click_pos);
@@ -515,6 +586,7 @@
tokens, extraction_span,
/*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
embedding_executor_.get(),
+ /*embedding_cache=*/nullptr,
selection_feature_processor_->EmbeddingSize() +
selection_feature_processor_->DenseFeaturesCount(),
&cached_features)) {
@@ -525,7 +597,8 @@
// Produce selection model candidates.
std::vector<TokenSpan> chunks;
if (!ModelChunk(tokens.size(), /*span_of_interest=*/symmetry_context_span,
- *cached_features, &chunks)) {
+ interpreter_manager->SelectionInterpreter(), *cached_features,
+ &chunks)) {
TC_LOG(ERROR) << "Could not chunk.";
return false;
}
@@ -549,10 +622,85 @@
bool TextClassifier::ModelClassifyText(
const std::string& context, CodepointSpan selection_indices,
+ InterpreterManager* interpreter_manager,
+ FeatureProcessor::EmbeddingCache* embedding_cache,
+ std::vector<ClassificationResult>* classification_results) const {
+ if (model_->triggering_options() == nullptr ||
+ !(model_->triggering_options()->enabled_modes() &
+ ModeFlag_CLASSIFICATION)) {
+ return true;
+ }
+ return ModelClassifyText(context, {}, selection_indices, interpreter_manager,
+ embedding_cache, classification_results);
+}
+
+namespace internal {
+std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
+ CodepointSpan selection_indices,
+ TokenSpan tokens_around_selection_to_copy) {
+ const auto first_selection_token = std::upper_bound(
+ cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
+ [](int selection_start, const Token& token) {
+ return selection_start < token.end;
+ });
+ const auto last_selection_token = std::lower_bound(
+ cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
+ [](const Token& token, int selection_end) {
+ return token.start < selection_end;
+ });
+
+ const int64 first_token = std::max(
+ static_cast<int64>(0),
+ static_cast<int64>((first_selection_token - cached_tokens.begin()) -
+ tokens_around_selection_to_copy.first));
+ const int64 last_token = std::min(
+ static_cast<int64>(cached_tokens.size()),
+ static_cast<int64>((last_selection_token - cached_tokens.begin()) +
+ tokens_around_selection_to_copy.second));
+
+ std::vector<Token> tokens;
+ tokens.reserve(last_token - first_token);
+ for (int i = first_token; i < last_token; ++i) {
+ tokens.push_back(cached_tokens[i]);
+ }
+ return tokens;
+}
+} // namespace internal
+
+TokenSpan TextClassifier::ClassifyTextUpperBoundNeededTokens() const {
+ const FeatureProcessorOptions_::BoundsSensitiveFeatures*
+ bounds_sensitive_features =
+ classification_feature_processor_->GetOptions()
+ ->bounds_sensitive_features();
+ if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+ // The extraction span is the selection span expanded to include a relevant
+ // number of tokens outside of the bounds of the selection.
+ return {bounds_sensitive_features->num_tokens_before(),
+ bounds_sensitive_features->num_tokens_after()};
+ } else {
+ // The extraction span is the clicked token with context_size tokens on
+ // either side.
+ const int context_size =
+ selection_feature_processor_->GetOptions()->context_size();
+ return {context_size, context_size};
+ }
+}
+
+bool TextClassifier::ModelClassifyText(
+ const std::string& context, const std::vector<Token>& cached_tokens,
+ CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const {
std::vector<Token> tokens;
+ if (cached_tokens.empty()) {
+ tokens = classification_feature_processor_->Tokenize(context);
+ } else {
+ tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
+ ClassifyTextUpperBoundNeededTokens());
+ }
+
int click_pos;
- classification_feature_processor_->TokenizeAndFindClick(
+ classification_feature_processor_->RetokenizeAndFindClick(
context, selection_indices,
classification_feature_processor_->GetOptions()
->only_use_line_with_click(),
@@ -586,7 +734,7 @@
// The extraction span is the clicked token with context_size tokens on
// either side.
const int context_size =
- selection_feature_processor_->GetOptions()->context_size();
+ classification_feature_processor_->GetOptions()->context_size();
extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
/*num_tokens_left=*/context_size,
/*num_tokens_right=*/context_size);
@@ -596,6 +744,7 @@
std::unique_ptr<CachedFeatures> cached_features;
if (!classification_feature_processor_->ExtractFeatures(
tokens, extraction_span, selection_indices, embedding_executor_.get(),
+ embedding_cache,
classification_feature_processor_->EmbeddingSize() +
classification_feature_processor_->DenseFeaturesCount(),
&cached_features)) {
@@ -612,9 +761,10 @@
cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
}
- TensorView<float> logits =
- classification_executor_->ComputeLogits(TensorView<float>(
- features.data(), {1, static_cast<int>(features.size())}));
+ TensorView<float> logits = classification_executor_->ComputeLogits(
+ TensorView<float>(features.data(),
+ {1, static_cast<int>(features.size())}),
+ interpreter_manager->ClassificationInterpreter());
if (!logits.is_valid()) {
TC_LOG(ERROR) << "Couldn't compute logits.";
return false;
@@ -668,7 +818,17 @@
const std::unique_ptr<UniLib::RegexMatcher> matcher =
regex_pattern.pattern->Matcher(selection_text_unicode);
int status = UniLib::RegexMatcher::kNoError;
- if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
+ bool matches;
+ if (regex_approximate_match_pattern_ids_.find(pattern_id) !=
+ regex_approximate_match_pattern_ids_.end()) {
+ matches = matcher->ApproximatelyMatches(&status);
+ } else {
+ matches = matcher->Matches(&status);
+ }
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ if (matches) {
*classification_result = {regex_pattern.collection_name,
regex_pattern.target_classification_score,
regex_pattern.priority_score};
@@ -686,13 +846,17 @@
const std::string& context, CodepointSpan selection_indices,
const ClassificationOptions& options,
ClassificationResult* classification_result) const {
+ if (!datetime_parser_) {
+ return false;
+ }
+
const std::string selection_text =
ExtractSelection(context, selection_indices);
std::vector<DatetimeParseResultSpan> datetime_spans;
if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
options.reference_timezone, options.locales,
- &datetime_spans)) {
+ ModeFlag_CLASSIFICATION, &datetime_spans)) {
TC_LOG(ERROR) << "Error during parsing datetime.";
return false;
}
@@ -719,6 +883,10 @@
return {};
}
+ if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
+ return {};
+ }
+
if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
<< std::get<0>(selection_indices) << " "
@@ -741,7 +909,11 @@
// Fallback to the model.
std::vector<ClassificationResult> model_result;
- if (ModelClassifyText(context, selection_indices, &model_result)) {
+
+ InterpreterManager interpreter_manager(selection_executor_.get(),
+ classification_executor_.get());
+ if (ModelClassifyText(context, selection_indices, &interpreter_manager,
+ /*embedding_cache=*/nullptr, &model_result)) {
return model_result;
}
@@ -750,7 +922,13 @@
}
bool TextClassifier::ModelAnnotate(const std::string& context,
+ InterpreterManager* interpreter_manager,
std::vector<AnnotatedSpan>* result) const {
+ if (model_->triggering_options() == nullptr ||
+ !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
+ return true;
+ }
+
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
std::vector<UnicodeTextRange> lines;
@@ -760,13 +938,19 @@
lines = selection_feature_processor_->SplitContext(context_unicode);
}
- std::vector<TokenSpan> chunks;
+ const float min_annotate_confidence =
+ (model_->triggering_options() != nullptr
+ ? model_->triggering_options()->min_annotate_confidence()
+ : 0.f);
+
+ FeatureProcessor::EmbeddingCache embedding_cache;
for (const UnicodeTextRange& line : lines) {
const std::string line_str =
UnicodeText::UTF8Substring(line.first, line.second);
- std::vector<Token> tokens;
- selection_feature_processor_->TokenizeAndFindClick(
+ std::vector<Token> tokens =
+ selection_feature_processor_->Tokenize(line_str);
+ selection_feature_processor_->RetokenizeAndFindClick(
line_str, {0, std::distance(line.first, line.second)},
selection_feature_processor_->GetOptions()->only_use_line_with_click(),
&tokens,
@@ -778,6 +962,7 @@
tokens, full_line_span,
/*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
embedding_executor_.get(),
+ /*embedding_cache=*/nullptr,
selection_feature_processor_->EmbeddingSize() +
selection_feature_processor_->DenseFeaturesCount(),
&cached_features)) {
@@ -787,6 +972,7 @@
std::vector<TokenSpan> local_chunks;
if (!ModelChunk(tokens.size(), /*span_of_interest=*/full_line_span,
+ interpreter_manager->SelectionInterpreter(),
*cached_features, &local_chunks)) {
TC_LOG(ERROR) << "Could not chunk.";
return false;
@@ -800,28 +986,28 @@
// Skip empty spans.
if (codepoint_span.first != codepoint_span.second) {
- chunks.push_back(
- {codepoint_span.first + offset, codepoint_span.second + offset});
+ std::vector<ClassificationResult> classification;
+ if (!ModelClassifyText(line_str, tokens, codepoint_span,
+ interpreter_manager, &embedding_cache,
+ &classification)) {
+ TC_LOG(ERROR) << "Could not classify text: "
+ << (codepoint_span.first + offset) << " "
+ << (codepoint_span.second + offset);
+ return false;
+ }
+
+ // Do not include the span if it's classified as "other".
+ if (!classification.empty() && !ClassifiedAsOther(classification) &&
+ classification[0].score >= min_annotate_confidence) {
+ AnnotatedSpan result_span;
+ result_span.span = {codepoint_span.first + offset,
+ codepoint_span.second + offset};
+ result_span.classification = std::move(classification);
+ result->push_back(std::move(result_span));
+ }
}
}
}
-
- for (const CodepointSpan& chunk : chunks) {
- std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(context, chunk, &classification)) {
- TC_LOG(ERROR) << "Could not classify text: " << chunk.first << " "
- << chunk.second;
- return false;
- }
-
- // Do not include the span if it's classified as "other".
- if (!classification.empty() && !ClassifiedAsOther(classification)) {
- AnnotatedSpan result_span;
- result_span.span = chunk;
- result_span.classification = std::move(classification);
- result->push_back(std::move(result_span));
- }
- }
return true;
}
@@ -834,8 +1020,14 @@
const std::string& context, const AnnotationOptions& options) const {
std::vector<AnnotatedSpan> candidates;
+ if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
+ return {};
+ }
+
+ InterpreterManager interpreter_manager(selection_executor_.get(),
+ classification_executor_.get());
// Annotate with the selection model.
- if (!ModelAnnotate(context, &candidates)) {
+ if (!ModelAnnotate(context, &interpreter_manager, &candidates)) {
TC_LOG(ERROR) << "Couldn't run ModelAnnotate.";
return {};
}
@@ -850,7 +1042,7 @@
// Annotate with the datetime model.
if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
options.reference_time_ms_utc, options.reference_timezone,
- options.locales, &candidates)) {
+ options.locales, ModeFlag_ANNOTATION, &candidates)) {
TC_LOG(ERROR) << "Couldn't run RegexChunk.";
return {};
}
@@ -864,22 +1056,17 @@
});
std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, &candidate_indices)) {
+ if (!ResolveConflicts(candidates, context, &interpreter_manager,
+ &candidate_indices)) {
TC_LOG(ERROR) << "Couldn't resolve conflicts.";
return {};
}
- const float min_annotate_confidence =
- (model_->triggering_options() != nullptr
- ? model_->triggering_options()->min_annotate_confidence()
- : 0.f);
-
std::vector<AnnotatedSpan> result;
result.reserve(candidate_indices.size());
for (const int i : candidate_indices) {
if (!candidates[i].classification.empty() &&
- !ClassifiedAsOther(candidates[i].classification) &&
- candidates[i].classification[0].score >= min_annotate_confidence) {
+ !ClassifiedAsOther(candidates[i].classification)) {
result.push_back(std::move(candidates[i]));
}
}
@@ -917,6 +1104,7 @@
bool TextClassifier::ModelChunk(int num_tokens,
const TokenSpan& span_of_interest,
+ tflite::Interpreter* selection_interpreter,
const CachedFeatures& cached_features,
std::vector<TokenSpan>* chunks) const {
const int max_selection_span =
@@ -935,14 +1123,15 @@
selection_feature_processor_->GetOptions()
->bounds_sensitive_features()
->enabled()) {
- if (!ModelBoundsSensitiveScoreChunks(num_tokens, span_of_interest,
- inference_span, cached_features,
- &scored_chunks)) {
+ if (!ModelBoundsSensitiveScoreChunks(
+ num_tokens, span_of_interest, inference_span, cached_features,
+ selection_interpreter, &scored_chunks)) {
return false;
}
} else {
if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
- cached_features, &scored_chunks)) {
+ cached_features, selection_interpreter,
+ &scored_chunks)) {
return false;
}
}
@@ -1001,6 +1190,7 @@
bool TextClassifier::ModelClickContextScoreChunks(
int num_tokens, const TokenSpan& span_of_interest,
const CachedFeatures& cached_features,
+ tflite::Interpreter* selection_interpreter,
std::vector<ScoredChunk>* scored_chunks) const {
const int max_batch_size = model_->selection_options()->batch_size();
@@ -1023,7 +1213,8 @@
const int batch_size = batch_end - batch_start;
const int features_size = cached_features.OutputFeaturesSize();
TensorView<float> logits = selection_executor_->ComputeLogits(
- TensorView<float>(all_features.data(), {batch_size, features_size}));
+ TensorView<float>(all_features.data(), {batch_size, features_size}),
+ selection_interpreter);
if (!logits.is_valid()) {
TC_LOG(ERROR) << "Couldn't compute logits.";
return false;
@@ -1070,6 +1261,7 @@
bool TextClassifier::ModelBoundsSensitiveScoreChunks(
int num_tokens, const TokenSpan& span_of_interest,
const TokenSpan& inference_span, const CachedFeatures& cached_features,
+ tflite::Interpreter* selection_interpreter,
std::vector<ScoredChunk>* scored_chunks) const {
const int max_selection_span =
selection_feature_processor_->GetOptions()->max_selection_span();
@@ -1077,6 +1269,15 @@
->selection_reduced_output_space()
? max_selection_span + 1
: 2 * max_selection_span + 1;
+ const bool score_single_token_spans_as_zero =
+ selection_feature_processor_->GetOptions()
+ ->bounds_sensitive_features()
+ ->score_single_token_spans_as_zero();
+
+ scored_chunks->clear();
+ if (score_single_token_spans_as_zero) {
+ scored_chunks->reserve(TokenSpanSize(span_of_interest));
+ }
// Prepare all chunk candidates into one batch:
// - Are contained in the inference span
@@ -1090,15 +1291,22 @@
for (int end = leftmost_end_index;
end <= inference_span.second && end - start <= max_chunk_length;
++end) {
- candidate_spans.emplace_back(start, end);
+ const TokenSpan candidate_span = {start, end};
+ if (score_single_token_spans_as_zero &&
+ TokenSpanSize(candidate_span) == 1) {
+ // Do not include the single token span in the batch, add a zero score
+ // for it directly to the output.
+ scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
+ } else {
+ candidate_spans.push_back(candidate_span);
+ }
}
}
const int max_batch_size = model_->selection_options()->batch_size();
std::vector<float> all_features;
- scored_chunks->clear();
- scored_chunks->reserve(candidate_spans.size());
+ scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
for (int batch_start = 0; batch_start < candidate_spans.size();
batch_start += max_batch_size) {
const int batch_end = std::min(batch_start + max_batch_size,
@@ -1116,7 +1324,8 @@
const int batch_size = batch_end - batch_start;
const int features_size = cached_features.OutputFeaturesSize();
TensorView<float> logits = selection_executor_->ComputeLogits(
- TensorView<float>(all_features.data(), {batch_size, features_size}));
+ TensorView<float>(all_features.data(), {batch_size, features_size}),
+ selection_interpreter);
if (!logits.is_valid()) {
TC_LOG(ERROR) << "Couldn't compute logits.";
return false;
@@ -1140,11 +1349,12 @@
bool TextClassifier::DatetimeChunk(const UnicodeText& context_unicode,
int64 reference_time_ms_utc,
const std::string& reference_timezone,
- const std::string& locales,
+ const std::string& locales, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
std::vector<DatetimeParseResultSpan> datetime_spans;
if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
- reference_timezone, locales, &datetime_spans)) {
+ reference_timezone, locales, mode,
+ &datetime_spans)) {
return false;
}
for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
diff --git a/text-classifier.h b/text-classifier.h
index 33d6357..0c79429 100644
--- a/text-classifier.h
+++ b/text-classifier.h
@@ -75,6 +75,33 @@
static AnnotationOptions Default() { return AnnotationOptions(); }
};
+// Holds TFLite interpreters for selection and classification models.
+// NOTE: this class is not thread-safe, thus should NOT be re-used across
+// threads.
+class InterpreterManager {
+ public:
+ // The constructor can be called with nullptr for any of the executors; this
+ // is defined behavior, as long as the corresponding *Interpreter() method is
+ // not called when the executor is null.
+ InterpreterManager(const ModelExecutor* selection_executor,
+ const ModelExecutor* classification_executor)
+ : selection_executor_(selection_executor),
+ classification_executor_(classification_executor) {}
+
+ // Gets or creates and caches an interpreter for the selection model.
+ tflite::Interpreter* SelectionInterpreter();
+
+ // Gets or creates and caches an interpreter for the classification model.
+ tflite::Interpreter* ClassificationInterpreter();
+
+ private:
+ const ModelExecutor* selection_executor_;
+ const ModelExecutor* classification_executor_;
+
+ std::unique_ptr<tflite::Interpreter> selection_interpreter_;
+ std::unique_ptr<tflite::Interpreter> classification_interpreter_;
+};
+
// A text processing model that provides text classification, annotation,
// selection suggestion for various types.
// NOTE: This class is not thread-safe.
@@ -167,6 +194,7 @@
// the span.
bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
const std::string& context,
+ InterpreterManager* interpreter_manager,
std::vector<int>* result) const;
// Resolves one conflict between candidates on indices 'start_index'
@@ -175,20 +203,35 @@
bool ResolveConflict(const std::string& context,
const std::vector<AnnotatedSpan>& candidates,
int start_index, int end_index,
+ InterpreterManager* interpreter_manager,
std::vector<int>* chosen_indices) const;
// Gets selection candidates from the ML model.
bool ModelSuggestSelection(const UnicodeText& context_unicode,
CodepointSpan click_indices,
+ InterpreterManager* interpreter_manager,
std::vector<AnnotatedSpan>* result) const;
// Classifies the selected text given the context string with the
// classification model.
// Returns true if no error occurred.
bool ModelClassifyText(
- const std::string& context, CodepointSpan selection_indices,
+ const std::string& context, const std::vector<Token>& cached_tokens,
+ CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const;
+ bool ModelClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ InterpreterManager* interpreter_manager,
+ FeatureProcessor::EmbeddingCache* embedding_cache,
+ std::vector<ClassificationResult>* classification_results) const;
+
+ // Returns a relative token span that represents how many tokens on the left
+ // from the selection and right from the selection are needed for the
+ // classifier input.
+ TokenSpan ClassifyTextUpperBoundNeededTokens() const;
+
// Classifies the selected text with the regular expressions models.
// Returns true if any regular expression matched and the result was set.
bool RegexClassifyText(const std::string& context,
@@ -207,6 +250,7 @@
// The annotations are sorted by their position in the context string and
// exclude spans classified as 'other'.
bool ModelAnnotate(const std::string& context,
+ InterpreterManager* interpreter_manager,
std::vector<AnnotatedSpan>* result) const;
// Groups the tokens into chunks. A chunk is a token span that should be the
@@ -219,6 +263,7 @@
// completely. The first and last chunk might extend beyond it.
// The chunks vector is cleared before filling.
bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
+ tflite::Interpreter* selection_interpreter,
const CachedFeatures& cached_features,
std::vector<TokenSpan>* chunks) const;
@@ -228,6 +273,7 @@
bool ModelClickContextScoreChunks(
int num_tokens, const TokenSpan& span_of_interest,
const CachedFeatures& cached_features,
+ tflite::Interpreter* selection_interpreter,
std::vector<ScoredChunk>* scored_chunks) const;
// A helper method for ModelChunk(). It generates scored chunk candidates for
@@ -236,6 +282,7 @@
bool ModelBoundsSensitiveScoreChunks(
int num_tokens, const TokenSpan& span_of_interest,
const TokenSpan& inference_span, const CachedFeatures& cached_features,
+ tflite::Interpreter* selection_interpreter,
std::vector<ScoredChunk>* scored_chunks) const;
// Produces chunks isolated by a set of regular expressions.
@@ -247,19 +294,19 @@
bool DatetimeChunk(const UnicodeText& context_unicode,
int64 reference_time_ms_utc,
const std::string& reference_timezone,
- const std::string& locales,
+ const std::string& locales, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const;
const Model* model_;
- std::unique_ptr<ModelExecutor> selection_executor_;
- std::unique_ptr<ModelExecutor> classification_executor_;
- std::unique_ptr<EmbeddingExecutor> embedding_executor_;
+ std::unique_ptr<const ModelExecutor> selection_executor_;
+ std::unique_ptr<const ModelExecutor> classification_executor_;
+ std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
- std::unique_ptr<FeatureProcessor> selection_feature_processor_;
- std::unique_ptr<FeatureProcessor> classification_feature_processor_;
+ std::unique_ptr<const FeatureProcessor> selection_feature_processor_;
+ std::unique_ptr<const FeatureProcessor> classification_feature_processor_;
- std::unique_ptr<DatetimeParser> datetime_parser_;
+ std::unique_ptr<const DatetimeParser> datetime_parser_;
private:
struct CompiledRegexPattern {
@@ -271,8 +318,12 @@
std::unique_ptr<ScopedMmap> mmap_;
bool initialized_ = false;
+ bool enabled_for_annotation_ = false;
+ bool enabled_for_classification_ = false;
+ bool enabled_for_selection_ = false;
std::vector<CompiledRegexPattern> regex_patterns_;
+ std::unordered_set<int> regex_approximate_match_pattern_ids_;
// Indices into regex_patterns_ for the different modes.
std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
@@ -282,6 +333,15 @@
const UniLib* unilib_;
};
+namespace internal {
+// Copies the tokens from 'cached_tokens' that correspond to
+// 'selection_indices', plus up to 'tokens_around_selection_to_copy' tokens on
+// the left and on the right of the selection.
+std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
+ CodepointSpan selection_indices,
+ TokenSpan tokens_around_selection_to_copy);
+} // namespace internal
+
// Interprets the buffer as a Model flatbuffer and returns it for reading.
const Model* ViewModel(const void* buffer, int size);
diff --git a/text-classifier_test.cc b/text-classifier_test.cc
index 1145ac5..74534e2 100644
--- a/text-classifier_test.cc
+++ b/text-classifier_test.cc
@@ -30,6 +30,7 @@
namespace {
using testing::ElementsAreArray;
+using testing::IsEmpty;
using testing::Pair;
using testing::Values;
@@ -105,6 +106,50 @@
"a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
}
+TEST_P(TextClassifierTest, ClassifyTextDisabledFail) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ unpacked_model->classification_model.clear();
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+
+ // The classification model is still needed for selection scores.
+ ASSERT_FALSE(classifier);
+}
+
+TEST_P(TextClassifierTest, ClassifyTextDisabled) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes =
+ ModeFlag_ANNOTATION_AND_SELECTION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_THAT(
+ classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
+ IsEmpty());
+}
+
std::unique_ptr<RegexModel_::PatternT> MakePattern(
const std::string& collection_name, const std::string& pattern,
const bool enabled_for_classification, const bool enabled_for_selection,
@@ -112,9 +157,12 @@
std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
result->collection_name = collection_name;
result->pattern = pattern;
- result->enabled_for_selection = enabled_for_selection;
- result->enabled_for_classification = enabled_for_classification;
- result->enabled_for_annotation = enabled_for_annotation;
+ // We cannot directly operate with |= on the flag, so use an int here.
+ int enabled_modes = ModeFlag_NONE;
+ if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
+ if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
+ if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
+ result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
result->target_classification_score = score;
result->priority_score = score;
return result;
@@ -171,7 +219,6 @@
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
-
TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) {
CREATE_UNILIB_FOR_TESTING;
const std::string test_model = ReadFile(GetModelPath() + GetParam());
@@ -308,7 +355,6 @@
IsAnnotatedSpan(79, 91, "phone"),
}));
}
-
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest, PhoneFiltering) {
@@ -371,6 +417,58 @@
std::make_pair(11, 12));
}
+TEST_P(TextClassifierTest, SuggestSelectionDisabledFail) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Disable the selection model.
+ unpacked_model->selection_model.clear();
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ // Selection model needs to be present for annotation.
+ ASSERT_FALSE(classifier);
+}
+
+TEST_P(TextClassifierTest, SuggestSelectionDisabled) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Disable the selection model.
+ unpacked_model->selection_model.clear();
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
+ unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ std::make_pair(11, 14));
+
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "call me at (800) 123-456 today", {11, 24})));
+
+ EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"),
+ IsEmpty());
+}
+
TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) {
CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
@@ -510,13 +608,14 @@
EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
}
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest, AnnotateFilteringDiscardAll) {
CREATE_UNILIB_FOR_TESTING;
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
- // Add test thresholds.
unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ // Add test threshold.
unpacked_model->triggering_options->min_annotate_confidence =
2.f; // Discards all results.
flatbuffers::FlatBufferBuilder builder;
@@ -531,8 +630,10 @@
const std::string test_string =
"& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
"number is 853 225 3556";
- EXPECT_TRUE(classifier->Annotate(test_string).empty());
+
+ EXPECT_EQ(classifier->Annotate(test_string).size(), 1);
}
+#endif
TEST_P(TextClassifierTest, AnnotateFilteringKeepAll) {
CREATE_UNILIB_FOR_TESTING;
@@ -543,6 +644,7 @@
unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
unpacked_model->triggering_options->min_annotate_confidence =
0.f; // Keeps all results.
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
flatbuffers::FlatBufferBuilder builder;
builder.Finish(Model::Pack(builder, unpacked_model.get()));
@@ -563,6 +665,27 @@
#endif
}
+TEST_P(TextClassifierTest, AnnotateDisabled) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Disable the model for annotation.
+ unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+ EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
+}
+
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
TEST_P(TextClassifierTest, ClassifyTextDate) {
std::unique_ptr<TextClassifier> classifier =
@@ -613,6 +736,32 @@
DatetimeGranularity::GRANULARITY_DAY);
result.clear();
}
+
+TEST_P(TextClassifierTest, SuggestTextDateDisabled) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Disable the patterns for selection.
+ for (int i = 0; i < unpacked_model->datetime_model->patterns.size(); i++) {
+ unpacked_model->datetime_model->patterns[i]->enabled_modes =
+ ModeFlag_ANNOTATION_AND_CLASSIFICATION;
+ }
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+ EXPECT_EQ("date",
+ FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
+ EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
+ std::make_pair(0, 7));
+ EXPECT_THAT(classifier->Annotate("january 1, 2017"),
+ ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
+}
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
class TestingTextClassifier : public TextClassifier {
@@ -640,7 +789,8 @@
{MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ classifier.ResolveConflicts(candidates, /*context=*/"",
+ /*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0}));
}
@@ -657,7 +807,8 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ classifier.ResolveConflicts(candidates, /*context=*/"",
+ /*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
}
@@ -672,7 +823,8 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ classifier.ResolveConflicts(candidates, /*context=*/"",
+ /*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
}
@@ -687,7 +839,8 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ classifier.ResolveConflicts(candidates, /*context=*/"",
+ /*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({1}));
}
@@ -704,7 +857,8 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ classifier.ResolveConflicts(candidates, /*context=*/"",
+ /*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
}
@@ -740,5 +894,27 @@
}
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+// These coarse tests are there only to make sure the execution happens in a
+// reasonable amount of time.
+TEST_P(TextClassifierTest, LongInputNoResultCheck) {
+ CREATE_UNILIB_FOR_TESTING;
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ for (const std::string& value :
+ std::vector<std::string>{"http://www.aaaaaaaaaaaaaaaaaaaa.com "}) {
+ const std::string input_100k =
+ std::string(50000, ' ') + value + std::string(50000, ' ');
+ const int value_length = value.size();
+
+ classifier->Annotate(input_100k);
+ classifier->SuggestSelection(input_100k, {50000, 50001});
+ classifier->ClassifyText(input_100k, {50000, 50000 + value_length});
+ }
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
} // namespace
} // namespace libtextclassifier2
diff --git a/token-feature-extractor.cc b/token-feature-extractor.cc
index e194179..13fba30 100644
--- a/token-feature-extractor.cc
+++ b/token-feature-extractor.cc
@@ -82,10 +82,12 @@
bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
std::vector<int>* sparse_features,
std::vector<float>* dense_features) const {
- if (sparse_features == nullptr || dense_features == nullptr) {
+ if (!dense_features) {
return false;
}
- *sparse_features = ExtractCharactergramFeatures(token);
+ if (sparse_features) {
+ *sparse_features = ExtractCharactergramFeatures(token);
+ }
*dense_features = ExtractDenseFeatures(token, is_in_span);
return true;
}
diff --git a/token-feature-extractor.h b/token-feature-extractor.h
index 1646f74..fee1355 100644
--- a/token-feature-extractor.h
+++ b/token-feature-extractor.h
@@ -71,7 +71,8 @@
// Extracts both the sparse (charactergram) and the dense features from a
// token. is_in_span is a bool indicator whether the token is a part of the
// selection span (true) or not (false).
- // Fails and returns false if either of the output pointers in a nullptr.
+ // The sparse_features output is optional. Fails and returns false if
+ // dense_features is a nullptr.
bool Extract(const Token& token, bool is_in_span,
std::vector<int>* sparse_features,
std::vector<float>* dense_features) const;
diff --git a/tokenizer.cc b/tokenizer.cc
index ebc9696..722a67b 100644
--- a/tokenizer.cc
+++ b/tokenizer.cc
@@ -26,20 +26,24 @@
Tokenizer::Tokenizer(
const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
bool split_on_script_change)
- : codepoint_ranges_(codepoint_ranges),
- split_on_script_change_(split_on_script_change) {
+ : split_on_script_change_(split_on_script_change) {
+ for (const TokenizationCodepointRange* range : codepoint_ranges) {
+ codepoint_ranges_.emplace_back(range->UnPack());
+ }
+
std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- [](const TokenizationCodepointRange* a,
- const TokenizationCodepointRange* b) {
- return a->start() < b->start();
+ [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
+ const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
+ return a->start < b->start;
});
}
-const TokenizationCodepointRange* Tokenizer::FindTokenizationRange(
+const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
int codepoint) const {
auto it = std::lower_bound(
codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
- [](const TokenizationCodepointRange* range, int codepoint) {
+ [](const std::unique_ptr<const TokenizationCodepointRangeT>& range,
+ int codepoint) {
// This function compares range with the codepoint for the purpose of
// finding the first greater or equal range. Because of the use of
// std::lower_bound it needs to return true when range < codepoint;
@@ -49,11 +53,11 @@
// It might seem weird that the condition is range.end <= codepoint
// here but when codepoint == range.end it means it's actually just
// outside of the range, thus the range is less than the codepoint.
- return range->end() <= codepoint;
+ return range->end <= codepoint;
});
- if (it != codepoint_ranges_.end() && (*it)->start() <= codepoint &&
- (*it)->end() > codepoint) {
- return *it;
+ if (it != codepoint_ranges_.end() && (*it)->start <= codepoint &&
+ (*it)->end > codepoint) {
+ return it->get();
} else {
return nullptr;
}
@@ -62,10 +66,10 @@
void Tokenizer::GetScriptAndRole(char32 codepoint,
TokenizationCodepointRange_::Role* role,
int* script) const {
- const TokenizationCodepointRange* range = FindTokenizationRange(codepoint);
+ const TokenizationCodepointRangeT* range = FindTokenizationRange(codepoint);
if (range) {
- *role = range->role();
- *script = range->script_id();
+ *role = range->role;
+ *script = range->script_id;
} else {
*role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
*script = kUnknownScript;
diff --git a/tokenizer.h b/tokenizer.h
index 9ce2c7c..2524e12 100644
--- a/tokenizer.h
+++ b/tokenizer.h
@@ -47,7 +47,7 @@
protected:
// Finds the tokenization codepoint range config for given codepoint.
// Internally uses binary search so should be O(log(# of codepoint_ranges)).
- const TokenizationCodepointRange* FindTokenizationRange(int codepoint) const;
+ const TokenizationCodepointRangeT* FindTokenizationRange(int codepoint) const;
// Finds the role and script for given codepoint. If not found, DEFAULT_ROLE
// and kUnknownScript are assigned.
@@ -58,7 +58,8 @@
private:
// Codepoint ranges that determine how different codepoints are tokenized.
// The ranges must not overlap.
- std::vector<const TokenizationCodepointRange*> codepoint_ranges_;
+ std::vector<std::unique_ptr<const TokenizationCodepointRangeT>>
+ codepoint_ranges_;
// If true, tokens will be additionally split when the codepoint's script_id
// changes.
diff --git a/tokenizer_test.cc b/tokenizer_test.cc
index d9a0dea..65072f3 100644
--- a/tokenizer_test.cc
+++ b/tokenizer_test.cc
@@ -58,10 +58,10 @@
}
TokenizationCodepointRange_::Role TestFindTokenizationRole(int c) const {
- const TokenizationCodepointRange* range =
+ const TokenizationCodepointRangeT* range =
tokenizer_->FindTokenizationRange(c);
if (range != nullptr) {
- return range->role();
+ return range->role;
} else {
return TokenizationCodepointRange_::Role_DEFAULT_ROLE;
}
diff --git a/types.h b/types.h
index dbfa312..b2f624d 100644
--- a/types.h
+++ b/types.h
@@ -305,6 +305,92 @@
typename std::vector<T>::const_iterator end_;
};
+struct DateParseData {
+ enum Relation {
+ NEXT = 1,
+ NEXT_OR_SAME = 2,
+ LAST = 3,
+ NOW = 4,
+ TOMORROW = 5,
+ YESTERDAY = 6,
+ PAST = 7,
+ FUTURE = 8
+ };
+
+ enum RelationType {
+ MONDAY = 1,
+ TUESDAY = 2,
+ WEDNESDAY = 3,
+ THURSDAY = 4,
+ FRIDAY = 5,
+ SATURDAY = 6,
+ SUNDAY = 7,
+ DAY = 8,
+ WEEK = 9,
+ MONTH = 10,
+ YEAR = 11
+ };
+
+ enum Fields {
+ YEAR_FIELD = 1 << 0,
+ MONTH_FIELD = 1 << 1,
+ DAY_FIELD = 1 << 2,
+ HOUR_FIELD = 1 << 3,
+ MINUTE_FIELD = 1 << 4,
+ SECOND_FIELD = 1 << 5,
+ AMPM_FIELD = 1 << 6,
+ ZONE_OFFSET_FIELD = 1 << 7,
+ DST_OFFSET_FIELD = 1 << 8,
+ RELATION_FIELD = 1 << 9,
+ RELATION_TYPE_FIELD = 1 << 10,
+ RELATION_DISTANCE_FIELD = 1 << 11
+ };
+
+ enum AMPM { AM = 0, PM = 1 };
+
+ enum TimeUnit {
+ DAYS = 1,
+ WEEKS = 2,
+ MONTHS = 3,
+ HOURS = 4,
+ MINUTES = 5,
+ SECONDS = 6,
+ YEARS = 7
+ };
+
+ // Bit mask of fields which have been set on the struct
+ int field_set_mask;
+
+ // Fields describing absolute date fields.
+ // Year of the date seen in the text match.
+ int year;
+ // Month of the year starting with January = 1.
+ int month;
+ // Day of the month starting with 1.
+ int day_of_month;
+ // Hour of the day with a range of 0-23,
+ // values less than 12 need the AMPM field below or heuristics
+ // to definitively determine the time.
+ int hour;
+ // Minute of the hour with a range of 0-59.
+ int minute;
+ // Second of the minute with a range of 0-59.
+ int second;
+ // 0 == AM, 1 == PM
+ int ampm;
+ // Number of hours offset from UTC this date time is in.
+ int zone_offset;
+ // Number of hours offset for DST
+ int dst_offset;
+
+ // The permutation from now that was made to find the date time.
+ Relation relation;
+ // The unit of measure of the change to the date time.
+ RelationType relation_type;
+ // The number of units of change that were made.
+ int relation_distance;
+};
+
} // namespace libtextclassifier2
#endif // LIBTEXTCLASSIFIER_TYPES_H_
diff --git a/util/calendar/calendar-icu.h b/util/calendar/calendar-icu.h
index 50cb716..dc0a4f4 100644
--- a/util/calendar/calendar-icu.h
+++ b/util/calendar/calendar-icu.h
@@ -19,9 +19,9 @@
#include <string>
+#include "types.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"
-#include "util/calendar/types.h"
namespace libtextclassifier2 {
diff --git a/util/calendar/types.h b/util/calendar/types.h
deleted file mode 100644
index 4f58911..0000000
--- a/util/calendar/types.h
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTIL_CALENDAR_TYPES_H_
-#define LIBTEXTCLASSIFIER_UTIL_CALENDAR_TYPES_H_
-
-struct DateParseData {
- enum Relation {
- NEXT = 1,
- NEXT_OR_SAME = 2,
- LAST = 3,
- NOW = 4,
- TOMORROW = 5,
- YESTERDAY = 6,
- PAST = 7,
- FUTURE = 8
- };
-
- enum RelationType {
- MONDAY = 1,
- TUESDAY = 2,
- WEDNESDAY = 3,
- THURSDAY = 4,
- FRIDAY = 5,
- SATURDAY = 6,
- SUNDAY = 7,
- DAY = 8,
- WEEK = 9,
- MONTH = 10,
- YEAR = 11
- };
-
- enum Fields {
- YEAR_FIELD = 1 << 0,
- MONTH_FIELD = 1 << 1,
- DAY_FIELD = 1 << 2,
- HOUR_FIELD = 1 << 3,
- MINUTE_FIELD = 1 << 4,
- SECOND_FIELD = 1 << 5,
- AMPM_FIELD = 1 << 6,
- ZONE_OFFSET_FIELD = 1 << 7,
- DST_OFFSET_FIELD = 1 << 8,
- RELATION_FIELD = 1 << 9,
- RELATION_TYPE_FIELD = 1 << 10,
- RELATION_DISTANCE_FIELD = 1 << 11
- };
-
- enum AMPM { AM = 0, PM = 1 };
-
- enum TimeUnit {
- DAYS = 1,
- WEEKS = 2,
- MONTHS = 3,
- HOURS = 4,
- MINUTES = 5,
- SECONDS = 6,
- YEARS = 7
- };
-
- // Bit mask of fields which have been set on the struct
- int field_set_mask;
-
- // Fields describing absolute date fields.
- // Year of the date seen in the text match.
- int year;
- // Month of the year starting with January = 1.
- int month;
- // Day of the month starting with 1.
- int day_of_month;
- // Hour of the day with a range of 0-23,
- // values less than 12 need the AMPM field below or heuristics
- // to definitively determine the time.
- int hour;
- // Hour of the day with a range of 0-59.
- int minute;
- // Hour of the day with a range of 0-59.
- int second;
- // 0 == AM, 1 == PM
- int ampm;
- // Number of hours offset from UTC this date time is in.
- int zone_offset;
- // Number of hours offest for DST
- int dst_offset;
-
- // The permutation from now that was made to find the date time.
- Relation relation;
- // The unit of measure of the change to the date time.
- RelationType relation_type;
- // The number of units of change that were made.
- int relation_distance;
-};
-
-#endif // LIBTEXTCLASSIFIER_UTIL_CALENDAR_TYPES_H_
diff --git a/util/utf8/unicodetext.cc b/util/utf8/unicodetext.cc
index 79381bf..90a581f 100644
--- a/util/utf8/unicodetext.cc
+++ b/util/utf8/unicodetext.cc
@@ -199,7 +199,7 @@
}
std::string UnicodeText::ToUTF8String() const {
- return std::string(begin(), end());
+ return UTF8Substring(begin(), end());
}
std::string UnicodeText::UTF8Substring(const const_iterator& first,
diff --git a/util/utf8/unicodetext.h b/util/utf8/unicodetext.h
index 7fb8ac1..8e13496 100644
--- a/util/utf8/unicodetext.h
+++ b/util/utf8/unicodetext.h
@@ -194,7 +194,6 @@
void append(const char* bytes, int byte_length);
void Copy(const char* data, int size);
- void TakeOwnershipOf(char* data, int size, int capacity);
void PointTo(const char* data, int size);
private:
diff --git a/util/utf8/unicodetext_test.cc b/util/utf8/unicodetext_test.cc
new file mode 100644
index 0000000..8aef952
--- /dev/null
+++ b/util/utf8/unicodetext_test.cc
@@ -0,0 +1,167 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/utf8/unicodetext.h"
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier2 {
+namespace {
+
+class UnicodeTextTest : public testing::Test {
+ protected:
+ UnicodeTextTest() : empty_text_() {
+ text_.AppendCodepoint(0x1C0);
+ text_.AppendCodepoint(0x4E8C);
+ text_.AppendCodepoint(0xD7DB);
+ text_.AppendCodepoint(0x34);
+ text_.AppendCodepoint(0x1D11E);
+ }
+
+ UnicodeText empty_text_;
+ UnicodeText text_;
+};
+
+// Tests for our modifications of UnicodeText.
+TEST(UnicodeTextTest, Custom) {
+ UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);
+ EXPECT_EQ(text.ToUTF8String(), "1234😋hello");
+ EXPECT_EQ(text.size_codepoints(), 10);
+ EXPECT_EQ(text.size_bytes(), 13);
+
+ auto it_begin = text.begin();
+ std::advance(it_begin, 4);
+ auto it_end = text.begin();
+ std::advance(it_end, 6);
+ EXPECT_EQ(text.UTF8Substring(it_begin, it_end), "😋h");
+}
+
+TEST(UnicodeTextTest, Ownership) {
+ const std::string src = "\u304A\u00B0\u106B";
+
+ UnicodeText alias;
+ alias.PointToUTF8(src.data(), src.size());
+ EXPECT_EQ(alias.data(), src.data());
+ UnicodeText::const_iterator it = alias.begin();
+ EXPECT_EQ(*it++, 0x304A);
+ EXPECT_EQ(*it++, 0x00B0);
+ EXPECT_EQ(*it++, 0x106B);
+ EXPECT_EQ(it, alias.end());
+
+ UnicodeText t = alias; // Copy initialization copies the data.
+ EXPECT_NE(t.data(), alias.data());
+}
+
+class IteratorTest : public UnicodeTextTest {};
+
+TEST_F(IteratorTest, Iterates) {
+ UnicodeText::const_iterator iter = text_.begin();
+ EXPECT_EQ(0x1C0, *iter);
+ EXPECT_EQ(&iter, &++iter); // operator++ returns *this.
+ EXPECT_EQ(0x4E8C, *iter++);
+ EXPECT_EQ(0xD7DB, *iter);
+ // Make sure you can dereference more than once.
+ EXPECT_EQ(0xD7DB, *iter);
+ EXPECT_EQ(0x34, *++iter);
+ EXPECT_EQ(0x1D11E, *++iter);
+ ASSERT_TRUE(iter != text_.end());
+ iter++;
+ EXPECT_TRUE(iter == text_.end());
+}
+
+TEST_F(IteratorTest, MultiPass) {
+ // Also tests Default Constructible and Assignable.
+ UnicodeText::const_iterator i1, i2;
+ i1 = text_.begin();
+ i2 = i1;
+ EXPECT_EQ(0x4E8C, *++i1);
+ EXPECT_TRUE(i1 != i2);
+ EXPECT_EQ(0x1C0, *i2);
+ ++i2;
+ EXPECT_TRUE(i1 == i2);
+ EXPECT_EQ(0x4E8C, *i2);
+}
+
+TEST_F(IteratorTest, ReverseIterates) {
+ UnicodeText::const_iterator iter = text_.end();
+ EXPECT_TRUE(iter == text_.end());
+ iter--;
+ ASSERT_TRUE(iter != text_.end());
+ EXPECT_EQ(0x1D11E, *iter--);
+ EXPECT_EQ(0x34, *iter);
+ EXPECT_EQ(0xD7DB, *--iter);
+ // Make sure you can dereference more than once.
+ EXPECT_EQ(0xD7DB, *iter);
+ --iter;
+ EXPECT_EQ(0x4E8C, *iter--);
+ EXPECT_EQ(0x1C0, *iter);
+ EXPECT_TRUE(iter == text_.begin());
+}
+
+TEST_F(IteratorTest, Comparable) {
+ UnicodeText::const_iterator i1, i2;
+ i1 = text_.begin();
+ i2 = i1;
+ ++i2;
+
+ EXPECT_TRUE(i1 < i2);
+ EXPECT_TRUE(text_.begin() <= i1);
+ EXPECT_FALSE(i1 >= i2);
+ EXPECT_FALSE(i1 > text_.end());
+}
+
+TEST_F(IteratorTest, Advance) {
+ UnicodeText::const_iterator iter = text_.begin();
+ EXPECT_EQ(0x1C0, *iter);
+ std::advance(iter, 4);
+ EXPECT_EQ(0x1D11E, *iter);
+ ++iter;
+ EXPECT_TRUE(iter == text_.end());
+}
+
+TEST_F(IteratorTest, Distance) {
+ UnicodeText::const_iterator iter = text_.begin();
+ EXPECT_EQ(0, std::distance(text_.begin(), iter));
+ EXPECT_EQ(5, std::distance(iter, text_.end()));
+ ++iter;
+ ++iter;
+ EXPECT_EQ(2, std::distance(text_.begin(), iter));
+ EXPECT_EQ(3, std::distance(iter, text_.end()));
+ ++iter;
+ ++iter;
+ EXPECT_EQ(4, std::distance(text_.begin(), iter));
+ ++iter;
+ EXPECT_EQ(0, std::distance(iter, text_.end()));
+}
+
+class OperatorTest : public UnicodeTextTest {};
+
+TEST_F(OperatorTest, Clear) {
+ UnicodeText empty_text(UTF8ToUnicodeText("", /*do_copy=*/false));
+ EXPECT_FALSE(text_ == empty_text);
+ text_.clear();
+ EXPECT_TRUE(text_ == empty_text);
+}
+
+TEST_F(OperatorTest, Empty) {
+ EXPECT_TRUE(empty_text_.empty());
+ EXPECT_FALSE(text_.empty());
+ text_.clear();
+ EXPECT_TRUE(text_.empty());
+}
+
+} // namespace
+} // namespace libtextclassifier2
diff --git a/util/utf8/unilib-icu.cc b/util/utf8/unilib-icu.cc
index 3d5500a..b1eac2c 100644
--- a/util/utf8/unilib-icu.cc
+++ b/util/utf8/unilib-icu.cc
@@ -85,12 +85,11 @@
constexpr int UniLib::RegexMatcher::kNoError;
bool UniLib::RegexMatcher::Matches(int* status) const {
- std::string text = "";
- text_.toUTF8String(text);
if (!matcher_) {
*status = kError;
return false;
}
+
UErrorCode icu_status = U_ZERO_ERROR;
const bool result = matcher_->matches(/*startIndex=*/0, icu_status);
if (U_FAILURE(icu_status)) {
@@ -101,6 +100,31 @@
return result;
}
+bool UniLib::RegexMatcher::ApproximatelyMatches(int* status) {
+ if (!matcher_) {
+ *status = kError;
+ return false;
+ }
+
+ matcher_->reset();
+ *status = kNoError;
+ if (!Find(status) || *status != kNoError) {
+ return false;
+ }
+ const int found_start = Start(status);
+ if (*status != kNoError) {
+ return false;
+ }
+ const int found_end = End(status);
+ if (*status != kNoError) {
+ return false;
+ }
+ if (found_start != 0 || found_end != text_.countChar32()) {
+ return false;
+ }
+ return true;
+}
+
bool UniLib::RegexMatcher::Find(int* status) {
if (!matcher_) {
*status = kError;
@@ -117,18 +141,7 @@
}
int UniLib::RegexMatcher::Start(int* status) const {
- if (!matcher_) {
- *status = kError;
- return kError;
- }
- UErrorCode icu_status = U_ZERO_ERROR;
- const int result = matcher_->start(icu_status);
- if (U_FAILURE(icu_status)) {
- *status = kError;
- return kError;
- }
- *status = kNoError;
- return result;
+ return Start(/*group_idx=*/0, status);
}
int UniLib::RegexMatcher::Start(int group_idx, int* status) const {
@@ -143,22 +156,24 @@
return kError;
}
*status = kNoError;
- return result;
+ return text_.countChar32(/*start=*/0, /*length=*/result);
}
-int UniLib::RegexMatcher::End(int* status) const {
- if (!matcher_) {
- *status = kError;
- return kError;
- }
+int UniLib::RegexMatcher::Start(StringPiece group_name, int* status) const {
UErrorCode icu_status = U_ZERO_ERROR;
- const int result = matcher_->end(icu_status);
+ const int group_idx = pattern_->groupNumberFromName(
+ icu::UnicodeString::fromUTF8(
+ icu::StringPiece(group_name.data(), group_name.size())),
+ icu_status);
if (U_FAILURE(icu_status)) {
*status = kError;
return kError;
}
- *status = kNoError;
- return result;
+ return Start(group_idx, status);
+}
+
+int UniLib::RegexMatcher::End(int* status) const {
+ return End(/*group_idx=*/0, status);
}
int UniLib::RegexMatcher::End(int group_idx, int* status) const {
@@ -173,23 +188,24 @@
return kError;
}
*status = kNoError;
- return result;
+ return text_.countChar32(/*start=*/0, /*length=*/result);
+}
+
+int UniLib::RegexMatcher::End(StringPiece group_name, int* status) const {
+ UErrorCode icu_status = U_ZERO_ERROR;
+ const int group_idx = pattern_->groupNumberFromName(
+ icu::UnicodeString::fromUTF8(
+ icu::StringPiece(group_name.data(), group_name.size())),
+ icu_status);
+ if (U_FAILURE(icu_status)) {
+ *status = kError;
+ return kError;
+ }
+ return End(group_idx, status);
}
UnicodeText UniLib::RegexMatcher::Group(int* status) const {
- if (!matcher_) {
- *status = kError;
- return UTF8ToUnicodeText("", /*do_copy=*/false);
- }
- std::string result = "";
- UErrorCode icu_status = U_ZERO_ERROR;
- matcher_->group(icu_status).toUTF8String(result);
- if (U_FAILURE(icu_status)) {
- *status = kError;
- return UTF8ToUnicodeText("", /*do_copy=*/false);
- }
- *status = kNoError;
- return UTF8ToUnicodeText(result, /*do_copy=*/true);
+ return Group(/*group_idx=*/0, status);
}
UnicodeText UniLib::RegexMatcher::Group(int group_idx, int* status) const {
@@ -199,21 +215,22 @@
}
std::string result = "";
UErrorCode icu_status = U_ZERO_ERROR;
- matcher_->group(group_idx, icu_status).toUTF8String(result);
+ const icu::UnicodeString result_icu = matcher_->group(group_idx, icu_status);
if (U_FAILURE(icu_status)) {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
+ result_icu.toUTF8String(result);
*status = kNoError;
return UTF8ToUnicodeText(result, /*do_copy=*/true);
}
-UnicodeText UniLib::RegexMatcher::Group(const std::string& group_name,
+UnicodeText UniLib::RegexMatcher::Group(StringPiece group_name,
int* status) const {
UErrorCode icu_status = U_ZERO_ERROR;
const int group_idx = pattern_->groupNumberFromName(
icu::UnicodeString::fromUTF8(
- icu::StringPiece(group_name.c_str(), group_name.size())),
+ icu::StringPiece(group_name.data(), group_name.size())),
icu_status);
if (U_FAILURE(icu_status)) {
*status = kError;
@@ -226,7 +243,9 @@
UniLib::BreakIterator::BreakIterator(const UnicodeText& text)
: text_(icu::UnicodeString::fromUTF8(
- icu::StringPiece(text.data(), text.size_bytes()))) {
+ icu::StringPiece(text.data(), text.size_bytes()))),
+ last_break_index_(0),
+ last_unicode_index_(0) {
icu::ErrorCode status;
break_iterator_.reset(
icu::BreakIterator::createWordInstance(icu::Locale("en"), status));
@@ -238,12 +257,14 @@
}
int UniLib::BreakIterator::Next() {
- const int result = break_iterator_->next();
- if (result == icu::BreakIterator::DONE) {
+ const int break_index = break_iterator_->next();
+ if (break_index == icu::BreakIterator::DONE) {
return BreakIterator::kDone;
- } else {
- return result;
}
+ last_unicode_index_ +=
+ text_.countChar32(last_break_index_, break_index - last_break_index_);
+ last_break_index_ = break_index;
+ return last_unicode_index_;
}
std::unique_ptr<UniLib::RegexPattern> UniLib::CreateRegexPattern(
diff --git a/util/utf8/unilib-icu.h b/util/utf8/unilib-icu.h
index 8070d24..9488a00 100644
--- a/util/utf8/unilib-icu.h
+++ b/util/utf8/unilib-icu.h
@@ -23,6 +23,7 @@
#include <memory>
#include "util/base/integral_types.h"
+#include "util/strings/stringpiece.h"
#include "util/utf8/unicodetext.h"
#include "unicode/brkiter.h"
#include "unicode/errorcode.h"
@@ -55,10 +56,18 @@
// Checks whether the input text matches the pattern exactly.
bool Matches(int* status) const;
+ // Approximate Matches() implementation based on Find(). It uses
+ // the first Find() result and then checks that it spans the whole input.
+ // NOTE: Unlike Matches() it can result in false negatives.
+ // NOTE: Resets the matcher, so the current Find() state will be lost.
+ bool ApproximatelyMatches(int* status);
+
// Finds occurrences of the pattern in the input text.
+ // Can be called repeatedly to find all occurrences. A call will update
// internal state, so that 'Start', 'End' and 'Group' can be called to get
// information about the match.
+ // NOTE: Any call to ApproximatelyMatches() in between Find() calls will
+ // modify the state.
bool Find(int* status);
// Gets the start offset of the last match (from 'Find').
@@ -72,6 +81,9 @@
// was not called previously.
int Start(int group_idx, int* status) const;
+ // Same as above but uses the group name instead of the index.
+ int Start(StringPiece group_name, int* status) const;
+
// Gets the end offset of the last match (from 'Find').
// Sets status to 'kError' if 'Find'
// was not called previously.
@@ -83,6 +95,9 @@
// was not called previously.
int End(int group_idx, int* status) const;
+ // Same as above but uses the group name instead of the index.
+ int End(StringPiece group_name, int* status) const;
+
// Gets the text of the last match (from 'Find').
// Sets status to 'kError' if 'Find' was not called previously.
UnicodeText Group(int* status) const;
@@ -95,7 +110,7 @@
// Gets the text of the specified group of the last match (from 'Find').
// Sets status to 'kError' if an invalid group was specified or if 'Find'
// was not called previously.
- UnicodeText Group(const std::string& group_name, int* status) const;
+ UnicodeText Group(StringPiece group_name, int* status) const;
protected:
friend class RegexPattern;
@@ -133,6 +148,8 @@
private:
std::unique_ptr<icu::BreakIterator> break_iterator_;
icu::UnicodeString text_;
+ int last_break_index_;
+ int last_unicode_index_;
};
std::unique_ptr<RegexPattern> CreateRegexPattern(
diff --git a/util/utf8/unilib_test.cc b/util/utf8/unilib_test.cc
index bff2ffc..665bfec 100644
--- a/util/utf8/unilib_test.cc
+++ b/util/utf8/unilib_test.cc
@@ -15,17 +15,19 @@
*/
#include "util/utf8/unilib.h"
-#include "util/utf8/unicodetext.h"
#include "util/base/logging.h"
-
+#include "util/utf8/unicodetext.h"
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier2 {
namespace {
+using ::testing::ElementsAre;
+
TEST(UniLibTest, CharacterClassesAscii) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
EXPECT_TRUE(unilib.IsOpeningBracket('('));
EXPECT_TRUE(unilib.IsClosingBracket(')'));
EXPECT_FALSE(unilib.IsWhitespace(')'));
@@ -45,7 +47,7 @@
#ifndef LIBTEXTCLASSIFIER_UNILIB_DUMMY
TEST(UniLibTest, CharacterClassesUnicode) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
EXPECT_TRUE(unilib.IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON
EXPECT_TRUE(unilib.IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS
EXPECT_FALSE(unilib.IsWhitespace(0x23F0)); // ALARM CLOCK
@@ -68,8 +70,8 @@
}
#endif // ndef LIBTEXTCLASSIFIER_UNILIB_DUMMY
-TEST(UniLibTest, Regex) {
- CREATE_UNILIB_FOR_TESTING
+TEST(UniLibTest, RegexInterface) {
+ CREATE_UNILIB_FOR_TESTING;
const UnicodeText regex_pattern =
UTF8ToUnicodeText("[0-9]+", /*do_copy=*/true);
std::unique_ptr<UniLib::RegexPattern> pattern =
@@ -80,48 +82,161 @@
TC_LOG(INFO) << matcher->Matches(&status);
TC_LOG(INFO) << matcher->Find(&status);
TC_LOG(INFO) << matcher->Start(0, &status);
+ TC_LOG(INFO) << matcher->Start("group_name", &status);
TC_LOG(INFO) << matcher->End(0, &status);
+ TC_LOG(INFO) << matcher->End("group_name", &status);
TC_LOG(INFO) << matcher->Group(0, &status).size_codepoints();
}
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST(UniLibTest, Regex) {
+ CREATE_UNILIB_FOR_TESTING;
+
+ // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
+ // test the regex functionality with it to verify we are handling the indices
+ // correctly.
+ const UnicodeText regex_pattern =
+ UTF8ToUnicodeText("[0-9]+😋", /*do_copy=*/false);
+ std::unique_ptr<UniLib::RegexPattern> pattern =
+ unilib.CreateRegexPattern(regex_pattern);
+ int status;
+ std::unique_ptr<UniLib::RegexMatcher> matcher;
+
+ matcher = pattern->Matcher(UTF8ToUnicodeText("0123😋", /*do_copy=*/false));
+ EXPECT_TRUE(matcher->Matches(&status));
+ EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_TRUE(matcher->Matches(&status)); // Check that the state is reset.
+ EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+
+ matcher = pattern->Matcher(
+ UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
+ EXPECT_FALSE(matcher->Matches(&status));
+ EXPECT_FALSE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+
+ matcher = pattern->Matcher(
+ UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
+ EXPECT_TRUE(matcher->Find(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Start(0, &status), 8);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->End(0, &status), 13);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST(UniLibTest, RegexGroups) {
+ CREATE_UNILIB_FOR_TESTING;
+
+ // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
+ // test the regex functionality with it to verify we are handling the indices
+ // correctly.
+ const UnicodeText regex_pattern = UTF8ToUnicodeText(
+ "(?<group1>[0-9])(?<group2>[0-9]+)😋", /*do_copy=*/false);
+ std::unique_ptr<UniLib::RegexPattern> pattern =
+ unilib.CreateRegexPattern(regex_pattern);
+ int status;
+ std::unique_ptr<UniLib::RegexMatcher> matcher;
+
+ matcher = pattern->Matcher(
+ UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
+ EXPECT_TRUE(matcher->Find(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Start(0, &status), 8);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Start(1, &status), 8);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Start("group1", &status), 8);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Start(2, &status), 9);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Start("group2", &status), 9);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->End(0, &status), 13);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->End(1, &status), 9);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->End("group1", &status), 9);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->End(2, &status), 12);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->End("group2", &status), 12);
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(1, &status).ToUTF8String(), "0");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_EQ(matcher->Group(2, &status).ToUTF8String(), "123");
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+
TEST(UniLibTest, BreakIterator) {
- CREATE_UNILIB_FOR_TESTING
+ CREATE_UNILIB_FOR_TESTING;
const UnicodeText text = UTF8ToUnicodeText("some text", /*do_copy=*/false);
std::unique_ptr<UniLib::BreakIterator> iterator =
unilib.CreateBreakIterator(text);
- TC_LOG(INFO) << iterator->Next();
- TC_LOG(INFO) << UniLib::BreakIterator::kDone;
+ std::vector<int> break_indices;
+ int break_index = 0;
+ while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) {
+ break_indices.push_back(break_index);
+ }
+ EXPECT_THAT(break_indices, ElementsAre(4, 5, 9));
}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST(UniLibTest, BreakIterator4ByteUTF8) {
+ CREATE_UNILIB_FOR_TESTING;
+ const UnicodeText text = UTF8ToUnicodeText("😋😋😋", /*do_copy=*/false);
+ std::unique_ptr<UniLib::BreakIterator> iterator =
+ unilib.CreateBreakIterator(text);
+ std::vector<int> break_indices;
+ int break_index = 0;
+ while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) {
+ break_indices.push_back(break_index);
+ }
+ EXPECT_THAT(break_indices, ElementsAre(1, 2, 3));
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
+#ifndef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
TEST(UniLibTest, IntegerParse) {
-#if !defined LIBTEXTCLASSIFIER_UNILIB_JAVAICU
CREATE_UNILIB_FOR_TESTING;
int result;
EXPECT_TRUE(
unilib.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
EXPECT_EQ(result, 123);
-#endif
}
+#endif // ndef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST(UniLibTest, IntegerParseFullWidth) {
-#if defined LIBTEXTCLASSIFIER_UNILIB_ICU
CREATE_UNILIB_FOR_TESTING;
int result;
// The input string here is full width
EXPECT_TRUE(unilib.ParseInt32(UTF8ToUnicodeText("１２３", /*do_copy=*/false),
&result));
EXPECT_EQ(result, 123);
-#endif
}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST(UniLibTest, IntegerParseFullWidthWithAlpha) {
-#if defined LIBTEXTCLASSIFIER_UNILIB_ICU
CREATE_UNILIB_FOR_TESTING;
int result;
// The input string here is full width
EXPECT_FALSE(unilib.ParseInt32(UTF8ToUnicodeText("１a３", /*do_copy=*/false),
&result));
-#endif
}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
} // namespace
} // namespace libtextclassifier2