Adds support for ICU tokenization.

Bug: 36886053

Test: Built, tested on device, google3 regression and unit tests pass.
Change-Id: Ia3345c6c7a5aa816233d3b3ae10e2a92b31f08a7
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index 0ba25ca..6f2dc73 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -24,6 +24,9 @@
 #include "util/base/logging.h"
 #include "util/strings/utf8.h"
 #include "util/utf8/unicodetext.h"
+#include "unicode/brkiter.h"
+#include "unicode/errorcode.h"
+#include "unicode/uchar.h"
 
 namespace libtextclassifier {
 
@@ -178,7 +181,22 @@
 
 std::vector<Token> FeatureProcessor::Tokenize(
     const std::string& utf8_text) const {
-  return tokenizer_.Tokenize(utf8_text);
+  if (options_.tokenization_type() ==
+      libtextclassifier::FeatureProcessorOptions::INTERNAL_TOKENIZER) {
+    return tokenizer_.Tokenize(utf8_text);
+  } else if (options_.tokenization_type() ==
+             libtextclassifier::FeatureProcessorOptions::ICU) {
+    std::vector<Token> result;
+    if (ICUTokenize(utf8_text, &result)) {
+      return result;
+    } else {
+      return {};
+    }
+  } else {
+    TC_LOG(ERROR) << "Unknown tokenization type specified. Using "
+                     "internal.";
+    return tokenizer_.Tokenize(utf8_text);
+  }
 }
 
 bool FeatureProcessor::LabelToSpan(
@@ -278,7 +296,6 @@
   }
 }
 
-// Converts a codepoint span to a token span in the given list of tokens.
 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
                                    CodepointSpan codepoint_span) {
   const int codepoint_start = std::get<0>(codepoint_span);
@@ -299,6 +316,12 @@
   return {start_token, end_token};
 }
 
+CodepointSpan TokenSpanToCodepointSpan(
+    const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
+  return {selectable_tokens[token_span.first].start,
+          selectable_tokens[token_span.second - 1].end};
+}
+
 namespace {
 
 // Finds a single token that completely contains the given span.
@@ -633,4 +656,49 @@
   return options_.context_size();
 }
 
+bool FeatureProcessor::ICUTokenize(const std::string& context,
+                                   std::vector<Token>* result) const {
+  icu::ErrorCode status;
+  icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(context);
+  std::unique_ptr<icu::BreakIterator> break_iterator(
+      icu::BreakIterator::createWordInstance(icu::Locale("en"), status));
+  if (!status.isSuccess()) {
+    TC_LOG(ERROR) << "Break iterator did not initialize properly: "
+                  << status.errorName();
+    return false;
+  }
+
+  break_iterator->setText(unicode_text);
+
+  size_t last_break_index = 0;
+  size_t break_index = 0;
+  size_t last_unicode_index = 0;
+  size_t unicode_index = 0;
+  while ((break_index = break_iterator->next()) != icu::BreakIterator::DONE) {
+    icu::UnicodeString token(unicode_text, last_break_index,
+                             break_index - last_break_index);
+    int token_length = token.countChar32();
+    unicode_index = last_unicode_index + token_length;
+
+    std::string token_utf8;
+    token.toUTF8String(token_utf8);
+
+    // Skip whitespace-only tokens unless the options ask to keep them.
+    bool is_whitespace = true;
+    for (int i = 0; i < token.length(); i++) {
+      if (!u_isWhitespace(token.char32At(i))) {
+        is_whitespace = false;
+        break;
+      }
+    }
+    if (!is_whitespace || options_.icu_preserve_whitespace_tokens()) {
+      result->push_back(Token(token_utf8, last_unicode_index, unicode_index));
+    }
+    last_break_index = break_index;
+    last_unicode_index = unicode_index;
+  }
+
+  return true;
+}
+
 }  // namespace libtextclassifier
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index 2f1e530..1efcf63 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -79,9 +79,14 @@
 
 }  // namespace internal
 
+// Converts a codepoint span to a token span in the given list of tokens.
 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
                                    CodepointSpan codepoint_span);
 
+// Converts a token span to a codepoint span in the given list of tokens.
+CodepointSpan TokenSpanToCodepointSpan(
+    const std::vector<Token>& selectable_tokens, TokenSpan token_span);
+
 // Takes care of preparing features for the span prediction model.
 class FeatureProcessor {
  public:
@@ -203,6 +208,10 @@
   // Pads tokens with options.context_size() padding tokens on both sides.
   int PadContext(std::vector<Token>* tokens) const;
 
+  // Tokenizes the input text using ICU tokenizer.
+  bool ICUTokenize(const std::string& context,
+                   std::vector<Token>* result) const;
+
   const TokenFeatureExtractor feature_extractor_;
 
  private:
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index ca0484f..71457bc 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -201,6 +201,14 @@
     return {};
   }
 
+  if (selection_label_spans != nullptr) {
+    if (!feature_processor.SelectionLabelSpans(output_tokens,
+                                               selection_label_spans)) {
+      TC_LOG(ERROR) << "Could not get spans for selection labels.";
+      return {};
+    }
+  }
+
   std::vector<float> scores;
   network.ComputeLogits(features, &scores);
   return scores;
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index 7a4c9f1..4e21af4 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -88,9 +88,6 @@
   // Whether to remap all digits to a single number.
   optional bool remap_digits = 20 [default = false];
 
