Import libtextclassifier code and models.
This fixes a problem where models were not unmapped.
Also includes updated model files for some Tier-1 languages and
improved tokenization for CJK.
Switch back to dynamic linking of protobuf-lite library.
Bug: 37446398
Bug: 36886059
Test: Unit tests pass.
Change-Id: I5f9e8747918f49d8f1f7c65f3b8a6610141795df
diff --git a/Android.mk b/Android.mk
index 13c7891..f52a691 100644
--- a/Android.mk
+++ b/Android.mk
@@ -68,7 +68,7 @@
LOCAL_C_INCLUDES += $(proto_sources_dir)/proto/external/libtextclassifier
LOCAL_STATIC_LIBRARIES += libtextclassifier_protos
-LOCAL_STATIC_LIBRARIES += libprotobuf-cpp-lite
+LOCAL_SHARED_LIBRARIES += libprotobuf-cpp-lite
LOCAL_SHARED_LIBRARIES += liblog
LOCAL_SHARED_LIBRARIES += libicuuc libicui18n
LOCAL_REQUIRED_MODULES := textclassifier.langid.model
diff --git a/models/textclassifier.smartselection.ar.model b/models/textclassifier.smartselection.ar.model
index b2ea877..7d1065a 100644
--- a/models/textclassifier.smartselection.ar.model
+++ b/models/textclassifier.smartselection.ar.model
Binary files differ
diff --git a/models/textclassifier.smartselection.de.model b/models/textclassifier.smartselection.de.model
index cf6ee82..ebe319a 100644
--- a/models/textclassifier.smartselection.de.model
+++ b/models/textclassifier.smartselection.de.model
Binary files differ
diff --git a/models/textclassifier.smartselection.es.model b/models/textclassifier.smartselection.es.model
index 09612d1..917afd5 100644
--- a/models/textclassifier.smartselection.es.model
+++ b/models/textclassifier.smartselection.es.model
Binary files differ
diff --git a/models/textclassifier.smartselection.fr.model b/models/textclassifier.smartselection.fr.model
index ee19caf..619e09d 100644
--- a/models/textclassifier.smartselection.fr.model
+++ b/models/textclassifier.smartselection.fr.model
Binary files differ
diff --git a/models/textclassifier.smartselection.it.model b/models/textclassifier.smartselection.it.model
index 36a3016..7b0de8e 100644
--- a/models/textclassifier.smartselection.it.model
+++ b/models/textclassifier.smartselection.it.model
Binary files differ
diff --git a/models/textclassifier.smartselection.ja.model b/models/textclassifier.smartselection.ja.model
index 12f1e29..bb31154 100644
--- a/models/textclassifier.smartselection.ja.model
+++ b/models/textclassifier.smartselection.ja.model
Binary files differ
diff --git a/models/textclassifier.smartselection.ru.model b/models/textclassifier.smartselection.ru.model
index 9261e1a..6d42d66 100644
--- a/models/textclassifier.smartselection.ru.model
+++ b/models/textclassifier.smartselection.ru.model
Binary files differ
diff --git a/models/textclassifier.smartselection.th.model b/models/textclassifier.smartselection.th.model
index 482ff2e..f9b96ff 100644
--- a/models/textclassifier.smartselection.th.model
+++ b/models/textclassifier.smartselection.th.model
Binary files differ
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index 919b61f..1b15982 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -186,13 +186,18 @@
libtextclassifier::FeatureProcessorOptions::INTERNAL_TOKENIZER) {
return tokenizer_.Tokenize(utf8_text);
} else if (options_.tokenization_type() ==
- libtextclassifier::FeatureProcessorOptions::ICU) {
+ libtextclassifier::FeatureProcessorOptions::ICU ||
+ options_.tokenization_type() ==
+ libtextclassifier::FeatureProcessorOptions::MIXED) {
std::vector<Token> result;
- if (ICUTokenize(utf8_text, &result)) {
- return result;
- } else {
+ if (!ICUTokenize(utf8_text, &result)) {
return {};
}
+ if (options_.tokenization_type() ==
+ libtextclassifier::FeatureProcessorOptions::MIXED) {
+ InternalRetokenize(utf8_text, &result);
+ }
+ return result;
} else {
TC_LOG(ERROR) << "Unknown tokenization type specified. Using "
"internal.";
@@ -429,19 +434,20 @@
return true;
}
-void FeatureProcessor::PrepareSupportedCodepointRanges(
+void FeatureProcessor::PrepareCodepointRanges(
const std::vector<FeatureProcessorOptions::CodepointRange>&
- codepoint_ranges) {
- supported_codepoint_ranges_.clear();
- supported_codepoint_ranges_.reserve(codepoint_ranges.size());
+ codepoint_ranges,
+ std::vector<CodepointRange>* prepared_codepoint_ranges) {
+ prepared_codepoint_ranges->clear();
+ prepared_codepoint_ranges->reserve(codepoint_ranges.size());
for (const FeatureProcessorOptions::CodepointRange& range :
codepoint_ranges) {
- supported_codepoint_ranges_.push_back(
+ prepared_codepoint_ranges->push_back(
CodepointRange(range.start(), range.end()));
}
- std::sort(supported_codepoint_ranges_.begin(),
- supported_codepoint_ranges_.end(),
+ std::sort(prepared_codepoint_ranges->begin(),
+ prepared_codepoint_ranges->end(),
[](const CodepointRange& a, const CodepointRange& b) {
return a.start < b.start;
});
@@ -458,7 +464,7 @@
const UnicodeText value =
UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
for (auto codepoint : value) {
- if (IsCodepointSupported(codepoint)) {
+ if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
++num_supported;
}
++num_total;
@@ -468,9 +474,10 @@
return static_cast<float>(num_supported) / static_cast<float>(num_total);
}
-bool FeatureProcessor::IsCodepointSupported(int codepoint) const {
- auto it = std::lower_bound(supported_codepoint_ranges_.begin(),
- supported_codepoint_ranges_.end(), codepoint,
+bool FeatureProcessor::IsCodepointInRanges(
+ int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const {
+ auto it = std::lower_bound(codepoint_ranges.begin(), codepoint_ranges.end(),
+ codepoint,
[](const CodepointRange& range, int codepoint) {
// This function compares range with the
// codepoint for the purpose of finding the first
@@ -487,7 +494,7 @@
// than the codepoint.
return range.end <= codepoint;
});
- if (it != supported_codepoint_ranges_.end() && it->start <= codepoint &&
+ if (it != codepoint_ranges.end() && it->start <= codepoint &&
it->end > codepoint) {
return true;
} else {
@@ -691,7 +698,71 @@
last_unicode_index = unicode_index;
}
- return result;
+ return true;
+}
+
+void FeatureProcessor::InternalRetokenize(const std::string& context,
+ std::vector<Token>* tokens) const {
+ const UnicodeText unicode_text =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ std::vector<Token> result;
+ CodepointSpan span(-1, -1);
+ for (Token& token : *tokens) {
+ const UnicodeText unicode_token_value =
+ UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+ bool should_retokenize = true;
+ for (const int codepoint : unicode_token_value) {
+ if (!IsCodepointInRanges(codepoint,
+ internal_tokenizer_codepoint_ranges_)) {
+ should_retokenize = false;
+ break;
+ }
+ }
+
+ if (should_retokenize) {
+ if (span.first < 0) {
+ span.first = token.start;
+ }
+ span.second = token.end;
+ } else {
+ TokenizeSubstring(unicode_text, span, &result);
+ span.first = -1;
+ result.emplace_back(std::move(token));
+ }
+ }
+ TokenizeSubstring(unicode_text, span, &result);
+
+ *tokens = std::move(result);
+}
+
+void FeatureProcessor::TokenizeSubstring(const UnicodeText& unicode_text,
+ CodepointSpan span,
+ std::vector<Token>* result) const {
+ if (span.first < 0) {
+ // There is no span to tokenize.
+ return;
+ }
+
+ // Extract the substring.
+ UnicodeText::const_iterator it_begin = unicode_text.begin();
+ for (int i = 0; i < span.first; ++i) {
+ ++it_begin;
+ }
+ UnicodeText::const_iterator it_end = unicode_text.begin();
+ for (int i = 0; i < span.second; ++i) {
+ ++it_end;
+ }
+ const std::string text = unicode_text.UTF8Substring(it_begin, it_end);
+
+ // Run the tokenizer and update the token bounds to reflect the offset of the
+ // substring.
+ std::vector<Token> tokens = tokenizer_.Tokenize(text);
+ for (Token& token : tokens) {
+ token.start += span.first;
+ token.end += span.first;
+ result->emplace_back(std::move(token));
+ }
}
} // namespace libtextclassifier
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index 0260a7c..2c64b67 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -29,6 +29,7 @@
#include "smartselect/tokenizer.h"
#include "smartselect/types.h"
#include "util/base/logging.h"
+#include "util/utf8/unicodetext.h"
namespace libtextclassifier {
@@ -96,9 +97,13 @@
tokenizer_({options.tokenization_codepoint_config().begin(),
options.tokenization_codepoint_config().end()}) {
MakeLabelMaps();
- PrepareSupportedCodepointRanges(
- {options.supported_codepoint_ranges().begin(),
- options.supported_codepoint_ranges().end()});
+ PrepareCodepointRanges({options.supported_codepoint_ranges().begin(),
+ options.supported_codepoint_ranges().end()},
+ &supported_codepoint_ranges_);
+ PrepareCodepointRanges(
+ {options.internal_tokenizer_codepoint_ranges().begin(),
+ options.internal_tokenizer_codepoint_ranges().end()},
+ &internal_tokenizer_codepoint_ranges_);
}
explicit FeatureProcessor(const std::string& serialized_options)
@@ -187,17 +192,20 @@
// Converts a token span to the corresponding label.
int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
- void PrepareSupportedCodepointRanges(
+ void PrepareCodepointRanges(
const std::vector<FeatureProcessorOptions::CodepointRange>&
- codepoint_range_configs);
+ codepoint_ranges,
+ std::vector<CodepointRange>* prepared_codepoint_ranges);
// Returns the ratio of supported codepoints to total number of codepoints in
// the input context around given click position.
float SupportedCodepointsRatio(int click_pos,
const std::vector<Token>& tokens) const;
- // Returns true if given codepoint is supported.
- bool IsCodepointSupported(int codepoint) const;
+ // Returns true if given codepoint is covered by the given sorted vector of
+ // codepoint ranges.
+ bool IsCodepointInRanges(
+ int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
// Finds the center token index in tokens vector, using the method defined
// in options_.
@@ -208,8 +216,29 @@
bool ICUTokenize(const std::string& context,
std::vector<Token>* result) const;
+ // Takes the result of ICU tokenization and retokenizes stretches of tokens
+ // made of a specific subset of characters using the internal tokenizer.
+ void InternalRetokenize(const std::string& context,
+ std::vector<Token>* tokens) const;
+
+ // Tokenizes a substring of the unicode string, appending the resulting tokens
+ // to the output vector. The resulting tokens have bounds relative to the full
+ // string. Does nothing if the start of the span is negative.
+ void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
+ std::vector<Token>* result) const;
+
const TokenFeatureExtractor feature_extractor_;
+ // Codepoint ranges that define what codepoints are supported by the model.
+ // NOTE: Must be sorted.
+ std::vector<CodepointRange> supported_codepoint_ranges_;
+
+ // Codepoint ranges that define which tokens (consisting of which codepoints)
+ // should be re-tokenized with the internal tokenizer in the mixed
+ // tokenization mode.
+ // NOTE: Must be sorted.
+ std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
+
private:
const FeatureProcessorOptions options_;
@@ -221,10 +250,6 @@
std::map<std::string, int> collection_to_label_;
Tokenizer tokenizer_;
-
- // Codepoint ranges that define what codepoints are supported by the model.
- // NOTE: Must be sorted.
- std::vector<CodepointRange> supported_codepoint_ranges_;
};
} // namespace libtextclassifier
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index 44c544b..dee8f8b 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -39,6 +39,7 @@
using nlp_core::MemoryImageReader;
using nlp_core::MmapFile;
using nlp_core::MmapHandle;
+using nlp_core::ScopedMmap;
namespace {
@@ -94,8 +95,8 @@
}
}
-TextClassificationModel::TextClassificationModel(int fd) {
- initialized_ = LoadModels(fd);
+TextClassificationModel::TextClassificationModel(int fd) : mmap_(fd) {
+ initialized_ = LoadModels(mmap_.handle());
if (!initialized_) {
TC_LOG(ERROR) << "Failed to load models";
return;
@@ -169,8 +170,7 @@
} // namespace
-bool TextClassificationModel::LoadModels(int fd) {
- MmapHandle mmap_handle = MmapFile(fd);
+bool TextClassificationModel::LoadModels(const MmapHandle& mmap_handle) {
if (!mmap_handle.ok()) {
return false;
}
@@ -207,15 +207,15 @@
}
bool ReadSelectionModelOptions(int fd, ModelOptions* model_options) {
- MmapHandle mmap_handle = MmapFile(fd);
- if (!mmap_handle.ok()) {
+ ScopedMmap mmap = ScopedMmap(fd);
+ if (!mmap.handle().ok()) {
TC_LOG(ERROR) << "Can't mmap.";
return false;
}
const char *selection_model, *sharing_model;
int selection_model_length, sharing_model_length;
- ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
+ ParseMergedModel(mmap.handle(), &selection_model, &selection_model_length,
&sharing_model, &sharing_model_length);
MemoryImageReader<EmbeddingNetworkProto> reader(selection_model,
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
index 312df1e..522372c 100644
--- a/smartselect/text-classification-model.h
+++ b/smartselect/text-classification-model.h
@@ -27,6 +27,7 @@
#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/memory_image/embedding-network-params-from-image.h"
+#include "common/mmap.h"
#include "smartselect/feature-processor.h"
#include "smartselect/model-params.h"
#include "smartselect/text-classification-model.pb.h"
@@ -89,7 +90,7 @@
SharingModelOptions sharing_options_;
private:
- bool LoadModels(int fd);
+ bool LoadModels(const nlp_core::MmapHandle& mmap_handle);
nlp_core::EmbeddingNetwork::Vector InferInternal(
const std::string& context, CodepointSpan span,
@@ -108,6 +109,7 @@
CodepointSpan click_indices) const;
bool initialized_;
+ nlp_core::ScopedMmap mmap_;
std::unique_ptr<ModelParams> selection_params_;
std::unique_ptr<FeatureProcessor> selection_feature_processor_;
std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index 40e4ed0..b5b0287 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -156,6 +156,10 @@
// A set of codepoint ranges supported by the model.
repeated CodepointRange supported_codepoint_ranges = 23;
+ // A set of codepoint ranges to use in the mixed tokenization mode to identify
+ // stretches of tokens to re-tokenize using the internal tokenizer.
+ repeated CodepointRange internal_tokenizer_codepoint_ranges = 34;
+
// Minimum ratio of supported codepoints in the input context. If the ratio
// is lower than this, the feature computation will fail.
optional float min_supported_codepoint_ratio = 24 [default = 0.0];
@@ -179,6 +183,11 @@
// Use ICU for tokenization.
ICU = 2;
+
+ // First apply ICU tokenization. Then identify stretches of tokens
+ // consisting only of codepoints in internal_tokenizer_codepoint_ranges
+ // and re-tokenize them using the internal tokenizer.
+ MIXED = 3;
}
optional TokenizationType tokenization_type = 30
[default = INTERNAL_TOKENIZER];
diff --git a/tests/feature-processor_test.cc b/tests/feature-processor_test.cc
index cf09f96..4e27afc 100644
--- a/tests/feature-processor_test.cc
+++ b/tests/feature-processor_test.cc
@@ -203,8 +203,9 @@
using FeatureProcessor::FeatureProcessor;
using FeatureProcessor::SpanToLabel;
using FeatureProcessor::SupportedCodepointsRatio;
- using FeatureProcessor::IsCodepointSupported;
+ using FeatureProcessor::IsCodepointInRanges;
using FeatureProcessor::ICUTokenize;
+ using FeatureProcessor::supported_codepoint_ranges_;
};
TEST(FeatureProcessorTest, SpanToLabel) {
@@ -369,15 +370,24 @@
EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
1, feature_processor.Tokenize("ěěě řřř ěěě")),
FloatEq(0.0));
- EXPECT_FALSE(feature_processor.IsCodepointSupported(-1));
- EXPECT_TRUE(feature_processor.IsCodepointSupported(0));
- EXPECT_TRUE(feature_processor.IsCodepointSupported(10));
- EXPECT_TRUE(feature_processor.IsCodepointSupported(127));
- EXPECT_FALSE(feature_processor.IsCodepointSupported(128));
- EXPECT_FALSE(feature_processor.IsCodepointSupported(9999));
- EXPECT_TRUE(feature_processor.IsCodepointSupported(10000));
- EXPECT_FALSE(feature_processor.IsCodepointSupported(10001));
- EXPECT_TRUE(feature_processor.IsCodepointSupported(25000));
+ EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+ -1, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ 0, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ 10, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ 127, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+ 128, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+ 9999, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ 10000, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+ 10001, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ 25000, feature_processor.supported_codepoint_ranges_));
std::vector<Token> tokens;
int click_pos;
@@ -559,5 +569,46 @@
// clang-format on
}
+TEST(FeatureProcessorTest, MixedTokenize) {
+ FeatureProcessorOptions options;
+ options.set_tokenization_type(
+ libtextclassifier::FeatureProcessorOptions::MIXED);
+
+ TokenizationCodepointRange* config =
+ options.add_tokenization_codepoint_config();
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+
+ FeatureProcessorOptions::CodepointRange* range;
+ range = options.add_internal_tokenizer_codepoint_ranges();
+ range->set_start(0);
+ range->set_end(128);
+
+ range = options.add_internal_tokenizer_codepoint_ranges();
+ range->set_start(128);
+ range->set_end(256);
+
+ range = options.add_internal_tokenizer_codepoint_ranges();
+ range->set_start(256);
+ range->set_end(384);
+
+ range = options.add_internal_tokenizer_codepoint_ranges();
+ range->set_start(384);
+ range->set_end(592);
+
+ TestingFeatureProcessor feature_processor(options);
+ std::vector<Token> tokens = feature_processor.Tokenize(
+ "こんにちはJapanese-ląnguagę text 世界 http://www.google.com/");
+ ASSERT_EQ(tokens,
+ // clang-format off
+ std::vector<Token>({Token("こんにちは", 0, 5),
+ Token("Japanese-ląnguagę", 5, 22),
+ Token("text", 23, 27),
+ Token("世界", 28, 30),
+ Token("http://www.google.com/", 31, 53)}));
+ // clang-format on
+}
+
} // namespace
} // namespace libtextclassifier