Fixes datetime model overtriggering, SMS shortcodes being linkified to maps, certain phone number formats not being recognized, incorrect date jumps when the granularity of the date is less than a day, and incorrect start of week for non-US locales.
am: 434442da2a
Change-Id: I5c16327d523dee342db7bd9c0beed6de150a3274
diff --git a/datetime/parser.cc b/datetime/parser.cc
index e9a6eb1..4bc5dff 100644
--- a/datetime/parser.cc
+++ b/datetime/parser.cc
@@ -95,6 +95,12 @@
}
}
+ if (model->default_locales() != nullptr) {
+ for (const int locale : *model->default_locales()) {
+ default_locale_ids_.push_back(locale);
+ }
+ }
+
use_extractors_for_locating_ = model->use_extractors_for_locating();
initialized_ = true;
@@ -110,14 +116,13 @@
anchor_start_end, results);
}
-bool DatetimeParser::Parse(
- const UnicodeText& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const {
- std::vector<DatetimeParseResultSpan> found_spans;
- std::unordered_set<int> executed_rules;
- for (const int locale_id : ParseAndExpandLocales(locales)) {
+bool DatetimeParser::FindSpansUsingLocales(
+ const std::vector<int>& locale_ids, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
+ std::unordered_set<int>* executed_rules,
+ std::vector<DatetimeParseResultSpan>* found_spans) const {
+ for (const int locale_id : locale_ids) {
auto rules_it = locale_to_rules_.find(locale_id);
if (rules_it == locale_to_rules_.end()) {
continue;
@@ -125,7 +130,7 @@
for (const int rule_id : rules_it->second) {
// Skip rules that were already executed in previous locales.
- if (executed_rules.find(rule_id) != executed_rules.end()) {
+ if (executed_rules->find(rule_id) != executed_rules->end()) {
continue;
}
@@ -133,15 +138,33 @@
continue;
}
- executed_rules.insert(rule_id);
+ executed_rules->insert(rule_id);
if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
- reference_timezone, locale_id, anchor_start_end,
- &found_spans)) {
+ reference_timezone, reference_locale, locale_id,
+ anchor_start_end, found_spans)) {
return false;
}
}
}
+ return true;
+}
+
+bool DatetimeParser::Parse(
+ const UnicodeText& input, const int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const {
+ std::vector<DatetimeParseResultSpan> found_spans;
+ std::unordered_set<int> executed_rules;
+ std::string reference_locale;
+ const std::vector<int> requested_locales =
+ ParseAndExpandLocales(locales, &reference_locale);
+ if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
+ reference_timezone, mode, anchor_start_end,
+ reference_locale, &executed_rules, &found_spans)) {
+ return false;
+ }
std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
int counter = 0;
@@ -186,7 +209,8 @@
bool DatetimeParser::HandleParseMatch(
const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
int64 reference_time_ms_utc, const std::string& reference_timezone,
- int locale_id, std::vector<DatetimeParseResultSpan>* result) const {
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResultSpan>* result) const {
int status = UniLib::RegexMatcher::kNoError;
const int start = matcher.Start(&status);
if (status != UniLib::RegexMatcher::kNoError) {
@@ -200,7 +224,8 @@
DatetimeParseResultSpan parse_result;
if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
- locale_id, &(parse_result.data), &parse_result.span)) {
+ reference_locale, locale_id, &(parse_result.data),
+ &parse_result.span)) {
return false;
}
if (!use_extractors_for_locating_) {
@@ -219,22 +244,24 @@
bool DatetimeParser::ParseWithRule(
const CompiledRule& rule, const UnicodeText& input,
const int64 reference_time_ms_utc, const std::string& reference_timezone,
- const int locale_id, bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* result) const {
+ const std::string& reference_locale, const int locale_id,
+ bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
std::unique_ptr<UniLib::RegexMatcher> matcher =
rule.compiled_regex->Matcher(input);
int status = UniLib::RegexMatcher::kNoError;
if (anchor_start_end) {
if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
- reference_timezone, locale_id, result)) {
+ reference_timezone, reference_locale, locale_id,
+ result)) {
return false;
}
}
} else {
while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
- reference_timezone, locale_id, result)) {
+ reference_timezone, reference_locale, locale_id,
+ result)) {
return false;
}
}
@@ -242,11 +269,14 @@
return true;
}
-constexpr char const* kDefaultLocale = "";
-
std::vector<int> DatetimeParser::ParseAndExpandLocales(
- const std::string& locales) const {
+ const std::string& locales, std::string* reference_locale) const {
std::vector<StringPiece> split_locales = strings::Split(locales, ',');
+ if (!split_locales.empty()) {
+ *reference_locale = split_locales[0].ToString();
+ } else {
+ *reference_locale = "";
+ }
std::vector<int> result;
for (const StringPiece& locale_str : split_locales) {
@@ -264,36 +294,35 @@
const std::string script = locale.Script();
const std::string region = locale.Region();
- // First, try adding language-script-* locale.
- if (!script.empty()) {
- locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
- if (locale_it != locale_string_to_id_.end()) {
- result.push_back(locale_it->second);
- }
- }
- // Second, try adding language-* locale.
- if (!language.empty()) {
- locale_it = locale_string_to_id_.find(language + "-*");
- if (locale_it != locale_string_to_id_.end()) {
- result.push_back(locale_it->second);
- }
- }
-
- // Second, try adding *-region locale.
+ // First, try adding *-region locale.
if (!region.empty()) {
locale_it = locale_string_to_id_.find("*-" + region);
if (locale_it != locale_string_to_id_.end()) {
result.push_back(locale_it->second);
}
}
+ // Second, try adding language-script-* locale.
+ if (!script.empty()) {
+ locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+ // Third, try adding language-* locale.
+ if (!language.empty()) {
+ locale_it = locale_string_to_id_.find(language + "-*");
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
}
- // Add a default fallback locale to the end of the list.
- auto locale_it = locale_string_to_id_.find(kDefaultLocale);
- if (locale_it != locale_string_to_id_.end()) {
- result.push_back(locale_it->second);
- } else {
- TC_VLOG(1) << "Could not add default locale.";
+ // Add the default locales if they haven't been added already.
+ const std::unordered_set<int> result_set(result.begin(), result.end());
+ for (const int default_locale_id : default_locale_ids_) {
+ if (result_set.find(default_locale_id) == result_set.end()) {
+ result.push_back(default_locale_id);
+ }
}
return result;
@@ -351,6 +380,7 @@
const UniLib::RegexMatcher& matcher,
const int64 reference_time_ms_utc,
const std::string& reference_timezone,
+ const std::string& reference_locale,
int locale_id, DatetimeParseResult* result,
CodepointSpan* result_span) const {
DateParseData parse;
@@ -361,12 +391,13 @@
return false;
}
- if (!calendar_lib_.InterpretParseData(parse, reference_time_ms_utc,
- reference_timezone,
- &(result->time_ms_utc))) {
+ result->granularity = GetGranularity(parse);
+
+ if (!calendar_lib_.InterpretParseData(
+ parse, reference_time_ms_utc, reference_timezone, reference_locale,
+ result->granularity, &(result->time_ms_utc))) {
return false;
}
- result->granularity = GetGranularity(parse);
return true;
}
diff --git a/datetime/parser.h b/datetime/parser.h
index c9d2119..0666607 100644
--- a/datetime/parser.h
+++ b/datetime/parser.h
@@ -20,6 +20,7 @@
#include <memory>
#include <string>
#include <unordered_map>
+#include <unordered_set>
#include <vector>
#include "datetime/extractor.h"
@@ -60,12 +61,23 @@
ZlibDecompressor* decompressor);
// Returns a list of locale ids for given locale spec string (comma-separated
- // locale names).
- std::vector<int> ParseAndExpandLocales(const std::string& locales) const;
+ // locale names). Assigns the first parsed locale to reference_locale.
+ std::vector<int> ParseAndExpandLocales(const std::string& locales,
+ std::string* reference_locale) const;
+
+ // Helper function that finds datetime spans, only using the rules associated
+ // with the given locales.
+ bool FindSpansUsingLocales(
+ const std::vector<int>& locale_ids, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
+ std::unordered_set<int>* executed_rules,
+ std::vector<DatetimeParseResultSpan>* found_spans) const;
bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input,
int64 reference_time_ms_utc,
- const std::string& reference_timezone, const int locale_id,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, const int locale_id,
bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* result) const;
@@ -73,7 +85,8 @@
bool ExtractDatetime(const CompiledRule& rule,
const UniLib::RegexMatcher& matcher,
int64 reference_time_ms_utc,
- const std::string& reference_timezone, int locale_id,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
DatetimeParseResult* result,
CodepointSpan* result_span) const;
@@ -81,7 +94,8 @@
bool HandleParseMatch(const CompiledRule& rule,
const UniLib::RegexMatcher& matcher,
int64 reference_time_ms_utc,
- const std::string& reference_timezone, int locale_id,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
std::vector<DatetimeParseResultSpan>* result) const;
private:
@@ -93,6 +107,7 @@
std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
type_and_locale_to_extractor_rule_;
std::unordered_map<std::string, int> locale_string_to_id_;
+ std::vector<int> default_locale_ids_;
CalendarLib calendar_lib_;
bool use_extractors_for_locating_;
};
diff --git a/datetime/parser_test.cc b/datetime/parser_test.cc
index 36525e2..e61ed12 100644
--- a/datetime/parser_test.cc
+++ b/datetime/parser_test.cc
@@ -76,28 +76,41 @@
const int64 expected_ms_utc,
DatetimeGranularity expected_granularity,
bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich") {
- auto expected_start_index = marked_text.find("{");
- EXPECT_TRUE(expected_start_index != std::string::npos);
- auto expected_end_index = marked_text.find("}");
- EXPECT_TRUE(expected_end_index != std::string::npos);
+ const std::string& timezone = "Europe/Zurich",
+ const std::string& locales = "en-US") {
+ const UnicodeText marked_text_unicode =
+ UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
+ auto brace_open_it =
+ std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
+ auto brace_end_it =
+ std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
+ TC_CHECK(brace_open_it != marked_text_unicode.end());
+ TC_CHECK(brace_end_it != marked_text_unicode.end());
std::string text;
- text += std::string(marked_text.begin(),
- marked_text.begin() + expected_start_index);
- text += std::string(marked_text.begin() + expected_start_index + 1,
- marked_text.begin() + expected_end_index);
- text += std::string(marked_text.begin() + expected_end_index + 1,
- marked_text.end());
+ text +=
+ UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
+ text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
+ text += UnicodeText::UTF8Substring(std::next(brace_end_it),
+ marked_text_unicode.end());
std::vector<DatetimeParseResultSpan> results;
- if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
+ if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION,
anchor_start_end, &results)) {
TC_LOG(ERROR) << text;
TC_CHECK(false);
}
- EXPECT_TRUE(!results.empty());
+ if (results.empty()) {
+ TC_LOG(ERROR) << "No results.";
+ return false;
+ }
+
+ const int expected_start_index =
+ std::distance(marked_text_unicode.begin(), brace_open_it);
+ // The -1 below is to account for the opening bracket character.
+ const int expected_end_index =
+ std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
std::vector<DatetimeParseResultSpan> filtered_results;
for (const DatetimeParseResultSpan& result : results) {
@@ -108,7 +121,7 @@
}
const std::vector<DatetimeParseResultSpan> expected{
- {{expected_start_index, expected_end_index - 1},
+ {{expected_start_index, expected_end_index},
{expected_ms_utc, expected_granularity},
/*target_classification_score=*/1.0,
/*priority_score=*/0.0}};
@@ -126,6 +139,14 @@
return matches;
}
+ bool ParsesCorrectlyGerman(const std::string& marked_text,
+ const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity) {
+ return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+ }
+
protected:
std::string model_buffer_;
std::unique_ptr<TextClassifier> classifier_;
@@ -143,39 +164,26 @@
TEST_F(ParserTest, Parse) {
EXPECT_TRUE(
ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectly("{1 2 2018}", 1514847600000, GRANULARITY_DAY));
EXPECT_TRUE(
ParsesCorrectly("{january 31 2018}", 1517353200000, GRANULARITY_DAY));
EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectly("{19/apr/2010:06:36:15}", 1271651775000,
- GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{09/Mar/2004 22:02:40}", 1078866160000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{Dec 2, 2010 2:39:58 AM}", 1291253998000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{Jun 09 2011 15:28:14}", 1307626094000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{Apr 20 00:00:35 2010}", 1271714435000,
- GRANULARITY_SECOND));
EXPECT_TRUE(
ParsesCorrectly("{Mar 16 08:12:04}", 6419524000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2012-10-14T22:11:20}", 1350245480000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2014-07-01T14:59:55}.711Z", 1404219595000,
- GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29},573", 1277512289000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}", 1137899465000,
GRANULARITY_SECOND));
- EXPECT_TRUE(
- ParsesCorrectly("{150423 11:42:35}", 1429782155000, GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{11:42:35}", 38555000, GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{11:42:35}.173", 38555000, GRANULARITY_SECOND));
EXPECT_TRUE(
ParsesCorrectly("{23/Apr 11:42:35},173", 9715355000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015:11:42:35}", 1429782155000,
- GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}", 1429782155000,
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}", 1429782155000,
@@ -192,18 +200,14 @@
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}.883", 1429782155000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{8/5/2011 3:31:18 AM}:234}", 1312507878000,
- GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{19/apr/2010:06:36:15}", 1271651775000,
- GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly(
"Are sentiments apartments decisively the especially alteration. "
"Thrown shy denote ten ladies though ask saw. Or by to he going "
"think order event music. Incommode so intention defective at "
"convinced. Led income months itself and houses you. After nor "
- "you leave might share court balls. {19/apr/2010:06:36:15} Are "
+ "you leave might share court balls. {19/apr/2010 06:36:15} Are "
"sentiments apartments decisively the especially alteration. "
"Thrown shy denote ten ladies though ask saw. Or by to he going "
"think order event music. Incommode so intention defective at "
@@ -212,8 +216,8 @@
1271651775000, GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000,
GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4}", 1514775600000,
- GRANULARITY_HOUR));
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30 am}", 1514777400000,
+ GRANULARITY_MINUTE));
EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000,
GRANULARITY_HOUR));
@@ -250,6 +254,117 @@
/*anchor_start_end=*/true));
}
+TEST_F(ParserTest, ParseGerman) {
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{Januar 1 2018}", 1514761200000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{1 2 2018}", 1517439600000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectlyGerman("lorem {1 Januar 2018} ipsum",
+ 1514761200000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{19/Apr/2010:06:36:15}", 1271651775000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{09/März/2004 22:02:40}", 1078866160000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{Dez 2, 2010 2:39:58}", 1291253998000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{Juni 09 2011 15:28:14}", 1307626094000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{März 16 08:12:04}", 6419524000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29},573", 1277512289000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}", 1137899465000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{11:42:35}", 38555000, GRANULARITY_SECOND));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{11:42:35}.173", 38555000, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr 11:42:35},173", 9715355000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}.883", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}.883", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}.883", 1429782155000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}", 1271651775000,
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}", 1514777400000,
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30 nachm}",
+ 1514820600000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4 nachm}", 1514818800000,
+ GRANULARITY_HOUR));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{14.03.2017}", 1489446000000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{heute}", -3600000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{nächste Woche}", 342000000, GRANULARITY_WEEK));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{nächsten Tag}", 82800000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{in drei Tagen}", 255600000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{in drei Wochen}", 1551600000, GRANULARITY_WEEK));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{vor drei Tagen}", -262800000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{morgen}", 82800000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{morgen um 4:00}", 97200000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{morgen um 4}", 97200000, GRANULARITY_HOUR));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{nächsten Mittwoch}", 514800000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{nächsten Mittwoch um 4}", 529200000,
+ GRANULARITY_HOUR));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{Vor drei Tagen}", -262800000, GRANULARITY_DAY));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{in einer woche}", 342000000, GRANULARITY_WEEK));
+ EXPECT_TRUE(
+ ParsesCorrectlyGerman("{in einer tag}", 82800000, GRANULARITY_DAY));
+}
+
+TEST_F(ParserTest, ParseNonUs) {
+ EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"en-GB"));
+ EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en"));
+}
+
+TEST_F(ParserTest, ParseUs) {
+ EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"en-US"));
+ EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"es-US"));
+}
+
+TEST_F(ParserTest, ParseUnknownLanguage) {
+ EXPECT_TRUE(ParsesCorrectly("bylo to {31. 12. 2015} v 6 hodin", 1451516400000,
+ GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
+}
+
class ParserLocaleTest : public testing::Test {
public:
void SetUp() override;
@@ -281,7 +396,8 @@
model.locales.push_back("en-*");
model.locales.push_back("zh-Hant-*");
model.locales.push_back("*-CH");
- model.locales.push_back("");
+ model.locales.push_back("default");
+ model.default_locales.push_back(6);
AddPattern(/*regex=*/"en-US", /*locale=*/0, &model.patterns);
AddPattern(/*regex=*/"en-CH", /*locale=*/1, &model.patterns);
diff --git a/feature-processor.cc b/feature-processor.cc
index aa71740..551e649 100644
--- a/feature-processor.cc
+++ b/feature-processor.cc
@@ -771,12 +771,8 @@
} // namespace internal
-bool FeatureProcessor::ExtractFeatures(
- const std::vector<Token>& tokens, TokenSpan token_span,
- CodepointSpan selection_span_for_feature,
- const EmbeddingExecutor* embedding_executor,
- EmbeddingCache* embedding_cache, int feature_vector_size,
- std::unique_ptr<CachedFeatures>* cached_features) const {
+bool FeatureProcessor::HasEnoughSupportedCodepoints(
+ const std::vector<Token>& tokens, TokenSpan token_span) const {
if (options_->min_supported_codepoint_ratio() > 0) {
const float supported_codepoint_ratio =
SupportedCodepointsRatio(token_span, tokens);
@@ -786,7 +782,15 @@
return false;
}
}
+ return true;
+}
+bool FeatureProcessor::ExtractFeatures(
+ const std::vector<Token>& tokens, TokenSpan token_span,
+ CodepointSpan selection_span_for_feature,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache, int feature_vector_size,
+ std::unique_ptr<CachedFeatures>* cached_features) const {
std::unique_ptr<std::vector<float>> features(new std::vector<float>());
features->reserve(feature_vector_size * TokenSpanSize(token_span));
for (int i = token_span.first; i < token_span.second; ++i) {
diff --git a/feature-processor.h b/feature-processor.h
index 553bd1e..98d3449 100644
--- a/feature-processor.h
+++ b/feature-processor.h
@@ -159,6 +159,11 @@
bool only_use_line_with_click,
std::vector<Token>* tokens, int* click_pos) const;
+ // Returns true if the token span has enough supported codepoints (as defined
+ // in the model config); otherwise returns false and the model should not run.
+ bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
+ TokenSpan token_span) const;
+
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
diff --git a/feature-processor_test.cc b/feature-processor_test.cc
index 70ef0a7..58b3033 100644
--- a/feature-processor_test.cc
+++ b/feature-processor_test.cc
@@ -588,10 +588,6 @@
EXPECT_TRUE(feature_processor.IsCodepointInRanges(
25000, feature_processor.supported_codepoint_ranges_));
- std::unique_ptr<CachedFeatures> cached_features;
-
- FakeEmbeddingExecutor embedding_executor;
-
const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7),
Token("eee", 8, 11)};
@@ -601,11 +597,8 @@
TestingFeatureProcessor feature_processor2(
flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
&unilib);
- EXPECT_TRUE(feature_processor2.ExtractFeatures(
- tokens, /*token_span=*/{0, 3},
- /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
- &embedding_executor, /*embedding_cache=*/nullptr,
- /*feature_vector_size=*/4, &cached_features));
+ EXPECT_TRUE(feature_processor2.HasEnoughSupportedCodepoints(
+ tokens, /*token_span=*/{0, 3}));
options.min_supported_codepoint_ratio = 0.2;
flatbuffers::DetachedBuffer options3_fb =
@@ -613,11 +606,8 @@
TestingFeatureProcessor feature_processor3(
flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
&unilib);
- EXPECT_TRUE(feature_processor3.ExtractFeatures(
- tokens, /*token_span=*/{0, 3},
- /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
- &embedding_executor, /*embedding_cache=*/nullptr,
- /*feature_vector_size=*/4, &cached_features));
+ EXPECT_TRUE(feature_processor3.HasEnoughSupportedCodepoints(
+ tokens, /*token_span=*/{0, 3}));
options.min_supported_codepoint_ratio = 0.5;
flatbuffers::DetachedBuffer options4_fb =
@@ -625,11 +615,8 @@
TestingFeatureProcessor feature_processor4(
flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()),
&unilib);
- EXPECT_FALSE(feature_processor4.ExtractFeatures(
- tokens, /*token_span=*/{0, 3},
- /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
- &embedding_executor, /*embedding_cache=*/nullptr,
- /*feature_vector_size=*/4, &cached_features));
+ EXPECT_FALSE(feature_processor4.HasEnoughSupportedCodepoints(
+ tokens, /*token_span=*/{0, 3}));
}
TEST(FeatureProcessorTest, InSpanFeature) {
diff --git a/model.fbs b/model.fbs
index 23cf229..fb9778b 100755
--- a/model.fbs
+++ b/model.fbs
@@ -165,6 +165,9 @@
// Limits for addresses.
address_min_num_tokens:int;
+
+ // Maximum number of tokens to attempt a classification (-1 is unlimited).
+ max_num_tokens:int = -1;
}
// List of regular expression matchers to check.
@@ -249,6 +252,22 @@
// If true, will use the extractors for determining the match location as
// opposed to using the location where the global pattern matched.
use_extractors_for_locating:bool = 1;
+
+ // List of locale ids whose rules are always run, after the requested
+ // ones.
+ default_locales:[int];
+}
+
+namespace libtextclassifier2.DatetimeModelLibrary_;
+table Item {
+ key:string;
+ value:libtextclassifier2.DatetimeModel;
+}
+
+// A set of named DateTime models.
+namespace libtextclassifier2;
+table DatetimeModelLibrary {
+ models:[libtextclassifier2.DatetimeModelLibrary_.Item];
}
// Options controlling the output of the Tensorflow Lite models.
diff --git a/model_generated.h b/model_generated.h
index ecf08fc..6ef75f6 100755
--- a/model_generated.h
+++ b/model_generated.h
@@ -59,6 +59,16 @@
struct DatetimeModel;
struct DatetimeModelT;
+namespace DatetimeModelLibrary_ {
+
+struct Item;
+struct ItemT;
+
+} // namespace DatetimeModelLibrary_
+
+struct DatetimeModelLibrary;
+struct DatetimeModelLibraryT;
+
struct ModelTriggeringOptions;
struct ModelTriggeringOptionsT;
@@ -730,10 +740,12 @@
int32_t phone_min_num_digits;
int32_t phone_max_num_digits;
int32_t address_min_num_tokens;
+ int32_t max_num_tokens;
ClassificationModelOptionsT()
: phone_min_num_digits(7),
phone_max_num_digits(15),
- address_min_num_tokens(0) {
+ address_min_num_tokens(0),
+ max_num_tokens(-1) {
}
};
@@ -742,7 +754,8 @@
enum {
VT_PHONE_MIN_NUM_DIGITS = 4,
VT_PHONE_MAX_NUM_DIGITS = 6,
- VT_ADDRESS_MIN_NUM_TOKENS = 8
+ VT_ADDRESS_MIN_NUM_TOKENS = 8,
+ VT_MAX_NUM_TOKENS = 10
};
int32_t phone_min_num_digits() const {
return GetField<int32_t>(VT_PHONE_MIN_NUM_DIGITS, 7);
@@ -753,11 +766,15 @@
int32_t address_min_num_tokens() const {
return GetField<int32_t>(VT_ADDRESS_MIN_NUM_TOKENS, 0);
}
+ int32_t max_num_tokens() const {
+ return GetField<int32_t>(VT_MAX_NUM_TOKENS, -1);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_PHONE_MIN_NUM_DIGITS) &&
VerifyField<int32_t>(verifier, VT_PHONE_MAX_NUM_DIGITS) &&
VerifyField<int32_t>(verifier, VT_ADDRESS_MIN_NUM_TOKENS) &&
+ VerifyField<int32_t>(verifier, VT_MAX_NUM_TOKENS) &&
verifier.EndTable();
}
ClassificationModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -777,6 +794,9 @@
void add_address_min_num_tokens(int32_t address_min_num_tokens) {
fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_ADDRESS_MIN_NUM_TOKENS, address_min_num_tokens, 0);
}
+ void add_max_num_tokens(int32_t max_num_tokens) {
+ fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_MAX_NUM_TOKENS, max_num_tokens, -1);
+ }
explicit ClassificationModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -793,8 +813,10 @@
flatbuffers::FlatBufferBuilder &_fbb,
int32_t phone_min_num_digits = 7,
int32_t phone_max_num_digits = 15,
- int32_t address_min_num_tokens = 0) {
+ int32_t address_min_num_tokens = 0,
+ int32_t max_num_tokens = -1) {
ClassificationModelOptionsBuilder builder_(_fbb);
+ builder_.add_max_num_tokens(max_num_tokens);
builder_.add_address_min_num_tokens(address_min_num_tokens);
builder_.add_phone_max_num_digits(phone_max_num_digits);
builder_.add_phone_min_num_digits(phone_min_num_digits);
@@ -1339,6 +1361,7 @@
std::vector<std::unique_ptr<DatetimeModelPatternT>> patterns;
std::vector<std::unique_ptr<DatetimeModelExtractorT>> extractors;
bool use_extractors_for_locating;
+ std::vector<int32_t> default_locales;
DatetimeModelT()
: use_extractors_for_locating(true) {
}
@@ -1350,7 +1373,8 @@
VT_LOCALES = 4,
VT_PATTERNS = 6,
VT_EXTRACTORS = 8,
- VT_USE_EXTRACTORS_FOR_LOCATING = 10
+ VT_USE_EXTRACTORS_FOR_LOCATING = 10,
+ VT_DEFAULT_LOCALES = 12
};
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *locales() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_LOCALES);
@@ -1364,6 +1388,9 @@
bool use_extractors_for_locating() const {
return GetField<uint8_t>(VT_USE_EXTRACTORS_FOR_LOCATING, 1) != 0;
}
+ const flatbuffers::Vector<int32_t> *default_locales() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DEFAULT_LOCALES);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_LOCALES) &&
@@ -1376,6 +1403,8 @@
verifier.Verify(extractors()) &&
verifier.VerifyVectorOfTables(extractors()) &&
VerifyField<uint8_t>(verifier, VT_USE_EXTRACTORS_FOR_LOCATING) &&
+ VerifyOffset(verifier, VT_DEFAULT_LOCALES) &&
+ verifier.Verify(default_locales()) &&
verifier.EndTable();
}
DatetimeModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1398,6 +1427,9 @@
void add_use_extractors_for_locating(bool use_extractors_for_locating) {
fbb_.AddElement<uint8_t>(DatetimeModel::VT_USE_EXTRACTORS_FOR_LOCATING, static_cast<uint8_t>(use_extractors_for_locating), 1);
}
+ void add_default_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> default_locales) {  // Records the offset of an already-created vector.
+ fbb_.AddOffset(DatetimeModel::VT_DEFAULT_LOCALES, default_locales);
+ }
explicit DatetimeModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1415,8 +1447,10 @@
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> locales = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelPattern>>> patterns = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<DatetimeModelExtractor>>> extractors = 0,
- bool use_extractors_for_locating = true) {
+ bool use_extractors_for_locating = true,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> default_locales = 0) {
DatetimeModelBuilder builder_(_fbb);
+ builder_.add_default_locales(default_locales);
builder_.add_extractors(extractors);
builder_.add_patterns(patterns);
builder_.add_locales(locales);
@@ -1429,17 +1463,162 @@
const std::vector<flatbuffers::Offset<flatbuffers::String>> *locales = nullptr,
const std::vector<flatbuffers::Offset<DatetimeModelPattern>> *patterns = nullptr,
const std::vector<flatbuffers::Offset<DatetimeModelExtractor>> *extractors = nullptr,
- bool use_extractors_for_locating = true) {
+ bool use_extractors_for_locating = true,
+ const std::vector<int32_t> *default_locales = nullptr) {
return libtextclassifier2::CreateDatetimeModel(
_fbb,
locales ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*locales) : 0,
patterns ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelPattern>>(*patterns) : 0,
extractors ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>>(*extractors) : 0,
- use_extractors_for_locating);
+ use_extractors_for_locating,
+ default_locales ? _fbb.CreateVector<int32_t>(*default_locales) : 0);
}
flatbuffers::Offset<DatetimeModel> CreateDatetimeModel(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+namespace DatetimeModelLibrary_ {
+
+struct ItemT : public flatbuffers::NativeTable {  // Unpacked (object API) form of Item, produced by Item::UnPack.
+ typedef Item TableType;
+ std::string key;
+ std::unique_ptr<libtextclassifier2::DatetimeModelT> value;  // Owned; left null when the serialized value is absent.
+ ItemT() {
+ }
+};
+
+struct Item FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {  // Accessor view over one serialized key/value entry.
+ typedef ItemT NativeTableType;
+ enum {
+ VT_KEY = 4,  // Field ids used with GetPointer/VerifyOffset/AddOffset.
+ VT_VALUE = 6
+ };
+ const flatbuffers::String *key() const {
+ return GetPointer<const flatbuffers::String *>(VT_KEY);
+ }
+ const libtextclassifier2::DatetimeModel *value() const {
+ return GetPointer<const libtextclassifier2::DatetimeModel *>(VT_VALUE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {  // Checks both fields, including the nested table via VerifyTable.
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_KEY) &&
+ verifier.Verify(key()) &&
+ VerifyOffset(verifier, VT_VALUE) &&
+ verifier.VerifyTable(value()) &&
+ verifier.EndTable();
+ }
+ ItemT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ItemT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<Item> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ItemT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ItemBuilder {  // Writes one Item table: opened in the constructor, closed by Finish().
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_key(flatbuffers::Offset<flatbuffers::String> key) {
+ fbb_.AddOffset(Item::VT_KEY, key);
+ }
+ void add_value(flatbuffers::Offset<libtextclassifier2::DatetimeModel> value) {
+ fbb_.AddOffset(Item::VT_VALUE, value);
+ }
+ explicit ItemBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ItemBuilder &operator=(const ItemBuilder &);  // Declared only; no definition is visible here.
+ flatbuffers::Offset<Item> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Item>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Item> CreateItem(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> key = 0,
+ flatbuffers::Offset<libtextclassifier2::DatetimeModel> value = 0) {
+ ItemBuilder builder_(_fbb);
+ builder_.add_value(value);  // Fields are added in reverse declaration order (value, then key).
+ builder_.add_key(key);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Item> CreateItemDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *key = nullptr,
+ flatbuffers::Offset<libtextclassifier2::DatetimeModel> value = 0) {
+ return libtextclassifier2::DatetimeModelLibrary_::CreateItem(
+ _fbb,
+ key ? _fbb.CreateString(key) : 0,  // A null key is passed on as offset 0.
+ value);
+}
+
+flatbuffers::Offset<Item> CreateItem(flatbuffers::FlatBufferBuilder &_fbb, const ItemT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);  // Object-API overload (takes the unpacked ItemT).
+
+} // namespace DatetimeModelLibrary_
+
+struct DatetimeModelLibraryT : public flatbuffers::NativeTable {  // Unpacked (object API) form of DatetimeModelLibrary.
+ typedef DatetimeModelLibrary TableType;
+ std::vector<std::unique_ptr<libtextclassifier2::DatetimeModelLibrary_::ItemT>> models;  // Owned key/value entries.
+ DatetimeModelLibraryT() {
+ }
+};
+
+struct DatetimeModelLibrary FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {  // Accessor view over a serialized list of Item entries.
+ typedef DatetimeModelLibraryT NativeTableType;
+ enum {
+ VT_MODELS = 4  // Field id used with GetPointer/VerifyOffset/AddOffset.
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> *models() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> *>(VT_MODELS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {  // Checks the vector itself and then every Item table in it.
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_MODELS) &&
+ verifier.Verify(models()) &&
+ verifier.VerifyVectorOfTables(models()) &&
+ verifier.EndTable();
+ }
+ DatetimeModelLibraryT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(DatetimeModelLibraryT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<DatetimeModelLibrary> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct DatetimeModelLibraryBuilder {  // Writes one DatetimeModelLibrary table: opened in the constructor, closed by Finish().
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_models(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>>> models) {
+ fbb_.AddOffset(DatetimeModelLibrary::VT_MODELS, models);
+ }
+ explicit DatetimeModelLibraryBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DatetimeModelLibraryBuilder &operator=(const DatetimeModelLibraryBuilder &);  // Declared only; no definition is visible here.
+ flatbuffers::Offset<DatetimeModelLibrary> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<DatetimeModelLibrary>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibrary(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>>> models = 0) {
+ DatetimeModelLibraryBuilder builder_(_fbb);  // One-shot helper: open table, add the single field, finish.
+ builder_.add_models(models);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibraryDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> *models = nullptr) {
+ return libtextclassifier2::CreateDatetimeModelLibrary(
+ _fbb,
+ models ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>>(*models) : 0);  // Null vector is passed on as offset 0.
+}
+
+flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibrary(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct ModelTriggeringOptionsT : public flatbuffers::NativeTable {
typedef ModelTriggeringOptions TableType;
float min_annotate_confidence;
@@ -2809,6 +2988,7 @@
{ auto _e = phone_min_num_digits(); _o->phone_min_num_digits = _e; };
{ auto _e = phone_max_num_digits(); _o->phone_max_num_digits = _e; };
{ auto _e = address_min_num_tokens(); _o->address_min_num_tokens = _e; };
+ { auto _e = max_num_tokens(); _o->max_num_tokens = _e; };
}
inline flatbuffers::Offset<ClassificationModelOptions> ClassificationModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2822,11 +3002,13 @@
auto _phone_min_num_digits = _o->phone_min_num_digits;
auto _phone_max_num_digits = _o->phone_max_num_digits;
auto _address_min_num_tokens = _o->address_min_num_tokens;
+ auto _max_num_tokens = _o->max_num_tokens;
return libtextclassifier2::CreateClassificationModelOptions(
_fbb,
_phone_min_num_digits,
_phone_max_num_digits,
- _address_min_num_tokens);
+ _address_min_num_tokens,
+ _max_num_tokens);
}
namespace RegexModel_ {
@@ -3025,6 +3207,7 @@
{ auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<DatetimeModelPatternT>(_e->Get(_i)->UnPack(_resolver)); } } };
{ auto _e = extractors(); if (_e) { _o->extractors.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->extractors[_i] = std::unique_ptr<DatetimeModelExtractorT>(_e->Get(_i)->UnPack(_resolver)); } } };
{ auto _e = use_extractors_for_locating(); _o->use_extractors_for_locating = _e; };
+ { auto _e = default_locales(); if (_e) { _o->default_locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->default_locales[_i] = _e->Get(_i); } } };
}
inline flatbuffers::Offset<DatetimeModel> DatetimeModel::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -3039,12 +3222,73 @@
auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelPattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreateDatetimeModelPattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0;
auto _extractors = _o->extractors.size() ? _fbb.CreateVector<flatbuffers::Offset<DatetimeModelExtractor>> (_o->extractors.size(), [](size_t i, _VectorArgs *__va) { return CreateDatetimeModelExtractor(*__va->__fbb, __va->__o->extractors[i].get(), __va->__rehasher); }, &_va ) : 0;
auto _use_extractors_for_locating = _o->use_extractors_for_locating;
+ auto _default_locales = _o->default_locales.size() ? _fbb.CreateVector(_o->default_locales) : 0;
return libtextclassifier2::CreateDatetimeModel(
_fbb,
_locales,
_patterns,
_extractors,
- _use_extractors_for_locating);
+ _use_extractors_for_locating,
+ _default_locales);
+}
+
+namespace DatetimeModelLibrary_ {
+
+inline ItemT *Item::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ItemT();  // Heap-allocated and returned as a raw pointer.
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void Item::UnPackTo(ItemT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = key(); if (_e) _o->key = _e->str(); };  // Fields absent from the buffer leave *_o untouched.
+ { auto _e = value(); if (_e) _o->value = std::unique_ptr<libtextclassifier2::DatetimeModelT>(_e->UnPack(_resolver)); };
+}
+
+inline flatbuffers::Offset<Item> Item::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ItemT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateItem(_fbb, _o, _rehasher);  // Thin forwarder to the object-API CreateItem overload.
+}
+
+inline flatbuffers::Offset<Item> CreateItem(flatbuffers::FlatBufferBuilder &_fbb, const ItemT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ItemT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _key = _o->key.empty() ? 0 : _fbb.CreateString(_o->key);  // An empty key is passed on as offset 0.
+ auto _value = _o->value ? CreateDatetimeModel(_fbb, _o->value.get(), _rehasher) : 0;  // A null value is passed on as offset 0.
+ return libtextclassifier2::DatetimeModelLibrary_::CreateItem(
+ _fbb,
+ _key,
+ _value);
+}
+
+} // namespace DatetimeModelLibrary_
+
+inline DatetimeModelLibraryT *DatetimeModelLibrary::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new DatetimeModelLibraryT();  // Heap-allocated and returned as a raw pointer.
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void DatetimeModelLibrary::UnPackTo(DatetimeModelLibraryT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;  // Resizes _o->models and unpacks each Item; a missing vector leaves _o->models untouched.
+ (void)_resolver;
+ { auto _e = models(); if (_e) { _o->models.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->models[_i] = std::unique_ptr<libtextclassifier2::DatetimeModelLibrary_::ItemT>(_e->Get(_i)->UnPack(_resolver)); } } };
+}
+
+inline flatbuffers::Offset<DatetimeModelLibrary> DatetimeModelLibrary::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateDatetimeModelLibrary(_fbb, _o, _rehasher);  // Thin forwarder to the object-API overload.
+}
+
+inline flatbuffers::Offset<DatetimeModelLibrary> CreateDatetimeModelLibrary(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelLibraryT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;  // Serializes every entry of _o->models; an empty list is passed on as offset 0.
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelLibraryT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _models = _o->models.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelLibrary_::Item>> (_o->models.size(), [](size_t i, _VectorArgs *__va) { return CreateItem(*__va->__fbb, __va->__o->models[i].get(), __va->__rehasher); }, &_va ) : 0;
+ return libtextclassifier2::CreateDatetimeModelLibrary(
+ _fbb,
+ _models);
}
inline ModelTriggeringOptionsT *ModelTriggeringOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
diff --git a/models/textclassifier.ar.model b/models/textclassifier.ar.model
index 39e7ea2..2342daa 100644
--- a/models/textclassifier.ar.model
+++ b/models/textclassifier.ar.model
Binary files differ
diff --git a/models/textclassifier.en.model b/models/textclassifier.en.model
index 04f90b7..a40f940 100644
--- a/models/textclassifier.en.model
+++ b/models/textclassifier.en.model
Binary files differ
diff --git a/models/textclassifier.es.model b/models/textclassifier.es.model
index bc79119..7de4e5d 100644
--- a/models/textclassifier.es.model
+++ b/models/textclassifier.es.model
Binary files differ
diff --git a/models/textclassifier.fr.model b/models/textclassifier.fr.model
index 768968e..1072041 100644
--- a/models/textclassifier.fr.model
+++ b/models/textclassifier.fr.model
Binary files differ
diff --git a/models/textclassifier.it.model b/models/textclassifier.it.model
index 823d02b..5bc98ae 100644
--- a/models/textclassifier.it.model
+++ b/models/textclassifier.it.model
Binary files differ
diff --git a/models/textclassifier.ja.model b/models/textclassifier.ja.model
index c65b9b0..9f60b8a 100644
--- a/models/textclassifier.ja.model
+++ b/models/textclassifier.ja.model
Binary files differ
diff --git a/models/textclassifier.ko.model b/models/textclassifier.ko.model
index 0c12ebe..451df45 100644
--- a/models/textclassifier.ko.model
+++ b/models/textclassifier.ko.model
Binary files differ
diff --git a/models/textclassifier.nl.model b/models/textclassifier.nl.model
index a4aedb5..07ea076 100644
--- a/models/textclassifier.nl.model
+++ b/models/textclassifier.nl.model
Binary files differ
diff --git a/models/textclassifier.pl.model b/models/textclassifier.pl.model
index 6797f93..6cf62a5 100644
--- a/models/textclassifier.pl.model
+++ b/models/textclassifier.pl.model
Binary files differ
diff --git a/models/textclassifier.pt.model b/models/textclassifier.pt.model
index 39fa301..a745d58 100644
--- a/models/textclassifier.pt.model
+++ b/models/textclassifier.pt.model
Binary files differ
diff --git a/models/textclassifier.ru.model b/models/textclassifier.ru.model
index a824d4c..aa97ebc 100644
--- a/models/textclassifier.ru.model
+++ b/models/textclassifier.ru.model
Binary files differ
diff --git a/models/textclassifier.th.model b/models/textclassifier.th.model
index 5430511..37339b7 100644
--- a/models/textclassifier.th.model
+++ b/models/textclassifier.th.model
Binary files differ
diff --git a/models/textclassifier.tr.model b/models/textclassifier.tr.model
index 2132f89..2405d9e 100644
--- a/models/textclassifier.tr.model
+++ b/models/textclassifier.tr.model
Binary files differ
diff --git a/models/textclassifier.universal.model b/models/textclassifier.universal.model
new file mode 100644
index 0000000..5c4220f
--- /dev/null
+++ b/models/textclassifier.universal.model
Binary files differ
diff --git a/models/textclassifier.zh-Hant.model b/models/textclassifier.zh-Hant.model
index 96341ce..32edfe4 100644
--- a/models/textclassifier.zh-Hant.model
+++ b/models/textclassifier.zh-Hant.model
Binary files differ
diff --git a/models/textclassifier.zh.model b/models/textclassifier.zh.model
index adcab0f..eb1ff61 100644
--- a/models/textclassifier.zh.model
+++ b/models/textclassifier.zh.model
Binary files differ
diff --git a/test_data/test_model.fb b/test_data/test_model.fb
index 0f0161d..c651bdb 100644
--- a/test_data/test_model.fb
+++ b/test_data/test_model.fb
Binary files differ
diff --git a/test_data/test_model_cc.fb b/test_data/test_model_cc.fb
index 2500551..53af6bf 100644
--- a/test_data/test_model_cc.fb
+++ b/test_data/test_model_cc.fb
Binary files differ
diff --git a/test_data/wrong_embeddings.fb b/test_data/wrong_embeddings.fb
index 9879e0b..e1aa3ea 100644
--- a/test_data/wrong_embeddings.fb
+++ b/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/text-classifier.cc b/text-classifier.cc
index 417c84a..e20813a 100644
--- a/text-classifier.cc
+++ b/text-classifier.cc
@@ -246,6 +246,7 @@
if (model_->regex_model()) {
if (!InitializeRegexModel(decompressor.get())) {
TC_LOG(ERROR) << "Could not initialize regex model.";
+ return;
}
}
@@ -284,8 +285,7 @@
bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) {
if (!model_->regex_model()->patterns()) {
- initialized_ = false;
- return false;
+ return true;
}
// Initialize pattern recognizers.
@@ -705,6 +705,11 @@
}
extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
+ if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
+ *tokens, extraction_span)) {
+ return true;
+ }
+
std::unique_ptr<CachedFeatures> cached_features;
if (!selection_feature_processor_->ExtractFeatures(
*tokens, extraction_span,
@@ -831,6 +836,14 @@
&tokens, &click_pos);
const TokenSpan selection_token_span =
CodepointSpanToTokenSpan(tokens, selection_indices);
+ const int selection_num_tokens = TokenSpanSize(selection_token_span);
+ if (model_->classification_options()->max_num_tokens() > 0 &&
+ model_->classification_options()->max_num_tokens() <
+ selection_num_tokens) {
+ *classification_results = {{kOtherCollection, 1.0}};
+ return true;
+ }
+
const FeatureProcessorOptions_::BoundsSensitiveFeatures*
bounds_sensitive_features =
classification_feature_processor_->GetOptions()
@@ -865,6 +878,12 @@
}
extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});
+ if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
+ tokens, extraction_span)) {
+ *classification_results = {{kOtherCollection, 1.0}};
+ return true;
+ }
+
std::unique_ptr<CachedFeatures> cached_features;
if (!classification_feature_processor_->ExtractFeatures(
tokens, extraction_span, selection_indices, embedding_executor_.get(),
@@ -928,7 +947,7 @@
// Address class sanity check.
if (!classification_results->empty() &&
classification_results->begin()->collection == kAddressCollection) {
- if (TokenSpanSize(selection_token_span) <
+ if (selection_num_tokens <
model_->classification_options()->address_min_num_tokens()) {
*classification_results = {{kOtherCollection, 1.0}};
}
@@ -1108,6 +1127,12 @@
/*click_pos=*/nullptr);
const TokenSpan full_line_span = {0, tokens->size()};
+ // TODO(zilka): Add support for greater granularity of this check.
+ if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
+ *tokens, full_line_span)) {
+ continue;
+ }
+
std::unique_ptr<CachedFeatures> cached_features;
if (!selection_feature_processor_->ExtractFeatures(
*tokens, full_line_span,
@@ -1162,9 +1187,14 @@
return true;
}
-const FeatureProcessor& TextClassifier::SelectionFeatureProcessorForTests()
+const FeatureProcessor* TextClassifier::SelectionFeatureProcessorForTests()
const {
- return *selection_feature_processor_;
+ return selection_feature_processor_.get();
+}
+
+const FeatureProcessor* TextClassifier::ClassificationFeatureProcessorForTests()
+ const {
+ return classification_feature_processor_.get();  // Borrowed pointer obtained via .get(); the classifier keeps the object.
}
const DatetimeParser* TextClassifier::DatetimeParserForTests() const {
@@ -1513,7 +1543,7 @@
const std::string& locales, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
if (!datetime_parser_) {
- return false;
+ return true;
}
std::vector<DatetimeParseResultSpan> datetime_spans;
diff --git a/text-classifier.h b/text-classifier.h
index ad94dc4..0692ecd 100644
--- a/text-classifier.h
+++ b/text-classifier.h
@@ -148,8 +148,9 @@
const std::string& context,
const AnnotationOptions& options = AnnotationOptions::Default()) const;
- // Exposes the selection feature processor for tests and evaluations.
- const FeatureProcessor& SelectionFeatureProcessorForTests() const;
+ // Exposes the feature processors for tests and evaluations.
+ const FeatureProcessor* SelectionFeatureProcessorForTests() const;
+ const FeatureProcessor* ClassificationFeatureProcessorForTests() const;
// Exposes the date time parser for tests and evaluations.
const DatetimeParser* DatetimeParserForTests() const;
diff --git a/text-classifier_test.cc b/text-classifier_test.cc
index 440cedf..c8ced76 100644
--- a/text-classifier_test.cc
+++ b/text-classifier_test.cc
@@ -1010,7 +1010,7 @@
result.clear();
options.reference_timezone = "Europe/Zurich";
options.locales = "en-US";
- result = classifier->ClassifyText("03/05", {0, 5}, options);
+ result = classifier->ClassifyText("03.05.1970", {0, 10}, options);
ASSERT_EQ(result.size(), 1);
EXPECT_THAT(result[0].collection, "date");
@@ -1020,8 +1020,8 @@
result.clear();
options.reference_timezone = "Europe/Zurich";
- options.locales = "en-GB,en-US";
- result = classifier->ClassifyText("03/05", {0, 5}, options);
+ options.locales = "de";
+ result = classifier->ClassifyText("03.05.1970", {0, 10}, options);
ASSERT_EQ(result.size(), 1);
EXPECT_THAT(result[0].collection, "date");
@@ -1212,6 +1212,44 @@
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST_P(TextClassifierTest, MaxTokenLength) {  // Checks that classification_options.max_num_tokens gates ClassifyText.
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ std::unique_ptr<TextClassifier> classifier;
+
+ // With an unrestricted number of tokens (-1), classification behaves normally.
+ unpacked_model->classification_options->max_num_tokens = -1;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(FirstResult(classifier->ClassifyText(
+ "I live at 350 Third Street, Cambridge.", {10, 37})),
+ "address");
+
+ // Restrict the maximum number of tokens to suppress the classification.
+ unpacked_model->classification_options->max_num_tokens = 3;
+
+ flatbuffers::FlatBufferBuilder builder2;  // The modified model is re-packed into a fresh builder.
+ builder2.Finish(Model::Pack(builder2, unpacked_model.get()));
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder2.GetBufferPointer()),
+ builder2.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(FirstResult(classifier->ClassifyText(
+ "I live at 350 Third Street, Cambridge.", {10, 37})),
+ "other");
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest, MinAddressTokenLength) {
CREATE_UNILIB_FOR_TESTING;
const std::string test_model = ReadFile(GetModelPath() + GetParam());
diff --git a/util/base/macros.h b/util/base/macros.h
index edb980e..a021ab9 100644
--- a/util/base/macros.h
+++ b/util/base/macros.h
@@ -68,10 +68,12 @@
//
// In either case this macro has no effect on runtime behavior and performance
// of code.
-#if defined(__clang__) && defined(LANG_CXX11) && defined(__has_warning)
+#if defined(__clang__) && defined(__has_warning)
#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough")
-#define TC_FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT
+#define TC_FALLTHROUGH_INTENDED [[clang::fallthrough]]
#endif
+#elif defined(__GNUC__) && __GNUC__ >= 7
+#define TC_FALLTHROUGH_INTENDED [[gnu::fallthrough]]
#endif
#ifndef TC_FALLTHROUGH_INTENDED
diff --git a/util/calendar/calendar-icu.cc b/util/calendar/calendar-icu.cc
index 99deeb2..34ea22d 100644
--- a/util/calendar/calendar-icu.cc
+++ b/util/calendar/calendar-icu.cc
@@ -18,6 +18,7 @@
#include <memory>
+#include "util/base/macros.h"
#include "unicode/gregocal.h"
#include "unicode/timezone.h"
#include "unicode/ucal.h"
@@ -62,7 +63,7 @@
TC_LOG(ERROR) << "error day of week";
return false;
}
- date->add(UCalendarDateFields::UCAL_DATE, 1, status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a day";
return false;
@@ -72,7 +73,7 @@
}
return true;
case DateParseData::DAY:
- date->add(UCalendarDateFields::UCAL_DATE, -1 * distance, status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, -1 * distance, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a day";
return false;
@@ -81,7 +82,8 @@
return true;
case DateParseData::WEEK:
date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1);
- date->add(UCalendarDateFields::UCAL_DATE, -7 * (distance - 1), status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, -7 * (distance - 1),
+ status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a week";
return false;
@@ -89,7 +91,7 @@
return true;
case DateParseData::MONTH:
- date->set(UCalendarDateFields::UCAL_DATE, 1);
+ date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1);
date->add(UCalendarDateFields::UCAL_MONTH, -1 * (distance - 1), status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a month";
@@ -127,7 +129,7 @@
TC_LOG(ERROR) << "error day of week";
return false;
}
- date->add(UCalendarDateFields::UCAL_DATE, 1, status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a day";
return false;
@@ -135,7 +137,7 @@
}
return true;
case DateParseData::DAY:
- date->add(UCalendarDateFields::UCAL_DATE, 1, status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a day";
return false;
@@ -144,7 +146,7 @@
return true;
case DateParseData::WEEK:
date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1);
- date->add(UCalendarDateFields::UCAL_DATE, 7, status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 7, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a week";
return false;
@@ -152,7 +154,7 @@
return true;
case DateParseData::MONTH:
- date->set(UCalendarDateFields::UCAL_DATE, 1);
+ date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1);
date->add(UCalendarDateFields::UCAL_MONTH, 1, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a month";
@@ -190,7 +192,7 @@
TC_LOG(ERROR) << "error day of week";
return false;
}
- date->add(UCalendarDateFields::UCAL_DATE, 1, status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a day";
return false;
@@ -200,7 +202,7 @@
}
return true;
case DateParseData::DAY:
- date->add(UCalendarDateFields::UCAL_DATE, distance, status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, distance, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a day";
return false;
@@ -209,7 +211,7 @@
return true;
case DateParseData::WEEK:
date->set(UCalendarDateFields::UCAL_DAY_OF_WEEK, 1);
- date->add(UCalendarDateFields::UCAL_DATE, 7 * distance, status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 7 * distance, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a week";
return false;
@@ -217,7 +219,7 @@
return true;
case DateParseData::MONTH:
- date->set(UCalendarDateFields::UCAL_DATE, 1);
+ date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1);
date->add(UCalendarDateFields::UCAL_MONTH, 1 * distance, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a month";
@@ -238,15 +240,61 @@
}
}
+// Truncates `calendar` in place to the start of the `granularity` period.
+// Returns false if ICU fails to compute the calendar fields.
+bool RoundToGranularity(DatetimeGranularity granularity,
+                        icu::Calendar* calendar) {
+  // Force recomputation before doing the rounding.
+  UErrorCode status = U_ZERO_ERROR;
+  calendar->get(UCalendarDateFields::UCAL_DAY_OF_WEEK, status);
+  if (U_FAILURE(status)) {
+    TC_LOG(ERROR) << "Can't interpret date.";
+    return false;
+  }
+
+  switch (granularity) {
+    case GRANULARITY_YEAR:
+      calendar->set(UCalendarDateFields::UCAL_MONTH, 0);  // 0 == January.
+      TC_FALLTHROUGH_INTENDED;
+    case GRANULARITY_MONTH:
+      calendar->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1);
+      TC_FALLTHROUGH_INTENDED;
+    case GRANULARITY_DAY:
+      // 24-hour field: zeroing 12-hour UCAL_HOUR would keep PM (-> noon).
+      calendar->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, 0);
+      TC_FALLTHROUGH_INTENDED;
+    case GRANULARITY_HOUR:
+      calendar->set(UCalendarDateFields::UCAL_MINUTE, 0);
+      TC_FALLTHROUGH_INTENDED;
+    case GRANULARITY_MINUTE:
+      calendar->set(UCalendarDateFields::UCAL_SECOND, 0);
+      break;
+    case GRANULARITY_WEEK:  // Not part of the fallthrough chain above.
+      calendar->set(UCalendarDateFields::UCAL_DAY_OF_WEEK,
+                    calendar->getFirstDayOfWeek());
+      calendar->set(UCalendarDateFields::UCAL_HOUR_OF_DAY, 0);  // As for DAY.
+      calendar->set(UCalendarDateFields::UCAL_MINUTE, 0);
+      calendar->set(UCalendarDateFields::UCAL_SECOND, 0);
+      break;
+    case GRANULARITY_UNKNOWN:  // Nothing to truncate.
+    case GRANULARITY_SECOND:
+      break;
+  }
+  return true;
+}
+
} // namespace
bool CalendarLib::InterpretParseData(const DateParseData& parse_data,
int64 reference_time_ms_utc,
const std::string& reference_timezone,
+ const std::string& reference_locale,
+ DatetimeGranularity granularity,
int64* interpreted_time_ms_utc) const {
UErrorCode status = U_ZERO_ERROR;
- std::unique_ptr<icu::Calendar> date(icu::Calendar::createInstance(status));
+ std::unique_ptr<icu::Calendar> date(icu::Calendar::createInstance(
+ icu::Locale::createFromName(reference_locale.c_str()), status));
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error getting calendar instance";
return false;
@@ -307,14 +355,14 @@
// NOOP
break;
case DateParseData::Relation::TOMORROW:
- date->add(UCalendarDateFields::UCAL_DATE, 1, status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, 1, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error adding a day";
return false;
}
break;
case DateParseData::Relation::YESTERDAY:
- date->add(UCalendarDateFields::UCAL_DATE, -1, status);
+ date->add(UCalendarDateFields::UCAL_DAY_OF_MONTH, -1, status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error subtracting a day";
return false;
@@ -356,7 +404,7 @@
date->set(UCalendarDateFields::UCAL_MONTH, parse_data.month - 1);
}
if (parse_data.field_set_mask & DateParseData::Fields::DAY_FIELD) {
- date->set(UCalendarDateFields::UCAL_DATE, parse_data.day_of_month);
+ date->set(UCalendarDateFields::UCAL_DAY_OF_MONTH, parse_data.day_of_month);
}
if (parse_data.field_set_mask & DateParseData::Fields::HOUR_FIELD) {
if (parse_data.field_set_mask & DateParseData::Fields::AMPM_FIELD &&
@@ -372,11 +420,17 @@
if (parse_data.field_set_mask & DateParseData::Fields::SECOND_FIELD) {
date->set(UCalendarDateFields::UCAL_SECOND, parse_data.second);
}
+
+ if (!RoundToGranularity(granularity, date.get())) {
+ return false;
+ }
+
*interpreted_time_ms_utc = date->getTime(status);
if (U_FAILURE(status)) {
TC_LOG(ERROR) << "error getting time from instance";
return false;
}
+
return true;
}
} // namespace libtextclassifier2
diff --git a/util/calendar/calendar-icu.h b/util/calendar/calendar-icu.h
index dc0a4f4..8aae7ab 100644
--- a/util/calendar/calendar-icu.h
+++ b/util/calendar/calendar-icu.h
@@ -33,6 +33,8 @@
bool InterpretParseData(const DateParseData& parse_data,
int64 reference_time_ms_utc,
const std::string& reference_timezone,
+ const std::string& reference_locale,
+ DatetimeGranularity granularity,
int64* interpreted_time_ms_utc) const;
};
} // namespace libtextclassifier2
diff --git a/util/calendar/calendar_test.cc b/util/calendar/calendar_test.cc
index 7065a95..1f29106 100644
--- a/util/calendar/calendar_test.cc
+++ b/util/calendar/calendar_test.cc
@@ -33,9 +33,97 @@
DateParseData{0l, 0, 0, 0, 0, 0, 0, 0, 0, 0,
static_cast<DateParseData::Relation>(0),
static_cast<DateParseData::RelationType>(0), 0},
- 0L, "Zurich", &time);
+ 0L, "Zurich", "en-CH", GRANULARITY_UNKNOWN, &time);
TC_LOG(INFO) << result;
}
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+// Checks that InterpretParseData truncates the interpreted time to the start
+// of the requested granularity, and that the start of the week follows the
+// reference locale (Monday for the CH locales, Sunday for the US locales).
+TEST(CalendarTest, RoundingToGranularity) {
+  CalendarLib calendar;
+  int64 time = 0;
+  // Value-initialize so that the fields not assigned below have a defined
+  // value.
+  DateParseData data{};
+  data.year = 2018;
+  data.month = 4;
+  data.day_of_month = 25;
+  data.hour = 9;
+  data.minute = 33;
+  data.second = 59;
+  data.field_set_mask = DateParseData::YEAR_FIELD | DateParseData::MONTH_FIELD |
+                        DateParseData::DAY_FIELD | DateParseData::HOUR_FIELD |
+                        DateParseData::MINUTE_FIELD |
+                        DateParseData::SECOND_FIELD;
+  ASSERT_TRUE(calendar.InterpretParseData(
+      data,
+      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+      /*reference_locale=*/"en-CH",
+      /*granularity=*/GRANULARITY_YEAR, &time));
+  EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
+
+  ASSERT_TRUE(calendar.InterpretParseData(
+      data,
+      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+      /*reference_locale=*/"en-CH",
+      /*granularity=*/GRANULARITY_MONTH, &time));
+  EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */);
+
+  ASSERT_TRUE(calendar.InterpretParseData(
+      data,
+      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+      /*reference_locale=*/"en-CH",
+      /*granularity=*/GRANULARITY_WEEK, &time));
+  EXPECT_EQ(time, 1524434400000L /* Mon Apr 23 2018 00:00:00 */);
+
+  ASSERT_TRUE(calendar.InterpretParseData(
+      data,
+      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+      /*reference_locale=*/"*-CH",
+      /*granularity=*/GRANULARITY_WEEK, &time));
+  EXPECT_EQ(time, 1524434400000L /* Mon Apr 23 2018 00:00:00 */);
+
+  ASSERT_TRUE(calendar.InterpretParseData(
+      data,
+      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+      /*reference_locale=*/"en-US",
+      /*granularity=*/GRANULARITY_WEEK, &time));
+  EXPECT_EQ(time, 1524348000000L /* Sun Apr 22 2018 00:00:00 */);
+
+  ASSERT_TRUE(calendar.InterpretParseData(
+      data,
+      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+      /*reference_locale=*/"*-US",
+      /*granularity=*/GRANULARITY_WEEK, &time));
+  EXPECT_EQ(time, 1524348000000L /* Sun Apr 22 2018 00:00:00 */);
+
+  ASSERT_TRUE(calendar.InterpretParseData(
+      data,
+      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+      /*reference_locale=*/"en-CH",
+      /*granularity=*/GRANULARITY_DAY, &time));
+  EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */);
+
+  ASSERT_TRUE(calendar.InterpretParseData(
+      data,
+      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+      /*reference_locale=*/"en-CH",
+      /*granularity=*/GRANULARITY_HOUR, &time));
+  EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */);
+
+  ASSERT_TRUE(calendar.InterpretParseData(
+      data,
+      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+      /*reference_locale=*/"en-CH",
+      /*granularity=*/GRANULARITY_MINUTE, &time));
+  EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */);
+
+  ASSERT_TRUE(calendar.InterpretParseData(
+      data,
+      /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+      /*reference_locale=*/"en-CH",
+      /*granularity=*/GRANULARITY_SECOND, &time));
+  EXPECT_EQ(time, 1524641639000 /* Apr 25 2018 09:33:59 */);
+}
+#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
+
} // namespace
} // namespace libtextclassifier2
diff --git a/util/flatbuffers.cc b/util/flatbuffers.cc
new file mode 100644
index 0000000..6c0108e
--- /dev/null
+++ b/util/flatbuffers.cc
@@ -0,0 +1,26 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/flatbuffers.h"
+
+namespace libtextclassifier2 {
+
+// Specialization for Model: the file identifier is whatever ModelIdentifier()
+// returns.
+template <>
+const char* FlatbufferFileIdentifier<Model>() {
+ return ModelIdentifier();
+}
+
+} // namespace libtextclassifier2
diff --git a/util/flatbuffers.h b/util/flatbuffers.h
new file mode 100644
index 0000000..93d73b6
--- /dev/null
+++ b/util/flatbuffers.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers.
+
+#ifndef LIBTEXTCLASSIFIER_UTIL_FLATBUFFERS_H_
+#define LIBTEXTCLASSIFIER_UTIL_FLATBUFFERS_H_
+
+#include <memory>
+#include <string>
+
+#include "model_generated.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier2 {
+
+// Interprets `buffer` (of `size` bytes) as a serialized 'FlatbufferMessage'
+// and verifies its integrity.
+//
+// Returns the root message on success. Returns nullptr if `buffer` is null,
+// if `size` is too small to hold the leading root offset, or if verification
+// fails.
+template <typename FlatbufferMessage>
+const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
+  // GetRoot() reads the root offset stored at the start of the buffer without
+  // any bounds check, so reject inputs that cannot contain it (this also
+  // covers negative sizes) before touching the data.
+  if (buffer == nullptr ||
+      size < static_cast<int>(sizeof(flatbuffers::uoffset_t))) {
+    return nullptr;
+  }
+  const FlatbufferMessage* message =
+      flatbuffers::GetRoot<FlatbufferMessage>(buffer);
+  if (message == nullptr) {
+    return nullptr;
+  }
+  flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
+                                 size);
+  return message->Verify(verifier) ? message : nullptr;
+}
+
+// Overload of LoadAndVerifyFlatbuffer that reads the serialized message from
+// a std::string.
+template <typename FlatbufferMessage>
+const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) {
+  const void* const bytes = buffer.c_str();
+  const int num_bytes = static_cast<int>(buffer.size());
+  return LoadAndVerifyFlatbuffer<FlatbufferMessage>(bytes, num_bytes);
+}
+
+// Verifies `buffer` as a 'FlatbufferMessage' and, on success, unpacks it into
+// its mutable (NativeTableType) form, which the returned unique_ptr owns.
+// Returns an empty pointer if loading or verification fails.
+template <typename FlatbufferMessage>
+std::unique_ptr<typename FlatbufferMessage::NativeTableType>
+LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) {
+  using MutableMessage = typename FlatbufferMessage::NativeTableType;
+  std::unique_ptr<MutableMessage> unpacked;
+  if (const FlatbufferMessage* verified =
+          LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size)) {
+    unpacked.reset(verified->UnPack());
+  }
+  return unpacked;
+}
+
+// Overload of LoadAndVerifyMutableFlatbuffer that reads the serialized
+// message from a std::string.
+template <typename FlatbufferMessage>
+std::unique_ptr<typename FlatbufferMessage::NativeTableType>
+LoadAndVerifyMutableFlatbuffer(const std::string& buffer) {
+  const void* const bytes = buffer.c_str();
+  const int num_bytes = static_cast<int>(buffer.size());
+  return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(bytes, num_bytes);
+}
+
+// Returns the file identifier passed to FlatBufferBuilder::Finish() when a
+// 'FlatbufferMessage' is packed. The generic version returns nullptr.
+template <typename FlatbufferMessage>
+const char* FlatbufferFileIdentifier() {
+ return nullptr;
+}
+
+// Explicit specialization for Model; only declared in this header.
+template <>
+const char* FlatbufferFileIdentifier<Model>();
+
+// Serializes the mutable message into a flatbuffer, finished with the file
+// identifier for 'FlatbufferMessage', and returns the bytes as a std::string.
+template <typename FlatbufferMessage>
+std::string PackFlatbuffer(
+    const typename FlatbufferMessage::NativeTableType* mutable_message) {
+  flatbuffers::FlatBufferBuilder builder;
+  const auto root_offset = FlatbufferMessage::Pack(builder, mutable_message);
+  const char* const file_identifier =
+      FlatbufferFileIdentifier<FlatbufferMessage>();
+  builder.Finish(root_offset, file_identifier);
+
+  const char* const packed_bytes =
+      reinterpret_cast<const char*>(builder.GetBufferPointer());
+  return std::string(packed_bytes, builder.GetSize());
+}
+
+} // namespace libtextclassifier2
+
+#endif // LIBTEXTCLASSIFIER_UTIL_FLATBUFFERS_H_
diff --git a/util/utf8/unicodetext.cc b/util/utf8/unicodetext.cc
index 70fecd4..2ef79e9 100644
--- a/util/utf8/unicodetext.cc
+++ b/util/utf8/unicodetext.cc
@@ -292,5 +292,8 @@
return UTF8ToUnicodeText(str.data(), str.size(), do_copy);
}
+// Convenience overload that forwards to the (str, do_copy) overload with
+// do_copy set to true.
+UnicodeText UTF8ToUnicodeText(const std::string& str) {
+ return UTF8ToUnicodeText(str, /*do_copy=*/true);
+}
} // namespace libtextclassifier2
diff --git a/util/utf8/unicodetext.h b/util/utf8/unicodetext.h
index 7300111..ec08f53 100644
--- a/util/utf8/unicodetext.h
+++ b/util/utf8/unicodetext.h
@@ -217,6 +217,7 @@
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len, bool do_copy);
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy);
UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy);
+// Equivalent to calling the (str, do_copy) overload with do_copy set to true.
+UnicodeText UTF8ToUnicodeText(const std::string& str);
} // namespace libtextclassifier2
diff --git a/zlib-utils.cc b/zlib-utils.cc
index 8650c9c..7e6646f 100644
--- a/zlib-utils.cc
+++ b/zlib-utils.cc
@@ -19,6 +19,7 @@
#include <memory>
#include "util/base/logging.h"
+#include "util/flatbuffers.h"
namespace libtextclassifier2 {
@@ -150,6 +151,72 @@
return true;
}
+namespace {
+
+// Decompresses a single pattern held in mutable (CompressedBufferT) form into
+// `uncompressed_pattern`.
+//
+// ZlibDecompressor::Decompress() is given a serialized CompressedBuffer, so
+// the mutable object is packed and re-loaded before being handed over.
+//
+// Returns true without touching `uncompressed_pattern` when
+// `compressed_pattern` is null (there is nothing to decompress). Returns false
+// if the packed buffer fails verification or decompression fails.
+bool DecompressBuffer(const CompressedBufferT* compressed_pattern,
+                      ZlibDecompressor* zlib_decompressor,
+                      std::string* uncompressed_pattern) {
+  if (compressed_pattern == nullptr) {
+    // Packing requires a valid object; a missing buffer means the pattern is
+    // not stored compressed.
+    return true;
+  }
+  const std::string packed_pattern =
+      PackFlatbuffer<CompressedBuffer>(compressed_pattern);
+  const CompressedBuffer* verified_pattern =
+      LoadAndVerifyFlatbuffer<CompressedBuffer>(packed_pattern);
+  if (verified_pattern == nullptr) {
+    return false;
+  }
+  return zlib_decompressor->Decompress(verified_pattern, uncompressed_pattern);
+}
+
+}  // namespace
+
+// Decompresses every regex and datetime pattern of `model` in place.
+//
+// For each pattern the decompressed text is written to its `pattern` field and
+// its `compressed_pattern` is released afterwards. Returns false as soon as
+// the decompressor cannot be created or any pattern fails to decompress; the
+// patterns processed before the failure stay decompressed.
+bool DecompressModel(ModelT* model) {
+  std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+      ZlibDecompressor::Instance();
+  if (!zlib_decompressor) {
+    TC_LOG(ERROR) << "Cannot initialize decompressor.";
+    return false;
+  }
+
+  // Decompress regex rules.
+  if (model->regex_model != nullptr) {
+    for (int i = 0; i < model->regex_model->patterns.size(); i++) {
+      RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
+      if (!DecompressBuffer(pattern->compressed_pattern.get(),
+                            zlib_decompressor.get(), &pattern->pattern)) {
+        TC_LOG(ERROR) << "Cannot decompress pattern: " << i;
+        return false;
+      }
+      pattern->compressed_pattern.reset(nullptr);
+    }
+  }
+
+  // Decompress date-time rules.
+  if (model->datetime_model != nullptr) {
+    // Each datetime pattern carries several regexes, each compressed
+    // separately.
+    for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
+      DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
+      for (int j = 0; j < pattern->regexes.size(); j++) {
+        DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
+        if (!DecompressBuffer(regex->compressed_pattern.get(),
+                              zlib_decompressor.get(), &regex->pattern)) {
+          TC_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j;
+          return false;
+        }
+        regex->compressed_pattern.reset(nullptr);
+      }
+    }
+    for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
+      DatetimeModelExtractorT* extractor =
+          model->datetime_model->extractors[i].get();
+      if (!DecompressBuffer(extractor->compressed_pattern.get(),
+                            zlib_decompressor.get(), &extractor->pattern)) {
+        TC_LOG(ERROR) << "Cannot decompress pattern: " << i;
+        return false;
+      }
+      extractor->compressed_pattern.reset(nullptr);
+    }
+  }
+  return true;
+}
+
std::string CompressSerializedModel(const std::string& model) {
std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
TC_CHECK(unpacked_model != nullptr);
@@ -162,8 +229,8 @@
std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern(
const UniLib& unilib, const flatbuffers::String* uncompressed_pattern,
- const CompressedBuffer* compressed_pattern,
- ZlibDecompressor* decompressor) {
+ const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor,
+ std::string* result_pattern_text) {
UnicodeText unicode_regex_pattern;
std::string decompressed_pattern;
if (compressed_pattern != nullptr &&
@@ -186,6 +253,10 @@
uncompressed_pattern->Length(), /*do_copy=*/false);
}
+ if (result_pattern_text != nullptr) {
+ *result_pattern_text = unicode_regex_pattern.ToUTF8String();
+ }
+
std::unique_ptr<UniLib::RegexPattern> regex_pattern =
unilib.CreateRegexPattern(unicode_regex_pattern);
if (!regex_pattern) {
diff --git a/zlib-utils.h b/zlib-utils.h
index d79f76e..136f4d2 100644
--- a/zlib-utils.h
+++ b/zlib-utils.h
@@ -62,13 +62,17 @@
// Compresses regex and datetime rules in the model in place.
bool CompressModel(ModelT* model);
+// Decompresses regex and datetime rules in the model in place: each pattern's
+// text is stored in its `pattern` field and its `compressed_pattern` is
+// cleared. Returns false if the decompressor cannot be initialized or a
+// pattern cannot be decompressed.
+bool DecompressModel(ModelT* model);
+
// Compresses regex and datetime rules in the model.
std::string CompressSerializedModel(const std::string& model);
// Create and compile a regex pattern from optionally compressed pattern.
std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern(
const UniLib& unilib, const flatbuffers::String* uncompressed_pattern,
- const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor);
+ const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor,
+ std::string* result_pattern_text = nullptr);
} // namespace libtextclassifier2
diff --git a/zlib-utils_test.cc b/zlib-utils_test.cc
index d3b5a19..155f14f 100644
--- a/zlib-utils_test.cc
+++ b/zlib-utils_test.cc
@@ -84,6 +84,15 @@
->compressed_pattern(),
&uncompressed_pattern));
EXPECT_EQ(uncompressed_pattern, "an example datetime extractor");
+
+ EXPECT_TRUE(DecompressModel(&model));
+ EXPECT_EQ(model.regex_model->patterns[0]->pattern, "this is a test pattern");
+ EXPECT_EQ(model.regex_model->patterns[1]->pattern,
+ "this is a second test pattern");
+ EXPECT_EQ(model.datetime_model->patterns[0]->regexes[0]->pattern,
+ "an example datetime pattern");
+ EXPECT_EQ(model.datetime_model->extractors[0]->pattern,
+ "an example datetime extractor");
}
} // namespace libtextclassifier2