-  // If true, tokenize on space, otherwise tokenize using ICU.
-  optional bool tokenize_on_space = 7 [default = true];
-
   // If true, the selection classifier output will contain only the selections
   // that are feasible (e.g., those that are shorter than max_selection_span),
   // if false, the output will be a complete cross-product of possible
@@ -169,6 +166,43 @@
   //      and at the end, the dense features for the tokens are concatenated
   //      to it. So the resulting feature vector has two regions.
   optional int32 feature_version = 25 [default = 0];
+
+  // These settings control whether a distortion is applied to part of the data,
+  // for Smart Sharing. Distortion means modifying (expanding) the bounds of the
+  // selection and changing the example's collection to "other". The goal is to
+  // expose the model to overselections as negative examples.
+  // If true, distortion is applied. Otherwise the other settings are ignored.
+  optional bool distortion_enable = 26;
+  // Probability settings. The individual values and their sum should be in the
+  // range [0, 1]. They specify the probability of the distortion being applied
+  // to just one of the bounds (left or right with equal probability), both
+  // bounds, or not at all (the remaining probability).
+  // If the context does not contain tokens on the given side of the selection,
+  // the probabilistic decision is ignored. This means that the actual frequency
+  // of distortion is somewhat lower than specified here.
+  optional double distortion_probability_one_side = 27;
+  optional double distortion_probability_both_sides = 28;
+  // The maximum number of tokens to include (on one side) when distortion is
+  // applied. The actual number is selected (independently for each side)
+  // uniformly from integers from 1 to this value, inclusive. If the context is
+  // too short, the end result is truncated.
+  optional double distortion_max_num_tokens = 29;
+
+  // Controls the type of tokenization the model will use for the input text.
+  enum TokenizationType {
+    INVALID_TOKENIZATION_TYPE = 0;
+
+    // Use the internal tokenizer for tokenization.
+    INTERNAL_TOKENIZER = 1;
+
+    // Use ICU for tokenization.
+    ICU = 2;
+  }
+  optional TokenizationType tokenization_type = 30
+      [default = INTERNAL_TOKENIZER];
+  optional bool icu_preserve_whitespace_tokens = 31 [default = false];
+
+  reserved 7;
 };
 
 extend nlp_core.EmbeddingNetworkProto {
diff --git a/tests/feature-processor_test.cc b/tests/feature-processor_test.cc
index e3a39e3..cf09f96 100644
--- a/tests/feature-processor_test.cc
+++ b/tests/feature-processor_test.cc
@@ -204,13 +204,13 @@
   using FeatureProcessor::SpanToLabel;
   using FeatureProcessor::SupportedCodepointsRatio;
   using FeatureProcessor::IsCodepointSupported;
+  using FeatureProcessor::ICUTokenize;
 };
 
 TEST(FeatureProcessorTest, SpanToLabel) {
   FeatureProcessorOptions options;
   options.set_context_size(1);
   options.set_max_selection_span(1);
-  options.set_tokenize_on_space(true);
   options.set_snap_label_span_boundaries_to_containing_tokens(false);
 
   TokenizationCodepointRange* config =
@@ -519,5 +519,45 @@
   EXPECT_EQ(click_index, 5);
 }
 
+TEST(FeatureProcessorTest, ICUTokenize) {
+  FeatureProcessorOptions options;
+  options.set_tokenization_type(
+      libtextclassifier::FeatureProcessorOptions::ICU);
+
+  TestingFeatureProcessor feature_processor(options);
+  std::vector<Token> tokens = feature_processor.Tokenize("พระบาทสมเด็จพระปรมิ");
+  ASSERT_EQ(tokens,
+            // clang-format off
+            std::vector<Token>({Token("พระบาท", 0, 6),
+                                Token("สมเด็จ", 6, 12),
+                                Token("พระ", 12, 15),
+                                Token("ปร", 15, 17),
+                                Token("มิ", 17, 19)}));
+  // clang-format on
+}
+
+TEST(FeatureProcessorTest, ICUTokenizeWithWhitespaces) {
+  FeatureProcessorOptions options;
+  options.set_tokenization_type(
+      libtextclassifier::FeatureProcessorOptions::ICU);
+  options.set_icu_preserve_whitespace_tokens(true);
+
+  TestingFeatureProcessor feature_processor(options);
+  std::vector<Token> tokens =
+      feature_processor.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
+  ASSERT_EQ(tokens,
+            // clang-format off
+            std::vector<Token>({Token("พระบาท", 0, 6),
+                                Token(" ", 6, 7),
+                                Token("สมเด็จ", 7, 13),
+                                Token(" ", 13, 14),
+                                Token("พระ", 14, 17),
+                                Token(" ", 17, 18),
+                                Token("ปร", 18, 20),
+                                Token(" ", 20, 21),
+                                Token("มิ", 21, 23)}));
+  // clang-format on
+}
+
 }  // namespace
 }  // namespace libtextclassifier