Import latest version of libtextclassifier.
This includes newly trained i18n models.
The inference code now includes the option to normalize all input
to lowercase.
Bug: 36886059
Bug: 37534119
Test: Unit tests pass.
Change-Id: I28d1bd2241720f720d2dcabfb5710748a311b302
diff --git a/models/textclassifier.smartselection.ar.model b/models/textclassifier.smartselection.ar.model
index f81db50..b2ea877 100644
--- a/models/textclassifier.smartselection.ar.model
+++ b/models/textclassifier.smartselection.ar.model
Binary files differ
diff --git a/models/textclassifier.smartselection.de.model b/models/textclassifier.smartselection.de.model
index 0017b16..cf6ee82 100644
--- a/models/textclassifier.smartselection.de.model
+++ b/models/textclassifier.smartselection.de.model
Binary files differ
diff --git a/models/textclassifier.smartselection.en.model b/models/textclassifier.smartselection.en.model
index 850033a..92dd58b 100644
--- a/models/textclassifier.smartselection.en.model
+++ b/models/textclassifier.smartselection.en.model
Binary files differ
diff --git a/models/textclassifier.smartselection.es.model b/models/textclassifier.smartselection.es.model
index 23c2f0f..09612d1 100644
--- a/models/textclassifier.smartselection.es.model
+++ b/models/textclassifier.smartselection.es.model
Binary files differ
diff --git a/models/textclassifier.smartselection.fr.model b/models/textclassifier.smartselection.fr.model
index 12ee346..ee19caf 100644
--- a/models/textclassifier.smartselection.fr.model
+++ b/models/textclassifier.smartselection.fr.model
Binary files differ
diff --git a/models/textclassifier.smartselection.it.model b/models/textclassifier.smartselection.it.model
index 5a315b4..36a3016 100644
--- a/models/textclassifier.smartselection.it.model
+++ b/models/textclassifier.smartselection.it.model
Binary files differ
diff --git a/models/textclassifier.smartselection.ja.model b/models/textclassifier.smartselection.ja.model
index 525b36e..12f1e29 100644
--- a/models/textclassifier.smartselection.ja.model
+++ b/models/textclassifier.smartselection.ja.model
Binary files differ
diff --git a/models/textclassifier.smartselection.ko.model b/models/textclassifier.smartselection.ko.model
index b4c7b2d..70c55ca 100644
--- a/models/textclassifier.smartselection.ko.model
+++ b/models/textclassifier.smartselection.ko.model
Binary files differ
diff --git a/models/textclassifier.smartselection.nl.model b/models/textclassifier.smartselection.nl.model
index 470d83e..f80bbe9 100644
--- a/models/textclassifier.smartselection.nl.model
+++ b/models/textclassifier.smartselection.nl.model
Binary files differ
diff --git a/models/textclassifier.smartselection.pl.model b/models/textclassifier.smartselection.pl.model
index f3af3a4..aa8f585 100644
--- a/models/textclassifier.smartselection.pl.model
+++ b/models/textclassifier.smartselection.pl.model
Binary files differ
diff --git a/models/textclassifier.smartselection.pt-PT.model b/models/textclassifier.smartselection.pt-PT.model
index 2a27489..196a98b 100644
--- a/models/textclassifier.smartselection.pt-PT.model
+++ b/models/textclassifier.smartselection.pt-PT.model
Binary files differ
diff --git a/models/textclassifier.smartselection.ru.model b/models/textclassifier.smartselection.ru.model
index 0876043..9261e1a 100644
--- a/models/textclassifier.smartselection.ru.model
+++ b/models/textclassifier.smartselection.ru.model
Binary files differ
diff --git a/models/textclassifier.smartselection.th.model b/models/textclassifier.smartselection.th.model
index c26aedb..482ff2e 100644
--- a/models/textclassifier.smartselection.th.model
+++ b/models/textclassifier.smartselection.th.model
Binary files differ
diff --git a/models/textclassifier.smartselection.tr.model b/models/textclassifier.smartselection.tr.model
index 9d41ab7..b5c184f 100644
--- a/models/textclassifier.smartselection.tr.model
+++ b/models/textclassifier.smartselection.tr.model
Binary files differ
diff --git a/models/textclassifier.smartselection.zh-Hant.model b/models/textclassifier.smartselection.zh-Hant.model
index 7409a67..69e4ebb 100644
--- a/models/textclassifier.smartselection.zh-Hant.model
+++ b/models/textclassifier.smartselection.zh-Hant.model
Binary files differ
diff --git a/models/textclassifier.smartselection.zh.model b/models/textclassifier.smartselection.zh.model
index a9c0ee7..199e9b1 100644
--- a/models/textclassifier.smartselection.zh.model
+++ b/models/textclassifier.smartselection.zh.model
Binary files differ
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index 9c6c505..919b61f 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -49,6 +49,7 @@
extractor_options.regexp_features.push_back(options.regexp_feature(i));
}
extractor_options.remap_digits = options.remap_digits();
+ extractor_options.lowercase_tokens = options.lowercase_tokens();
return extractor_options;
}
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index ce83910..40e4ed0 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -98,6 +98,9 @@
// Whether to remap all digits to a single number.
optional bool remap_digits = 20 [default = false];
+ // Whether to lower-case each token before generating hashgrams.
+ optional bool lowercase_tokens = 33;
+
// If true, the selection classifier output will contain only the selections
// that are feasible (e.g., those that are shorter than max_selection_span),
// if false, the output will be a complete cross-product of possible
diff --git a/smartselect/token-feature-extractor.cc b/smartselect/token-feature-extractor.cc
index 6013ef3..479be41 100644
--- a/smartselect/token-feature-extractor.cc
+++ b/smartselect/token-feature-extractor.cc
@@ -16,6 +16,8 @@
#include "smartselect/token-feature-extractor.h"
+#include <string>
+
#include "util/base/logging.h"
#include "util/hash/farmhash.h"
#include "util/strings/stringpiece.h"
@@ -27,27 +29,46 @@
namespace {
-std::string MapDigitsToZeroAscii(const std::string& token) {
+std::string RemapTokenAscii(const std::string& token,
+ const TokenFeatureExtractorOptions& options) {
+ if (!options.remap_digits && !options.lowercase_tokens) {
+ return token;
+ }
+
std::string copy = token;
for (int i = 0; i < token.size(); ++i) {
- if (isdigit(copy[i])) {
+ if (options.remap_digits && isdigit(copy[i])) {
copy[i] = '0';
}
+ if (options.lowercase_tokens) {
+ copy[i] = tolower(copy[i]);
+ }
}
return copy;
}
-void MapDigitsToZeroUnicode(const std::string& token, UnicodeText* remapped) {
- remapped->clear();
+void RemapTokenUnicode(const std::string& token,
+ const TokenFeatureExtractorOptions& options,
+ UnicodeText* remapped) {
+ if (!options.remap_digits && !options.lowercase_tokens) {
+ // Leave remapped untouched.
+ return;
+ }
UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
+ icu::UnicodeString icu_string;
for (auto it = word.begin(); it != word.end(); ++it) {
- if (u_isdigit(*it)) {
- remapped->AppendUTF8("0", 1);
+ if (options.remap_digits && u_isdigit(*it)) {
+ icu_string.append('0');
+ } else if (options.lowercase_tokens) {
+ icu_string.append(u_tolower(*it));
} else {
- remapped->AppendUTF8(it.utf8_data(), it.utf8_length());
+ icu_string.append(*it);
}
}
+ std::string utf8_str;
+ icu_string.toUTF8String(utf8_str);
+ remapped->CopyUTF8(utf8_str.data(), utf8_str.length());
}
} // namespace
@@ -87,12 +108,7 @@
if (token.is_padding || token.value.empty()) {
result.push_back(HashToken("<PAD>"));
} else {
- std::string word;
- if (options_.remap_digits) {
- word = MapDigitsToZeroAscii(token.value);
- } else {
- word = token.value;
- }
+ const std::string word = RemapTokenAscii(token.value, options_);
// Trim words that are over max_word_length characters.
const int max_word_length = options_.max_word_length;
@@ -137,9 +153,7 @@
result.push_back(HashToken("<PAD>"));
} else {
UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- if (options_.remap_digits) {
- MapDigitsToZeroUnicode(token.value, &word);
- }
+ RemapTokenUnicode(token.value, options_, &word);
// Trim the word if needed by finding a left-cut point and right-cut point.
auto left_cut = word.begin();
diff --git a/smartselect/token-feature-extractor.h b/smartselect/token-feature-extractor.h
index 8502199..8287fbd 100644
--- a/smartselect/token-feature-extractor.h
+++ b/smartselect/token-feature-extractor.h
@@ -50,6 +50,9 @@
// Whether to remap digits to a single number.
bool remap_digits = false;
+ // Whether to lowercase all tokens.
+ bool lowercase_tokens = false;
+
// Maximum length of a word.
int max_word_length = 20;
};
diff --git a/tests/testdata/smartselection.model b/tests/testdata/smartselection.model
index 850033a..92dd58b 100644
--- a/tests/testdata/smartselection.model
+++ b/tests/testdata/smartselection.model
Binary files differ
diff --git a/tests/text-classification-model_test.cc b/tests/text-classification-model_test.cc
index cac093d..ed00876 100644
--- a/tests/text-classification-model_test.cc
+++ b/tests/text-classification-model_test.cc
@@ -267,8 +267,7 @@
{90, 103})));
// Single word.
- EXPECT_EQ("other",
- FindBestResult(model->ClassifyText("Barack Obama", {0, 12})));
+ EXPECT_EQ("other", FindBestResult(model->ClassifyText("obama", {0, 5})));
EXPECT_EQ("other", FindBestResult(model->ClassifyText("asdf", {0, 4})));
EXPECT_EQ("<INVALID RESULTS>",
FindBestResult(model->ClassifyText("asdf", {0, 0})));
diff --git a/tests/token-feature-extractor_test.cc b/tests/token-feature-extractor_test.cc
index 277549e..c85ba50 100644
--- a/tests/token-feature-extractor_test.cc
+++ b/tests/token-feature-extractor_test.cc
@@ -250,6 +250,47 @@
testing::Not(testing::ElementsAreArray(sparse_features2)));
}
+TEST(TokenFeatureExtractorTest, LowercaseAscii) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.lowercase_tokens = true;
+ options.unicode_aware_features = false;
+ TokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
+ &dense_features);
+
+ std::vector<int> sparse_features2;
+ extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+
+ extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+}
+
+TEST(TokenFeatureExtractorTest, LowercaseUnicode) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.lowercase_tokens = true;
+ options.unicode_aware_features = true;
+ TokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"ŘŘ", 0, 6}, true, &sparse_features, &dense_features);
+
+ std::vector<int> sparse_features2;
+ extractor.Extract(Token{"řř", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+}
+
TEST(TokenFeatureExtractorTest, RegexFeatures) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;