Adds support for ICU tokenization.
Bug: 36886053
Test: Built, tested on device, google3 regression and unit tests pass.
Change-Id: Ia3345c6c7a5aa816233d3b3ae10e2a92b31f08a7
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index 0ba25ca..6f2dc73 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -24,6 +24,9 @@
#include "util/base/logging.h"
#include "util/strings/utf8.h"
#include "util/utf8/unicodetext.h"
+#include "unicode/brkiter.h"
+#include "unicode/errorcode.h"
+#include "unicode/uchar.h"
namespace libtextclassifier {
@@ -178,7 +181,22 @@
std::vector<Token> FeatureProcessor::Tokenize(
const std::string& utf8_text) const {
- return tokenizer_.Tokenize(utf8_text);
+ if (options_.tokenization_type() ==
+ libtextclassifier::FeatureProcessorOptions::INTERNAL_TOKENIZER) {
+ return tokenizer_.Tokenize(utf8_text);
+ } else if (options_.tokenization_type() ==
+ libtextclassifier::FeatureProcessorOptions::ICU) {
+ std::vector<Token> result;
+ if (ICUTokenize(utf8_text, &result)) {
+ return result;
+ } else {
+ return {};
+ }
+ } else {
+ TC_LOG(ERROR) << "Unknown tokenization type specified. Using "
+ "internal.";
+ return tokenizer_.Tokenize(utf8_text);
+ }
}
bool FeatureProcessor::LabelToSpan(
@@ -278,7 +296,6 @@
}
}
-// Converts a codepoint span to a token span in the given list of tokens.
TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
CodepointSpan codepoint_span) {
const int codepoint_start = std::get<0>(codepoint_span);
@@ -299,6 +316,12 @@
return {start_token, end_token};
}
+CodepointSpan TokenSpanToCodepointSpan(
+ const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
+ return {selectable_tokens[token_span.first].start,
+ selectable_tokens[token_span.second - 1].end};
+}
+
namespace {
// Finds a single token that completely contains the given span.
@@ -633,4 +656,49 @@
return options_.context_size();
}
+bool FeatureProcessor::ICUTokenize(const std::string& context,
+ std::vector<Token>* result) const {
+ icu::ErrorCode status;
+ icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(context);
+ std::unique_ptr<icu::BreakIterator> break_iterator(
+ icu::BreakIterator::createWordInstance(icu::Locale("en"), status));
+ if (!status.isSuccess()) {
+ TC_LOG(ERROR) << "Break iterator did not initialize properly: "
+ << status.errorName();
+ return false;
+ }
+
+ break_iterator->setText(unicode_text);
+
+ size_t last_break_index = 0;
+ size_t break_index = 0;
+ size_t last_unicode_index = 0;
+ size_t unicode_index = 0;
+ while ((break_index = break_iterator->next()) != icu::BreakIterator::DONE) {
+ icu::UnicodeString token(unicode_text, last_break_index,
+ break_index - last_break_index);
+ int token_length = token.countChar32();
+ unicode_index = last_unicode_index + token_length;
+
+ std::string token_utf8;
+ token.toUTF8String(token_utf8);
+
+ bool is_whitespace = true;
+ for (int i = 0; i < token.length(); i++) {
+ if (!u_isWhitespace(token.char32At(i))) {
+ is_whitespace = false;
+ }
+ }
+
+ if (!is_whitespace || options_.icu_preserve_whitespace_tokens()) {
+ result->push_back(Token(token_utf8, last_unicode_index, unicode_index));
+ }
+
+ last_break_index = break_index;
+ last_unicode_index = unicode_index;
+ }
+
+ return result;
+}
+
} // namespace libtextclassifier
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index 2f1e530..1efcf63 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -79,9 +79,14 @@
} // namespace internal
+// Converts a codepoint span to a token span in the given list of tokens.
TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
CodepointSpan codepoint_span);
+// Converts a token span to a codepoint span in the given list of tokens.
+CodepointSpan TokenSpanToCodepointSpan(
+ const std::vector<Token>& selectable_tokens, TokenSpan token_span);
+
// Takes care of preparing features for the span prediction model.
class FeatureProcessor {
public:
@@ -203,6 +208,10 @@
// Pads tokens with options.context_size() padding tokens on both sides.
int PadContext(std::vector<Token>* tokens) const;
+ // Tokenizes the input text using ICU tokenizer.
+ bool ICUTokenize(const std::string& context,
+ std::vector<Token>* result) const;
+
const TokenFeatureExtractor feature_extractor_;
private:
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index ca0484f..71457bc 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -201,6 +201,14 @@
return {};
}
+ if (selection_label_spans != nullptr) {
+ if (!feature_processor.SelectionLabelSpans(output_tokens,
+ selection_label_spans)) {
+ TC_LOG(ERROR) << "Could not get spans for selection labels.";
+ return {};
+ }
+ }
+
std::vector<float> scores;
network.ComputeLogits(features, &scores);
return scores;
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index 7a4c9f1..4e21af4 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -88,9 +88,6 @@
// Whether to remap all digits to a single number.
optional bool remap_digits = 20 [default = false];
- // If true, tokenize on space, otherwise tokenize using ICU.
- optional bool tokenize_on_space = 7 [default = true];
-
// If true, the selection classifier output will contain only the selections
// that are feasible (e.g., those that are shorter than max_selection_span),
// if false, the output will be a complete cross-product of possible
@@ -169,6 +166,43 @@
// and at the end, the dense features for the tokens are concatenated
// to it. So the resulting feature vector has two regions.
optional int32 feature_version = 25 [default = 0];
+
+ // These settings control whether a distortion is applied to part of the data,
+ // for Smart Sharing. Distortion means modifying (expanding) the bounds of the
+ // selection and changing the example's collection to "other". The goal is to
+ // expose the model to overselections as negative examples.
+ // If true, distortion is applied. Otherwise the other settings are ignored.
+ optional bool distortion_enable = 26;
+ // Probability settings. They individual values and their sum should be in the
+ // range [0, 1]. They specify the probability of the distorition being applied
+ // to just one of the bounds (left or right with equal probabiolity), both
+ // bounds, or not at all (the remaining probability).
+ // If the context does not contain tokens on the given side of the selection,
+ // the probabilistic decision is ignored. This means that the actual frequency
+ // of distortion is somewhat lower than specified here.
+ optional double distortion_probability_one_side = 27;
+ optional double distortion_probability_both_sides = 28;
+ // The maximum number of tokens to include (on one side) when distortion is
+ // applied. The actual number is selected (independently for each side)
+ // uniformly from integers from 1 to this value, inclusive. If the context is
+ // too short, the end result is truncated.
+ optional double distortion_max_num_tokens = 29;
+
+ // Controls the type of tokenization the model will use for the input text.
+ enum TokenizationType {
+ INVALID_TOKENIZATION_TYPE = 0;
+
+ // Use the internal tokenizer for tokenization.
+ INTERNAL_TOKENIZER = 1;
+
+ // Use ICU for tokenization.
+ ICU = 2;
+ }
+ optional TokenizationType tokenization_type = 30
+ [default = INTERNAL_TOKENIZER];
+ optional bool icu_preserve_whitespace_tokens = 31 [default = false];
+
+ reserved 7;
};
extend nlp_core.EmbeddingNetworkProto {
diff --git a/tests/feature-processor_test.cc b/tests/feature-processor_test.cc
index e3a39e3..cf09f96 100644
--- a/tests/feature-processor_test.cc
+++ b/tests/feature-processor_test.cc
@@ -204,13 +204,13 @@
using FeatureProcessor::SpanToLabel;
using FeatureProcessor::SupportedCodepointsRatio;
using FeatureProcessor::IsCodepointSupported;
+ using FeatureProcessor::ICUTokenize;
};
TEST(FeatureProcessorTest, SpanToLabel) {
FeatureProcessorOptions options;
options.set_context_size(1);
options.set_max_selection_span(1);
- options.set_tokenize_on_space(true);
options.set_snap_label_span_boundaries_to_containing_tokens(false);
TokenizationCodepointRange* config =
@@ -519,5 +519,45 @@
EXPECT_EQ(click_index, 5);
}
+TEST(FeatureProcessorTest, ICUTokenize) {
+ FeatureProcessorOptions options;
+ options.set_tokenization_type(
+ libtextclassifier::FeatureProcessorOptions::ICU);
+
+ TestingFeatureProcessor feature_processor(options);
+ std::vector<Token> tokens = feature_processor.Tokenize("พระบาทสมเด็จพระปรมิ");
+ ASSERT_EQ(tokens,
+ // clang-format off
+ std::vector<Token>({Token("พระบาท", 0, 6),
+ Token("สมเด็จ", 6, 12),
+ Token("พระ", 12, 15),
+ Token("ปร", 15, 17),
+ Token("มิ", 17, 19)}));
+ // clang-format on
+}
+
+TEST(FeatureProcessorTest, ICUTokenizeWithWhitespaces) {
+ FeatureProcessorOptions options;
+ options.set_tokenization_type(
+ libtextclassifier::FeatureProcessorOptions::ICU);
+ options.set_icu_preserve_whitespace_tokens(true);
+
+ TestingFeatureProcessor feature_processor(options);
+ std::vector<Token> tokens =
+ feature_processor.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
+ ASSERT_EQ(tokens,
+ // clang-format off
+ std::vector<Token>({Token("พระบาท", 0, 6),
+ Token(" ", 6, 7),
+ Token("สมเด็จ", 7, 13),
+ Token(" ", 13, 14),
+ Token("พระ", 14, 17),
+ Token(" ", 17, 18),
+ Token("ปร", 18, 20),
+ Token(" ", 20, 21),
+ Token("มิ", 21, 23)}));
+ // clang-format on
+}
+
} // namespace
} // namespace libtextclassifier