DO NOT MERGE - Merge Android 10 into master
Bug: 139893257
Change-Id: I97b1574b40b11d420d5ca9a00900b35816844ed0
diff --git a/Android.bp b/Android.bp
index 4e02d66..0962f2a 100644
--- a/Android.bp
+++ b/Android.bp
@@ -73,12 +73,15 @@
"-Wno-unused-parameter",
"-Wno-extern-c-compat",
+ "-funsigned-char",
"-fvisibility=hidden",
"-DLIBTEXTCLASSIFIER_UNILIB_ICU",
"-DZLIB_CONST",
"-DSAFTM_COMPACT_LOGGING",
+ "-DTC3_WITH_ACTIONS_OPS",
"-DTC3_UNILIB_JAVAICU",
"-DTC3_CALENDAR_JAVAICU",
+ "-DTC3_AOSP"
],
product_variables: {
@@ -89,9 +92,19 @@
},
generated_headers: [
+ "libtextclassifier_fbgen_flatbuffers",
+ "libtextclassifier_fbgen_tokenizer",
+ "libtextclassifier_fbgen_codepoint_range",
+ "libtextclassifier_fbgen_entity-data",
"libtextclassifier_fbgen_zlib_buffer",
+ "libtextclassifier_fbgen_resources_extra",
"libtextclassifier_fbgen_intent_config",
"libtextclassifier_fbgen_annotator_model",
+ "libtextclassifier_fbgen_actions_model",
+ "libtextclassifier_fbgen_tflite_text_encoder_config",
+ "libtextclassifier_fbgen_lang_id_embedded_network",
+ "libtextclassifier_fbgen_lang_id_model",
+ "libtextclassifier_fbgen_actions-entity-data",
],
header_libs: [
@@ -106,6 +119,7 @@
],
static_libs: [
+ "liblua",
"libutf",
],
}
@@ -113,32 +127,104 @@
// -----------------
// Generate headers with FlatBuffer schema compiler.
// -----------------
+genrule_defaults {
+ name: "fbgen",
+ tools: ["flatc"],
+ // "depfile" is used here in conjunction with flatc's -M to gather the deps
+ cmd: "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -I external/libtextclassifier -M $(in) >$(depfile) && " +
+ "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -I external/libtextclassifier -o $$(dirname $(out)) $(in)",
+ depfile: true,
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_flatbuffers",
+ srcs: ["utils/flatbuffers.fbs"],
+ out: ["utils/flatbuffers_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_tokenizer",
+ srcs: ["utils/tokenizer.fbs"],
+ out: ["utils/tokenizer_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_codepoint_range",
+ srcs: ["utils/codepoint-range.fbs"],
+ out: ["utils/codepoint-range_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_resources_extra",
+ srcs: ["utils/resources.fbs"],
+ out: ["utils/resources_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_entity-data",
+ srcs: ["annotator/entity-data.fbs"],
+ out: ["annotator/entity-data_generated.h"],
+ defaults: ["fbgen"],
+}
genrule {
name: "libtextclassifier_fbgen_zlib_buffer",
- tools: ["flatc"],
- cmd: "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -o $(genDir)/utils/zlib $(in)",
srcs: ["utils/zlib/buffer.fbs"],
out: ["utils/zlib/buffer_generated.h"],
+ defaults: ["fbgen"],
}
genrule {
name: "libtextclassifier_fbgen_intent_config",
- tools: ["flatc"],
- cmd: "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -o $(genDir)/utils/intents $(in)",
srcs: ["utils/intents/intent-config.fbs"],
out: ["utils/intents/intent-config_generated.h"],
+ defaults: ["fbgen"],
}
genrule {
name: "libtextclassifier_fbgen_annotator_model",
- tools: ["flatc"],
- // "depfile" is used here in conjunction with flatc's -M to gather the deps of annotator/model.fbs
- cmd: "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -I external/libtextclassifier -M $(in) >$(depfile) && " +
- "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -I external/libtextclassifier -o $(genDir)/annotator $(in)",
- depfile: true,
srcs: ["annotator/model.fbs"],
out: ["annotator/model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_actions_model",
+ srcs: ["actions/actions_model.fbs"],
+ out: ["actions/actions_model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_tflite_text_encoder_config",
+ srcs: ["utils/tflite/text_encoder_config.fbs"],
+ out: ["utils/tflite/text_encoder_config_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_lang_id_embedded_network",
+ srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
+ out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_lang_id_model",
+ srcs: ["lang_id/common/flatbuffers/model.fbs"],
+ out: ["lang_id/common/flatbuffers/model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_actions-entity-data",
+ srcs: ["actions/actions-entity-data.fbs"],
+ out: ["actions/actions-entity-data_generated.h"],
+ defaults: ["fbgen"],
}
// -----------------
@@ -152,12 +238,17 @@
exclude_srcs: [
"**/*_test.cc",
"**/*-test-lib.cc",
+ "utils/testing/*.cc",
"test-util.*",
+ "utils/calendar/*_test-include.*",
+ "utils/utf8/*_test-include.*"
],
required: [
"libtextclassifier_annotator_en_model",
"libtextclassifier_annotator_universal_model",
+ "libtextclassifier_actions_suggestions_universal_model",
+ "libtextclassifier_lang_id_model",
],
version_script: "jni.lds",
@@ -174,11 +265,17 @@
data: [
"annotator/test_data/**/*",
+ "actions/test_data/**/*",
],
srcs: ["**/*.cc"],
// TODO: Do not filter out tflite test once the dependency issue is resolved.
- exclude_srcs: ["utils/tflite/*_test.cc"],
+ exclude_srcs: [
+ "utils/tflite/*_test.cc",
+ "utils/flatbuffers_test.cc",
+ "utils/calendar/*_test-include.*",
+ "utils/utf8/*_test-include.*"
+ ],
static_libs: ["libgmock"],
@@ -211,3 +308,27 @@
src: "models/textclassifier.universal.model",
sub_dir: "textclassifier",
}
+
+// ---------------------------
+// Actions Suggestions models
+// ---------------------------
+
+prebuilt_etc {
+ name: "libtextclassifier_actions_suggestions_universal_model",
+ filename: "actions_suggestions.universal.model",
+ owner: "google",
+ src: "models/actions_suggestions.universal.model",
+ sub_dir: "textclassifier",
+}
+
+// ------------
+// LangId model
+// ------------
+
+prebuilt_etc {
+ name: "libtextclassifier_lang_id_model",
+ filename: "lang_id.model",
+ owner: "google",
+ src: "models/lang_id.model",
+ sub_dir: "textclassifier",
+}
diff --git a/actions/actions-entity-data.fbs b/actions/actions-entity-data.fbs
new file mode 100755
index 0000000..4ed68bb
--- /dev/null
+++ b/actions/actions-entity-data.fbs
@@ -0,0 +1,24 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Extra information and data associated with actions.
+namespace libtextclassifier3;
+table ActionsEntityData {
+ // Extracted text.
+ text:string;
+}
+
+root_type libtextclassifier3.ActionsEntityData;
diff --git a/actions/actions-suggestions.cc b/actions/actions-suggestions.cc
new file mode 100644
index 0000000..29a4424
--- /dev/null
+++ b/actions/actions-suggestions.cc
@@ -0,0 +1,1450 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/actions-suggestions.h"
+
+#include <memory>
+
+#include "actions/lua-actions.h"
+#include "actions/types.h"
+#include "actions/zlib-utils.h"
+#include "utils/base/logging.h"
+#include "utils/flatbuffers.h"
+#include "utils/lua-utils.h"
+#include "utils/regex-match.h"
+#include "utils/strings/split.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/zlib/zlib_regex.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+
+const std::string& ActionsSuggestions::kViewCalendarType =
+ *[]() { return new std::string("view_calendar"); }();
+const std::string& ActionsSuggestions::kViewMapType =
+ *[]() { return new std::string("view_map"); }();
+const std::string& ActionsSuggestions::kTrackFlightType =
+ *[]() { return new std::string("track_flight"); }();
+const std::string& ActionsSuggestions::kOpenUrlType =
+ *[]() { return new std::string("open_url"); }();
+const std::string& ActionsSuggestions::kSendSmsType =
+ *[]() { return new std::string("send_sms"); }();
+const std::string& ActionsSuggestions::kCallPhoneType =
+ *[]() { return new std::string("call_phone"); }();
+const std::string& ActionsSuggestions::kSendEmailType =
+ *[]() { return new std::string("send_email"); }();
+const std::string& ActionsSuggestions::kShareLocation =
+ *[]() { return new std::string("share_location"); }();
+
+namespace {
+
+const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
+ flatbuffers::Verifier verifier(addr, size);
+ if (VerifyActionsModelBuffer(verifier)) {
+ return GetActionsModel(addr);
+ } else {
+ return nullptr;
+ }
+}
+
+template <typename T>
+T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
+ const T default_value) {
+ if (values == nullptr) {
+ return default_value;
+ }
+ return values->GetField<T>(field_offset, default_value);
+}
+
+// Returns number of (tail) messages of a conversation to consider.
+int NumMessagesToConsider(const Conversation& conversation,
+ const int max_conversation_history_length) {
+ return ((max_conversation_history_length < 0 ||
+ conversation.messages.size() < max_conversation_history_length)
+ ? conversation.messages.size()
+ : max_conversation_history_length);
+}
+
+} // namespace
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
+ const uint8_t* buffer, const int size, const UniLib* unilib,
+ const std::string& triggering_preconditions_overlay) {
+ auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
+ const ActionsModel* model = LoadAndVerifyModel(buffer, size);
+ if (model == nullptr) {
+ return nullptr;
+ }
+ actions->model_ = model;
+ actions->SetOrCreateUnilib(unilib);
+ actions->triggering_preconditions_overlay_buffer_ =
+ triggering_preconditions_overlay;
+ if (!actions->ValidateAndInitialize()) {
+ return nullptr;
+ }
+ return actions;
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
+ const std::string& triggering_preconditions_overlay) {
+ if (!mmap->handle().ok()) {
+ TC3_VLOG(1) << "Mmap failed.";
+ return nullptr;
+ }
+ const ActionsModel* model = LoadAndVerifyModel(
+ reinterpret_cast<const uint8_t*>(mmap->handle().start()),
+ mmap->handle().num_bytes());
+ if (!model) {
+ TC3_LOG(ERROR) << "Model verification failed.";
+ return nullptr;
+ }
+ auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
+ actions->model_ = model;
+ actions->mmap_ = std::move(mmap);
+ actions->SetOrCreateUnilib(unilib);
+ actions->triggering_preconditions_overlay_buffer_ =
+ triggering_preconditions_overlay;
+ if (!actions->ValidateAndInitialize()) {
+ return nullptr;
+ }
+ return actions;
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
+ std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay) {
+ if (!mmap->handle().ok()) {
+ TC3_VLOG(1) << "Mmap failed.";
+ return nullptr;
+ }
+ const ActionsModel* model = LoadAndVerifyModel(
+ reinterpret_cast<const uint8_t*>(mmap->handle().start()),
+ mmap->handle().num_bytes());
+ if (!model) {
+ TC3_LOG(ERROR) << "Model verification failed.";
+ return nullptr;
+ }
+ auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
+ actions->model_ = model;
+ actions->mmap_ = std::move(mmap);
+ actions->owned_unilib_ = std::move(unilib);
+ actions->unilib_ = actions->owned_unilib_.get();
+ actions->triggering_preconditions_overlay_buffer_ =
+ triggering_preconditions_overlay;
+ if (!actions->ValidateAndInitialize()) {
+ return nullptr;
+ }
+ return actions;
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
+ const int fd, const int offset, const int size, const UniLib* unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
+ if (offset >= 0 && size >= 0) {
+ mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
+ } else {
+ mmap.reset(new libtextclassifier3::ScopedMmap(fd));
+ }
+ return FromScopedMmap(std::move(mmap), unilib,
+ triggering_preconditions_overlay);
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
+ const int fd, const int offset, const int size,
+ std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
+ if (offset >= 0 && size >= 0) {
+ mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
+ } else {
+ mmap.reset(new libtextclassifier3::ScopedMmap(fd));
+ }
+ return FromScopedMmap(std::move(mmap), std::move(unilib),
+ triggering_preconditions_overlay);
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
+ const int fd, const UniLib* unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return FromScopedMmap(std::move(mmap), unilib,
+ triggering_preconditions_overlay);
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
+ const int fd, std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return FromScopedMmap(std::move(mmap), std::move(unilib),
+ triggering_preconditions_overlay);
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
+ const std::string& path, const UniLib* unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(path));
+ return FromScopedMmap(std::move(mmap), unilib,
+ triggering_preconditions_overlay);
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
+ const std::string& path, std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(path));
+ return FromScopedMmap(std::move(mmap), std::move(unilib),
+ triggering_preconditions_overlay);
+}
+
+void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
+ if (unilib != nullptr) {
+ unilib_ = unilib;
+ } else {
+ owned_unilib_.reset(new UniLib);
+ unilib_ = owned_unilib_.get();
+ }
+}
+
+bool ActionsSuggestions::ValidateAndInitialize() {
+ if (model_ == nullptr) {
+ TC3_LOG(ERROR) << "No model specified.";
+ return false;
+ }
+
+ if (model_->smart_reply_action_type() == nullptr) {
+ TC3_LOG(ERROR) << "No smart reply action type specified.";
+ return false;
+ }
+
+ if (!InitializeTriggeringPreconditions()) {
+ TC3_LOG(ERROR) << "Could not initialize preconditions.";
+ return false;
+ }
+
+ if (model_->locales() &&
+ !ParseLocales(model_->locales()->c_str(), &locales_)) {
+ TC3_LOG(ERROR) << "Could not parse model supported locales.";
+ return false;
+ }
+
+ if (model_->tflite_model_spec() != nullptr) {
+ model_executor_ = TfLiteModelExecutor::FromBuffer(
+ model_->tflite_model_spec()->tflite_model());
+ if (!model_executor_) {
+ TC3_LOG(ERROR) << "Could not initialize model executor.";
+ return false;
+ }
+ }
+
+ if (model_->annotation_actions_spec() != nullptr &&
+ model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
+ for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
+ *model_->annotation_actions_spec()->annotation_mapping()) {
+ annotation_entity_types_.insert(mapping->annotation_collection()->str());
+ }
+ }
+
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+ if (!InitializeRules(decompressor.get())) {
+ TC3_LOG(ERROR) << "Could not initialize rules.";
+ return false;
+ }
+
+ if (model_->actions_entity_data_schema() != nullptr) {
+ entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
+ model_->actions_entity_data_schema()->Data(),
+ model_->actions_entity_data_schema()->size());
+ if (entity_data_schema_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not load entity data schema data.";
+ return false;
+ }
+
+ entity_data_builder_.reset(
+ new ReflectiveFlatbufferBuilder(entity_data_schema_));
+ } else {
+ entity_data_schema_ = nullptr;
+ }
+
+ std::string actions_script;
+ if (GetUncompressedString(model_->lua_actions_script(),
+ model_->compressed_lua_actions_script(),
+ decompressor.get(), &actions_script) &&
+ !actions_script.empty()) {
+ if (!Compile(actions_script, &lua_bytecode_)) {
+ TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
+ return false;
+ }
+ }
+
+ if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+ model_->ranking_options(), decompressor.get(),
+ model_->smart_reply_action_type()->str()))) {
+ TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
+ return false;
+ }
+
+ // Create feature processor if specified.
+ const ActionsTokenFeatureProcessorOptions* options =
+ model_->feature_processor_options();
+ if (options != nullptr) {
+ if (options->tokenizer_options() == nullptr) {
+ TC3_LOG(ERROR) << "No tokenizer options specified.";
+ return false;
+ }
+
+ feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
+ embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
+ options->embedding_model(), options->embedding_size(),
+ options->embedding_quantization_bits());
+
+ if (embedding_executor_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not initialize embedding executor.";
+ return false;
+ }
+
+ // Cache embedding of padding, start and end token.
+ if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
+ !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
+ !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
+ TC3_LOG(ERROR) << "Could not precompute token embeddings.";
+ return false;
+ }
+ token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
+ }
+
+ // Create low confidence model if specified.
+ if (model_->low_confidence_ngram_model() != nullptr) {
+ ngram_model_ = NGramModel::Create(model_->low_confidence_ngram_model(),
+ feature_processor_ == nullptr
+ ? nullptr
+ : feature_processor_->tokenizer(),
+ unilib_);
+ if (ngram_model_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool ActionsSuggestions::InitializeTriggeringPreconditions() {
+ triggering_preconditions_overlay_ =
+ LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
+ triggering_preconditions_overlay_buffer_);
+
+ if (triggering_preconditions_overlay_ == nullptr &&
+ !triggering_preconditions_overlay_buffer_.empty()) {
+ TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
+ return false;
+ }
+ const flatbuffers::Table* overlay =
+ reinterpret_cast<const flatbuffers::Table*>(
+ triggering_preconditions_overlay_);
+ const TriggeringPreconditions* defaults = model_->preconditions();
+ if (defaults == nullptr) {
+ TC3_LOG(ERROR) << "No triggering conditions specified.";
+ return false;
+ }
+
+ preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
+ defaults->min_smart_reply_triggering_score());
+ preconditions_.max_sensitive_topic_score = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
+ defaults->max_sensitive_topic_score());
+ preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
+ defaults->suppress_on_sensitive_topic());
+ preconditions_.min_input_length =
+ ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
+ defaults->min_input_length());
+ preconditions_.max_input_length =
+ ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
+ defaults->max_input_length());
+ preconditions_.min_locale_match_fraction = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
+ defaults->min_locale_match_fraction());
+ preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
+ defaults->handle_missing_locale_as_supported());
+ preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
+ defaults->handle_unknown_locale_as_supported());
+ preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
+ defaults->suppress_on_low_confidence_input());
+ preconditions_.diversification_distance_threshold = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_DIVERSIFICATION_DISTANCE_THRESHOLD,
+ defaults->diversification_distance_threshold());
+ preconditions_.confidence_threshold =
+ ValueOrDefault(overlay, TriggeringPreconditions::VT_CONFIDENCE_THRESHOLD,
+ defaults->confidence_threshold());
+ preconditions_.empirical_probability_factor = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_EMPIRICAL_PROBABILITY_FACTOR,
+ defaults->empirical_probability_factor());
+ preconditions_.min_reply_score_threshold = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
+ defaults->min_reply_score_threshold());
+
+ return true;
+}
+
+bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
+ std::vector<float>* embedding) const {
+ return feature_processor_->AppendFeatures(
+ {token_id},
+ /*dense_features=*/{}, embedding_executor_.get(), embedding);
+}
+
+bool ActionsSuggestions::InitializeRules(ZlibDecompressor* decompressor) {
+ if (model_->rules() != nullptr) {
+ if (!InitializeRules(decompressor, model_->rules(), &rules_)) {
+ TC3_LOG(ERROR) << "Could not initialize action rules.";
+ return false;
+ }
+ }
+
+ if (model_->low_confidence_rules() != nullptr) {
+ if (!InitializeRules(decompressor, model_->low_confidence_rules(),
+ &low_confidence_rules_)) {
+ TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
+ return false;
+ }
+ }
+
+ // Extend by rules provided by the overwrite.
+ // NOTE: The rules from the original models are *not* cleared.
+ if (triggering_preconditions_overlay_ != nullptr &&
+ triggering_preconditions_overlay_->low_confidence_rules() != nullptr) {
+ // These rules are optionally compressed, but separately.
+ std::unique_ptr<ZlibDecompressor> overwrite_decompressor =
+ ZlibDecompressor::Instance();
+ if (overwrite_decompressor == nullptr) {
+      TC3_LOG(ERROR) << "Could not initialize decompressor for overwrite rules.";
+ return false;
+ }
+ if (!InitializeRules(
+ overwrite_decompressor.get(),
+ triggering_preconditions_overlay_->low_confidence_rules(),
+ &low_confidence_rules_)) {
+ TC3_LOG(ERROR)
+ << "Could not initialize low confidence rules from overwrite.";
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool ActionsSuggestions::InitializeRules(
+ ZlibDecompressor* decompressor, const RulesModel* rules,
+ std::vector<CompiledRule>* compiled_rules) const {
+ for (const RulesModel_::Rule* rule : *rules->rule()) {
+ std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
+ UncompressMakeRegexPattern(
+ *unilib_, rule->pattern(), rule->compressed_pattern(),
+ rules->lazy_regex_compilation(), decompressor);
+ if (compiled_pattern == nullptr) {
+ TC3_LOG(ERROR) << "Failed to load rule pattern.";
+ return false;
+ }
+
+ // Check whether there is a check on the output.
+ std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
+ if (rule->output_pattern() != nullptr ||
+ rule->compressed_output_pattern() != nullptr) {
+ compiled_output_pattern = UncompressMakeRegexPattern(
+ *unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
+ rules->lazy_regex_compilation(), decompressor);
+ if (compiled_output_pattern == nullptr) {
+ TC3_LOG(ERROR) << "Failed to load rule output pattern.";
+ return false;
+ }
+ }
+
+ compiled_rules->emplace_back(rule, std::move(compiled_pattern),
+ std::move(compiled_output_pattern));
+ }
+
+ return true;
+}
+
+bool ActionsSuggestions::IsLowConfidenceInput(
+ const Conversation& conversation, const int num_messages,
+ std::vector<int>* post_check_rules) const {
+ for (int i = 1; i <= num_messages; i++) {
+ const std::string& message =
+ conversation.messages[conversation.messages.size() - i].text;
+ const UnicodeText message_unicode(
+ UTF8ToUnicodeText(message, /*do_copy=*/false));
+
+ // Run ngram linear regression model.
+ if (ngram_model_ != nullptr) {
+ if (ngram_model_->Eval(message_unicode)) {
+ return true;
+ }
+ }
+
+ // Run the regex based rules.
+ for (int low_confidence_rule = 0;
+ low_confidence_rule < low_confidence_rules_.size();
+ low_confidence_rule++) {
+ const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
+ const std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rule.pattern->Matcher(message_unicode);
+ int status = UniLib::RegexMatcher::kNoError;
+ if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ // Rule only applies to input-output pairs, so defer the check.
+ if (rule.output_pattern != nullptr) {
+ post_check_rules->push_back(low_confidence_rule);
+ continue;
+ }
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool ActionsSuggestions::FilterConfidenceOutput(
+ const std::vector<int>& post_check_rules,
+ std::vector<ActionSuggestion>* actions) const {
+ if (post_check_rules.empty() || actions->empty()) {
+ return true;
+ }
+ std::vector<ActionSuggestion> filtered_text_replies;
+ for (const ActionSuggestion& action : *actions) {
+ if (action.response_text.empty()) {
+ filtered_text_replies.push_back(action);
+ continue;
+ }
+ bool passes_post_check = true;
+ const UnicodeText text_reply_unicode(
+ UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
+ for (const int rule_id : post_check_rules) {
+ const std::unique_ptr<UniLib::RegexMatcher> matcher =
+ low_confidence_rules_[rule_id].output_pattern->Matcher(
+ text_reply_unicode);
+ if (matcher == nullptr) {
+ TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
+ return false;
+ }
+ int status = UniLib::RegexMatcher::kNoError;
+ if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
+ passes_post_check = false;
+ break;
+ }
+ }
+ if (passes_post_check) {
+ filtered_text_replies.push_back(action);
+ }
+ }
+ *actions = std::move(filtered_text_replies);
+ return true;
+}
+
+ActionSuggestion ActionsSuggestions::SuggestionFromSpec(
+ const ActionSuggestionSpec* action, const std::string& default_type,
+ const std::string& default_response_text,
+ const std::string& default_serialized_entity_data,
+ const float default_score, const float default_priority_score) const {
+ ActionSuggestion suggestion;
+ suggestion.score = action != nullptr ? action->score() : default_score;
+ suggestion.priority_score =
+ action != nullptr ? action->priority_score() : default_priority_score;
+ suggestion.type = action != nullptr && action->type() != nullptr
+ ? action->type()->str()
+ : default_type;
+ suggestion.response_text =
+ action != nullptr && action->response_text() != nullptr
+ ? action->response_text()->str()
+ : default_response_text;
+ suggestion.serialized_entity_data =
+ action != nullptr && action->serialized_entity_data() != nullptr
+ ? action->serialized_entity_data()->str()
+ : default_serialized_entity_data;
+ return suggestion;
+}
+
+std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
+ const std::vector<std::string>& context) const {
+ std::vector<std::vector<Token>> tokens;
+ tokens.reserve(context.size());
+ for (const std::string& message : context) {
+ tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
+ }
+ return tokens;
+}
+
+bool ActionsSuggestions::EmbedTokensPerMessage(
+ const std::vector<std::vector<Token>>& tokens,
+ std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
+ const int num_messages = tokens.size();
+ *max_num_tokens_per_message = 0;
+ for (int i = 0; i < num_messages; i++) {
+ const int num_message_tokens = tokens[i].size();
+ if (num_message_tokens > *max_num_tokens_per_message) {
+ *max_num_tokens_per_message = num_message_tokens;
+ }
+ }
+
+ if (model_->feature_processor_options()->min_num_tokens_per_message() >
+ *max_num_tokens_per_message) {
+ *max_num_tokens_per_message =
+ model_->feature_processor_options()->min_num_tokens_per_message();
+ }
+ if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
+ *max_num_tokens_per_message >
+ model_->feature_processor_options()->max_num_tokens_per_message()) {
+ *max_num_tokens_per_message =
+ model_->feature_processor_options()->max_num_tokens_per_message();
+ }
+
+ // Embed all tokens and add paddings to pad tokens of each message to the
+ // maximum number of tokens in a message of the conversation.
+ // If a number of tokens is specified in the model config, tokens at the
+ // beginning of a message are dropped if they don't fit in the limit.
+ for (int i = 0; i < num_messages; i++) {
+ const int start =
+ std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
+ for (int pos = start; pos < tokens[i].size(); pos++) {
+ if (!feature_processor_->AppendTokenFeatures(
+ tokens[i][pos], embedding_executor_.get(), embeddings)) {
+ TC3_LOG(ERROR) << "Could not run token feature extractor.";
+ return false;
+ }
+ }
+ // Add padding.
+ for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
+ embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
+ embedded_padding_token_.end());
+ }
+ }
+
+ return true;
+}
+
+// Embeds the tokens of all messages into one flat vector, bracketing each
+// message with a `start` and an `end` token embedding. If the model config
+// specifies a maximum total token count, whole older messages (and leading
+// tokens of the oldest surviving message) are dropped; if a minimum is
+// configured, padding embeddings are appended at the end.
+// Returns false if the token feature extractor fails.
+// NOTE(review): `tokens` is taken by value here; the header declaration
+// presumably matches — consider `const&` to avoid the copy. TODO confirm.
+bool ActionsSuggestions::EmbedAndFlattenTokens(
+    const std::vector<std::vector<Token>> tokens,
+    std::vector<float>* embeddings, int* total_token_count) const {
+  const int num_messages = tokens.size();
+  int start_message = 0;
+  int message_token_offset = 0;
+
+  // If a maximum model input length is specified, we need to check how
+  // much we need to trim at the start.
+  const int max_num_total_tokens =
+      model_->feature_processor_options()->max_num_total_tokens();
+  if (max_num_total_tokens > 0) {
+    int total_tokens = 0;
+    start_message = num_messages - 1;
+    // Walk backwards from the newest message until the budget is exhausted.
+    for (; start_message >= 0; start_message--) {
+      // Tokens of the message + start and end token.
+      const int num_message_tokens = tokens[start_message].size() + 2;
+      total_tokens += num_message_tokens;
+
+      // Check whether we exhausted the budget.
+      if (total_tokens >= max_num_total_tokens) {
+        // Number of leading tokens of `start_message` to skip (counted
+        // including its `start` token).
+        message_token_offset = total_tokens - max_num_total_tokens;
+        break;
+      }
+    }
+  }
+
+  // Add embeddings.
+  *total_token_count = 0;
+  for (int i = start_message; i < num_messages; i++) {
+    if (message_token_offset == 0) {
+      ++(*total_token_count);
+      // Add `start message` token.
+      embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
+                         embedded_start_token_.end());
+    }
+
+    // `message_token_offset - 1` real tokens are skipped for the first
+    // message; the `- 1` accounts for the already-skipped `start` token.
+    for (int pos = std::max(0, message_token_offset - 1);
+         pos < tokens[i].size(); pos++) {
+      ++(*total_token_count);
+      if (!feature_processor_->AppendTokenFeatures(
+              tokens[i][pos], embedding_executor_.get(), embeddings)) {
+        TC3_LOG(ERROR) << "Could not run token feature extractor.";
+        return false;
+      }
+    }
+
+    // Add `end message` token.
+    ++(*total_token_count);
+    embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
+                       embedded_end_token_.end());
+
+    // Reset for the subsequent messages.
+    message_token_offset = 0;
+  }
+
+  // Add optional padding.
+  const int min_num_total_tokens =
+      model_->feature_processor_options()->min_num_total_tokens();
+  for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
+    embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
+                       embedded_padding_token_.end());
+  }
+
+  return true;
+}
+
+// Resizes the model's input tensors (if `resize_inputs` is enabled in the
+// TFLite model spec) to match the current conversation length / token counts,
+// then allocates the tensors. Inputs with a negative index in the spec are
+// treated as absent and skipped. Returns true iff AllocateTensors succeeds.
+bool ActionsSuggestions::AllocateInput(const int conversation_length,
+                                       const int max_tokens,
+                                       const int total_token_count,
+                                       tflite::Interpreter* interpreter) const {
+  if (model_->tflite_model_spec()->resize_inputs()) {
+    if (model_->tflite_model_spec()->input_context() >= 0) {
+      interpreter->ResizeInputTensor(
+          interpreter->inputs()[model_->tflite_model_spec()->input_context()],
+          {1, conversation_length});
+    }
+    if (model_->tflite_model_spec()->input_user_id() >= 0) {
+      interpreter->ResizeInputTensor(
+          interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
+          {1, conversation_length});
+    }
+    if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
+      interpreter->ResizeInputTensor(
+          interpreter
+              ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
+          {1, conversation_length});
+    }
+    if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
+      interpreter->ResizeInputTensor(
+          interpreter
+              ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
+          {conversation_length, 1});
+    }
+    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
+      // Per-message token embeddings: one row of `max_tokens` embeddings
+      // (padded per message) for each message in the conversation.
+      interpreter->ResizeInputTensor(
+          interpreter
+              ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
+          {conversation_length, max_tokens, token_embedding_size_});
+    }
+    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
+      // Flattened embeddings: a single row covering all messages.
+      interpreter->ResizeInputTensor(
+          interpreter->inputs()[model_->tflite_model_spec()
+                                    ->input_flattened_token_embeddings()],
+          {1, total_token_count});
+    }
+  }
+
+  return interpreter->AllocateTensors() == kTfLiteOk;
+}
+
+// Populates all inputs of the TensorFlow Lite model that are configured in
+// the model spec (inputs with a negative index are absent and skipped):
+// tokenizes/embeds the conversation when any token-based input is present,
+// resizes/allocates tensors, and copies the scalar and vector inputs.
+// Returns false on tokenization, allocation, or feature-extraction failure.
+bool ActionsSuggestions::SetupModelInput(
+    const std::vector<std::string>& context, const std::vector<int>& user_ids,
+    const std::vector<float>& time_diffs, const int num_suggestions,
+    const float confidence_threshold, const float diversification_distance,
+    const float empirical_probability_factor,
+    tflite::Interpreter* interpreter) const {
+  // Compute token embeddings.
+  std::vector<std::vector<Token>> tokens;
+  std::vector<float> token_embeddings;
+  std::vector<float> flattened_token_embeddings;
+  int max_tokens = 0;
+  int total_token_count = 0;
+  if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
+      model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
+      model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
+    if (feature_processor_ == nullptr) {
+      TC3_LOG(ERROR) << "No feature processor specified.";
+      return false;
+    }
+
+    // Tokenize the messages in the conversation.
+    tokens = Tokenize(context);
+    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
+      if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
+        TC3_LOG(ERROR) << "Could not extract token features.";
+        return false;
+      }
+    }
+    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
+      if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
+                                 &total_token_count)) {
+        TC3_LOG(ERROR) << "Could not extract token features.";
+        return false;
+      }
+    }
+  }
+
+  if (!AllocateInput(context.size(), max_tokens, total_token_count,
+                     interpreter)) {
+    TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
+    return false;
+  }
+  if (model_->tflite_model_spec()->input_context() >= 0) {
+    model_executor_->SetInput<std::string>(
+        model_->tflite_model_spec()->input_context(), context, interpreter);
+  }
+  if (model_->tflite_model_spec()->input_context_length() >= 0) {
+    model_executor_->SetInput<int>(
+        model_->tflite_model_spec()->input_context_length(), context.size(),
+        interpreter);
+  }
+  if (model_->tflite_model_spec()->input_user_id() >= 0) {
+    model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
+                                   user_ids, interpreter);
+  }
+  if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
+    model_executor_->SetInput<int>(
+        model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
+        interpreter);
+  }
+  if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
+    model_executor_->SetInput<float>(
+        model_->tflite_model_spec()->input_time_diffs(), time_diffs,
+        interpreter);
+  }
+  if (model_->tflite_model_spec()->input_diversification_distance() >= 0) {
+    model_executor_->SetInput<float>(
+        model_->tflite_model_spec()->input_diversification_distance(),
+        diversification_distance, interpreter);
+  }
+  if (model_->tflite_model_spec()->input_confidence_threshold() >= 0) {
+    model_executor_->SetInput<float>(
+        model_->tflite_model_spec()->input_confidence_threshold(),
+        confidence_threshold, interpreter);
+  }
+  if (model_->tflite_model_spec()->input_empirical_probability_factor() >= 0) {
+    // Fixed: previously this fed `confidence_threshold` into the
+    // empirical-probability-factor input (copy-paste bug).
+    model_executor_->SetInput<float>(
+        model_->tflite_model_spec()->input_empirical_probability_factor(),
+        empirical_probability_factor, interpreter);
+  }
+  if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
+    std::vector<int> num_tokens_per_message(tokens.size());
+    for (int i = 0; i < tokens.size(); i++) {
+      num_tokens_per_message[i] = tokens[i].size();
+    }
+    model_executor_->SetInput<int>(
+        model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
+        interpreter);
+  }
+  if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
+    model_executor_->SetInput<float>(
+        model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
+        interpreter);
+  }
+  if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
+    model_executor_->SetInput<float>(
+        model_->tflite_model_spec()->input_flattened_token_embeddings(),
+        flattened_token_embeddings, interpreter);
+  }
+  return true;
+}
+
+// Reads the model's outputs into `response`: triggering and sensitivity
+// scores (with threshold-based filtering flags), smart reply predictions,
+// and per-class action suggestions. Returns false when a configured output
+// tensor is missing or malformed. Outputs with a negative index in the
+// model spec are treated as absent and skipped.
+bool ActionsSuggestions::ReadModelOutput(
+    tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
+    ActionsSuggestionsResponse* response) const {
+  // Read sensitivity and triggering score predictions.
+  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
+    const TensorView<float>& triggering_score =
+        model_executor_->OutputView<float>(
+            model_->tflite_model_spec()->output_triggering_score(),
+            interpreter);
+    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
+      TC3_LOG(ERROR) << "Could not compute triggering score.";
+      return false;
+    }
+    response->triggering_score = triggering_score.data()[0];
+    // Mark the response filtered when the score is below the precondition
+    // threshold; smart replies are suppressed below.
+    response->output_filtered_min_triggering_score =
+        (response->triggering_score <
+         preconditions_.min_smart_reply_triggering_score);
+  }
+  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
+    const TensorView<float>& sensitive_topic_score =
+        model_executor_->OutputView<float>(
+            model_->tflite_model_spec()->output_sensitive_topic_score(),
+            interpreter);
+    if (!sensitive_topic_score.is_valid() ||
+        sensitive_topic_score.dim(0) != 1) {
+      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
+      return false;
+    }
+    response->sensitivity_score = sensitive_topic_score.data()[0];
+    response->output_filtered_sensitivity =
+        (response->sensitivity_score >
+         preconditions_.max_sensitive_topic_score);
+  }
+
+  // Suppress model outputs.
+  if (response->output_filtered_sensitivity) {
+    return true;
+  }
+
+  // Read smart reply predictions.
+  std::vector<ActionSuggestion> text_replies;
+  if (!response->output_filtered_min_triggering_score &&
+      model_->tflite_model_spec()->output_replies() >= 0) {
+    const std::vector<tflite::StringRef> replies =
+        model_executor_->Output<tflite::StringRef>(
+            model_->tflite_model_spec()->output_replies(), interpreter);
+    TensorView<float> scores = model_executor_->OutputView<float>(
+        model_->tflite_model_spec()->output_replies_scores(), interpreter);
+    for (int i = 0; i < replies.size(); i++) {
+      // Skip empty replies and replies below the minimum score threshold.
+      if (replies[i].len == 0) continue;
+      const float score = scores.data()[i];
+      if (score < preconditions_.min_reply_score_threshold) {
+        continue;
+      }
+      response->actions.push_back({std::string(replies[i].str, replies[i].len),
+                                   model_->smart_reply_action_type()->str(),
+                                   score});
+    }
+  }
+
+  // Read actions suggestions.
+  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
+    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
+        model_->tflite_model_spec()->output_actions_scores(), interpreter);
+    for (int i = 0; i < model_->action_type()->Length(); i++) {
+      const ActionTypeOptions* action_type = model_->action_type()->Get(i);
+      // Skip disabled action classes, such as the default other category.
+      if (!action_type->enabled()) {
+        continue;
+      }
+      const float score = actions_scores.data()[i];
+      if (score < action_type->min_triggering_score()) {
+        continue;
+      }
+      ActionSuggestion suggestion =
+          SuggestionFromSpec(action_type->action(),
+                             /*default_type=*/action_type->name()->str());
+      suggestion.score = score;
+      response->actions.push_back(suggestion);
+    }
+  }
+
+  return true;
+}
+
+// Runs the TFLite model over the last `num_messages` messages of the
+// conversation: builds an interpreter, gathers text / user ids / inter-message
+// time deltas, sets up the inputs, invokes the model, and reads the outputs
+// into `response`. A missing model executor is not an error (returns true,
+// leaving `*interpreter` unset).
+bool ActionsSuggestions::SuggestActionsFromModel(
+    const Conversation& conversation, const int num_messages,
+    const ActionSuggestionOptions& options,
+    ActionsSuggestionsResponse* response,
+    std::unique_ptr<tflite::Interpreter>* interpreter) const {
+  TC3_CHECK_LE(num_messages, conversation.messages.size());
+
+  if (!model_executor_) {
+    return true;
+  }
+  *interpreter = model_executor_->CreateInterpreter();
+
+  if (!*interpreter) {
+    TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
+                      "actions suggestions model.";
+    return false;
+  }
+
+  std::vector<std::string> context;
+  std::vector<int> user_ids;
+  std::vector<float> time_diffs;
+  context.reserve(num_messages);
+  user_ids.reserve(num_messages);
+  time_diffs.reserve(num_messages);
+
+  // Gather last `num_messages` messages from the conversation.
+  int64 last_message_reference_time_ms_utc = 0;
+  const float second_in_ms = 1000;
+  for (int i = conversation.messages.size() - num_messages;
+       i < conversation.messages.size(); i++) {
+    const ConversationMessage& message = conversation.messages[i];
+    context.push_back(message.text);
+    user_ids.push_back(message.user_id);
+
+    // Time delta (in seconds) to the previous timestamped message; 0 when
+    // either timestamp is missing. Negative deltas are clamped to 0.
+    float time_diff_secs = 0;
+    if (message.reference_time_ms_utc != 0 &&
+        last_message_reference_time_ms_utc != 0) {
+      time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
+                                       last_message_reference_time_ms_utc) /
+                                          second_in_ms);
+    }
+    if (message.reference_time_ms_utc != 0) {
+      last_message_reference_time_ms_utc = message.reference_time_ms_utc;
+    }
+    time_diffs.push_back(time_diff_secs);
+  }
+
+  if (!SetupModelInput(context, user_ids, time_diffs,
+                       /*num_suggestions=*/model_->num_smart_replies(),
+                       preconditions_.confidence_threshold,
+                       preconditions_.diversification_distance_threshold,
+                       preconditions_.empirical_probability_factor,
+                       interpreter->get())) {
+    TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
+    return false;
+  }
+
+  if ((*interpreter)->Invoke() != kTfLiteOk) {
+    TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
+    return false;
+  }
+
+  return ReadModelOutput(interpreter->get(), options, response);
+}
+
+// Builds the AnnotationOptions used to annotate a single conversation
+// message: language tags, reference time/timezone from the message itself,
+// plus usecase / entity-data settings from the model's annotation spec.
+AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
+    const ConversationMessage& message) const {
+  AnnotationOptions options;
+  options.detected_text_language_tags = message.detected_text_language_tags;
+  options.reference_time_ms_utc = message.reference_time_ms_utc;
+  options.reference_timezone = message.reference_timezone;
+  options.annotation_usecase =
+      model_->annotation_actions_spec()->annotation_usecase();
+  options.is_serialized_entity_data_enabled =
+      model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
+  options.entity_types = annotation_entity_types_;
+  return options;
+}
+
+// Creates action suggestions from annotations on recent conversation
+// messages. Walks messages from newest to oldest, bounded by the
+// max-history limits in the annotation spec, annotating messages on the fly
+// when they carry no precomputed annotations and an annotator is available.
+// No-op when the model has no annotation mapping configured.
+void ActionsSuggestions::SuggestActionsFromAnnotations(
+    const Conversation& conversation, const ActionSuggestionOptions& options,
+    const Annotator* annotator, std::vector<ActionSuggestion>* actions) const {
+  if (model_->annotation_actions_spec() == nullptr ||
+      model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
+      model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
+    return;
+  }
+
+  // Create actions based on the annotations.
+  const int max_from_any_person =
+      model_->annotation_actions_spec()->max_history_from_any_person();
+  const int max_from_last_person =
+      model_->annotation_actions_spec()->max_history_from_last_person();
+  const int last_person = conversation.messages.back().user_id;
+
+  int num_messages_last_person = 0;
+  int num_messages_any_person = 0;
+  bool all_from_last_person = true;
+  for (int message_index = conversation.messages.size() - 1; message_index >= 0;
+       message_index--) {
+    const ConversationMessage& message = conversation.messages[message_index];
+    std::vector<AnnotatedSpan> annotations = message.annotations;
+
+    // Update how many messages we have processed from the last person in the
+    // conversation and from any person in the conversation.
+    num_messages_any_person++;
+    if (all_from_last_person && message.user_id == last_person) {
+      num_messages_last_person++;
+    } else {
+      all_from_last_person = false;
+    }
+
+    // Stop when both history budgets are exhausted; the last-person budget
+    // only extends the window while the trailing run is all one person.
+    if (num_messages_any_person > max_from_any_person &&
+        (!all_from_last_person ||
+         num_messages_last_person > max_from_last_person)) {
+      break;
+    }
+
+    if (message.user_id == kLocalUserId) {
+      if (model_->annotation_actions_spec()->only_until_last_sent()) {
+        break;
+      }
+      if (!model_->annotation_actions_spec()->include_local_user_messages()) {
+        continue;
+      }
+    }
+
+    // Annotate on demand when no precomputed annotations were provided.
+    if (annotations.empty() && annotator != nullptr) {
+      annotations = annotator->Annotate(message.text,
+                                        AnnotationOptionsForMessage(message));
+    }
+    std::vector<ActionSuggestionAnnotation> action_annotations;
+    action_annotations.reserve(annotations.size());
+    for (const AnnotatedSpan& annotation : annotations) {
+      if (annotation.classification.empty()) {
+        continue;
+      }
+
+      // Only the top classification result of each span is considered.
+      const ClassificationResult& classification_result =
+          annotation.classification[0];
+
+      ActionSuggestionAnnotation action_annotation;
+      action_annotation.span = {
+          message_index, annotation.span,
+          UTF8ToUnicodeText(message.text, /*do_copy=*/false)
+              .UTF8Substring(annotation.span.first, annotation.span.second)};
+      action_annotation.entity = classification_result;
+      action_annotation.name = classification_result.collection;
+      action_annotations.push_back(action_annotation);
+    }
+
+    if (model_->annotation_actions_spec()->deduplicate_annotations()) {
+      // Create actions only for deduplicated annotations.
+      for (const int annotation_id :
+           DeduplicateAnnotations(action_annotations)) {
+        SuggestActionsFromAnnotation(
+            message_index, action_annotations[annotation_id], actions);
+      }
+    } else {
+      // Create actions for all annotations.
+      for (const ActionSuggestionAnnotation& annotation : action_annotations) {
+        SuggestActionsFromAnnotation(message_index, annotation, actions);
+      }
+    }
+  }
+}
+
+// Creates actions for a single annotation by matching its collection against
+// the model's annotation-to-action mapping. Skips mappings whose minimum
+// annotation score is not met; optionally copies the annotation score and
+// writes the annotated text into the mapped entity-data field.
+void ActionsSuggestions::SuggestActionsFromAnnotation(
+    const int message_index, const ActionSuggestionAnnotation& annotation,
+    std::vector<ActionSuggestion>* actions) const {
+  for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
+       *model_->annotation_actions_spec()->annotation_mapping()) {
+    if (annotation.entity.collection ==
+        mapping->annotation_collection()->str()) {
+      if (annotation.entity.score < mapping->min_annotation_score()) {
+        continue;
+      }
+      ActionSuggestion suggestion = SuggestionFromSpec(mapping->action());
+      if (mapping->use_annotation_score()) {
+        suggestion.score = annotation.entity.score;
+      }
+
+      // Set annotation text as (additional) entity data field.
+      if (mapping->entity_field() != nullptr) {
+        std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+            entity_data_builder_->NewRoot();
+        TC3_CHECK(entity_data != nullptr);
+
+        // Merge existing static entity data.
+        if (!suggestion.serialized_entity_data.empty()) {
+          entity_data->MergeFromSerializedFlatbuffer(
+              StringPiece(suggestion.serialized_entity_data.c_str(),
+                          suggestion.serialized_entity_data.size()));
+        }
+
+        entity_data->ParseAndSet(mapping->entity_field(), annotation.span.text);
+        suggestion.serialized_entity_data = entity_data->Serialize();
+      }
+
+      suggestion.annotations = {annotation};
+      actions->push_back(suggestion);
+    }
+  }
+}
+
+// Deduplicates annotations by (name, span text), keeping for each key the
+// index of the annotation with the highest entity score. Returns the
+// surviving indices into `annotations` (ordered by key, not input order).
+std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
+    const std::vector<ActionSuggestionAnnotation>& annotations) const {
+  std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;
+
+  for (int i = 0; i < annotations.size(); i++) {
+    const std::pair<std::string, std::string> key = {annotations[i].name,
+                                                     annotations[i].span.text};
+    auto entry = deduplicated_annotations.find(key);
+    if (entry != deduplicated_annotations.end()) {
+      // Keep the annotation with the higher score.
+      if (annotations[entry->second].entity.score <
+          annotations[i].entity.score) {
+        entry->second = i;
+      }
+      continue;
+    }
+    deduplicated_annotations.insert(entry, {key, i});
+  }
+
+  std::vector<int> result;
+  result.reserve(deduplicated_annotations.size());
+  for (const auto& key_and_annotation : deduplicated_annotations) {
+    result.push_back(key_and_annotation.second);
+  }
+  return result;
+}
+
+// Fills `annotation` from a regex capturing group when the group's rule spec
+// names an annotation or annotation type. Returns false when the span cannot
+// be extracted or the group did not take part in the match; returns true
+// (without touching `annotation`) when the spec requests no annotation.
+bool ActionsSuggestions::FillAnnotationFromMatchGroup(
+    const UniLib::RegexMatcher* matcher,
+    const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
+    const int message_index, ActionSuggestionAnnotation* annotation) const {
+  if (group->annotation_name() != nullptr ||
+      group->annotation_type() != nullptr) {
+    int status = UniLib::RegexMatcher::kNoError;
+    const CodepointSpan span = {matcher->Start(group->group_id(), &status),
+                                matcher->End(group->group_id(), &status)};
+    std::string text =
+        matcher->Group(group->group_id(), &status).ToUTF8String();
+    // `status` accumulates across the three matcher calls; a single check
+    // at the end covers all of them.
+    if (status != UniLib::RegexMatcher::kNoError) {
+      TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
+      return false;
+    }
+
+    // The capturing group was not part of the match.
+    if (span.first == kInvalidIndex || span.second == kInvalidIndex) {
+      return false;
+    }
+    annotation->span.span = span;
+    annotation->span.message_index = message_index;
+    annotation->span.text = text;
+    if (group->annotation_name() != nullptr) {
+      annotation->name = group->annotation_name()->str();
+    }
+    if (group->annotation_type() != nullptr) {
+      annotation->entity.collection = group->annotation_type()->str();
+    }
+  }
+  return true;
+}
+
+// Creates action suggestions from the compiled regex rules, applied to the
+// last message of the conversation. For every match of every rule, each
+// configured action spec may contribute: static entity data, entity data and
+// annotations extracted from capturing groups, and text replies built from
+// group text. Returns false on regex or entity-data extraction failure.
+bool ActionsSuggestions::SuggestActionsFromRules(
+    const Conversation& conversation,
+    std::vector<ActionSuggestion>* actions) const {
+  // Create actions based on rules checking the last message.
+  const int message_index = conversation.messages.size() - 1;
+  const std::string& message = conversation.messages.back().text;
+  const UnicodeText message_unicode(
+      UTF8ToUnicodeText(message, /*do_copy=*/false));
+  for (const CompiledRule& rule : rules_) {
+    const std::unique_ptr<UniLib::RegexMatcher> matcher =
+        rule.pattern->Matcher(message_unicode);
+    int status = UniLib::RegexMatcher::kNoError;
+    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+      for (const RulesModel_::Rule_::RuleActionSpec* rule_action :
+           *rule.rule->actions()) {
+        const ActionSuggestionSpec* action = rule_action->action();
+        std::vector<ActionSuggestionAnnotation> annotations;
+
+        bool sets_entity_data = false;
+        std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+            entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
+                                            : nullptr;
+
+        // Set static entity data.
+        if (action != nullptr && action->serialized_entity_data() != nullptr) {
+          TC3_CHECK(entity_data != nullptr);
+          sets_entity_data = true;
+          entity_data->MergeFromSerializedFlatbuffer(
+              StringPiece(action->serialized_entity_data()->c_str(),
+                          action->serialized_entity_data()->size()));
+        }
+
+        // Add entity data from rule capturing groups.
+        if (rule_action->capturing_group() != nullptr) {
+          for (const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup*
+                   group : *rule_action->capturing_group()) {
+            if (group->entity_field() != nullptr) {
+              TC3_CHECK(entity_data != nullptr);
+              sets_entity_data = true;
+              if (!SetFieldFromCapturingGroup(
+                      group->group_id(), group->entity_field(), matcher.get(),
+                      entity_data.get())) {
+                TC3_LOG(ERROR)
+                    << "Could not set entity data from rule capturing group.";
+                return false;
+              }
+            }
+
+            // Create a text annotation for the group span.
+            ActionSuggestionAnnotation annotation;
+            if (FillAnnotationFromMatchGroup(matcher.get(), group,
+                                             message_index, &annotation)) {
+              annotations.push_back(annotation);
+            }
+
+            // Create text reply.
+            if (group->text_reply() != nullptr) {
+              // Use a separate status variable to avoid shadowing the outer
+              // match-loop `status` (previously shadowed).
+              int group_status = UniLib::RegexMatcher::kNoError;
+              const std::string group_text =
+                  matcher->Group(group->group_id(), &group_status)
+                      .ToUTF8String();
+              if (group_status != UniLib::RegexMatcher::kNoError) {
+                // Fixed log message: was "Could get text ...".
+                TC3_LOG(ERROR) << "Could not get text from capturing group.";
+                return false;
+              }
+              if (group_text.empty()) {
+                // The group was not part of the match, ignore and continue.
+                continue;
+              }
+              actions->push_back(SuggestionFromSpec(
+                  group->text_reply(),
+                  /*default_type=*/model_->smart_reply_action_type()->str(),
+                  /*default_response_text=*/group_text));
+            }
+          }
+        }
+
+        if (action != nullptr) {
+          ActionSuggestion suggestion = SuggestionFromSpec(action);
+          suggestion.annotations = annotations;
+          if (sets_entity_data) {
+            suggestion.serialized_entity_data = entity_data->Serialize();
+          }
+          actions->push_back(suggestion);
+        }
+      }
+    }
+  }
+  return true;
+}
+
+// Runs the model's embedded Lua actions script, if any, to produce
+// additional action suggestions. A missing script is not an error
+// (returns true without adding actions).
+bool ActionsSuggestions::SuggestActionsFromLua(
+    const Conversation& conversation, const TfLiteModelExecutor* model_executor,
+    const tflite::Interpreter* interpreter,
+    const reflection::Schema* annotation_entity_data_schema,
+    std::vector<ActionSuggestion>* actions) const {
+  if (lua_bytecode_.empty()) {
+    return true;
+  }
+
+  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
+      lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
+      interpreter, entity_data_schema_, annotation_entity_data_schema);
+  if (lua_actions == nullptr) {
+    TC3_LOG(ERROR) << "Could not create lua actions.";
+    return false;
+  }
+  return lua_actions->SuggestActions(actions);
+}
+
+// Top-level suggestion pipeline: checks the input preconditions (length,
+// locale match, low-confidence filter), then gathers actions from
+// annotations, the TFLite model, the Lua script, and the regex rules, and
+// finally post-filters low-confidence output. Returns false only on hard
+// errors; precondition-based suppression returns true with filter flags set
+// on `response`.
+bool ActionsSuggestions::GatherActionsSuggestions(
+    const Conversation& conversation, const Annotator* annotator,
+    const ActionSuggestionOptions& options,
+    ActionsSuggestionsResponse* response) const {
+  if (conversation.messages.empty()) {
+    return true;
+  }
+
+  const int num_messages = NumMessagesToConsider(
+      conversation, model_->max_conversation_history_length());
+
+  if (num_messages <= 0) {
+    TC3_LOG(INFO) << "No messages provided for actions suggestions.";
+    return false;
+  }
+
+  SuggestActionsFromAnnotations(conversation, options, annotator,
+                                &response->actions);
+
+  int input_text_length = 0;
+  int num_matching_locales = 0;
+  for (int i = conversation.messages.size() - num_messages;
+       i < conversation.messages.size(); i++) {
+    input_text_length += conversation.messages[i].text.length();
+    std::vector<Locale> message_languages;
+    if (!ParseLocales(conversation.messages[i].detected_text_language_tags,
+                      &message_languages)) {
+      continue;
+    }
+    if (Locale::IsAnyLocaleSupported(
+            message_languages, locales_,
+            preconditions_.handle_unknown_locale_as_supported)) {
+      ++num_matching_locales;
+    }
+  }
+
+  // Bail out if we are provided with too few or too much input.
+  if (input_text_length < preconditions_.min_input_length ||
+      (preconditions_.max_input_length >= 0 &&
+       input_text_length > preconditions_.max_input_length)) {
+    TC3_LOG(INFO) << "Too much or not enough input for inference.";
+    // Fixed: previously `return response;` — the non-null pointer silently
+    // converted to `true`; make the intended success return explicit.
+    return true;
+  }
+
+  // Bail out if the text does not look like it can be handled by the model.
+  const float matching_fraction =
+      static_cast<float>(num_matching_locales) / num_messages;
+  if (matching_fraction < preconditions_.min_locale_match_fraction) {
+    TC3_LOG(INFO) << "Not enough locale matches.";
+    response->output_filtered_locale_mismatch = true;
+    return true;
+  }
+
+  std::vector<int> post_check_rules;
+  if (preconditions_.suppress_on_low_confidence_input &&
+      IsLowConfidenceInput(conversation, num_messages, &post_check_rules)) {
+    response->output_filtered_low_confidence = true;
+    return true;
+  }
+
+  std::unique_ptr<tflite::Interpreter> interpreter;
+  if (!SuggestActionsFromModel(conversation, num_messages, options, response,
+                               &interpreter)) {
+    TC3_LOG(ERROR) << "Could not run model.";
+    return false;
+  }
+
+  // Suppress all predictions if the conversation was deemed sensitive.
+  if (preconditions_.suppress_on_sensitive_topic &&
+      response->output_filtered_sensitivity) {
+    return true;
+  }
+
+  if (!SuggestActionsFromLua(
+          conversation, model_executor_.get(), interpreter.get(),
+          annotator != nullptr ? annotator->entity_data_schema() : nullptr,
+          &response->actions)) {
+    TC3_LOG(ERROR) << "Could not suggest actions from script.";
+    return false;
+  }
+
+  if (!SuggestActionsFromRules(conversation, &response->actions)) {
+    TC3_LOG(ERROR) << "Could not suggest actions from rules.";
+    return false;
+  }
+
+  if (preconditions_.suppress_on_low_confidence_input &&
+      !FilterConfidenceOutput(post_check_rules, &response->actions)) {
+    TC3_LOG(ERROR) << "Could not post-check actions.";
+    return false;
+  }
+
+  return true;
+}
+
+// Public entry point: gathers suggestions and ranks them. On gathering or
+// ranking failure the returned response carries no actions (any filter
+// flags already set on the response are preserved).
+ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
+    const Conversation& conversation, const Annotator* annotator,
+    const ActionSuggestionOptions& options) const {
+  ActionsSuggestionsResponse response;
+  if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
+    TC3_LOG(ERROR) << "Could not gather actions suggestions.";
+    response.actions.clear();
+  } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
+                                   annotator != nullptr
+                                       ? annotator->entity_data_schema()
+                                       : nullptr)) {
+    TC3_LOG(ERROR) << "Could not rank actions.";
+    response.actions.clear();
+  }
+  return response;
+}
+
+// Convenience overload without an annotator; on-the-fly annotation of
+// messages is skipped.
+ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
+    const Conversation& conversation,
+    const ActionSuggestionOptions& options) const {
+  return SuggestActions(conversation, /*annotator=*/nullptr, options);
+}
+
+// Accessors for the loaded model and its entity-data reflection schema.
+const ActionsModel* ActionsSuggestions::model() const { return model_; }
+const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
+  return entity_data_schema_;
+}
+
+// Interprets `buffer` as a flatbuffer-serialized ActionsModel without taking
+// ownership. Returns nullptr when `buffer` is null or verification fails.
+const ActionsModel* ViewActionsModel(const void* buffer, int size) {
+  if (buffer == nullptr) {
+    return nullptr;
+  }
+  return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
+}
+
+} // namespace libtextclassifier3
diff --git a/actions/actions-suggestions.h b/actions/actions-suggestions.h
new file mode 100644
index 0000000..2dde133
--- /dev/null
+++ b/actions/actions-suggestions.h
@@ -0,0 +1,319 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "actions/actions_model_generated.h"
+#include "actions/feature-processor.h"
+#include "actions/ngram-model.h"
+#include "actions/ranker.h"
+#include "actions/types.h"
+#include "annotator/annotator.h"
+#include "annotator/model-executor.h"
+#include "annotator/types.h"
+#include "utils/flatbuffers.h"
+#include "utils/i18n/locale.h"
+#include "utils/memory/mmap.h"
+#include "utils/tflite-model-executor.h"
+#include "utils/utf8/unilib.h"
+#include "utils/variant.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Options for suggesting actions.
+// Currently empty; exists so the public API can grow options without
+// signature changes.
+struct ActionSuggestionOptions {
+ static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
+};
+
+// Class for predicting actions following a conversation.
+class ActionsSuggestions {
+ public:
+ // Creates ActionsSuggestions from given data buffer with model.
+ static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
+ const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
+ const std::string& triggering_preconditions_overlay = "");
+
+ // Creates ActionsSuggestions from model in the ScopedMmap object and takes
+ // ownership of it.
+ static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
+ const UniLib* unilib = nullptr,
+ const std::string& triggering_preconditions_overlay = "");
+ // Same as above, but also takes ownership of the unilib.
+ static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
+ std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay);
+
+ // Creates ActionsSuggestions from model given as a file descriptor, offset
+ // and size in it. If offset and size are less than 0, will ignore them and
+ // will just use the fd.
+ static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
+ const int fd, const int offset, const int size,
+ const UniLib* unilib = nullptr,
+ const std::string& triggering_preconditions_overlay = "");
+ // Same as above, but also takes ownership of the unilib.
+ static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
+ const int fd, const int offset, const int size,
+ std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay = "");
+
+ // Creates ActionsSuggestions from model given as a file descriptor.
+ static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
+ const int fd, const UniLib* unilib = nullptr,
+ const std::string& triggering_preconditions_overlay = "");
+ // Same as above, but also takes ownership of the unilib.
+ static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
+ const int fd, std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay);
+
+ // Creates ActionsSuggestions from model given as a POSIX path.
+ static std::unique_ptr<ActionsSuggestions> FromPath(
+ const std::string& path, const UniLib* unilib = nullptr,
+ const std::string& triggering_preconditions_overlay = "");
+ // Same as above, but also takes ownership of unilib.
+ static std::unique_ptr<ActionsSuggestions> FromPath(
+ const std::string& path, std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay);
+
+ // Suggests actions for `conversation` without an annotator.
+ ActionsSuggestionsResponse SuggestActions(
+ const Conversation& conversation,
+ const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
+
+ // Same as above, but uses `annotator` to annotate the messages first.
+ ActionsSuggestionsResponse SuggestActions(
+ const Conversation& conversation, const Annotator* annotator,
+ const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
+
+ // Read-only access to the backing model and its entity data schema.
+ const ActionsModel* model() const;
+ const reflection::Schema* entity_data_schema() const;
+
+ // `user_id` value that marks messages of the local user in a Conversation.
+ static const int kLocalUserId = 0;
+
+ // Should be in sync with those defined in Android.
+ // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
+ static const std::string& kViewCalendarType;
+ static const std::string& kViewMapType;
+ static const std::string& kTrackFlightType;
+ static const std::string& kOpenUrlType;
+ static const std::string& kSendSmsType;
+ static const std::string& kCallPhoneType;
+ static const std::string& kSendEmailType;
+ static const std::string& kShareLocation;
+
+ protected:
+ // Exposed for testing.
+ bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
+
+ // Embeds the tokens per message separately. Each message is padded to the
+ // maximum length with the padding token.
+ bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
+ std::vector<float>* embeddings,
+ int* max_num_tokens_per_message) const;
+
+ // Concatenates the embedded message tokens - separated by start and end
+ // token between messages.
+ // If the total token count is greater than the maximum length, tokens at the
+ // start are dropped to fit into the limit.
+ // If the total token count is smaller than the minimum length, padding tokens
+ // are added to the end.
+ // Messages are assumed to be ordered by recency - most recent is last.
+ // NOTE(review): `tokens` is taken by const value (a copy); a const
+ // reference looks intended, but the out-of-line definition would have to
+ // change in lockstep — verify before touching.
+ bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>> tokens,
+ std::vector<float>* embeddings,
+ int* total_token_count) const;
+
+ const ActionsModel* model_;
+
+ // Feature extractor and options.
+ std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
+ std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
+ std::vector<float> embedded_padding_token_;
+ std::vector<float> embedded_start_token_;
+ std::vector<float> embedded_end_token_;
+ int token_embedding_size_;
+
+ private:
+ // A regex rule from the model together with its compiled trigger pattern
+ // and (optional) output pattern.
+ struct CompiledRule {
+ const RulesModel_::Rule* rule;
+ std::unique_ptr<UniLib::RegexPattern> pattern;
+ std::unique_ptr<UniLib::RegexPattern> output_pattern;
+ CompiledRule(const RulesModel_::Rule* rule,
+ std::unique_ptr<UniLib::RegexPattern> pattern,
+ std::unique_ptr<UniLib::RegexPattern> output_pattern)
+ : rule(rule),
+ pattern(std::move(pattern)),
+ output_pattern(std::move(output_pattern)) {}
+ };
+
+ // Checks that model contains all required fields, and initializes internal
+ // datastructures.
+ bool ValidateAndInitialize();
+
+ void SetOrCreateUnilib(const UniLib* unilib);
+
+ // Initializes regular expression rules.
+ bool InitializeRules(ZlibDecompressor* decompressor);
+ bool InitializeRules(ZlibDecompressor* decompressor, const RulesModel* rules,
+ std::vector<CompiledRule>* compiled_rules) const;
+
+ // Prepare preconditions.
+ // Takes values from flag provided data, but falls back to model provided
+ // values for parameters that are not explicitly provided.
+ bool InitializeTriggeringPreconditions();
+
+ // Tokenizes a conversation and produces the tokens per message.
+ std::vector<std::vector<Token>> Tokenize(
+ const std::vector<std::string>& context) const;
+
+ bool AllocateInput(const int conversation_length, const int max_tokens,
+ const int total_token_count,
+ tflite::Interpreter* interpreter) const;
+
+ bool SetupModelInput(const std::vector<std::string>& context,
+ const std::vector<int>& user_ids,
+ const std::vector<float>& time_diffs,
+ const int num_suggestions,
+ const float confidence_threshold,
+ const float diversification_distance,
+ const float empirical_probability_factor,
+ tflite::Interpreter* interpreter) const;
+ bool ReadModelOutput(tflite::Interpreter* interpreter,
+ const ActionSuggestionOptions& options,
+ ActionsSuggestionsResponse* response) const;
+
+ bool SuggestActionsFromModel(
+ const Conversation& conversation, const int num_messages,
+ const ActionSuggestionOptions& options,
+ ActionsSuggestionsResponse* response,
+ std::unique_ptr<tflite::Interpreter>* interpreter) const;
+
+ // Creates options for annotation of a message.
+ AnnotationOptions AnnotationOptionsForMessage(
+ const ConversationMessage& message) const;
+
+ void SuggestActionsFromAnnotations(
+ const Conversation& conversation, const ActionSuggestionOptions& options,
+ const Annotator* annotator, std::vector<ActionSuggestion>* actions) const;
+
+ void SuggestActionsFromAnnotation(
+ const int message_index, const ActionSuggestionAnnotation& annotation,
+ std::vector<ActionSuggestion>* actions) const;
+
+ // Deduplicates equivalent annotations - annotations that have the same type
+ // and same span text.
+ // Returns the indices of the deduplicated annotations.
+ std::vector<int> DeduplicateAnnotations(
+ const std::vector<ActionSuggestionAnnotation>& annotations) const;
+
+ bool SuggestActionsFromRules(const Conversation& conversation,
+ std::vector<ActionSuggestion>* actions) const;
+
+ bool SuggestActionsFromLua(
+ const Conversation& conversation,
+ const TfLiteModelExecutor* model_executor,
+ const tflite::Interpreter* interpreter,
+ const reflection::Schema* annotation_entity_data_schema,
+ std::vector<ActionSuggestion>* actions) const;
+
+ bool GatherActionsSuggestions(const Conversation& conversation,
+ const Annotator* annotator,
+ const ActionSuggestionOptions& options,
+ ActionsSuggestionsResponse* response) const;
+
+ // Checks whether the input triggers the low confidence checks.
+ bool IsLowConfidenceInput(const Conversation& conversation,
+ const int num_messages,
+ std::vector<int>* post_check_rules) const;
+ // Checks and filters suggestions triggering the low confidence post checks.
+ bool FilterConfidenceOutput(const std::vector<int>& post_check_rules,
+ std::vector<ActionSuggestion>* actions) const;
+
+ ActionSuggestion SuggestionFromSpec(
+ const ActionSuggestionSpec* action, const std::string& default_type = "",
+ const std::string& default_response_text = "",
+ const std::string& default_serialized_entity_data = "",
+ const float default_score = 0.0f,
+ const float default_priority_score = 0.0f) const;
+
+ bool FillAnnotationFromMatchGroup(
+ const UniLib::RegexMatcher* matcher,
+ const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
+ const int message_index, ActionSuggestionAnnotation* annotation) const;
+
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
+
+ // Tensorflow Lite models.
+ std::unique_ptr<const TfLiteModelExecutor> model_executor_;
+
+ // Rules.
+ std::vector<CompiledRule> rules_, low_confidence_rules_;
+
+ // unilib_ always points at the UniLib in use; owned_unilib_ is only set
+ // when ownership was transferred in (see SetOrCreateUnilib).
+ std::unique_ptr<UniLib> owned_unilib_;
+ const UniLib* unilib_;
+
+ // Locales supported by the model.
+ std::vector<Locale> locales_;
+
+ // Annotation entities used by the model.
+ std::unordered_set<std::string> annotation_entity_types_;
+
+ // Builder for creating extra data.
+ const reflection::Schema* entity_data_schema_;
+ std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
+ std::unique_ptr<ActionsSuggestionsRanker> ranker_;
+
+ std::string lua_bytecode_;
+
+ // Triggering preconditions. These parameters can be backed by the model and
+ // (partially) be provided by flags.
+ TriggeringPreconditionsT preconditions_;
+ std::string triggering_preconditions_overlay_buffer_;
+ const TriggeringPreconditions* triggering_preconditions_overlay_;
+
+ // Low confidence input ngram classifier.
+ std::unique_ptr<const NGramModel> ngram_model_;
+};
+
+// Interprets the buffer as a Model flatbuffer and returns it for reading.
+const ActionsModel* ViewActionsModel(const void* buffer, int size);
+
+// Opens model from given path and runs a function, passing the loaded Model
+// flatbuffer as an argument.
+//
+// This is mainly useful if we don't want to pay the cost for the model
+// initialization because we'll be only reading some flatbuffer values from the
+// file.
+template <typename ReturnType, typename Func>
+ReturnType VisitActionsModel(const std::string& path, Func function) {
+ ScopedMmap mmap(path);
+ if (!mmap.handle().ok()) {
+ // Bug fix: `return` was missing here — control fell through and called
+ // ViewActionsModel on an invalid mmap handle, and for a non-void
+ // ReturnType the function could end without returning a value (UB).
+ return function(/*model=*/nullptr);
+ }
+ const ActionsModel* model =
+ ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
+ return function(model);
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
diff --git a/actions/actions-suggestions_test.cc b/actions/actions-suggestions_test.cc
new file mode 100644
index 0000000..e0cfbaa
--- /dev/null
+++ b/actions/actions-suggestions_test.cc
@@ -0,0 +1,1332 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/actions-suggestions.h"
+
+#include <fstream>
+#include <iterator>
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "actions/test_utils.h"
+#include "actions/zlib-utils.h"
+#include "annotator/collections.h"
+#include "annotator/types.h"
+#include "utils/flatbuffers.h"
+#include "utils/flatbuffers_generated.h"
+#include "utils/hash/farmhash.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+namespace {
+using testing::_;
+
+constexpr char kModelFileName[] = "actions_suggestions_test.model";
+constexpr char kHashGramModelFileName[] =
+ "actions_suggestions_test.hashgram.model";
+
+// Reads the entire contents of `file_name` into a string (empty on open
+// failure, since the stream iterator range is then empty).
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+// Directory prefix for the test model files; empty string means the test
+// runner's current working directory.
+std::string GetModelPath() {
+ return "";
+}
+
+// Fixture providing a UniLib instance and helpers for loading the test
+// models named by kModelFileName / kHashGramModelFileName.
+class ActionsSuggestionsTest : public testing::Test {
+ protected:
+ ActionsSuggestionsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+ // Loads the default end-to-end test model.
+ std::unique_ptr<ActionsSuggestions> LoadTestModel() {
+ return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName,
+ &unilib_);
+ }
+ // Loads the model variant with a hashgram classifier (per file name).
+ std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
+ return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
+ &unilib_);
+ }
+ UniLib unilib_;
+};
+
+// Smoke test: the test model file loads successfully.
+TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
+ EXPECT_THAT(LoadTestModel(), testing::NotNull());
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActions) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies*/);
+}
+
+// An unsupported locale ("zz") must suppress all suggestions.
+TEST_F(ActionsSuggestionsTest, SuggestNoActionsForUnknownLocale) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"zz"}}});
+ EXPECT_THAT(response.actions, testing::IsEmpty());
+}
+
+// An address annotation on the message should surface a view_map action.
+TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotations) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("address", 1.0)};
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "are you at home?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions.front().type, "view_map");
+ EXPECT_EQ(response.actions.front().score, 1.0);
+}
+
+// A custom annotation->action mapping with an entity field path should copy
+// the annotated span text ("home") into the action's serialized entity data.
+TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotationsWithEntityData) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ SetTestEntityDataSchema(actions_model.get());
+
+ // Set custom actions from annotations config.
+ actions_model->annotation_actions_spec->annotation_mapping.clear();
+ actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
+ new AnnotationActionsSpec_::AnnotationMappingT);
+ AnnotationActionsSpec_::AnnotationMappingT* mapping =
+ actions_model->annotation_actions_spec->annotation_mapping.back().get();
+ mapping->annotation_collection = "address";
+ mapping->action.reset(new ActionSuggestionSpecT);
+ mapping->action->type = "save_location";
+ mapping->action->score = 1.0;
+ mapping->action->priority_score = 2.0;
+ mapping->entity_field.reset(new FlatbufferFieldPathT);
+ mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
+ mapping->entity_field->field.back()->field_name = "location";
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_);
+
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("address", 1.0)};
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "are you at home?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions.front().type, "save_location");
+ EXPECT_EQ(response.actions.front().score, 1.0);
+
+ // Check that the `location` entity field holds the text from the address
+ // annotation.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ response.actions.front().serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
+ "home");
+}
+
+// Equivalent annotations (same collection and same span text, "LX38") are
+// deduplicated by default; the higher-scoring one wins.
+TEST_F(ActionsSuggestionsTest, SuggestActionsFromDuplicatedAnnotations) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ AnnotatedSpan flight_annotation;
+ flight_annotation.span = {11, 15};
+ flight_annotation.classification = {ClassificationResult("flight", 2.5)};
+ AnnotatedSpan flight_annotation2;
+ flight_annotation2.span = {35, 39};
+ flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
+ AnnotatedSpan email_annotation;
+ email_annotation.span = {55, 68};
+ email_annotation.classification = {ClassificationResult("email", 2.0)};
+
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1,
+ "call me at LX38 or send message to LX38 or test@test.com.",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {flight_annotation, flight_annotation2, email_annotation},
+ /*locales=*/"en"}}});
+
+ ASSERT_GE(response.actions.size(), 2);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[0].score, 3.0);
+ EXPECT_EQ(response.actions[1].type, "send_email");
+ EXPECT_EQ(response.actions[1].score, 2.0);
+}
+
+// With deduplication disabled, both flight annotations produce actions.
+TEST_F(ActionsSuggestionsTest, SuggestActionsAnnotationsNoDeduplication) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ // Disable deduplication.
+ actions_model->annotation_actions_spec->deduplicate_annotations = false;
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_);
+ AnnotatedSpan flight_annotation;
+ flight_annotation.span = {11, 15};
+ flight_annotation.classification = {ClassificationResult("flight", 2.5)};
+ AnnotatedSpan flight_annotation2;
+ flight_annotation2.span = {35, 39};
+ flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
+ AnnotatedSpan email_annotation;
+ email_annotation.span = {55, 68};
+ email_annotation.classification = {ClassificationResult("email", 2.0)};
+
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1,
+ "call me at LX38 or send message to LX38 or test@test.com.",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {flight_annotation, flight_annotation2, email_annotation},
+ /*locales=*/"en"}}});
+
+ ASSERT_GE(response.actions.size(), 3);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[0].score, 3.0);
+ EXPECT_EQ(response.actions[1].type, "track_flight");
+ EXPECT_EQ(response.actions[1].score, 2.5);
+ EXPECT_EQ(response.actions[2].type, "send_email");
+ EXPECT_EQ(response.actions[2].score, 2.0);
+}
+
+// Builds a four-message conversation — local user, user 2, then two
+// messages from user 1, the last carrying a flight annotation — applies
+// `set_config_fn` to the unpacked model, and returns the suggested actions.
+// Smart replies are disabled so only annotation-based actions remain.
+ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
+ const std::function<void(ActionsModelT*)>& set_config_fn,
+ const UniLib* unilib = nullptr) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+
+ // Set custom config.
+ set_config_fn(actions_model.get());
+
+ // Disable smart reply for easier testing.
+ actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib);
+
+ AnnotatedSpan flight_annotation;
+ flight_annotation.span = {15, 19};
+ flight_annotation.classification = {ClassificationResult("flight", 2.0)};
+ AnnotatedSpan email_annotation;
+ email_annotation.span = {0, 16};
+ email_annotation.classification = {ClassificationResult("email", 1.0)};
+
+ return actions_suggestions->SuggestActions(
+ {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
+ "hehe@android.com",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {email_annotation},
+ /*locales=*/"en"},
+ {/*user_id=*/2,
+ "yoyo@android.com",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {email_annotation},
+ /*locales=*/"en"},
+ {/*user_id=*/1,
+ "test@android.com",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {email_annotation},
+ /*locales=*/"en"},
+ {/*user_id=*/1,
+ "I am on flight LX38.",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/
+ {flight_annotation},
+ /*locales=*/"en"}}});
+}
+
+// With history limited to one message from anyone, only the last message's
+// flight annotation produces an action.
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastMessage) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ false;
+ actions_model->annotation_actions_spec->only_until_last_sent = true;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 1;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 1;
+ },
+ &unilib_);
+ EXPECT_EQ(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+}
+
+// Raising the last-person history also picks up the previous message of the
+// same sender (its email annotation).
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastPerson) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ false;
+ actions_model->annotation_actions_spec->only_until_last_sent = true;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 1;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 3;
+ },
+ &unilib_);
+ EXPECT_EQ(response.actions.size(), 2);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[1].type, "send_email");
+}
+
+// The next tests vary max_history_from_any_person and check which annotated
+// messages (excluding the local user's) still contribute actions.
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsFromAny) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ false;
+ actions_model->annotation_actions_spec->only_until_last_sent = true;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 2;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 1;
+ },
+ &unilib_);
+ EXPECT_EQ(response.actions.size(), 2);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[1].type, "send_email");
+}
+
+TEST_F(ActionsSuggestionsTest,
+ SuggestActionsWithAnnotationsFromAnyManyMessages) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ false;
+ actions_model->annotation_actions_spec->only_until_last_sent = true;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 3;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 1;
+ },
+ &unilib_);
+ EXPECT_EQ(response.actions.size(), 3);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[1].type, "send_email");
+ EXPECT_EQ(response.actions[2].type, "send_email");
+}
+
+TEST_F(ActionsSuggestionsTest,
+ SuggestActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ false;
+ actions_model->annotation_actions_spec->only_until_last_sent = true;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 5;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 1;
+ },
+ &unilib_);
+ EXPECT_EQ(response.actions.size(), 3);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[1].type, "send_email");
+ EXPECT_EQ(response.actions[2].type, "send_email");
+}
+
+// Including local-user messages adds the local user's email annotation too.
+TEST_F(ActionsSuggestionsTest,
+ SuggestActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
+ const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
+ [](ActionsModelT* actions_model) {
+ actions_model->annotation_actions_spec->include_local_user_messages =
+ true;
+ actions_model->annotation_actions_spec->only_until_last_sent = false;
+ actions_model->annotation_actions_spec->max_history_from_any_person = 5;
+ actions_model->annotation_actions_spec->max_history_from_last_person =
+ 1;
+ },
+ &unilib_);
+ EXPECT_EQ(response.actions.size(), 4);
+ EXPECT_EQ(response.actions[0].type, "track_flight");
+ EXPECT_EQ(response.actions[1].type, "send_email");
+ EXPECT_EQ(response.actions[2].type, "send_email");
+ EXPECT_EQ(response.actions[3].type, "send_email");
+}
+
+// Rebuilds the test model with `set_value_fn` applied (typically tightening
+// a precondition) and verifies that at most `expected_size` actions are
+// suggested (EXPECT_LE — the threshold is expected to suppress output;
+// default 0 means nothing may trigger).
+void TestSuggestActionsWithThreshold(
+ const std::function<void(ActionsModelT*)>& set_value_fn,
+ const UniLib* unilib = nullptr, const int expected_size = 0,
+ const std::string& preconditions_overwrite = "") {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ set_value_fn(actions_model.get());
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib, preconditions_overwrite);
+ ASSERT_TRUE(actions_suggestions);
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "I have the low-ground. Where are you?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_LE(response.actions.size(), expected_size);
+}
+
+// Raising the smart-reply thresholds must suppress replies, leaving at most
+// the non-reply actions.
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithTriggeringScore) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
+ },
+ &unilib_,
+ /*expected_size=*/1 /*no smart reply, only actions*/
+ );
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinReplyScore) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->min_reply_score_threshold = 1.0;
+ },
+ &unilib_,
+ /*expected_size=*/1 /*no smart reply, only actions*/
+ );
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithSensitiveTopicScore) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->max_sensitive_topic_score = 0.0;
+ },
+ &unilib_,
+ /*expected_size=*/4 /* no sensitive prediction in test model*/);
+}
+
+// Input-length preconditions — from the model or from a flag-provided
+// overlay buffer — must suppress all suggestions (default expected_size 0).
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithMaxInputLength) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->max_input_length = 0;
+ },
+ &unilib_);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinInputLength) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->min_input_length = 100;
+ },
+ &unilib_);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithPreconditionsOverwrite) {
+ TriggeringPreconditionsT preconditions_overwrite;
+ preconditions_overwrite.max_input_length = 0;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(
+ TriggeringPreconditions::Pack(builder, &preconditions_overwrite));
+ TestSuggestActionsWithThreshold(
+ // Keep model untouched.
+ [](ActionsModelT* actions_model) {}, &unilib_,
+ /*expected_size=*/0,
+ std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize()));
+}
+
+#ifdef TC3_UNILIB_ICU
+// A low-confidence rule matching the input ("low-ground") must suppress all
+// suggestions.
+TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidence) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->suppress_on_low_confidence_input = true;
+ actions_model->low_confidence_rules.reset(new RulesModelT);
+ actions_model->low_confidence_rules->rule.emplace_back(
+ new RulesModel_::RuleT);
+ actions_model->low_confidence_rules->rule.back()->pattern =
+ "low-ground";
+ },
+ &unilib_);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidenceInputOutput) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ // Add custom triggering rule.
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
+ RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
+ rule->pattern = "^(?i:hello\\s(there))$";
+ {
+ std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
+ new RulesModel_::Rule_::RuleActionSpecT);
+ rule_action->action.reset(new ActionSuggestionSpecT);
+ rule_action->action->type = "text_reply";
+ rule_action->action->response_text = "General Desaster!";
+ rule_action->action->score = 1.0f;
+ rule_action->action->priority_score = 1.0f;
+ rule->actions.push_back(std::move(rule_action));
+ }
+ {
+ std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
+ new RulesModel_::Rule_::RuleActionSpecT);
+ rule_action->action.reset(new ActionSuggestionSpecT);
+ rule_action->action->type = "text_reply";
+ rule_action->action->response_text = "General Kenobi!";
+ rule_action->action->score = 1.0f;
+ rule_action->action->priority_score = 1.0f;
+ rule->actions.push_back(std::move(rule_action));
+ }
+
+ // Add input-output low confidence rule.
+ actions_model->preconditions->suppress_on_low_confidence_input = true;
+ actions_model->low_confidence_rules.reset(new RulesModelT);
+ actions_model->low_confidence_rules->rule.emplace_back(
+ new RulesModel_::RuleT);
+ actions_model->low_confidence_rules->rule.back()->pattern = "hello";
+ actions_model->low_confidence_rules->rule.back()->output_pattern =
+ "(?i:desaster)";
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_);
+ ASSERT_TRUE(actions_suggestions);
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "hello there",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
+}
+
+TEST_F(ActionsSuggestionsTest,
+ SuggestActionsLowConfidenceInputOutputOverwrite) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ actions_model->low_confidence_rules.reset();
+
+ // Add custom triggering rule.
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
+ RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
+ rule->pattern = "^(?i:hello\\s(there))$";
+ {
+ std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
+ new RulesModel_::Rule_::RuleActionSpecT);
+ rule_action->action.reset(new ActionSuggestionSpecT);
+ rule_action->action->type = "text_reply";
+ rule_action->action->response_text = "General Desaster!";
+ rule_action->action->score = 1.0f;
+ rule_action->action->priority_score = 1.0f;
+ rule->actions.push_back(std::move(rule_action));
+ }
+ {
+ std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
+ new RulesModel_::Rule_::RuleActionSpecT);
+ rule_action->action.reset(new ActionSuggestionSpecT);
+ rule_action->action->type = "text_reply";
+ rule_action->action->response_text = "General Kenobi!";
+ rule_action->action->score = 1.0f;
+ rule_action->action->priority_score = 1.0f;
+ rule->actions.push_back(std::move(rule_action));
+ }
+
+ // Add custom triggering rule via overwrite.
+ actions_model->preconditions->low_confidence_rules.reset();
+ TriggeringPreconditionsT preconditions;
+ preconditions.suppress_on_low_confidence_input = true;
+ preconditions.low_confidence_rules.reset(new RulesModelT);
+ preconditions.low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
+ preconditions.low_confidence_rules->rule.back()->pattern = "hello";
+ preconditions.low_confidence_rules->rule.back()->output_pattern =
+ "(?i:desaster)";
+ flatbuffers::FlatBufferBuilder preconditions_builder;
+ preconditions_builder.Finish(
+ TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
+ std::string serialize_preconditions = std::string(
+ reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
+ preconditions_builder.GetSize());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_, serialize_preconditions);
+
+ ASSERT_TRUE(actions_suggestions);
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "hello there",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
+}
+#endif
+
+// With suppress_on_sensitive_topic enabled and max_sensitive_topic_score = 0,
+// any sensitivity signal from the model must suppress even annotation-based
+// actions (the address annotation yields nothing).
+TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
+  const std::string actions_model_string =
+      ReadFile(GetModelPath() + kModelFileName);
+  std::unique_ptr<ActionsModelT> actions_model =
+      UnPackActionsModel(actions_model_string.c_str());
+
+  // Don't test if no sensitivity score is produced
+  if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
+    return;
+  }
+
+  actions_model->preconditions->max_sensitive_topic_score = 0.0;
+  actions_model->preconditions->suppress_on_sensitive_topic = true;
+  flatbuffers::FlatBufferBuilder builder;
+  FinishActionsModelBuffer(builder,
+                           ActionsModel::Pack(builder, actions_model.get()));
+  std::unique_ptr<ActionsSuggestions> actions_suggestions =
+      ActionsSuggestions::FromUnownedBuffer(
+          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+          builder.GetSize(), &unilib_);
+  AnnotatedSpan annotation;
+  annotation.span = {11, 15};
+  annotation.classification = {
+      ClassificationResult(Collections::Address(), 1.0)};
+  const ActionsSuggestionsResponse& response =
+      actions_suggestions->SuggestActions(
+          {{{/*user_id=*/1, "are you at home?",
+             /*reference_time_ms_utc=*/0,
+             /*reference_timezone=*/"Europe/Zurich",
+             /*annotations=*/{annotation},
+             /*locales=*/"en"}}});
+  EXPECT_THAT(response.actions, testing::IsEmpty());
+}
+
+// With max_conversation_history_length raised to 10, a two-message
+// conversation (one local-user message, one remote with an address
+// annotation) still produces the annotation-based "view_map" action.
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) {
+  const std::string actions_model_string =
+      ReadFile(GetModelPath() + kModelFileName);
+  std::unique_ptr<ActionsModelT> actions_model =
+      UnPackActionsModel(actions_model_string.c_str());
+
+  // Allow a larger conversation context.
+  actions_model->max_conversation_history_length = 10;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishActionsModelBuffer(builder,
+                           ActionsModel::Pack(builder, actions_model.get()));
+  std::unique_ptr<ActionsSuggestions> actions_suggestions =
+      ActionsSuggestions::FromUnownedBuffer(
+          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+          builder.GetSize(), &unilib_);
+  AnnotatedSpan annotation;
+  annotation.span = {11, 15};
+  annotation.classification = {
+      ClassificationResult(Collections::Address(), 1.0)};
+  const ActionsSuggestionsResponse& response =
+      actions_suggestions->SuggestActions(
+          {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
+             /*reference_time_ms_utc=*/10000,
+             /*reference_timezone=*/"Europe/Zurich",
+             /*annotations=*/{}, /*locales=*/"en"},
+            {/*user_id=*/1, "good! are you at home?",
+             /*reference_time_ms_utc=*/15000,
+             /*reference_timezone=*/"Europe/Zurich",
+             /*annotations=*/{annotation},
+             /*locales=*/"en"}}});
+  ASSERT_GE(response.actions.size(), 1);
+  EXPECT_EQ(response.actions[0].type, "view_map");
+  EXPECT_EQ(response.actions[0].score, 1.0);
+}
+
+// A flight-number annotation in the message must surface as a
+// "track_flight" action whose attached annotation points back at the
+// original message span.
+TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
+  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+  AnnotatedSpan annotation;
+  annotation.span = {8, 12};
+  annotation.classification = {
+      ClassificationResult(Collections::Flight(), 1.0)};
+
+  const ActionsSuggestionsResponse& response =
+      actions_suggestions->SuggestActions(
+          {{{/*user_id=*/1, "I'm on LX38?",
+             /*reference_time_ms_utc=*/0,
+             /*reference_timezone=*/"Europe/Zurich",
+             /*annotations=*/{annotation},
+             /*locales=*/"en"}}});
+
+  ASSERT_GE(response.actions.size(), 2);
+  EXPECT_EQ(response.actions[0].type, "track_flight");
+  EXPECT_EQ(response.actions[0].score, 1.0);
+  EXPECT_EQ(response.actions[0].annotations.size(), 1);
+  // The action's annotation must reference message 0 and the span {8, 12}.
+  EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
+  EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
+}
+
+#ifdef TC3_UNILIB_ICU
+// A regex rule action with capturing groups: group 0 (whole match) is
+// written to the "greeting" entity field, group 1 to "location", and
+// pre-serialized entity data supplies "person". The resulting serialized
+// entity data of the suggested action must contain all three values.
+TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
+  const std::string actions_model_string =
+      ReadFile(GetModelPath() + kModelFileName);
+  std::unique_ptr<ActionsModelT> actions_model =
+      UnPackActionsModel(actions_model_string.c_str());
+  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+  actions_model->rules.reset(new RulesModelT());
+  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
+  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
+  rule->pattern = "^(?i:hello\\s(there))$";
+  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
+  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
+  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
+  action->type = "text_reply";
+  action->response_text = "General Kenobi!";
+  action->score = 1.0f;
+  action->priority_score = 1.0f;
+
+  // Set capturing groups for entity data.
+  rule->actions.back()->capturing_group.emplace_back(
+      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
+  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
+      rule->actions.back()->capturing_group.back().get();
+  greeting_group->group_id = 0;
+  greeting_group->entity_field.reset(new FlatbufferFieldPathT);
+  greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
+  greeting_group->entity_field->field.back()->field_name = "greeting";
+  rule->actions.back()->capturing_group.emplace_back(
+      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
+  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* location_group =
+      rule->actions.back()->capturing_group.back().get();
+  location_group->group_id = 1;
+  location_group->entity_field.reset(new FlatbufferFieldPathT);
+  location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
+  location_group->entity_field->field.back()->field_name = "location";
+
+  // Set test entity data schema.
+  SetTestEntityDataSchema(actions_model.get());
+
+  // Use meta data to generate custom serialized entity data.
+  ReflectiveFlatbufferBuilder entity_data_builder(
+      flatbuffers::GetRoot<reflection::Schema>(
+          actions_model->actions_entity_data_schema.data()));
+  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+      entity_data_builder.NewRoot();
+  entity_data->Set("person", "Kenobi");
+  action->serialized_entity_data = entity_data->Serialize();
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishActionsModelBuffer(builder,
+                           ActionsModel::Pack(builder, actions_model.get()));
+  std::unique_ptr<ActionsSuggestions> actions_suggestions =
+      ActionsSuggestions::FromUnownedBuffer(
+          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+          builder.GetSize(), &unilib_);
+
+  const ActionsSuggestionsResponse& response =
+      actions_suggestions->SuggestActions(
+          {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
+             /*reference_timezone=*/"Europe/Zurich",
+             /*annotations=*/{}, /*locales=*/"en"}}});
+  EXPECT_GE(response.actions.size(), 1);
+  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
+
+  // Check entity data.
+  // Field offsets 4/6/8 are flatbuffer vtable offsets — presumably the
+  // greeting/location/person fields of the test entity data schema in
+  // declaration order; confirm against SetTestEntityDataSchema.
+  const flatbuffers::Table* entity =
+      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+          response.actions[0].serialized_entity_data.data()));
+  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+            "hello there");
+  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
+            "there");
+  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+            "Kenobi");
+}
+
+// A capturing group can itself become a text reply: group 1 of the pattern
+// captures "STOP" from the message, and the group's text_reply spec turns
+// the captured text into the suggested reply.
+TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
+  const std::string actions_model_string =
+      ReadFile(GetModelPath() + kModelFileName);
+  std::unique_ptr<ActionsModelT> actions_model =
+      UnPackActionsModel(actions_model_string.c_str());
+  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+  actions_model->rules.reset(new RulesModelT());
+  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
+  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
+  rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
+  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
+
+  // Set capturing groups for entity data.
+  rule->actions.back()->capturing_group.emplace_back(
+      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
+  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
+      rule->actions.back()->capturing_group.back().get();
+  code_group->group_id = 1;
+  code_group->text_reply.reset(new ActionSuggestionSpecT);
+  code_group->text_reply->score = 1.0f;
+  code_group->text_reply->priority_score = 1.0f;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishActionsModelBuffer(builder,
+                           ActionsModel::Pack(builder, actions_model.get()));
+  std::unique_ptr<ActionsSuggestions> actions_suggestions =
+      ActionsSuggestions::FromUnownedBuffer(
+          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+          builder.GetSize(), &unilib_);
+
+  const ActionsSuggestionsResponse& response =
+      actions_suggestions->SuggestActions(
+          {{{/*user_id=*/1,
+             "visit test.com or reply STOP to cancel your subscription",
+             /*reference_time_ms_utc=*/0,
+             /*reference_timezone=*/"Europe/Zurich",
+             /*annotations=*/{}, /*locales=*/"en"}}});
+  EXPECT_GE(response.actions.size(), 1);
+  // The reply text is the captured group, preserving the original casing.
+  EXPECT_EQ(response.actions[0].response_text, "STOP");
+}
+
+// Verifies deduplication: the model already produces a "share_location"
+// action for "Where are you?"; adding a custom rule that generates an action
+// of the same type must not increase the number of returned actions.
+TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
+  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
+      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+         /*reference_timezone=*/"Europe/Zurich",
+         /*annotations=*/{}, /*locales=*/"en"}}});
+
+  // Check that the location sharing model triggered.
+  // Iterate by const reference; the original `const ActionSuggestion action`
+  // copied every element per iteration (performance-for-range-copy).
+  bool has_location_sharing_action = false;
+  for (const ActionSuggestion& action : response.actions) {
+    if (action.type == ActionsSuggestions::kShareLocation) {
+      has_location_sharing_action = true;
+      break;
+    }
+  }
+  EXPECT_TRUE(has_location_sharing_action);
+  const int num_actions = response.actions.size();
+
+  // Add custom rule for location sharing.
+  const std::string actions_model_string =
+      ReadFile(GetModelPath() + kModelFileName);
+  std::unique_ptr<ActionsModelT> actions_model =
+      UnPackActionsModel(actions_model_string.c_str());
+  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+  actions_model->rules.reset(new RulesModelT());
+  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
+  actions_model->rules->rule.back()->pattern = "^(?i:where are you[.?]?)$";
+  actions_model->rules->rule.back()->actions.emplace_back(
+      new RulesModel_::Rule_::RuleActionSpecT);
+  actions_model->rules->rule.back()->actions.back()->action.reset(
+      new ActionSuggestionSpecT);
+  ActionSuggestionSpecT* action =
+      actions_model->rules->rule.back()->actions.back()->action.get();
+  action->score = 1.0f;
+  action->type = ActionsSuggestions::kShareLocation;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishActionsModelBuffer(builder,
+                           ActionsModel::Pack(builder, actions_model.get()));
+  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
+      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_);
+
+  // The duplicate rule-generated action must be merged away: same count.
+  response = actions_suggestions->SuggestActions(
+      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
+         /*reference_timezone=*/"Europe/Zurich",
+         /*annotations=*/{}, /*locales=*/"en"}}});
+  EXPECT_EQ(response.actions.size(), num_actions);
+}
+
+// Conflicting actions on the same span are resolved by priority score: a
+// custom rule annotating "LX38" as "code" with priority_score 2.0 must win
+// over the annotation-derived "track_flight" action.
+TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
+  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+  AnnotatedSpan annotation;
+  annotation.span = {7, 11};
+  annotation.classification = {
+      ClassificationResult(Collections::Flight(), 1.0)};
+  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
+      {{{/*user_id=*/1, "I'm on LX38",
+         /*reference_time_ms_utc=*/0,
+         /*reference_timezone=*/"Europe/Zurich",
+         /*annotations=*/{annotation},
+         /*locales=*/"en"}}});
+
+  // Check that the phone actions are present.
+  EXPECT_GE(response.actions.size(), 1);
+  EXPECT_EQ(response.actions[0].type, "track_flight");
+
+  // Add custom rule.
+  const std::string actions_model_string =
+      ReadFile(GetModelPath() + kModelFileName);
+  std::unique_ptr<ActionsModelT> actions_model =
+      UnPackActionsModel(actions_model_string.c_str());
+  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+  actions_model->rules.reset(new RulesModelT());
+  actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
+  RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
+  rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
+  rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
+  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
+  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
+  action->score = 1.0f;
+  // Higher priority than the annotation-based action, so it must win.
+  action->priority_score = 2.0f;
+  action->type = "test_code";
+  rule->actions.back()->capturing_group.emplace_back(
+      new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
+  RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
+      rule->actions.back()->capturing_group.back().get();
+  code_group->group_id = 1;
+  code_group->annotation_name = "code";
+  code_group->annotation_type = "code";
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishActionsModelBuffer(builder,
+                           ActionsModel::Pack(builder, actions_model.get()));
+  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
+      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_);
+
+  response = actions_suggestions->SuggestActions(
+      {{{/*user_id=*/1, "I'm on LX38",
+         /*reference_time_ms_utc=*/0,
+         /*reference_timezone=*/"Europe/Zurich",
+         /*annotations=*/{annotation},
+         /*locales=*/"en"}}});
+  EXPECT_GE(response.actions.size(), 1);
+  EXPECT_EQ(response.actions[0].type, "test_code");
+}
+#endif
+
+// Two address annotations with different classification scores must yield
+// two "view_map" actions ranked by score in descending order (2.0 then 1.0).
+TEST_F(ActionsSuggestionsTest, SuggestActionsRanking) {
+  std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+  std::vector<AnnotatedSpan> annotations(2);
+  annotations[0].span = {11, 15};
+  annotations[0].classification = {ClassificationResult("address", 1.0)};
+  annotations[1].span = {19, 23};
+  annotations[1].classification = {ClassificationResult("address", 2.0)};
+  const ActionsSuggestionsResponse& response =
+      actions_suggestions->SuggestActions(
+          {{{/*user_id=*/1, "are you at home or work?",
+             /*reference_time_ms_utc=*/0,
+             /*reference_timezone=*/"Europe/Zurich",
+             /*annotations=*/annotations,
+             /*locales=*/"en"}}});
+  EXPECT_GE(response.actions.size(), 2);
+  EXPECT_EQ(response.actions[0].type, "view_map");
+  EXPECT_EQ(response.actions[0].score, 2.0);
+  EXPECT_EQ(response.actions[1].type, "view_map");
+  EXPECT_EQ(response.actions[1].score, 1.0);
+}
+
+// VisitActionsModel forwards the loaded model (or nullptr when the file
+// cannot be read) to the callback and returns the callback's result.
+TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
+  // Valid model file: callback receives a non-null model.
+  EXPECT_TRUE(VisitActionsModel<bool>(GetModelPath() + kModelFileName,
+                                      [](const ActionsModel* model) {
+                                        if (model == nullptr) {
+                                          return false;
+                                        }
+                                        return true;
+                                      }));
+  // Missing model file: callback receives nullptr.
+  EXPECT_FALSE(VisitActionsModel<bool>(GetModelPath() + "non_existing_model.fb",
+                                       [](const ActionsModel* model) {
+                                         if (model == nullptr) {
+                                           return false;
+                                         }
+                                         return true;
+                                       }));
+}
+
+// Smoke test for the hash-gram based test model: an unrelated greeting
+// produces nothing, while known trigger phrases each produce exactly their
+// corresponding action.
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithHashGramModel) {
+  std::unique_ptr<ActionsSuggestions> actions_suggestions =
+      LoadHashGramTestModel();
+  ASSERT_TRUE(actions_suggestions != nullptr);
+  {
+    // "hello" is not a trigger phrase: no actions.
+    const ActionsSuggestionsResponse& response =
+        actions_suggestions->SuggestActions(
+            {{{/*user_id=*/1, "hello",
+               /*reference_time_ms_utc=*/0,
+               /*reference_timezone=*/"Europe/Zurich",
+               /*annotations=*/{},
+               /*locales=*/"en"}}});
+    EXPECT_THAT(response.actions, testing::IsEmpty());
+  }
+  {
+    // "where are you" triggers exactly the share_location action.
+    const ActionsSuggestionsResponse& response =
+        actions_suggestions->SuggestActions(
+            {{{/*user_id=*/1, "where are you",
+               /*reference_time_ms_utc=*/0,
+               /*reference_timezone=*/"Europe/Zurich",
+               /*annotations=*/{},
+               /*locales=*/"en"}}});
+    EXPECT_THAT(
+        response.actions,
+        ElementsAre(testing::Field(&ActionSuggestion::type, "share_location")));
+  }
+  {
+    // "do you know johns number" triggers exactly the share_contact action.
+    const ActionsSuggestionsResponse& response =
+        actions_suggestions->SuggestActions(
+            {{{/*user_id=*/1, "do you know johns number",
+               /*reference_time_ms_utc=*/0,
+               /*reference_timezone=*/"Europe/Zurich",
+               /*annotations=*/{},
+               /*locales=*/"en"}}});
+    EXPECT_THAT(
+        response.actions,
+        ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact")));
+  }
+}
+
+// Test class to expose token embedding methods for testing.
+// Derives privately from ActionsSuggestions to re-export the protected
+// EmbedAndFlattenTokens/EmbedTokensPerMessage methods.
+class TestingMessageEmbedder : private ActionsSuggestions {
+ public:
+  explicit TestingMessageEmbedder(const ActionsModel* model);
+
+  using ActionsSuggestions::EmbedAndFlattenTokens;
+  using ActionsSuggestions::EmbedTokensPerMessage;
+
+ protected:
+  // EmbeddingExecutor that always returns features based on
+  // the id of the sparse features.
+  class FakeEmbeddingExecutor : public EmbeddingExecutor {
+   public:
+    // Writes the single sparse feature id as the (one-dimensional)
+    // embedding value, making hashed bucket ids directly observable
+    // in the test expectations.
+    bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+                      const int dest_size) const override {
+      TC3_CHECK_GE(dest_size, 1);
+      EXPECT_EQ(sparse_features.size(), 1);
+      dest[0] = sparse_features.data()[0];
+      return true;
+    }
+  };
+};
+
+// Wires up the embedder state by hand: feature processor (without a UniLib),
+// the fake executor, and the pre-embedded special tokens. Expects the
+// configured token embedding size to be 1 so each token maps to one float.
+TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model) {
+  model_ = model;
+  const ActionsTokenFeatureProcessorOptions* options =
+      model->feature_processor_options();
+  feature_processor_.reset(
+      new ActionsFeatureProcessor(options, /*unilib=*/nullptr));
+  embedding_executor_.reset(new FakeEmbeddingExecutor());
+  EXPECT_TRUE(
+      EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
+  EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
+  EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
+  token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
+  EXPECT_EQ(token_embedding_size_, 1);
+}
+
+// Fixture for the embedding tests: builds a minimal model whose feature
+// processor uses unigram chargrams, 1000 hash buckets, embedding size 1,
+// and special token ids start=0 / end=1 / padding=2.
+class EmbeddingTest : public testing::Test {
+ protected:
+  EmbeddingTest() {
+    model_.feature_processor_options.reset(
+        new ActionsTokenFeatureProcessorOptionsT);
+    options_ = model_.feature_processor_options.get();
+    options_->chargram_orders = {1};
+    options_->num_buckets = 1000;
+    options_->embedding_size = 1;
+    options_->start_token_id = 0;
+    options_->end_token_id = 1;
+    options_->padding_token_id = 2;
+    options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
+  }
+
+  // Packs the (possibly test-modified) model and constructs an embedder on
+  // top of it. buffer_ keeps the packed model alive for the embedder's
+  // unowned model pointer.
+  TestingMessageEmbedder CreateTestingMessageEmbedder() {
+    flatbuffers::FlatBufferBuilder builder;
+    FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
+    buffer_ = builder.ReleaseBufferPointer();
+    return TestingMessageEmbedder(
+        flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
+  }
+
+  flatbuffers::DetachedBuffer buffer_;
+  ActionsModelT model_;
+  // Unowned; points into model_. Tests tweak it before building the embedder.
+  ActionsTokenFeatureProcessorOptionsT* options_;
+};
+
+// With no length bounds, each token embeds to its hash bucket id
+// (fingerprint mod num_buckets) and nothing is padded or dropped.
+TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
+  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+  std::vector<std::vector<Token>> tokens = {
+      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+  std::vector<float> embeddings;
+  int max_num_tokens_per_message = 0;
+
+  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
+                                             &max_num_tokens_per_message));
+
+  EXPECT_EQ(max_num_tokens_per_message, 3);
+  EXPECT_EQ(embeddings.size(), 3);
+  EXPECT_THAT(embeddings[0],
+              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[1],
+              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[2],
+              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+                               options_->num_buckets));
+}
+
+// With min_num_tokens_per_message = 5, a 3-token message is padded at the
+// end with two padding tokens.
+TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
+  options_->min_num_tokens_per_message = 5;
+  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+  std::vector<std::vector<Token>> tokens = {
+      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+  std::vector<float> embeddings;
+  int max_num_tokens_per_message = 0;
+
+  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
+                                             &max_num_tokens_per_message));
+
+  EXPECT_EQ(max_num_tokens_per_message, 5);
+  EXPECT_EQ(embeddings.size(), 5);
+  EXPECT_THAT(embeddings[0],
+              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[1],
+              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[2],
+              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[3], testing::FloatEq(options_->padding_token_id));
+  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->padding_token_id));
+}
+
+// With max_num_tokens_per_message = 2, excess tokens are dropped from the
+// beginning, keeping the most recent tokens ("b", "c").
+TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
+  options_->max_num_tokens_per_message = 2;
+  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+  std::vector<std::vector<Token>> tokens = {
+      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+  std::vector<float> embeddings;
+  int max_num_tokens_per_message = 0;
+
+  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
+                                             &max_num_tokens_per_message));
+
+  EXPECT_EQ(max_num_tokens_per_message, 2);
+  EXPECT_EQ(embeddings.size(), 2);
+  EXPECT_THAT(embeddings[0],
+              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[1],
+              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+                               options_->num_buckets));
+}
+
+// With multiple messages and no bounds, every message is padded to the
+// length of the longest one (3), so the shorter message gets one padding
+// token at the end.
+TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
+  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+  std::vector<std::vector<Token>> tokens = {
+      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
+      {Token("d", 0, 1), Token("e", 2, 3)}};
+  std::vector<float> embeddings;
+  int max_num_tokens_per_message = 0;
+
+  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
+                                             &max_num_tokens_per_message));
+
+  EXPECT_EQ(max_num_tokens_per_message, 3);
+  EXPECT_THAT(embeddings[0],
+              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[1],
+              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[2],
+              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[3],
+              testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[4],
+              testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
+}
+
+// Flattened embedding wraps each message in start/end tokens: 3 tokens
+// become start + a,b,c + end = 5 entries.
+TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
+  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+  std::vector<std::vector<Token>> tokens = {
+      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+  std::vector<float> embeddings;
+  int total_token_count = 0;
+
+  EXPECT_TRUE(
+      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
+
+  EXPECT_EQ(total_token_count, 5);
+  EXPECT_EQ(embeddings.size(), 5);
+  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
+  EXPECT_THAT(embeddings[1],
+              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[2],
+              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[3],
+              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
+}
+
+// With min_num_total_tokens = 7, the flattened sequence (5 entries) is
+// padded at the end with two padding tokens.
+TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
+  options_->min_num_total_tokens = 7;
+  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+  std::vector<std::vector<Token>> tokens = {
+      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+  std::vector<float> embeddings;
+  int total_token_count = 0;
+
+  EXPECT_TRUE(
+      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
+
+  EXPECT_EQ(total_token_count, 7);
+  EXPECT_EQ(embeddings.size(), 7);
+  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
+  EXPECT_THAT(embeddings[1],
+              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[2],
+              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[3],
+              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
+  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
+  EXPECT_THAT(embeddings[6], testing::FloatEq(options_->padding_token_id));
+}
+
+// With max_num_total_tokens = 3, the sequence is truncated from the
+// beginning: the start token and earliest tokens are dropped, keeping the
+// last token and the end token.
+TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
+  options_->max_num_total_tokens = 3;
+  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+  std::vector<std::vector<Token>> tokens = {
+      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
+  std::vector<float> embeddings;
+  int total_token_count = 0;
+
+  EXPECT_TRUE(
+      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
+
+  EXPECT_EQ(total_token_count, 3);
+  EXPECT_EQ(embeddings.size(), 3);
+  EXPECT_THAT(embeddings[0],
+              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[1],
+              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[2], testing::FloatEq(options_->end_token_id));
+}
+
+// Multiple messages are each wrapped with start/end tokens and
+// concatenated: (1+3+1) + (1+2+1) = 9 entries.
+TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
+  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+  std::vector<std::vector<Token>> tokens = {
+      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
+      {Token("d", 0, 1), Token("e", 2, 3)}};
+  std::vector<float> embeddings;
+  int total_token_count = 0;
+
+  EXPECT_TRUE(
+      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
+
+  EXPECT_EQ(total_token_count, 9);
+  EXPECT_EQ(embeddings.size(), 9);
+  EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
+  EXPECT_THAT(embeddings[1],
+              testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[2],
+              testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[3],
+              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
+  EXPECT_THAT(embeddings[5], testing::FloatEq(options_->start_token_id));
+  EXPECT_THAT(embeddings[6],
+              testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[7],
+              testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[8], testing::FloatEq(options_->end_token_id));
+}
+
+// Truncation across message boundaries: with max_num_total_tokens = 7, the
+// oldest entries are dropped first, so the kept window starts mid-way
+// through the first message ("c" + its end token) and includes the whole
+// second message.
+TEST_F(EmbeddingTest,
+       EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
+  options_->max_num_total_tokens = 7;
+  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
+  std::vector<std::vector<Token>> tokens = {
+      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
+      {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
+  std::vector<float> embeddings;
+  int total_token_count = 0;
+
+  EXPECT_TRUE(
+      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
+
+  EXPECT_EQ(total_token_count, 7);
+  EXPECT_EQ(embeddings.size(), 7);
+  EXPECT_THAT(embeddings[0],
+              testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[1], testing::FloatEq(options_->end_token_id));
+  EXPECT_THAT(embeddings[2], testing::FloatEq(options_->start_token_id));
+  EXPECT_THAT(embeddings[3],
+              testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[4],
+              testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[5],
+              testing::FloatEq(tc3farmhash::Fingerprint64("f", 1) %
+                               options_->num_buckets));
+  EXPECT_THAT(embeddings[6], testing::FloatEq(options_->end_token_id));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/actions/actions_jni.cc b/actions/actions_jni.cc
new file mode 100644
index 0000000..20891fa
--- /dev/null
+++ b/actions/actions_jni.cc
@@ -0,0 +1,408 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// JNI wrapper for actions.
+
+#include "actions/actions_jni.h"
+
+#include <jni.h>
+#include <map>
+#include <type_traits>
+#include <vector>
+
+#include "actions/actions-suggestions.h"
+#include "annotator/annotator.h"
+#include "annotator/annotator_jni_common.h"
+#include "utils/base/integral_types.h"
+#include "utils/intents/intent-generator.h"
+#include "utils/intents/jni.h"
+#include "utils/java/jni-cache.h"
+#include "utils/java/scoped_local_ref.h"
+#include "utils/java/string_utils.h"
+#include "utils/memory/mmap.h"
+
+using libtextclassifier3::ActionsSuggestions;
+using libtextclassifier3::ActionsSuggestionsResponse;
+using libtextclassifier3::ActionSuggestion;
+using libtextclassifier3::ActionSuggestionOptions;
+using libtextclassifier3::Annotator;
+using libtextclassifier3::Conversation;
+using libtextclassifier3::IntentGenerator;
+using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::ToStlString;
+
+// When using the Java's ICU, UniLib needs to be instantiated with a JavaVM
+// pointer from JNI. When using a standard ICU the pointer is not needed and the
+// objects are instantiated implicitly.
+#ifdef TC3_UNILIB_JAVAICU
+using libtextclassifier3::UniLib;
+#endif
+
+namespace libtextclassifier3 {
+
+namespace {
+
+// Cached state for model inference.
+// Keeps a jni cache, intent generator and model instance so that they don't
+// have to be recreated for each call.
+class ActionsSuggestionsJniContext {
+ public:
+  // Creates a new context from the given jni cache and model. Returns nullptr
+  // if either argument is null, or if the intent generator or template
+  // handler cannot be created. The caller takes ownership of the returned
+  // raw pointer (deleted via nativeCloseActionsModel).
+  static ActionsSuggestionsJniContext* Create(
+      const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
+      std::unique_ptr<ActionsSuggestions> model) {
+    if (jni_cache == nullptr || model == nullptr) {
+      return nullptr;
+    }
+    std::unique_ptr<IntentGenerator> intent_generator =
+        IntentGenerator::Create(model->model()->android_intent_options(),
+                                model->model()->resources(), jni_cache);
+    std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
+        libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
+
+    if (intent_generator == nullptr || template_handler == nullptr) {
+      return nullptr;
+    }
+
+    return new ActionsSuggestionsJniContext(jni_cache, std::move(model),
+                                            std::move(intent_generator),
+                                            std::move(template_handler));
+  }
+
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
+    return jni_cache_;
+  }
+
+  // Accessors below return non-owning pointers; the context keeps ownership.
+  ActionsSuggestions* model() const { return model_.get(); }
+
+  IntentGenerator* intent_generator() const { return intent_generator_.get(); }
+
+  RemoteActionTemplatesHandler* template_handler() const {
+    return template_handler_.get();
+  }
+
+ private:
+  ActionsSuggestionsJniContext(
+      const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
+      std::unique_ptr<ActionsSuggestions> model,
+      std::unique_ptr<IntentGenerator> intent_generator,
+      std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
+      : jni_cache_(jni_cache),
+        model_(std::move(model)),
+        intent_generator_(std::move(intent_generator)),
+        template_handler_(std::move(template_handler)) {}
+
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
+  std::unique_ptr<ActionsSuggestions> model_;
+  std::unique_ptr<IntentGenerator> intent_generator_;
+  std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
+};
+
+// Converts Java ActionSuggestionOptions into the native options struct.
+// NOTE(review): `joptions` is currently ignored and the defaults are always
+// returned — presumably per-call options are not plumbed through yet; confirm
+// this is intentional.
+ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env,
+                                                        jobject joptions) {
+  ActionSuggestionOptions options = ActionSuggestionOptions::Default();
+  return options;
+}
+
+// Converts the native action suggestions into a Java ActionSuggestion[],
+// optionally generating RemoteActionTemplates (Android intents) per action.
+// Returns nullptr if the Java result class cannot be resolved.
+jobjectArray ActionSuggestionsToJObjectArray(
+    JNIEnv* env, const ActionsSuggestionsJniContext* context,
+    jobject app_context,
+    const reflection::Schema* annotations_entity_data_schema,
+    const std::vector<ActionSuggestion>& action_result,
+    const Conversation& conversation, const jstring device_locales,
+    const bool generate_intents) {
+  const ScopedLocalRef<jclass> result_class(
+      env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+                     "$ActionSuggestion"),
+      env);
+  if (!result_class) {
+    TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
+    return nullptr;
+  }
+
+  // Constructor: (String replyText, String type, float score,
+  // NamedVariant[] entityData, byte[] serializedEntityData,
+  // RemoteActionTemplate[] remoteActionTemplates).
+  // NOTE(review): the jmethodID lookup result is not null-checked — confirm
+  // the Java class always has this constructor.
+  const jmethodID result_class_constructor = env->GetMethodID(
+      result_class.get(), "<init>",
+      "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
+          TC3_NAMED_VARIANT_CLASS_NAME_STR
+      ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V");
+  const jobjectArray results =
+      env->NewObjectArray(action_result.size(), result_class.get(), nullptr);
+  for (int i = 0; i < action_result.size(); i++) {
+    // Expand the action's serialized entity data into a NamedVariant[] if the
+    // model provides an entity data schema.
+    jobject extras = nullptr;
+
+    const reflection::Schema* actions_entity_data_schema =
+        context->model()->entity_data_schema();
+    if (actions_entity_data_schema != nullptr &&
+        !action_result[i].serialized_entity_data.empty()) {
+      extras = context->template_handler()->EntityDataAsNamedVariantArray(
+          actions_entity_data_schema, action_result[i].serialized_entity_data);
+    }
+
+    // Also expose the raw serialized entity data bytes to the Java side.
+    jbyteArray serialized_entity_data = nullptr;
+    if (!action_result[i].serialized_entity_data.empty()) {
+      serialized_entity_data =
+          env->NewByteArray(action_result[i].serialized_entity_data.size());
+      env->SetByteArrayRegion(
+          serialized_entity_data, 0,
+          action_result[i].serialized_entity_data.size(),
+          reinterpret_cast<const jbyte*>(
+              action_result[i].serialized_entity_data.data()));
+    }
+
+    jobject remote_action_templates_result = nullptr;
+    if (generate_intents) {
+      std::vector<RemoteActionTemplate> remote_action_templates;
+      if (context->intent_generator()->GenerateIntents(
+              device_locales, action_result[i], conversation, app_context,
+              actions_entity_data_schema, annotations_entity_data_schema,
+              &remote_action_templates)) {
+        remote_action_templates_result =
+            context->template_handler()->RemoteActionTemplatesToJObjectArray(
+                remote_action_templates);
+      }
+    }
+
+    ScopedLocalRef<jstring> reply = context->jni_cache()->ConvertToJavaString(
+        action_result[i].response_text);
+
+    // NOTE(review): local references created per iteration (extras,
+    // serialized_entity_data, the type string) are only released when the JNI
+    // call returns; very large result lists could exhaust the local reference
+    // table — confirm expected result sizes.
+    ScopedLocalRef<jobject> result(env->NewObject(
+        result_class.get(), result_class_constructor, reply.get(),
+        env->NewStringUTF(action_result[i].type.c_str()),
+        static_cast<jfloat>(action_result[i].score), extras,
+        serialized_entity_data, remote_action_templates_result));
+    env->SetObjectArrayElement(results, i, result.get());
+  }
+  return results;
+}
+
+// Converts a Java ConversationMessage to the native struct by invoking its
+// getters through JNI. Returns a default-constructed message if `jmessage`
+// is null or any getter call fails.
+ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
+  if (!jmessage) {
+    return {};
+  }
+
+  const ScopedLocalRef<jclass> message_class(
+      env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+                     "$ConversationMessage"),
+      env);
+  const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>(
+      env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText",
+      "Ljava/lang/String;");
+  const std::pair<bool, int32> status_or_user_id =
+      CallJniMethod0<int32>(env, jmessage, message_class.get(),
+                            &JNIEnv::CallIntMethod, "getUserId", "I");
+  const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>(
+      env, jmessage, message_class.get(), &JNIEnv::CallLongMethod,
+      "getReferenceTimeMsUtc", "J");
+  const std::pair<bool, jobject> status_or_reference_timezone =
+      CallJniMethod0<jobject>(env, jmessage, message_class.get(),
+                              &JNIEnv::CallObjectMethod, "getReferenceTimezone",
+                              "Ljava/lang/String;");
+  const std::pair<bool, jobject> status_or_detected_text_language_tags =
+      CallJniMethod0<jobject>(
+          env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
+          "getDetectedTextLanguageTags", "Ljava/lang/String;");
+  // Bail out with an empty message if any accessor failed.
+  if (!status_or_text.first || !status_or_user_id.first ||
+      !status_or_detected_text_language_tags.first ||
+      !status_or_reference_time.first || !status_or_reference_timezone.first) {
+    return {};
+  }
+
+  ConversationMessage message;
+  message.text = ToStlString(env, static_cast<jstring>(status_or_text.second));
+  message.user_id = status_or_user_id.second;
+  message.reference_time_ms_utc = status_or_reference_time.second;
+  message.reference_timezone = ToStlString(
+      env, static_cast<jstring>(status_or_reference_timezone.second));
+  message.detected_text_language_tags = ToStlString(
+      env, static_cast<jstring>(status_or_detected_text_language_tags.second));
+  return message;
+}
+
+// Converts a Java Conversation to the native struct. Returns an empty
+// conversation if `jconversation` is null or its messages cannot be
+// retrieved.
+Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) {
+  if (!jconversation) {
+    return {};
+  }
+
+  const ScopedLocalRef<jclass> conversation_class(
+      env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+                     "$Conversation"),
+      env);
+
+  const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>(
+      env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod,
+      "getConversationMessages",
+      "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;");
+
+  if (!status_or_messages.first) {
+    return {};
+  }
+
+  const jobjectArray jmessages =
+      reinterpret_cast<jobjectArray>(status_or_messages.second);
+
+  const int size = env->GetArrayLength(jmessages);
+
+  std::vector<ConversationMessage> messages;
+  // NOTE(review): per-element local references are not explicitly released;
+  // this is bounded by the conversation length.
+  for (int i = 0; i < size; i++) {
+    jobject jmessage = env->GetObjectArrayElement(jmessages, i);
+    ConversationMessage message = FromJavaConversationMessage(env, jmessage);
+    messages.push_back(message);
+  }
+  Conversation conversation;
+  conversation.messages = messages;
+  return conversation;
+}
+
+// Reads the supported locales string out of a memory-mapped actions model.
+// Yields the empty string when the mapping is invalid, the model cannot be
+// viewed, or no locales are set.
+jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+  if (mmap->handle().ok()) {
+    const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+        mmap->handle().start(), mmap->handle().num_bytes());
+    if (model != nullptr && model->locales() != nullptr) {
+      return env->NewStringUTF(model->locales()->c_str());
+    }
+  }
+  return env->NewStringUTF("");
+}
+
+// Reads the model version out of a memory-mapped actions model.
+// Yields 0 when the mapping is invalid or the model cannot be viewed.
+jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+  const ActionsModel* model =
+      mmap->handle().ok()
+          ? libtextclassifier3::ViewActionsModel(mmap->handle().start(),
+                                                 mmap->handle().num_bytes())
+          : nullptr;
+  return model != nullptr ? model->version() : 0;
+}
+
+// Reads the model name out of a memory-mapped actions model.
+// Yields the empty string when the mapping is invalid, the model cannot be
+// viewed, or no name is set.
+jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+  if (mmap->handle().ok()) {
+    const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+        mmap->handle().start(), mmap->handle().num_bytes());
+    if (model != nullptr && model->name() != nullptr) {
+      return env->NewStringUTF(model->name()->c_str());
+    }
+  }
+  return env->NewStringUTF("");
+}
+} // namespace
+} // namespace libtextclassifier3
+
+using libtextclassifier3::ActionsSuggestionsJniContext;
+using libtextclassifier3::ActionSuggestionsToJObjectArray;
+using libtextclassifier3::FromJavaActionSuggestionOptions;
+using libtextclassifier3::FromJavaConversation;
+
+// Loads an actions model from a file descriptor and wraps it in an
+// ActionsSuggestionsJniContext. Returns the context as an opaque jlong
+// handle, or 0 on failure; released via nativeCloseActionsModel.
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
+(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) {
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
+      libtextclassifier3::JniCache::Create(env);
+  // Optional serialized preconditions; forwarded to model creation.
+  std::string preconditions;
+  if (serialized_preconditions != nullptr &&
+      !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
+                                              &preconditions)) {
+    TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
+    return 0;
+  }
+  // With Java-backed ICU, the UniLib needs the JVM via the jni cache;
+  // otherwise it is created implicitly inside the model.
+#ifdef TC3_UNILIB_JAVAICU
+  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+      jni_cache,
+      ActionsSuggestions::FromFileDescriptor(
+          fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions)));
+#else
+  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+      jni_cache, ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr,
+                                                        preconditions)));
+#endif  // TC3_UNILIB_JAVAICU
+}
+
+// Same as nativeNewActionsModel, but loads the model from a file path
+// instead of a file descriptor. Returns 0 on failure.
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
+(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
+      libtextclassifier3::JniCache::Create(env);
+  const std::string path_str = ToStlString(env, path);
+  // Optional serialized preconditions; forwarded to model creation.
+  std::string preconditions;
+  if (serialized_preconditions != nullptr &&
+      !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
+                                              &preconditions)) {
+    TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
+    return 0;
+  }
+#ifdef TC3_UNILIB_JAVAICU
+  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+      jni_cache, ActionsSuggestions::FromPath(
+                     path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+                     preconditions)));
+#else
+  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+      jni_cache, ActionsSuggestions::FromPath(path_str, /*unilib=*/nullptr,
+                                              preconditions)));
+#endif  // TC3_UNILIB_JAVAICU
+}
+
+// Runs action suggestion on the given conversation and converts the result
+// into a Java ActionSuggestion[]. Returns nullptr if `ptr` (the native
+// context handle) is 0. The annotator handle is optional; when present it
+// contributes annotations and the schema used for annotation entity data.
+TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
+(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
+ jlong annotatorPtr, jobject app_context, jstring device_locales,
+ jboolean generate_intents) {
+  if (!ptr) {
+    return nullptr;
+  }
+  const Conversation conversation = FromJavaConversation(env, jconversation);
+  const ActionSuggestionOptions options =
+      FromJavaActionSuggestionOptions(env, joptions);
+  const ActionsSuggestionsJniContext* context =
+      reinterpret_cast<ActionsSuggestionsJniContext*>(ptr);
+  const Annotator* annotator = reinterpret_cast<Annotator*>(annotatorPtr);
+
+  const ActionsSuggestionsResponse response =
+      context->model()->SuggestActions(conversation, annotator, options);
+
+  // Fixed misspelled local name (was `anntotations_entity_data_schema`).
+  const reflection::Schema* annotations_entity_data_schema =
+      annotator ? annotator->entity_data_schema() : nullptr;
+  return ActionSuggestionsToJObjectArray(
+      env, context, app_context, annotations_entity_data_schema,
+      response.actions, conversation, device_locales, generate_intents);
+}
+
+// Releases the native context allocated by the nativeNewActionsModel*
+// methods. Deleting a null pointer (handle 0) is a harmless no-op.
+TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
+(JNIEnv* env, jobject clazz, jlong model_ptr) {
+  delete reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr);
+}
+
+// Returns the locales supported by the model behind the given file
+// descriptor, without building a full context.
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd) {
+  // Map the model file just long enough to read the metadata.
+  libtextclassifier3::ScopedMmap mmap(fd);
+  return libtextclassifier3::GetLocalesFromMmap(env, &mmap);
+}
+
+// Returns the name of the model behind the given file descriptor, without
+// building a full context.
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
+(JNIEnv* env, jobject clazz, jint fd) {
+  // Map the model file just long enough to read the metadata.
+  libtextclassifier3::ScopedMmap mmap(fd);
+  return libtextclassifier3::GetNameFromMmap(env, &mmap);
+}
+
+// Returns the version of the model behind the given file descriptor, without
+// building a full context.
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd) {
+  // Map the model file just long enough to read the metadata.
+  libtextclassifier3::ScopedMmap mmap(fd);
+  return libtextclassifier3::GetVersionFromMmap(env, &mmap);
+}
diff --git a/actions/actions_jni.h b/actions/actions_jni.h
new file mode 100644
index 0000000..fe2b998
--- /dev/null
+++ b/actions/actions_jni.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
+
+#include <jni.h>
+#include <string>
+#include "utils/java/jni-base.h"
+
+// Name of the Java class the native methods below are bound against; can be
+// overridden at build time.
+#ifndef TC3_ACTIONS_CLASS_NAME
+#define TC3_ACTIONS_CLASS_NAME ActionsSuggestionsModel
+#endif
+
+#define TC3_ACTIONS_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_ACTIONS_CLASS_NAME)
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Creates a new actions model from a file descriptor. Returns an opaque
+// handle to the native context, or 0 on failure.
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
+(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions);
+
+// Like nativeNewActionsModel, but loads the model from a file path.
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
+(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions);
+
+// Suggests actions for the given conversation; `ptr` is the handle returned
+// by one of the nativeNewActionsModel* methods.
+TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
+(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation, jobject joptions,
+ jlong annotatorPtr, jobject app_context, jstring device_locales,
+ jboolean generate_intents);
+
+// Releases the native context created by the nativeNewActionsModel* methods.
+TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
+(JNIEnv* env, jobject thiz, jlong ptr);
+
+// Model metadata accessors; these read directly from the model file
+// descriptor without creating a full context.
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
diff --git a/actions/actions_model.fbs b/actions/actions_model.fbs
new file mode 100755
index 0000000..42c7d88
--- /dev/null
+++ b/actions/actions_model.fbs
@@ -0,0 +1,480 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "annotator/model.fbs";
+include "utils/codepoint-range.fbs";
+include "utils/flatbuffers.fbs";
+include "utils/intents/intent-config.fbs";
+include "utils/resources.fbs";
+include "utils/tokenizer.fbs";
+include "utils/zlib/buffer.fbs";
+
+file_identifier "TC3A";
+
+// Specification of the TensorFlow Lite model used for suggesting actions,
+// including the indices of its input and output tensors.
+namespace libtextclassifier3;
+table TensorflowLiteModelSpec {
+  // TensorFlow Lite model for suggesting actions.
+  tflite_model:[ubyte] (force_align: 16);
+
+  // Input specification.
+  // (num messages,) int32 tensor, the user id per message.
+  input_user_id:int = 0;
+
+  // (num messages,) string tensor, each message of the conversation.
+  input_context:int = 1;
+
+  // int, the number of messages in the conversation.
+  input_context_length:int = 2;
+
+  // (num messages,) float tensor, the time difference in seconds of the
+  // messages in the conversation.
+  input_time_diffs:int = 3;
+
+  // int, the number of smart replies to produce.
+  input_num_suggestions:int = 4;
+
+  // float, the output diversification distance parameter.
+  input_diversification_distance:int = -1;
+
+  // float, the empirical probability factor parameter.
+  input_empirical_probability_factor:int = -1;
+
+  // float, the confidence threshold.
+  input_confidence_threshold:int = -1;
+
+  // Input port for hashed and embedded tokens, a (num messages, max tokens,
+  // embedding size) float tensor specifying the embeddings of each token of
+  // each message in the conversation.
+  input_token_embeddings:int = -1;
+
+  // Input port for the number of tokens per message.
+  // (num messages) int32 tensor specifying the number of tokens in each message
+  // in the conversation.
+  input_num_tokens:int = -1;
+
+  // Output specification.
+  output_replies:int = 0;
+
+  output_replies_scores:int = 1;
+  output_sensitive_topic_score:int = 3;
+  output_triggering_score:int = 4;
+  output_actions_scores:int = 5;
+
+  // Model setup.
+  // When true, the inputs are resized to the concrete input sizes before
+  // inference; otherwise, it's assumed that the model has the correct input
+  // shapes set.
+  resize_inputs:bool = false;
+
+  // Input port for the hashed, embedded and flattened/concatenated tokens.
+  // A (max tokens, embedding_size) float tensor specifying the embeddings of
+  // each token.
+  input_flattened_token_embeddings:int = -1;
+}
+
+// Configuration for the tokenizer.
+namespace libtextclassifier3;
+table ActionsTokenizerOptions {
+  type:TokenizationType = INTERNAL_TOKENIZER;
+
+  // If true, white space tokens will be kept when using the icu tokenizer.
+  icu_preserve_whitespace_tokens:bool = false;
+
+  // Codepoint ranges that determine what role the different codepoints play
+  // during tokenization. The ranges must not overlap.
+  tokenization_codepoint_config:[TokenizationCodepointRange];
+
+  // A set of codepoint ranges to use in the mixed tokenization mode to
+  // identify stretches of tokens to re-tokenize using the internal tokenizer.
+  internal_tokenizer_codepoint_ranges:[CodepointRange];
+
+  // If true, tokens will be also split when the codepoint's script_id changes
+  // as defined in TokenizationCodepointRange.
+  tokenize_on_script_change:bool = false;
+}
+
+// Configuration for the feature processor.
+namespace libtextclassifier3;
+table ActionsTokenFeatureProcessorOptions {
+  // Tokenizer options.
+  tokenizer_options:ActionsTokenizerOptions;
+
+  // Serialized TensorFlow Lite model with weights for the token embeddings.
+  embedding_model:[ubyte] (force_align: 16);
+
+  // Size of the embedding.
+  embedding_size:int = -1;
+
+  // Number of bits for quantization for embeddings.
+  embedding_quantization_bits:int = 8;
+
+  // Number of buckets used for hashing charactergrams.
+  num_buckets:int = -1;
+
+  // Orders of charactergrams to extract, e.g. 2 means character bigrams, 3
+  // character trigrams etc.
+  chargram_orders:[int];
+
+  // Whether to extract the token case feature.
+  extract_case_feature:bool;
+
+  // If true, will use the unicode-aware functionality for extracting features.
+  unicode_aware_features:bool;
+
+  // Regexp features to extract.
+  regexp_features:[string];
+
+  // Whether to remap digits to a single number.
+  remap_digits:bool;
+
+  // Whether to lowercase all tokens.
+  lowercase_tokens:bool;
+
+  // Maximum length of a word.
+  max_token_length:int = 20;
+
+  // The `max_num_tokens_per_message` and `min_num_tokens_per_message` are
+  // applied when tokens are embedded per message.
+  // If set and the number of tokens of a message is bigger than this limit,
+  // tokens at the beginning of the message are dropped to fit the limit.
+  max_num_tokens_per_message:int = -1;
+
+  // If set, the tokens of each message will be padded to this fixed number of
+  // tokens.
+  min_num_tokens_per_message:int = -1;
+
+  // If set and the total number of concatenated tokens is bigger than this
+  // limit, tokens at the start of the conversation are dropped.
+  max_num_total_tokens:int = -1;
+
+  // If set and the total number of concatenated tokens is smaller than this
+  // limit, the conversation is padded with padding tokens.
+  min_num_total_tokens:int = -1;
+
+  // Id that is used as encoding of the padding token.
+  padding_token_id:int = 0;
+
+  // Id that is used as encoding of the start of message token.
+  start_token_id:int = 1;
+
+  // Id that is used as encoding of the end of message token.
+  end_token_id:int = 2;
+}
+
+// N-Gram based linear regression model.
+namespace libtextclassifier3;
+table NGramLinearRegressionModel {
+  // A flat list of all the hashed n-grams concatenated back to back. Elements
+  // should only ever be accessed via the offset table below.
+  hashed_ngram_tokens:[uint];
+
+  // Offsets to the start of the n-grams in hashed_ngram_tokens. The last
+  // element in this array is the length of hashed_ngrams to make it easier to
+  // compute n-gram lengths.
+  ngram_start_offsets:[ushort];
+
+  // Weights of the n-grams.
+  ngram_weights:[float];
+
+  // The default weight assigned to n-grams that weren't matched.
+  default_token_weight:float;
+
+  // Maximum n-gram length to consider when calculating the denominator.
+  // This should usually be the same as max_ngram_length but can diverge
+  // if additional (longer) n-grams are added to a model as part of a minor
+  // update.
+  max_denom_ngram_length:int;
+
+  // If non-zero, the order of the skip-gram to match.
+  max_skips:int;
+
+  // The threshold above which the model output is considered positive.
+  threshold:float;
+
+  // Model specific tokenizer options.
+  // If not specified, will reuse the feature processor tokenizer.
+  tokenizer_options:ActionsTokenizerOptions;
+}
+
+namespace libtextclassifier3;
+table TriggeringPreconditions {
+  // Lower bound thresholds for the smart reply model prediction output.
+  min_smart_reply_triggering_score:float;
+
+  // Maximum sensitive score for which actions and smart replies are shown.
+  max_sensitive_topic_score:float = 1;
+
+  // Whether to suppress all model output when a conversation is classified as
+  // sensitive.
+  suppress_on_sensitive_topic:bool = true;
+
+  // Thresholds on the model prediction input.
+  // The minimal length of input to consider for prediction.
+  min_input_length:int = 0;
+
+  // The maximal length of input to consider for prediction, -1 if unbounded.
+  max_input_length:int = -1;
+
+  // Minimal fraction of messages in the input conversation that need to match
+  // a locale that the model can handle.
+  min_locale_match_fraction:float = 0.75;
+
+  handle_missing_locale_as_supported:bool = false;
+  handle_unknown_locale_as_supported:bool = false;
+
+  // Filter input with low-confidence triggers.
+  suppress_on_low_confidence_input:bool = true;
+
+  // Same as low_confidence_rules in ActionsModel.
+  // NOTE: Only fill this when the TriggeringPreconditions are pushed separately
+  // as a flag value (i.e. as overlay).
+  low_confidence_rules:RulesModel;
+
+  // Smart reply thresholds.
+  diversification_distance_threshold:float = 0;
+
+  confidence_threshold:float = 0;
+  empirical_probability_factor:float = 0;
+  min_reply_score_threshold:float = 0;
+}
+
+// Specification of a single action to suggest.
+namespace libtextclassifier3;
+table ActionSuggestionSpec {
+  // Type of the action suggestion.
+  type:string;
+
+  // Text of a smart reply action.
+  response_text:string;
+
+  // Score of the action.
+  score:float;
+
+  // Serialized entity information.
+  serialized_entity_data:string;
+
+  // Priority score used for internal conflict resolution.
+  priority_score:float = 0;
+}
+
+// Options to specify triggering behaviour per action class.
+namespace libtextclassifier3;
+table ActionTypeOptions {
+  // The name of the predicted action.
+  name:string;
+
+  // Triggering behaviour.
+  // Whether the action class is considered in the model output or not.
+  enabled:bool = true;
+
+  // Minimal output score threshold for the action class.
+  min_triggering_score:float = 0;
+
+  // The action to trigger when this class is predicted.
+  action:ActionSuggestionSpec;
+}
+
+// Mapping from an annotation collection to an action to suggest.
+namespace libtextclassifier3.AnnotationActionsSpec_;
+table AnnotationMapping {
+  // The annotation collection.
+  annotation_collection:string;
+
+  // The action to suggest for annotations of this collection.
+  action:ActionSuggestionSpec;
+
+  // Whether to use the score of the annotation as the action score.
+  use_annotation_score:bool = true;
+
+  // Minimum threshold for the annotation score for filtering.
+  min_annotation_score:float;
+
+  // If set, the text of the annotation will be used to set a field in the
+  // action entity data.
+  entity_field:FlatbufferFieldPath;
+}
+
+// Configuration for actions based on annotations.
+namespace libtextclassifier3;
+table AnnotationActionsSpec {
+  annotation_mapping:[AnnotationActionsSpec_.AnnotationMapping];
+
+  // Whether to deduplicate annotations by type and text prior to generating
+  // actions.
+  deduplicate_annotations:bool = true;
+
+  // Annotation usecase to specify for text annotation.
+  annotation_usecase:AnnotationUsecase = ANNOTATION_USECASE_SMART;
+
+  // Maximum number of recent messages to consider from any person.
+  // We consider at most `max_history_from_any_person` many recent messages if
+  // they were received from different users or at most the maximum of this and
+  // `max_history_from_last_person` if they are all from the same user.
+  max_history_from_any_person:int = 1;
+
+  // Maximum number of recent messages to consider from the last person.
+  max_history_from_last_person:int = 1;
+
+  // Whether to include messages from the local user.
+  include_local_user_messages:bool = false;
+
+  // Whether to only consider messages up to the last one sent by the local
+  // user.
+  only_until_last_sent:bool = true;
+
+  // If true, the annotator will populate serialized_entity_data in the
+  // results.
+  is_serialized_entity_data_enabled:bool = true;
+}
+
+// Ranking options.
+namespace libtextclassifier3;
+table RankingOptions {
+  // When true, actions suggestions are deduplicated by `type`, `response_text`
+  // and associated annotations, keeping the higher scoring actions.
+  deduplicate_suggestions:bool = true;
+
+  // When true, actions are deduplicated by the span they are referring to.
+  deduplicate_suggestions_by_span:bool = true;
+
+  // Optional script to run for ranking and filtering the action suggestions.
+  // The following global variables are available to the script:
+  // * input: (optionally deduplicated) action suggestions, via the `actions`
+  // global
+  // * output: indices of the actions to keep in the provided order.
+  lua_ranking_script:string;
+
+  // Compressed alternative to `lua_ranking_script`.
+  compressed_lua_ranking_script:CompressedBuffer;
+
+  // If true, suppresses smart replies if other smart actions are suggested.
+  suppress_smart_replies_with_actions:bool = false;
+
+  // If true, keep actions from the same entities together for ranking.
+  group_by_annotations:bool = true;
+}
+
+// Entity data to set from capturing groups.
+namespace libtextclassifier3.RulesModel_.Rule_.RuleActionSpec_;
+table RuleCapturingGroup {
+  // The id of the capturing group.
+  group_id:int;
+
+  // If set, the text of the capturing group will be used to set a field
+  // in the action entity data.
+  entity_field:FlatbufferFieldPath;
+
+  // If set, the capturing group will be used to create a text annotation
+  // with the given name and type.
+  annotation_type:string;
+
+  annotation_name:string;
+
+  // If set, the capturing group text will be used to create a text
+  // reply.
+  text_reply:ActionSuggestionSpec;
+}
+
+// The actions to produce upon triggering.
+namespace libtextclassifier3.RulesModel_.Rule_;
+table RuleActionSpec {
+  // The action.
+  action:ActionSuggestionSpec;
+
+  // How the rule's capturing groups map to entity data, annotations or text
+  // replies.
+  capturing_group:[RuleActionSpec_.RuleCapturingGroup];
+}
+
+// List of regular expression matchers.
+namespace libtextclassifier3.RulesModel_;
+table Rule {
+  // The regular expression pattern.
+  pattern:string;
+
+  // Compressed alternative to `pattern`.
+  compressed_pattern:CompressedBuffer;
+
+  // The actions to produce when the pattern matches.
+  actions:[Rule_.RuleActionSpec];
+
+  // Patterns for post-checking the outputs.
+  output_pattern:string;
+
+  // Compressed alternative to `output_pattern`.
+  compressed_output_pattern:CompressedBuffer;
+}
+
+// Rule based actions.
+namespace libtextclassifier3;
+table RulesModel {
+  // The list of rules to check.
+  rule:[RulesModel_.Rule];
+
+  // If true, will compile the regexes only on first use.
+  lazy_regex_compilation:bool = true;
+}
+
+// Top-level configuration of the actions model.
+namespace libtextclassifier3;
+table ActionsModel {
+  // Comma-separated list of locales supported by the model as BCP 47 tags.
+  locales:string;
+
+  // Version of the actions model.
+  version:int;
+
+  // A name for the model that can be used e.g. for logging.
+  name:string;
+
+  // The TensorFlow Lite model and its tensor indices.
+  tflite_model_spec:TensorflowLiteModelSpec;
+
+  // Output classes.
+  smart_reply_action_type:string;
+
+  action_type:[ActionTypeOptions];
+
+  // Triggering conditions of the model.
+  preconditions:TriggeringPreconditions;
+
+  // Default number of smart reply predictions.
+  num_smart_replies:int = 3;
+
+  // Length of message history to consider, -1 if unbounded.
+  max_conversation_history_length:int = 1;
+
+  // Configuration for mapping annotations to action suggestions.
+  annotation_actions_spec:AnnotationActionsSpec;
+
+  // Configuration for rules.
+  rules:RulesModel;
+
+  // Configuration for intent generation on Android.
+  android_intent_options:IntentFactoryModel;
+
+  // Model resources.
+  resources:ResourcePool;
+
+  // Schema data for handling entity data.
+  actions_entity_data_schema:[ubyte];
+
+  // Action ranking options.
+  ranking_options:RankingOptions;
+
+  // Lua based actions.
+  lua_actions_script:string;
+
+  // Compressed alternative to `lua_actions_script`.
+  compressed_lua_actions_script:CompressedBuffer;
+
+  // Low confidence classifiers.
+  low_confidence_rules:RulesModel;
+
+  low_confidence_ngram_model:NGramLinearRegressionModel;
+
+  // Feature processor options.
+  feature_processor_options:ActionsTokenFeatureProcessorOptions;
+}
+
+root_type libtextclassifier3.ActionsModel;
diff --git a/actions/feature-processor.cc b/actions/feature-processor.cc
new file mode 100644
index 0000000..d0b2072
--- /dev/null
+++ b/actions/feature-processor.cc
@@ -0,0 +1,132 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/feature-processor.h"
+
+namespace libtextclassifier3 {
+namespace {
+// Maps the flatbuffer feature processor options to the plain options struct
+// consumed by the shared TokenFeatureExtractor.
+TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
+    const ActionsTokenFeatureProcessorOptions* const options) {
+  TokenFeatureExtractorOptions extractor_options;
+  extractor_options.num_buckets = options->num_buckets();
+  if (options->chargram_orders() != nullptr) {
+    for (int order : *options->chargram_orders()) {
+      extractor_options.chargram_orders.push_back(order);
+    }
+  }
+  extractor_options.max_word_length = options->max_token_length();
+  extractor_options.extract_case_feature = options->extract_case_feature();
+  extractor_options.unicode_aware_features = options->unicode_aware_features();
+  // The selection mask feature is always disabled for the actions model.
+  extractor_options.extract_selection_mask_feature = false;
+  if (options->regexp_features() != nullptr) {
+    // TODO: fix the local typo "regexp_feauture" -> "regexp_feature".
+    for (const auto& regexp_feauture : *options->regexp_features()) {
+      extractor_options.regexp_features.push_back(regexp_feauture->str());
+    }
+  }
+  extractor_options.remap_digits = options->remap_digits();
+  extractor_options.lowercase_tokens = options->lowercase_tokens();
+  return extractor_options;
+}
+}  // namespace
+
+// Creates a tokenizer instance from the tokenizer options.
+std::unique_ptr<Tokenizer> CreateTokenizer(
+    const ActionsTokenizerOptions* options, const UniLib* unilib) {
+  std::vector<const TokenizationCodepointRange*> codepoint_config;
+  if (options->tokenization_codepoint_config() != nullptr) {
+    codepoint_config.insert(codepoint_config.end(),
+                            options->tokenization_codepoint_config()->begin(),
+                            options->tokenization_codepoint_config()->end());
+  }
+  std::vector<const CodepointRange*> internal_codepoint_config;
+  if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
+    internal_codepoint_config.insert(
+        internal_codepoint_config.end(),
+        options->internal_tokenizer_codepoint_ranges()->begin(),
+        options->internal_tokenizer_codepoint_ranges()->end());
+  }
+  // Script-change splitting is only honored when a codepoint config is given.
+  const bool tokenize_on_script_change =
+      options->tokenization_codepoint_config() != nullptr &&
+      options->tokenize_on_script_change();
+  return std::unique_ptr<Tokenizer>(new Tokenizer(
+      options->type(), unilib, codepoint_config, internal_codepoint_config,
+      tokenize_on_script_change, options->icu_preserve_whitespace_tokens()));
+}
+
+ActionsFeatureProcessor::ActionsFeatureProcessor(
+    const ActionsTokenFeatureProcessorOptions* options, const UniLib* unilib)
+    : options_(options),
+      tokenizer_(CreateTokenizer(options->tokenizer_options(), unilib)),
+      token_feature_extractor_(BuildTokenFeatureExtractorOptions(options),
+                               *unilib) {}
+
+// Total size of one token's feature vector: the embedded sparse features plus
+// the extractor's dense features.
+int ActionsFeatureProcessor::GetTokenEmbeddingSize() const {
+  return options_->embedding_size() +
+         token_feature_extractor_.DenseFeaturesCount();
+}
+
+bool ActionsFeatureProcessor::AppendFeatures(
+    const std::vector<int>& sparse_features,
+    const std::vector<float>& dense_features,
+    const EmbeddingExecutor* embedding_executor,
+    std::vector<float>* output_features) const {
+  // Embed the sparse features, appending them directly to the output.
+  const int embedding_size = options_->embedding_size();
+  output_features->resize(output_features->size() + embedding_size);
+  // Computed after the resize, so a reallocation cannot invalidate it; it
+  // points one past the `embedding_size` slots just added.
+  float* output_features_end =
+      output_features->data() + output_features->size();
+  if (!embedding_executor->AddEmbedding(
+          TensorView<int>(sparse_features.data(),
+                          {static_cast<int>(sparse_features.size())}),
+          /*dest=*/output_features_end - embedding_size,
+          /*dest_size=*/embedding_size)) {
+    TC3_LOG(ERROR) << "Could not embed token's sparse features.";
+    return false;
+  }
+
+  // Append the dense features to the output.
+  output_features->insert(output_features->end(), dense_features.begin(),
+                          dense_features.end());
+  return true;
+}
+
+bool ActionsFeatureProcessor::AppendTokenFeatures(
+    const Token& token, const EmbeddingExecutor* embedding_executor,
+    std::vector<float>* output_features) const {
+  // Extract the sparse and dense features.
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  if (!token_feature_extractor_.Extract(token, /*(unused) is_in_span=*/false,
+                                        &sparse_features, &dense_features)) {
+    TC3_LOG(ERROR) << "Could not extract token's features.";
+    return false;
+  }
+  return AppendFeatures(sparse_features, dense_features, embedding_executor,
+                        output_features);
+}
+
+bool ActionsFeatureProcessor::AppendTokenFeatures(
+    const std::vector<Token>& tokens,
+    const EmbeddingExecutor* embedding_executor,
+    std::vector<float>* output_features) const {
+  // Fails fast on the first token whose features cannot be appended; the
+  // output vector may already contain features of earlier tokens in that case.
+  for (const Token& token : tokens) {
+    if (!AppendTokenFeatures(token, embedding_executor, output_features)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+}  // namespace libtextclassifier3
diff --git a/actions/feature-processor.h b/actions/feature-processor.h
new file mode 100644
index 0000000..e34ccff
--- /dev/null
+++ b/actions/feature-processor.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_
+
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "annotator/model-executor.h"
+#include "annotator/types.h"
+#include "utils/token-feature-extractor.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Create tokenizer from options.
+std::unique_ptr<Tokenizer> CreateTokenizer(
+    const ActionsTokenizerOptions* options, const UniLib* unilib);
+
+// Feature processor for the actions suggestions model.
+class ActionsFeatureProcessor {
+ public:
+  ActionsFeatureProcessor(const ActionsTokenFeatureProcessorOptions* options,
+                          const UniLib* unilib);
+
+  // Embeds and appends features to the output vector.
+  bool AppendFeatures(const std::vector<int>& sparse_features,
+                      const std::vector<float>& dense_features,
+                      const EmbeddingExecutor* embedding_executor,
+                      std::vector<float>* output_features) const;
+
+  // Extracts the features of a token and appends them to the output vector.
+  bool AppendTokenFeatures(const Token& token,
+                           const EmbeddingExecutor* embedding_executor,
+                           std::vector<float>* output_features) const;
+
+  // Extracts the features of a vector of tokens and appends each to the output
+  // vector.
+  bool AppendTokenFeatures(const std::vector<Token>& tokens,
+                           const EmbeddingExecutor* embedding_executor,
+                           std::vector<float>* output_features) const;
+
+  // Size of one token's full feature vector (embedding + dense features).
+  int GetTokenEmbeddingSize() const;
+
+  const Tokenizer* tokenizer() const { return tokenizer_.get(); }
+
+ private:
+  // Not owned. NOTE(review): assumed to outlive this processor — confirm at
+  // the call sites.
+  const ActionsTokenFeatureProcessorOptions* options_;
+  const std::unique_ptr<Tokenizer> tokenizer_;
+  const TokenFeatureExtractor token_feature_extractor_;
+};
+
+}  // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_
diff --git a/actions/feature-processor_test.cc b/actions/feature-processor_test.cc
new file mode 100644
index 0000000..0a1e3ac
--- /dev/null
+++ b/actions/feature-processor_test.cc
@@ -0,0 +1,130 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/feature-processor.h"
+
+#include "actions/actions_model_generated.h"
+#include "annotator/model-executor.h"
+#include "utils/tensor-view.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::FloatEq;
+
+// EmbeddingExecutor that always returns features based on
+// the id of the sparse features.
+class FakeEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+                    const int dest_size) const override {
+    TC3_CHECK_GE(dest_size, 4);
+    EXPECT_EQ(sparse_features.size(), 1);
+    // Mirror the single sparse feature id into the four embedding slots
+    // (negated in the upper two) so the tests can recognize the fake values.
+    dest[0] = sparse_features.data()[0];
+    dest[1] = sparse_features.data()[0];
+    dest[2] = -sparse_features.data()[0];
+    dest[3] = -sparse_features.data()[0];
+    return true;
+  }
+
+ private:
+  // TODO: storage_ appears unused; consider removing.
+  std::vector<float> storage_;
+};
+
+class FeatureProcessorTest : public ::testing::Test {
+ protected:
+  FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+
+  // Serializes the given options object into a flatbuffer.
+  flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
+      ActionsTokenFeatureProcessorOptionsT* options) const {
+    flatbuffers::FlatBufferBuilder builder;
+    builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options));
+    return builder.Release();
+  }
+
+  FakeEmbeddingExecutor embedding_executor_;
+  UniLib unilib_;
+};
+
+TEST_F(FeatureProcessorTest, TokenEmbeddings) {
+  ActionsTokenFeatureProcessorOptionsT options;
+  options.embedding_size = 4;
+  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
+
+  flatbuffers::DetachedBuffer options_fb =
+      PackFeatureProcessorOptions(&options);
+  ActionsFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
+          options_fb.data()),
+      &unilib_);
+
+  Token token("aaa", 0, 3);
+  std::vector<float> token_features;
+  EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
+                                                    &token_features));
+  // Just the 4 embedding dimensions; no dense features are enabled.
+  EXPECT_EQ(token_features.size(), 4);
+}
+
+TEST_F(FeatureProcessorTest, TokenEmbeddingsCaseFeature) {
+  ActionsTokenFeatureProcessorOptionsT options;
+  options.embedding_size = 4;
+  options.extract_case_feature = true;
+  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
+
+  flatbuffers::DetachedBuffer options_fb =
+      PackFeatureProcessorOptions(&options);
+  ActionsFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
+          options_fb.data()),
+      &unilib_);
+
+  Token token("Aaa", 0, 3);
+  std::vector<float> token_features;
+  EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
+                                                    &token_features));
+  // 4 embedding dimensions + 1 dense case feature.
+  EXPECT_EQ(token_features.size(), 5);
+  EXPECT_THAT(token_features[4], FloatEq(1.0));
+}
+
+TEST_F(FeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
+  ActionsTokenFeatureProcessorOptionsT options;
+  options.embedding_size = 4;
+  options.extract_case_feature = true;
+  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
+
+  flatbuffers::DetachedBuffer options_fb =
+      PackFeatureProcessorOptions(&options);
+  ActionsFeatureProcessor feature_processor(
+      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
+          options_fb.data()),
+      &unilib_);
+
+  const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7),
+                                     Token("Cccc", 8, 12)};
+  std::vector<float> token_features;
+  EXPECT_TRUE(feature_processor.AppendTokenFeatures(
+      tokens, &embedding_executor_, &token_features));
+  // 3 tokens x (4 embedding dimensions + 1 case feature); the case feature
+  // sits at the end of each token's slice.
+  EXPECT_EQ(token_features.size(), 15);
+  EXPECT_THAT(token_features[4], FloatEq(1.0));
+  EXPECT_THAT(token_features[9], FloatEq(-1.0));
+  EXPECT_THAT(token_features[14], FloatEq(1.0));
+}
+
+}  // namespace
+}  // namespace libtextclassifier3
diff --git a/actions/lua-actions.cc b/actions/lua-actions.cc
new file mode 100644
index 0000000..5bbba98
--- /dev/null
+++ b/actions/lua-actions.cc
@@ -0,0 +1,164 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/lua-actions.h"
+#include "utils/base/logging.h"
+#include "utils/lua-utils.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lualib.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+// Returns a view on the requested model output tensor, or an invalid view if
+// the executor, interpreter or output index is not available.
+TensorView<float> GetTensorViewForOutput(
+    const TfLiteModelExecutor* model_executor,
+    const tflite::Interpreter* interpreter, int output) {
+  if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
+    return TensorView<float>::Invalid();
+  }
+  return model_executor->OutputView<float>(output, interpreter);
+}
+}  // namespace
+
+// Pushes the tensor element at `index` onto the lua stack.
+int LuaActionsSuggestions::TensorViewIterator::Item(
+    const TensorView<float>* tensor, const int64 index,
+    lua_State* state) const {
+  lua_pushnumber(state, tensor->data()[index]);
+  return 1;
+}
+
+std::unique_ptr<LuaActionsSuggestions>
+LuaActionsSuggestions::CreateLuaActionsSuggestions(
+    const std::string& snippet, const Conversation& conversation,
+    const TfLiteModelExecutor* model_executor,
+    const TensorflowLiteModelSpec* model_spec,
+    const tflite::Interpreter* interpreter,
+    const reflection::Schema* actions_entity_data_schema,
+    const reflection::Schema* annotations_entity_data_schema) {
+  auto lua_actions =
+      std::unique_ptr<LuaActionsSuggestions>(new LuaActionsSuggestions(
+          snippet, conversation, model_executor, model_spec, interpreter,
+          actions_entity_data_schema, annotations_entity_data_schema));
+  if (!lua_actions->Initialize()) {
+    TC3_LOG(ERROR)
+        << "Could not initialize lua environment for actions suggestions.";
+    return nullptr;
+  }
+  return lua_actions;
+}
+
+LuaActionsSuggestions::LuaActionsSuggestions(
+    const std::string& snippet, const Conversation& conversation,
+    const TfLiteModelExecutor* model_executor,
+    const TensorflowLiteModelSpec* model_spec,
+    const tflite::Interpreter* interpreter,
+    const reflection::Schema* actions_entity_data_schema,
+    const reflection::Schema* annotations_entity_data_schema)
+    : snippet_(snippet),
+      conversation_(conversation),
+      conversation_iterator_(annotations_entity_data_schema, this),
+      actions_scores_(
+          model_spec == nullptr
+              ? TensorView<float>::Invalid()
+              : GetTensorViewForOutput(model_executor, interpreter,
+                                       model_spec->output_actions_scores())),
+      smart_reply_scores_(
+          model_spec == nullptr
+              ? TensorView<float>::Invalid()
+              : GetTensorViewForOutput(model_executor, interpreter,
+                                       model_spec->output_replies_scores())),
+      sensitivity_score_(model_spec == nullptr
+                             ? TensorView<float>::Invalid()
+                             : GetTensorViewForOutput(
+                                   model_executor, interpreter,
+                                   model_spec->output_sensitive_topic_score())),
+      triggering_score_(
+          model_spec == nullptr
+              ? TensorView<float>::Invalid()
+              : GetTensorViewForOutput(model_executor, interpreter,
+                                       model_spec->output_triggering_score())),
+      actions_entity_data_schema_(actions_entity_data_schema),
+      annotations_entity_data_schema_(annotations_entity_data_schema) {}
+
+bool LuaActionsSuggestions::Initialize() {
+  return RunProtected([this] {
+    LoadDefaultLibraries();
+
+    // Expose conversation message stream.
+    conversation_iterator_.NewIterator("messages",
+                                       &conversation_.messages, state_);
+    lua_setglobal(state_, "messages");
+
+    // Expose ML model output.
+    lua_newtable(state_);
+    {
+      tensor_iterator_.NewIterator("actions_scores", &actions_scores_,
+                                   state_);
+      lua_setfield(state_, /*idx=*/-2, "actions_scores");
+    }
+    {
+      tensor_iterator_.NewIterator("reply_scores", &smart_reply_scores_,
+                                   state_);
+      lua_setfield(state_, /*idx=*/-2, "reply_scores");
+    }
+    {
+      tensor_iterator_.NewIterator("sensitivity", &sensitivity_score_,
+                                   state_);
+      lua_setfield(state_, /*idx=*/-2, "sensitivity");
+    }
+    {
+      tensor_iterator_.NewIterator("triggering_score",
+                                   &triggering_score_, state_);
+      lua_setfield(state_, /*idx=*/-2, "triggering_score");
+    }
+    lua_setglobal(state_, "model");
+
+    return LUA_OK;
+  }) == LUA_OK;
+}
+
+bool LuaActionsSuggestions::SuggestActions(
+    std::vector<ActionSuggestion>* actions) {
+  if (luaL_loadbuffer(state_, snippet_.data(), snippet_.size(),
+                      /*name=*/nullptr) != LUA_OK) {
+    TC3_LOG(ERROR) << "Could not load actions suggestions snippet.";
+    return false;
+  }
+
+  // The snippet is expected to return a single value: the actions table.
+  if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
+    TC3_LOG(ERROR) << "Could not run actions suggestions snippet.";
+    return false;
+  }
+
+  if (RunProtected(
+          [this, actions] {
+            return ReadActions(actions_entity_data_schema_,
+                               annotations_entity_data_schema_, this, actions);
+          },
+          /*num_args=*/1) != LUA_OK) {
+    TC3_LOG(ERROR) << "Could not read lua result.";
+    return false;
+  }
+  return true;
+}
+
+}  // namespace libtextclassifier3
diff --git a/actions/lua-actions.h b/actions/lua-actions.h
new file mode 100644
index 0000000..2f82653
--- /dev/null
+++ b/actions/lua-actions.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_
+
+#include "actions/actions_model_generated.h"
+#include "actions/lua-utils.h"
+#include "actions/types.h"
+#include "utils/lua-utils.h"
+#include "utils/tensor-view.h"
+#include "utils/tflite-model-executor.h"
+
+namespace libtextclassifier3 {
+
+// Lua backed actions suggestions.
+class LuaActionsSuggestions : public LuaEnvironment {
+ public:
+  static std::unique_ptr<LuaActionsSuggestions> CreateLuaActionsSuggestions(
+      const std::string& snippet, const Conversation& conversation,
+      const TfLiteModelExecutor* model_executor,
+      const TensorflowLiteModelSpec* model_spec,
+      const tflite::Interpreter* interpreter,
+      const reflection::Schema* actions_entity_data_schema,
+      const reflection::Schema* annotations_entity_data_schema);
+
+  // Runs the lua snippet and appends the produced suggestions to `actions`.
+  bool SuggestActions(std::vector<ActionSuggestion>* actions);
+
+ private:
+  // Model tensor lua iterator.
+  class TensorViewIterator
+      : public LuaEnvironment::ItemIterator<TensorView<float>> {
+   public:
+    explicit TensorViewIterator() {}
+    int Item(const TensorView<float>* tensor, const int64 index,
+             lua_State* state) const override;
+  };
+
+  LuaActionsSuggestions(
+      const std::string& snippet, const Conversation& conversation,
+      const TfLiteModelExecutor* model_executor,
+      const TensorflowLiteModelSpec* model_spec,
+      const tflite::Interpreter* interpreter,
+      const reflection::Schema* actions_entity_data_schema,
+      const reflection::Schema* annotations_entity_data_schema);
+
+  bool Initialize();
+
+  // NOTE(review): snippet_ and conversation_ are stored by reference; the
+  // caller must keep both alive for the lifetime of this object.
+  const std::string& snippet_;
+  const Conversation& conversation_;
+  ConversationIterator conversation_iterator_;
+  TensorViewIterator tensor_iterator_;
+  TensorView<float> actions_scores_;
+  TensorView<float> smart_reply_scores_;
+  TensorView<float> sensitivity_score_;
+  TensorView<float> triggering_score_;
+  const reflection::Schema* actions_entity_data_schema_;
+  const reflection::Schema* annotations_entity_data_schema_;
+};
+
+}  // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_
diff --git a/actions/lua-actions_test.cc b/actions/lua-actions_test.cc
new file mode 100644
index 0000000..f7b9cd5
--- /dev/null
+++ b/actions/lua-actions_test.cc
@@ -0,0 +1,201 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/lua-actions.h"
+
+#include <map>
+#include <string>
+
+#include "actions/test_utils.h"
+#include "actions/types.h"
+#include "utils/tflite-model-executor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Matches an ActionSuggestion on both its type and response text.
+MATCHER_P2(IsAction, type, response_text, "") {
+  return testing::Value(arg.type, type) &&
+         testing::Value(arg.response_text, response_text);
+}
+
+// Matches an ActionSuggestion on its type only.
+MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
+
+TEST(LuaActions, SimpleAction) {
+  Conversation conversation;
+  const std::string test_snippet = R"(
+    return {{ type = "test_action" }}
+  )";
+  std::vector<ActionSuggestion> actions;
+  EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+                  test_snippet, conversation,
+                  /*model_executor=*/nullptr,
+                  /*model_spec=*/nullptr,
+                  /*interpreter=*/nullptr,
+                  /*actions_entity_data_schema=*/nullptr,
+                  /*annotations_entity_data_schema=*/nullptr)
+                  ->SuggestActions(&actions));
+  EXPECT_THAT(actions,
+              testing::ElementsAreArray({IsActionType("test_action")}));
+}
+
+TEST(LuaActions, ConversationActions) {
+  Conversation conversation;
+  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
+  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
+  const std::string test_snippet = R"(
+    local actions = {}
+    for i, message in pairs(messages) do
+      if i < #messages then
+        if message.text == "hello there!" and
+           messages[i+1].text == "general kenobi!" then
+          table.insert(actions, {
+            type = "text_reply",
+            response_text = "you are a bold one!"
+          })
+        end
+        if message.text == "i am the senate!" and
+           messages[i+1].text == "not yet!" then
+          table.insert(actions, {
+            type = "text_reply",
+            response_text = "it's treason then"
+          })
+        end
+      end
+    end
+    return actions;
+  )";
+  std::vector<ActionSuggestion> actions;
+  EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+                  test_snippet, conversation,
+                  /*model_executor=*/nullptr,
+                  /*model_spec=*/nullptr,
+                  /*interpreter=*/nullptr,
+                  /*actions_entity_data_schema=*/nullptr,
+                  /*annotations_entity_data_schema=*/nullptr)
+                  ->SuggestActions(&actions));
+  // Only the first rule matches the conversation above.
+  EXPECT_THAT(actions, testing::ElementsAreArray(
+                           {IsAction("text_reply", "you are a bold one!")}));
+}
+
+TEST(LuaActions, SimpleModelAction) {
+  Conversation conversation;
+  // With no model spec, the exposed score tensors are empty.
+  const std::string test_snippet = R"(
+    if #model.actions_scores == 0 then
+      return {{ type = "test_action" }}
+    end
+    return {}
+  )";
+  std::vector<ActionSuggestion> actions;
+  EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+                  test_snippet, conversation,
+                  /*model_executor=*/nullptr,
+                  /*model_spec=*/nullptr,
+                  /*interpreter=*/nullptr,
+                  /*actions_entity_data_schema=*/nullptr,
+                  /*annotations_entity_data_schema=*/nullptr)
+                  ->SuggestActions(&actions));
+  EXPECT_THAT(actions,
+              testing::ElementsAreArray({IsActionType("test_action")}));
+}
+
+TEST(LuaActions, AnnotationActions) {
+  AnnotatedSpan annotation;
+  annotation.span = {11, 15};
+  annotation.classification = {ClassificationResult("address", 1.0)};
+  Conversation conversation = {{{/*user_id=*/1, "are you at home?",
+                                 /*reference_time_ms_utc=*/0,
+                                 /*reference_timezone=*/"Europe/Zurich",
+                                 /*annotations=*/{annotation},
+                                 /*locales=*/"en"}}};
+  const std::string test_snippet = R"(
+    local actions = {}
+    local last_message = messages[#messages]
+    for i, annotation in pairs(last_message.annotation) do
+      if #annotation.classification > 0 then
+        if annotation.classification[1].collection == "address" then
+           local text = string.sub(last_message.text,
+                                   annotation.span["begin"] + 1,
+                                   annotation.span["end"])
+           table.insert(actions, {
+             type = "text_reply",
+             response_text = "i am at " .. text,
+             annotation = {{
+               name = "location",
+               span = {
+                 text = text
+               },
+               entity = annotation.classification[1]
+             }},
+           })
+        end
+      end
+    end
+    return actions;
+  )";
+  std::vector<ActionSuggestion> actions;
+  EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+                  test_snippet, conversation,
+                  /*model_executor=*/nullptr,
+                  /*model_spec=*/nullptr,
+                  /*interpreter=*/nullptr,
+                  /*actions_entity_data_schema=*/nullptr,
+                  /*annotations_entity_data_schema=*/nullptr)
+                  ->SuggestActions(&actions));
+  EXPECT_THAT(actions, testing::ElementsAreArray(
+                           {IsAction("text_reply", "i am at home")}));
+  EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
+}
+
+TEST(LuaActions, EntityData) {
+  std::string test_schema = TestEntityDataSchema();
+  Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
+  const std::string test_snippet = R"(
+    return {{
+      type = "test",
+      entity = {
+        greeting = "hello",
+        location = "there",
+        person = "Kenobi",
+      },
+    }};
+  )";
+  std::vector<ActionSuggestion> actions;
+  EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+                  test_snippet, conversation,
+                  /*model_executor=*/nullptr,
+                  /*model_spec=*/nullptr,
+                  /*interpreter=*/nullptr,
+                  /*actions_entity_data_schema=*/
+                  flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
+                  /*annotations_entity_data_schema=*/nullptr)
+                  ->SuggestActions(&actions));
+  EXPECT_THAT(actions, testing::SizeIs(1));
+  EXPECT_EQ("test", actions.front().type);
+  const flatbuffers::Table* entity =
+      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+          actions.front().serialized_entity_data.data()));
+  // NOTE(review): the vtable offsets 4/6/8 are hard-coded to the field layout
+  // of TestEntityDataSchema (greeting/location/person) — keep in sync.
+  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+            "hello");
+  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
+            "there");
+  EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+            "Kenobi");
+}
+
+}  // namespace
+}  // namespace libtextclassifier3
diff --git a/actions/lua-ranker.cc b/actions/lua-ranker.cc
new file mode 100644
index 0000000..a185b07
--- /dev/null
+++ b/actions/lua-ranker.cc
@@ -0,0 +1,117 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/lua-ranker.h"
+#include "utils/base/logging.h"
+#include "utils/lua-utils.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lualib.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+
+std::unique_ptr<ActionsSuggestionsLuaRanker>
+ActionsSuggestionsLuaRanker::Create(
+    const Conversation& conversation, const std::string& ranker_code,
+    const reflection::Schema* entity_data_schema,
+    const reflection::Schema* annotations_entity_data_schema,
+    ActionsSuggestionsResponse* response) {
+  auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
+      new ActionsSuggestionsLuaRanker(
+          conversation, ranker_code, entity_data_schema,
+          annotations_entity_data_schema, response));
+  if (!ranker->Initialize()) {
+    TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
+    return nullptr;
+  }
+  return ranker;
+}
+
+bool ActionsSuggestionsLuaRanker::Initialize() {
+  return RunProtected([this] {
+    LoadDefaultLibraries();
+
+    // Expose generated actions.
+    actions_iterator_.NewIterator("actions", &response_->actions,
+                                  state_);
+    lua_setglobal(state_, "actions");
+
+    // Expose conversation message stream.
+    conversation_iterator_.NewIterator("messages",
+                                       &conversation_.messages, state_);
+    lua_setglobal(state_, "messages");
+    return LUA_OK;
+  }) == LUA_OK;
+}
+
+// Reads the ranking result — a lua table of 1-based indices into the current
+// actions — from the stack and reorders the response actions accordingly.
+int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
+  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+    TC3_LOG(ERROR) << "Expected actions table, got: "
+                   << lua_type(state_, /*idx=*/-1);
+    lua_pop(state_, 1);
+    lua_error(state_);
+    return LUA_ERRRUN;
+  }
+  std::vector<ActionSuggestion> ranked_actions;
+  lua_pushnil(state_);
+  while (lua_next(state_, /*idx=*/-2)) {
+    // lua_next pushed (key, value); the value is the 1-based action index.
+    // Pop just the value, keeping the key on the stack for the next iteration.
+    const int action_id =
+        static_cast<int>(lua_tointeger(state_, /*idx=*/-1)) - 1;
+    lua_pop(state_, 1);
+    // `action_id` is known non-negative on the right-hand side of the ||, so
+    // the signed/unsigned comparison against size() is safe.
+    if (action_id < 0 || action_id >= response_->actions.size()) {
+      TC3_LOG(ERROR) << "Invalid action index: " << action_id;
+      lua_error(state_);
+      return LUA_ERRRUN;
+    }
+    ranked_actions.push_back(response_->actions[action_id]);
+  }
+  lua_pop(state_, 1);
+  response_->actions = ranked_actions;
+  return LUA_OK;
+}
+
+bool ActionsSuggestionsLuaRanker::RankActions() {
+  if (response_->actions.empty()) {
+    // Nothing to do.
+    return true;
+  }
+
+  if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
+                      /*name=*/nullptr) != LUA_OK) {
+    TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
+    return false;
+  }
+
+  // The ranker snippet is expected to return a single value: the ranking table.
+  if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
+    TC3_LOG(ERROR) << "Could not run ranking snippet.";
+    return false;
+  }
+
+  if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
+      LUA_OK) {
+    TC3_LOG(ERROR) << "Could not read lua result.";
+    return false;
+  }
+  return true;
+}
+
+}  // namespace libtextclassifier3
diff --git a/actions/lua-ranker.h b/actions/lua-ranker.h
new file mode 100644
index 0000000..687f412
--- /dev/null
+++ b/actions/lua-ranker.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_RANKER_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_LUA_RANKER_H_
+
+#include <memory>
+#include <string>
+
+#include "actions/lua-utils.h"
+#include "actions/types.h"
+#include "utils/lua-utils.h"
+
+namespace libtextclassifier3 {
+
+// Lua backed action suggestion ranking.
+// Runs a lua snippet that can reorder, filter or duplicate the suggested
+// actions. The conversation and the current actions are exposed to the
+// script as globals (see Initialize in the .cc file).
+class ActionsSuggestionsLuaRanker : public LuaEnvironment {
+ public:
+  // Creates and initializes a ranker.
+  // NOTE(review): the instance keeps references to `conversation` and
+  // `ranker_code` and a raw pointer to `response` — all must outlive the
+  // returned ranker; a temporary string argument would dangle. Confirm all
+  // call sites pass long-lived objects.
+  static std::unique_ptr<ActionsSuggestionsLuaRanker> Create(
+      const Conversation& conversation, const std::string& ranker_code,
+      const reflection::Schema* entity_data_schema,
+      const reflection::Schema* annotations_entity_data_schema,
+      ActionsSuggestionsResponse* response);
+
+  // Runs the ranking snippet and applies its result to the response's
+  // actions. Returns false on script errors.
+  bool RankActions();
+
+ private:
+  explicit ActionsSuggestionsLuaRanker(
+      const Conversation& conversation, const std::string& ranker_code,
+      const reflection::Schema* entity_data_schema,
+      const reflection::Schema* annotations_entity_data_schema,
+      ActionsSuggestionsResponse* response)
+      : conversation_(conversation),
+        ranker_code_(ranker_code),
+        response_(response),
+        actions_iterator_(entity_data_schema, annotations_entity_data_schema,
+                          this),
+        conversation_iterator_(annotations_entity_data_schema, this) {}
+
+  // Sets up the lua environment for the snippet.
+  bool Initialize();
+
+  // Reads ranking results from the lua stack.
+  int ReadActionsRanking();
+
+  const Conversation& conversation_;
+  const std::string& ranker_code_;
+  ActionsSuggestionsResponse* response_;
+  const ActionsIterator actions_iterator_;
+  const ConversationIterator conversation_iterator_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_RANKER_H_
diff --git a/actions/lua-ranker_test.cc b/actions/lua-ranker_test.cc
new file mode 100644
index 0000000..a790042
--- /dev/null
+++ b/actions/lua-ranker_test.cc
@@ -0,0 +1,269 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/lua-ranker.h"
+
+#include <string>
+
+#include "actions/types.h"
+#include "utils/flatbuffers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Matches an ActionSuggestion on both its type and its response text.
+MATCHER_P2(IsAction, type, response_text, "") {
+  return testing::Value(arg.type, type) &&
+         testing::Value(arg.response_text, response_text);
+}
+
+// Matches an ActionSuggestion on its type only.
+MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
+
+// Builds a serialized reflection schema describing a single table
+// "EntityData" with one string field "test"; used by the entity data test
+// below.
+std::string TestEntitySchema() {
+  // Create fake entity data schema meta data.
+  // Cannot use object oriented API here as that is not available for the
+  // reflection schema.
+  flatbuffers::FlatBufferBuilder schema_builder;
+  std::vector<flatbuffers::Offset<reflection::Field>> fields = {
+      reflection::CreateField(
+          schema_builder,
+          /*name=*/schema_builder.CreateString("test"),
+          /*type=*/
+          reflection::CreateType(schema_builder,
+                                 /*base_type=*/reflection::String),
+          /*id=*/0,
+          /*offset=*/4)};
+  std::vector<flatbuffers::Offset<reflection::Enum>> enums;
+  std::vector<flatbuffers::Offset<reflection::Object>> objects = {
+      reflection::CreateObject(
+          schema_builder,
+          /*name=*/schema_builder.CreateString("EntityData"),
+          /*fields=*/
+          schema_builder.CreateVectorOfSortedTables(&fields))};
+  schema_builder.Finish(reflection::CreateSchema(
+      schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
+      schema_builder.CreateVectorOfSortedTables(&enums),
+      /*(unused) file_ident=*/0,
+      /*(unused) file_ext=*/0,
+      /*root_table*/ objects[0]));
+  // Copy the finished buffer out so the builder can go out of scope.
+  return std::string(
+      reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
+      schema_builder.GetSize());
+}
+
+TEST(LuaRankingTest, PassThrough) {
+  // An identity ranking snippet (returns indices 1..n unchanged) must keep
+  // the original action order.
+  const Conversation convo = {{{/*user_id=*/1, "hello hello"}}};
+  ActionsSuggestionsResponse suggestions;
+  suggestions.actions = {
+      {/*response_text=*/"hello there", /*type=*/"text_reply",
+       /*score=*/1.0},
+      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
+      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
+  const std::string snippet = R"(
+    local result = {}
+    for i=1,#actions do
+      table.insert(result, i)
+    end
+    return result
+  )";
+
+  EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
+                  convo, snippet, /*entity_data_schema=*/nullptr,
+                  /*annotations_entity_data_schema=*/nullptr, &suggestions)
+                  ->RankActions());
+  EXPECT_THAT(suggestions.actions,
+              testing::ElementsAreArray({IsActionType("text_reply"),
+                                         IsActionType("share_location"),
+                                         IsActionType("add_to_collection")}));
+}
+
+TEST(LuaRankingTest, Filtering) {
+  // A snippet that returns an empty table must drop every action.
+  const Conversation convo = {{{/*user_id=*/1, "hello hello"}}};
+  ActionsSuggestionsResponse suggestions;
+  suggestions.actions = {
+      {/*response_text=*/"hello there", /*type=*/"text_reply",
+       /*score=*/1.0},
+      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
+      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
+  const std::string snippet = R"(
+    return {}
+  )";
+
+  EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
+                  convo, snippet, /*entity_data_schema=*/nullptr,
+                  /*annotations_entity_data_schema=*/nullptr, &suggestions)
+                  ->RankActions());
+  EXPECT_THAT(suggestions.actions, testing::IsEmpty());
+}
+
+// The ranking result may repeat an index; each repetition duplicates the
+// corresponding action.
+TEST(LuaRankingTest, Duplication) {
+  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
+  ActionsSuggestionsResponse response;
+  response.actions = {
+      {/*response_text=*/"hello there", /*type=*/"text_reply",
+       /*score=*/1.0},
+      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
+      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
+  const std::string test_snippet = R"(
+    local result = {}
+    for i=1,#actions do
+      table.insert(result, 1)
+    end
+    return result
+  )";
+
+  EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
+                  conversation, test_snippet, /*entity_data_schema=*/nullptr,
+                  /*annotations_entity_data_schema=*/nullptr, &response)
+                  ->RankActions());
+  EXPECT_THAT(response.actions,
+              testing::ElementsAreArray({IsActionType("text_reply"),
+                                         IsActionType("text_reply"),
+                                         IsActionType("text_reply")}));
+}
+
+// The snippet can reorder actions arbitrarily; here it sorts them by
+// ascending score using the exposed `actions` global.
+TEST(LuaRankingTest, SortByScore) {
+  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
+  ActionsSuggestionsResponse response;
+  response.actions = {
+      {/*response_text=*/"hello there", /*type=*/"text_reply",
+       /*score=*/1.0},
+      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
+      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
+  const std::string test_snippet = R"(
+    function testScoreSorter(a, b)
+      return actions[a].score < actions[b].score
+    end
+    local result = {}
+    for i=1,#actions do
+      result[i] = i
+    end
+    table.sort(result, testScoreSorter)
+    return result
+  )";
+
+  EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
+                  conversation, test_snippet, /*entity_data_schema=*/nullptr,
+                  /*annotations_entity_data_schema=*/nullptr, &response)
+                  ->RankActions());
+  EXPECT_THAT(response.actions,
+              testing::ElementsAreArray({IsActionType("add_to_collection"),
+                                         IsActionType("share_location"),
+                                         IsActionType("text_reply")}));
+}
+
+// The snippet can filter by action type: all "text_reply" suggestions are
+// suppressed here.
+TEST(LuaRankingTest, SuppressType) {
+  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
+  ActionsSuggestionsResponse response;
+  response.actions = {
+      {/*response_text=*/"hello there", /*type=*/"text_reply",
+       /*score=*/1.0},
+      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
+      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
+  const std::string test_snippet = R"(
+    local result = {}
+    for id, action in pairs(actions) do
+      if action.type ~= "text_reply" then
+        table.insert(result, id)
+      end
+    end
+    return result
+  )";
+
+  EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
+                  conversation, test_snippet, /*entity_data_schema=*/nullptr,
+                  /*annotations_entity_data_schema=*/nullptr, &response)
+                  ->RankActions());
+  EXPECT_THAT(response.actions,
+              testing::ElementsAreArray({IsActionType("share_location"),
+                                         IsActionType("add_to_collection")}));
+}
+
+// The conversation is exposed to the snippet as the `messages` global; the
+// snippet checks the first message's text before ranking.
+TEST(LuaRankingTest, HandlesConversation) {
+  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
+  ActionsSuggestionsResponse response;
+  response.actions = {
+      {/*response_text=*/"hello there", /*type=*/"text_reply",
+       /*score=*/1.0},
+      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
+      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
+  const std::string test_snippet = R"(
+    local result = {}
+    if messages[1].text ~= "hello hello" then
+      return result
+    end
+    for id, action in pairs(actions) do
+      if action.type ~= "text_reply" then
+        table.insert(result, id)
+      end
+    end
+    return result
+  )";
+
+  EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
+                  conversation, test_snippet, /*entity_data_schema=*/nullptr,
+                  /*annotations_entity_data_schema=*/nullptr, &response)
+                  ->RankActions());
+  EXPECT_THAT(response.actions,
+              testing::ElementsAreArray({IsActionType("share_location"),
+                                         IsActionType("add_to_collection")}));
+}
+
+// Entity data fields from the action's serialized flatbuffer are exposed as
+// plain lua fields (here `action.test`), keyed by the schema's field names.
+TEST(LuaRankingTest, HandlesEntityData) {
+  std::string serialized_schema = TestEntitySchema();
+  const reflection::Schema* entity_data_schema =
+      flatbuffers::GetRoot<reflection::Schema>(serialized_schema.data());
+
+  // Create test entity data.
+  ReflectiveFlatbufferBuilder builder(entity_data_schema);
+  std::unique_ptr<ReflectiveFlatbuffer> buffer = builder.NewRoot();
+  buffer->Set("test", "value_a");
+  const std::string serialized_entity_data_a = buffer->Serialize();
+  buffer->Set("test", "value_b");
+  const std::string serialized_entity_data_b = buffer->Serialize();
+
+  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
+  ActionsSuggestionsResponse response;
+  response.actions = {
+      {/*response_text=*/"", /*type=*/"test",
+       /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
+       /*serialized_entity_data=*/serialized_entity_data_a},
+      {/*response_text=*/"", /*type=*/"test",
+       /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
+       /*serialized_entity_data=*/serialized_entity_data_b},
+      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
+      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
+  const std::string test_snippet = R"(
+    local result = {}
+    for id, action in pairs(actions) do
+      if action.type == "test" and action.test == "value_a" then
+        table.insert(result, id)
+      end
+    end
+    return result
+  )";
+
+  EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
+                  conversation, test_snippet, entity_data_schema,
+                  /*annotations_entity_data_schema=*/nullptr, &response)
+                  ->RankActions());
+  EXPECT_THAT(response.actions,
+              testing::ElementsAreArray({IsActionType("test")}));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/actions/lua-utils.cc b/actions/lua-utils.cc
new file mode 100644
index 0000000..edeadf9
--- /dev/null
+++ b/actions/lua-utils.cc
@@ -0,0 +1,354 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/lua-utils.h"
+
+namespace libtextclassifier3 {
+namespace {
+static constexpr const char* kTextKey = "text";
+static constexpr const char* kTimeUsecKey = "parsed_time_ms_utc";
+static constexpr const char* kGranularityKey = "granularity";
+static constexpr const char* kCollectionKey = "collection";
+static constexpr const char* kNameKey = "name";
+static constexpr const char* kScoreKey = "score";
+static constexpr const char* kPriorityScoreKey = "priority_score";
+static constexpr const char* kTypeKey = "type";
+static constexpr const char* kResponseTextKey = "response_text";
+static constexpr const char* kAnnotationKey = "annotation";
+static constexpr const char* kSpanKey = "span";
+static constexpr const char* kMessageKey = "message";
+static constexpr const char* kBeginKey = "begin";
+static constexpr const char* kEndKey = "end";
+static constexpr const char* kClassificationKey = "classification";
+static constexpr const char* kSerializedEntity = "serialized_entity";
+static constexpr const char* kEntityKey = "entity";
+} // namespace
+
+// Keyed lookup for classification annotations: `<iterator>.<key>` resolves
+// to the first annotation whose collection equals the key.
+template <>
+int AnnotationIterator<ClassificationResult>::Item(
+    const std::vector<ClassificationResult>* annotations, StringPiece key,
+    lua_State* state) const {
+  // Lookup annotation by collection.
+  for (const ClassificationResult& annotation : *annotations) {
+    if (key.Equals(annotation.collection)) {
+      PushAnnotation(annotation, entity_data_schema_, env_);
+      return 1;
+    }
+  }
+  // No match: surface the failure to the calling lua script.
+  TC3_LOG(ERROR) << "No annotation with collection: " << key.ToString()
+                 << " found.";
+  lua_error(state);
+  return 0;
+}
+
+// Keyed lookup for action suggestion annotations: resolves by annotation
+// name instead of collection.
+template <>
+int AnnotationIterator<ActionSuggestionAnnotation>::Item(
+    const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
+    lua_State* state) const {
+  // Lookup annotation by name.
+  for (const ActionSuggestionAnnotation& annotation : *annotations) {
+    if (key.Equals(annotation.name)) {
+      PushAnnotation(annotation, entity_data_schema_, env_);
+      return 1;
+    }
+  }
+  // No match: surface the failure to the calling lua script.
+  TC3_LOG(ERROR) << "No annotation with name: " << key.ToString() << " found.";
+  lua_error(state);
+  return 0;
+}
+
+// Pushes `classification` onto the lua stack as a table. The table starts
+// as the annotation's entity data flatbuffer (when a schema and serialized
+// data are available, otherwise empty) and is then extended with the fixed
+// fields: parse time, granularity, collection, score and the raw serialized
+// entity data.
+void PushAnnotation(const ClassificationResult& classification,
+                    const reflection::Schema* entity_data_schema,
+                    LuaEnvironment* env) {
+  if (entity_data_schema == nullptr ||
+      classification.serialized_entity_data.empty()) {
+    // Empty table.
+    lua_newtable(env->state());
+  } else {
+    env->PushFlatbuffer(entity_data_schema,
+                        flatbuffers::GetRoot<flatbuffers::Table>(
+                            classification.serialized_entity_data.data()));
+  }
+  // NOTE(review): kTimeUsecKey is the string "parsed_time_ms_utc" and the
+  // value is in milliseconds — the constant's "Usec" name is misleading.
+  lua_pushinteger(env->state(),
+                  classification.datetime_parse_result.time_ms_utc);
+  lua_setfield(env->state(), /*idx=*/-2, kTimeUsecKey);
+  lua_pushinteger(env->state(),
+                  classification.datetime_parse_result.granularity);
+  lua_setfield(env->state(), /*idx=*/-2, kGranularityKey);
+  env->PushString(classification.collection);
+  lua_setfield(env->state(), /*idx=*/-2, kCollectionKey);
+  lua_pushnumber(env->state(), classification.score);
+  lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
+  env->PushString(classification.serialized_entity_data);
+  lua_setfield(env->state(), /*idx=*/-2, kSerializedEntity);
+}
+
+// Same as above, but additionally exposes the annotated text under the
+// "text" field.
+void PushAnnotation(const ClassificationResult& classification,
+                    StringPiece text,
+                    const reflection::Schema* entity_data_schema,
+                    LuaEnvironment* env) {
+  PushAnnotation(classification, entity_data_schema, env);
+  env->PushString(text);
+  lua_setfield(env->state(), /*idx=*/-2, kTextKey);
+}
+
+// Pushes an annotated span as a lua table with a nested "span" table
+// (begin/end) and a lazily-iterated "classification" field.
+void PushAnnotatedSpan(
+    const AnnotatedSpan& annotated_span,
+    const AnnotationIterator<ClassificationResult>& annotation_iterator,
+    LuaEnvironment* env) {
+  lua_newtable(env->state());
+  {
+    // Nested span table with the codepoint range.
+    lua_newtable(env->state());
+    lua_pushinteger(env->state(), annotated_span.span.first);
+    lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
+    lua_pushinteger(env->state(), annotated_span.span.second);
+    lua_setfield(env->state(), /*idx=*/-2, kEndKey);
+  }
+  lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
+  annotation_iterator.NewIterator(kClassificationKey,
+                                  &annotated_span.classification, env->state());
+  lua_setfield(env->state(), /*idx=*/-2, kClassificationKey);
+}
+
+// Reads a MessageTextSpan from the table at the top of the lua stack.
+// Recognized fields: message (index), begin, end and text; unknown fields
+// are logged and ignored. Leaves the table itself on the stack.
+MessageTextSpan ReadSpan(LuaEnvironment* env) {
+  MessageTextSpan span;
+  lua_pushnil(env->state());
+  while (lua_next(env->state(), /*idx=*/-2)) {
+    const StringPiece key = env->ReadString(/*index=*/-2);
+    if (key.Equals(kMessageKey)) {
+      span.message_index =
+          static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
+    } else if (key.Equals(kBeginKey)) {
+      span.span.first =
+          static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
+    } else if (key.Equals(kEndKey)) {
+      span.span.second =
+          static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
+    } else if (key.Equals(kTextKey)) {
+      span.text = env->ReadString(/*index=*/-1).ToString();
+    } else {
+      TC3_LOG(INFO) << "Unknown span field: " << key.ToString();
+    }
+    // Pop the value; the key stays for the next lua_next call.
+    lua_pop(env->state(), 1);
+  }
+  return span;
+}
+
+// Reads a list of annotation tables from the table at the top of the lua
+// stack into `annotations`. Raises a lua error if the top value is not a
+// table; non-table list entries are logged and skipped.
+// NOTE(review): unlike ReadActions below, this does not pop the outer table
+// at the end — the caller (e.g. ReadAction's field loop) pops it; confirm
+// all call sites follow that convention.
+int ReadAnnotations(const reflection::Schema* entity_data_schema,
+                    LuaEnvironment* env,
+                    std::vector<ActionSuggestionAnnotation>* annotations) {
+  if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
+    TC3_LOG(ERROR) << "Expected annotations table, got: "
+                   << lua_type(env->state(), /*idx=*/-1);
+    lua_pop(env->state(), 1);
+    lua_error(env->state());
+    return LUA_ERRRUN;
+  }
+
+  // Read annotations.
+  lua_pushnil(env->state());
+  while (lua_next(env->state(), /*idx=*/-2)) {
+    if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
+      TC3_LOG(ERROR) << "Expected annotation table, got: "
+                     << lua_type(env->state(), /*idx=*/-1);
+      lua_pop(env->state(), 1);
+      continue;
+    }
+    annotations->push_back(ReadAnnotation(entity_data_schema, env));
+    lua_pop(env->state(), 1);
+  }
+  return LUA_OK;
+}
+
+// Reads a single annotation table from the top of the lua stack.
+// Recognized fields: name, span and entity (parsed as a classification
+// result with `entity_data_schema`); unknown fields are logged.
+ActionSuggestionAnnotation ReadAnnotation(
+    const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
+  ActionSuggestionAnnotation annotation;
+  lua_pushnil(env->state());
+  while (lua_next(env->state(), /*idx=*/-2)) {
+    const StringPiece key = env->ReadString(/*index=*/-2);
+    if (key.Equals(kNameKey)) {
+      annotation.name = env->ReadString(/*index=*/-1).ToString();
+    } else if (key.Equals(kSpanKey)) {
+      annotation.span = ReadSpan(env);
+    } else if (key.Equals(kEntityKey)) {
+      annotation.entity = ReadClassificationResult(entity_data_schema, env);
+    } else {
+      TC3_LOG(ERROR) << "Unknown annotation field: " << key.ToString();
+    }
+    lua_pop(env->state(), 1);
+  }
+  return annotation;
+}
+
+// Reads a ClassificationResult from the table at the top of the lua stack.
+// The entity data can be supplied either pre-serialized (serialized_entity
+// field) or as a lua table (entity field) encoded via `entity_data_schema`;
+// the latter overwrites the former if both are present.
+ClassificationResult ReadClassificationResult(
+    const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
+  ClassificationResult classification;
+  lua_pushnil(env->state());
+  while (lua_next(env->state(), /*idx=*/-2)) {
+    const StringPiece key = env->ReadString(/*index=*/-2);
+    if (key.Equals(kCollectionKey)) {
+      classification.collection = env->ReadString(/*index=*/-1).ToString();
+    } else if (key.Equals(kScoreKey)) {
+      classification.score =
+          static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
+    } else if (key.Equals(kTimeUsecKey)) {
+      classification.datetime_parse_result.time_ms_utc =
+          static_cast<int64>(lua_tonumber(env->state(), /*idx=*/-1));
+    } else if (key.Equals(kGranularityKey)) {
+      classification.datetime_parse_result.granularity =
+          static_cast<DatetimeGranularity>(
+              lua_tonumber(env->state(), /*idx=*/-1));
+    } else if (key.Equals(kSerializedEntity)) {
+      classification.serialized_entity_data =
+          env->ReadString(/*index=*/-1).ToString();
+    } else if (key.Equals(kEntityKey)) {
+      // Encode the lua table into a serialized flatbuffer.
+      auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot();
+      env->ReadFlatbuffer(buffer.get());
+      classification.serialized_entity_data = buffer->Serialize();
+    } else {
+      TC3_LOG(INFO) << "Unknown classification result field: "
+                    << key.ToString();
+    }
+    lua_pop(env->state(), 1);
+  }
+  return classification;
+}
+
+// Pushes an ActionSuggestionAnnotation as a lua table: the underlying
+// classification table extended with "name" and a nested "span" table
+// (message index, begin, end).
+void PushAnnotation(const ActionSuggestionAnnotation& annotation,
+                    const reflection::Schema* entity_data_schema,
+                    LuaEnvironment* env) {
+  PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema,
+                 env);
+  env->PushString(annotation.name);
+  lua_setfield(env->state(), /*idx=*/-2, kNameKey);
+  {
+    // Nested span table.
+    lua_newtable(env->state());
+    lua_pushinteger(env->state(), annotation.span.message_index);
+    lua_setfield(env->state(), /*idx=*/-2, kMessageKey);
+    lua_pushinteger(env->state(), annotation.span.span.first);
+    lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
+    lua_pushinteger(env->state(), annotation.span.span.second);
+    lua_setfield(env->state(), /*idx=*/-2, kEndKey);
+  }
+  lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
+}
+
+// Pushes an action suggestion as a lua table: the action's entity data
+// fields (when schema and serialized data are present, otherwise an empty
+// table) extended with type, response_text, score, priority_score and a
+// lazily-iterated "annotation" field.
+void PushAction(
+    const ActionSuggestion& action,
+    const reflection::Schema* entity_data_schema,
+    const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
+    LuaEnvironment* env) {
+  if (entity_data_schema == nullptr || action.serialized_entity_data.empty()) {
+    // Empty table.
+    lua_newtable(env->state());
+  } else {
+    env->PushFlatbuffer(entity_data_schema,
+                        flatbuffers::GetRoot<flatbuffers::Table>(
+                            action.serialized_entity_data.data()));
+  }
+  env->PushString(action.type);
+  lua_setfield(env->state(), /*idx=*/-2, kTypeKey);
+  env->PushString(action.response_text);
+  lua_setfield(env->state(), /*idx=*/-2, kResponseTextKey);
+  lua_pushnumber(env->state(), action.score);
+  lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
+  lua_pushnumber(env->state(), action.priority_score);
+  lua_setfield(env->state(), /*idx=*/-2, kPriorityScoreKey);
+  annotation_iterator.NewIterator(kAnnotationKey, &action.annotations,
+                                  env->state());
+  lua_setfield(env->state(), /*idx=*/-2, kAnnotationKey);
+}
+
+// Reads an action suggestion table from the top of the lua stack.
+// Recognized fields: response_text, type, score, priority_score, annotation
+// (a list of annotations) and entity (entity data encoded with
+// actions_entity_data_schema). Unknown fields are logged and skipped.
+ActionSuggestion ReadAction(
+    const reflection::Schema* actions_entity_data_schema,
+    const reflection::Schema* annotations_entity_data_schema,
+    LuaEnvironment* env) {
+  ActionSuggestion action;
+  lua_pushnil(env->state());
+  while (lua_next(env->state(), /*idx=*/-2)) {
+    const StringPiece key = env->ReadString(/*index=*/-2);
+    if (key.Equals(kResponseTextKey)) {
+      action.response_text = env->ReadString(/*index=*/-1).ToString();
+    } else if (key.Equals(kTypeKey)) {
+      action.type = env->ReadString(/*index=*/-1).ToString();
+    } else if (key.Equals(kScoreKey)) {
+      action.score = static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
+    } else if (key.Equals(kPriorityScoreKey)) {
+      action.priority_score =
+          static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
+    } else if (key.Equals(kAnnotationKey)) {
+      // Annotation entity data is encoded with the annotations schema (as in
+      // PushAction/ActionsIterator), not the actions schema; previously the
+      // actions schema was passed here and annotations_entity_data_schema
+      // was unused.
+      ReadAnnotations(annotations_entity_data_schema, env,
+                      &action.annotations);
+    } else if (key.Equals(kEntityKey)) {
+      auto buffer =
+          ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot();
+      env->ReadFlatbuffer(buffer.get());
+      action.serialized_entity_data = buffer->Serialize();
+    } else {
+      TC3_LOG(INFO) << "Unknown action field: " << key.ToString();
+    }
+    lua_pop(env->state(), 1);
+  }
+  return action;
+}
+
+// Reads a list of action suggestion tables from the top of the lua stack
+// into `actions`. Raises a lua error if the top value is not a table;
+// non-table list entries are logged and skipped. Pops the list table from
+// the stack when done.
+int ReadActions(const reflection::Schema* actions_entity_data_schema,
+                const reflection::Schema* annotations_entity_data_schema,
+                LuaEnvironment* env, std::vector<ActionSuggestion>* actions) {
+  if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
+    TC3_LOG(ERROR) << "Expected actions table, got: "
+                   << lua_type(env->state(), /*idx=*/-1);
+    lua_pop(env->state(), 1);
+    lua_error(env->state());
+    return LUA_ERRRUN;
+  }
+
+  // Read actions.
+  lua_pushnil(env->state());
+  while (lua_next(env->state(), /*idx=*/-2)) {
+    if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
+      TC3_LOG(ERROR) << "Expected action table, got: "
+                     << lua_type(env->state(), /*idx=*/-1);
+      lua_pop(env->state(), 1);
+      continue;
+    }
+    actions->push_back(ReadAction(actions_entity_data_schema,
+                                  annotations_entity_data_schema, env));
+    // Pop the value; the key stays for the next lua_next call.
+    lua_pop(env->state(), /*n=*/ 1);
+  }
+  // Pop the actions table itself.
+  lua_pop(env->state(), /*n=*/1);
+
+  return LUA_OK;
+}
+
+// Pushes the conversation message at `pos` as a lua table with the fields
+// user_id, text, time_ms_utc, timezone and a lazily-iterated "annotation"
+// field for the message's annotated spans.
+int ConversationIterator::Item(const std::vector<ConversationMessage>* messages,
+                               const int64 pos, lua_State* state) const {
+  const ConversationMessage& message = (*messages)[pos];
+  lua_newtable(state);
+  lua_pushinteger(state, message.user_id);
+  lua_setfield(state, /*idx=*/-2, "user_id");
+  env_->PushString(message.text);
+  lua_setfield(state, /*idx=*/-2, "text");
+  lua_pushinteger(state, message.reference_time_ms_utc);
+  lua_setfield(state, /*idx=*/-2, "time_ms_utc");
+  env_->PushString(message.reference_timezone);
+  lua_setfield(state, /*idx=*/-2, "timezone");
+  annotated_span_iterator_.NewIterator("annotation", &message.annotations,
+                                       state);
+  lua_setfield(state, /*idx=*/-2, "annotation");
+  return 1;
+}
+
+} // namespace libtextclassifier3
diff --git a/actions/lua-utils.h b/actions/lua-utils.h
new file mode 100644
index 0000000..4f06674
--- /dev/null
+++ b/actions/lua-utils.h
@@ -0,0 +1,182 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
+
+#include "actions/types.h"
+#include "annotator/types.h"
+#include "utils/flatbuffers.h"
+#include "utils/lua-utils.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lua.h"
+#include "lualib.h"
+#ifdef __cplusplus
+}
+#endif
+
+// Action specific shared lua utilities.
+namespace libtextclassifier3 {
+
+// Provides an annotation to lua.
+void PushAnnotation(const ClassificationResult& classification,
+ const reflection::Schema* entity_data_schema,
+ LuaEnvironment* env);
+void PushAnnotation(const ClassificationResult& classification,
+ StringPiece text,
+ const reflection::Schema* entity_data_schema,
+ LuaEnvironment* env);
+void PushAnnotation(const ActionSuggestionAnnotation& annotation,
+ const reflection::Schema* entity_data_schema,
+ LuaEnvironment* env);
+
+// A lua iterator to enumerate annotation.
+// Supports positional access (annotations[i]) and, via the specializations
+// declared below, keyed access (by collection for ClassificationResult, by
+// name for ActionSuggestionAnnotation).
+template <typename Annotation>
+class AnnotationIterator
+    : public LuaEnvironment::ItemIterator<std::vector<Annotation>> {
+ public:
+  AnnotationIterator(const reflection::Schema* entity_data_schema,
+                     LuaEnvironment* env)
+      : env_(env), entity_data_schema_(entity_data_schema) {}
+  // Positional access: pushes the annotation at `pos` as a lua table.
+  int Item(const std::vector<Annotation>* annotations, const int64 pos,
+           lua_State* state) const override {
+    PushAnnotation((*annotations)[pos], entity_data_schema_, env_);
+    return 1;
+  }
+  // Keyed access; specialized per annotation type.
+  int Item(const std::vector<Annotation>* annotations, StringPiece key,
+           lua_State* state) const override;
+
+ private:
+  LuaEnvironment* env_;
+  const reflection::Schema* entity_data_schema_;
+};
+
+template <>
+int AnnotationIterator<ClassificationResult>::Item(
+ const std::vector<ClassificationResult>* annotations, StringPiece key,
+ lua_State* state) const;
+
+template <>
+int AnnotationIterator<ActionSuggestionAnnotation>::Item(
+ const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
+ lua_State* state) const;
+
+void PushAnnotatedSpan(
+ const AnnotatedSpan& annotated_span,
+ const AnnotationIterator<ClassificationResult>& annotation_iterator,
+ LuaEnvironment* env);
+
+MessageTextSpan ReadSpan(LuaEnvironment* env);
+ActionSuggestionAnnotation ReadAnnotation(
+ const reflection::Schema* entity_data_schema, LuaEnvironment* env);
+int ReadAnnotations(const reflection::Schema* entity_data_schema,
+ LuaEnvironment* env,
+ std::vector<ActionSuggestionAnnotation>* annotations);
+ClassificationResult ReadClassificationResult(
+ const reflection::Schema* entity_data_schema, LuaEnvironment* env);
+
+// A lua iterator to enumerate annotated spans.
+class AnnotatedSpanIterator
+    : public LuaEnvironment::ItemIterator<std::vector<AnnotatedSpan>> {
+ public:
+  AnnotatedSpanIterator(
+      const AnnotationIterator<ClassificationResult>& annotation_iterator,
+      LuaEnvironment* env)
+      : env_(env), annotation_iterator_(annotation_iterator) {}
+  // Convenience constructor: builds the classification annotation iterator
+  // from the entity data schema.
+  AnnotatedSpanIterator(const reflection::Schema* entity_data_schema,
+                        LuaEnvironment* env)
+      : env_(env), annotation_iterator_(entity_data_schema, env) {}
+
+  // Pushes the annotated span at `pos` as a lua table.
+  int Item(const std::vector<AnnotatedSpan>* spans, const int64 pos,
+           lua_State* state) const override {
+    PushAnnotatedSpan((*spans)[pos], annotation_iterator_, env_);
+    return /*num results=*/1;
+  }
+
+ private:
+  LuaEnvironment* env_;
+  AnnotationIterator<ClassificationResult> annotation_iterator_;
+};
+
+// Provides an action to lua.
+void PushAction(
+ const ActionSuggestion& action,
+ const reflection::Schema* entity_data_schema,
+ const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
+ LuaEnvironment* env);
+
+ActionSuggestion ReadAction(
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema,
+ LuaEnvironment* env);
+int ReadActions(const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema,
+ LuaEnvironment* env, std::vector<ActionSuggestion>* actions);
+
+// A lua iterator to enumerate actions suggestions.
+// Uses `entity_data_schema` for the action's own entity data and
+// `annotations_entity_data_schema` for its annotations.
+class ActionsIterator
+    : public LuaEnvironment::ItemIterator<std::vector<ActionSuggestion>> {
+ public:
+  ActionsIterator(const reflection::Schema* entity_data_schema,
+                  const reflection::Schema* annotations_entity_data_schema,
+                  LuaEnvironment* env)
+      : env_(env),
+        entity_data_schema_(entity_data_schema),
+        annotation_iterator_(annotations_entity_data_schema, env) {}
+  // Pushes the action at `pos` as a lua table.
+  int Item(const std::vector<ActionSuggestion>* actions, const int64 pos,
+           lua_State* state) const override {
+    PushAction((*actions)[pos], entity_data_schema_, annotation_iterator_,
+               env_);
+    return /*num results=*/1;
+  }
+
+ private:
+  LuaEnvironment* env_;
+  const reflection::Schema* entity_data_schema_;
+  AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
+};
+
+// Conversation message lua iterator.
+// Enumerates the messages of a conversation; each item exposes the message
+// fields and an iterator over the message's annotated spans (see Item in
+// the .cc file).
+class ConversationIterator
+    : public LuaEnvironment::ItemIterator<std::vector<ConversationMessage>> {
+ public:
+  ConversationIterator(
+      const AnnotationIterator<ClassificationResult>& annotation_iterator,
+      LuaEnvironment* env)
+      : env_(env),
+        annotated_span_iterator_(
+            AnnotatedSpanIterator(annotation_iterator, env)) {}
+  // Convenience constructor: builds the span iterator from the entity data
+  // schema.
+  ConversationIterator(const reflection::Schema* entity_data_schema,
+                       LuaEnvironment* env)
+      : env_(env),
+        annotated_span_iterator_(
+            AnnotatedSpanIterator(entity_data_schema, env)) {}
+
+  // Pushes the message at `pos` as a lua table.
+  int Item(const std::vector<ConversationMessage>* messages, const int64 pos,
+           lua_State* state) const override;
+
+ private:
+  LuaEnvironment* env_;
+  AnnotatedSpanIterator annotated_span_iterator_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
diff --git a/actions/ngram-model.cc b/actions/ngram-model.cc
new file mode 100644
index 0000000..2263617
--- /dev/null
+++ b/actions/ngram-model.cc
@@ -0,0 +1,209 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/ngram-model.h"
+
+#include <algorithm>
+
+#include "actions/feature-processor.h"
+#include "utils/hash/farmhash.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// An iterator to iterate over the initial tokens of the n-grams of a model.
+// It models just the operations std::lower_bound/std::upper_bound need:
+// increment, advance, difference and dereference.
+class FirstTokenIterator {
+ public:
+  // Iterator trait aliases, spelled out explicitly: inheriting from
+  // std::iterator is deprecated in C++17 (and removed in C++20).
+  using iterator_category = std::random_access_iterator_tag;
+  using value_type = uint32;
+  using difference_type = ptrdiff_t;
+  using pointer = const uint32*;
+  // Dereferencing yields a value (computed from the model), not a reference.
+  using reference = uint32;
+
+  explicit FirstTokenIterator(const NGramLinearRegressionModel* model,
+                              int index)
+      : model_(model), index_(index) {}
+
+  FirstTokenIterator& operator++() {
+    index_++;
+    return *this;
+  }
+  FirstTokenIterator& operator+=(ptrdiff_t dist) {
+    index_ += dist;
+    return *this;
+  }
+  ptrdiff_t operator-(const FirstTokenIterator& other_it) const {
+    return index_ - other_it.index_;
+  }
+  // Returns the hash of the first token of the n-gram at `index_`.
+  uint32 operator*() const {
+    const uint32 token_offset = (*model_->ngram_start_offsets())[index_];
+    return (*model_->hashed_ngram_tokens())[token_offset];
+  }
+  int index() const { return index_; }
+
+ private:
+  const NGramLinearRegressionModel* model_;  // Not owned.
+  int index_;
+};
+
+} // anonymous namespace
+
+// Creates a model instance. Returns nullptr if no flatbuffer model is given,
+// or if neither a caller-provided tokenizer nor tokenizer options to build
+// one are available.
+std::unique_ptr<NGramModel> NGramModel::Create(
+    const NGramLinearRegressionModel* model, const Tokenizer* tokenizer,
+    const UniLib* unilib) {
+  if (model == nullptr) {
+    return nullptr;
+  }
+  // A tokenizer must either be shared by the caller or be constructible from
+  // the model's own tokenizer options.
+  if (tokenizer == nullptr && model->tokenizer_options() == nullptr) {
+    TC3_LOG(ERROR) << "No tokenizer options specified.";
+    return nullptr;
+  }
+  // Plain `new` (not make_unique) because the constructor is private.
+  return std::unique_ptr<NGramModel>(new NGramModel(model, tokenizer, unilib));
+}
+
+NGramModel::NGramModel(const NGramLinearRegressionModel* model,
+                       const Tokenizer* tokenizer, const UniLib* unilib)
+    : model_(model) {
+  // Create new tokenizer if options are specified, reuse feature processor
+  // tokenizer otherwise.
+  if (model->tokenizer_options() != nullptr) {
+    owned_tokenizer_ = CreateTokenizer(model->tokenizer_options(), unilib);
+    tokenizer_ = owned_tokenizer_.get();
+  } else {
+    // The borrowed tokenizer must outlive this model instance.
+    tokenizer_ = tokenizer;
+  }
+}
+
+// Returns whether a given n-gram matches the token stream.
+// Tokens are compared by hash; between two consecutive matched n-gram tokens
+// up to `max_skips` non-matching stream tokens may be skipped.
+bool NGramModel::IsNGramMatch(const uint32* tokens, size_t num_tokens,
+                              const uint32* ngram_tokens,
+                              size_t num_ngram_tokens, int max_skips) const {
+  // Use size_t indices to match the size parameters and avoid signed/
+  // unsigned comparison warnings.
+  size_t token_idx = 0, ngram_token_idx = 0;
+  int skip_remain = 0;
+  while (token_idx < num_tokens && ngram_token_idx < num_ngram_tokens) {
+    if (tokens[token_idx] == ngram_tokens[ngram_token_idx]) {
+      // Token matches. Advance both and reset the skip budget.
+      ++token_idx;
+      ++ngram_token_idx;
+      skip_remain = max_skips;
+    } else if (skip_remain > 0) {
+      // No match, but we have skips left, so just advance over the token.
+      ++token_idx;
+      skip_remain--;
+    } else {
+      // No match and we're out of skips. Reject.
+      return false;
+    }
+  }
+  // Matched iff the entire n-gram was consumed.
+  return ngram_token_idx == num_ngram_tokens;
+}
+
+// Calculates the total number of skip-grams that can be created for a stream
+// with the given number of tokens.
+uint64 NGramModel::GetNumSkipGrams(int num_tokens, int max_ngram_length,
+                                   int max_skips) {
+  // Start with unigrams; clamp to 0 so a (bogus) negative token count cannot
+  // wrap around to a huge unsigned total.
+  uint64 total = std::max(num_tokens, 0);
+  for (int ngram_len = 2;
+       ngram_len <= max_ngram_length && ngram_len <= num_tokens; ++ngram_len) {
+    // We can easily compute the expected length of the n-gram (with skips),
+    // but it doesn't account for the fact that they may be longer than the
+    // input and should be pruned.
+    // Instead, we iterate over the distribution of effective n-gram lengths
+    // and add each length individually.
+    const int num_gaps = ngram_len - 1;
+    const int len_min = ngram_len;
+    const int len_max = ngram_len + num_gaps * max_skips;
+    const int len_mid = (len_max + len_min) / 2;
+    for (int len_i = len_min; len_i <= len_max; ++len_i) {
+      if (len_i > num_tokens) continue;
+      const int num_configs_of_len_i =
+          len_i <= len_mid ? len_i - len_min + 1 : len_max - len_i + 1;
+      const int num_start_offsets = num_tokens - len_i + 1;
+      // Multiply in 64 bit: the 32-bit product could overflow for long
+      // token streams before being added to the unsigned total.
+      total += static_cast<uint64>(num_configs_of_len_i) * num_start_offsets;
+    }
+  }
+  return total;
+}
+
+// Returns the half-open [start, end) range of n-gram indices whose first
+// hashed token equals `token_hash`. Relies on the model's n-grams being
+// sorted by their first token hash — presumably guaranteed by the model
+// builder; verify against the model generation code.
+std::pair<int, int> NGramModel::GetFirstTokenMatches(uint32 token_hash) const {
+  const int num_ngrams = model_->ngram_weights()->size();
+  // Binary search over the first tokens of all n-grams.
+  const auto start_it = FirstTokenIterator(model_, 0);
+  const auto end_it = FirstTokenIterator(model_, num_ngrams);
+  const int start = std::lower_bound(start_it, end_it, token_hash).index();
+  const int end = std::upper_bound(start_it, end_it, token_hash).index();
+  return std::make_pair(start, end);
+}
+
+// Evaluates the n-gram linear regression model on `text` and tests the
+// resulting score against the model threshold. Returns true for a positive
+// classification; optionally reports the raw score through `score`.
+bool NGramModel::Eval(const UnicodeText& text, float* score) const {
+  const std::vector<Token> raw_tokens = tokenizer_->Tokenize(text);
+
+  // If we have no tokens, then just bail early.
+  if (raw_tokens.empty()) {
+    if (score != nullptr) {
+      *score = model_->default_token_weight();
+    }
+    return false;
+  }
+
+  // Hash the tokens.
+  std::vector<uint32> tokens;
+  tokens.reserve(raw_tokens.size());
+  for (const Token& raw_token : raw_tokens) {
+    tokens.push_back(tc3farmhash::Fingerprint32(raw_token.value.data(),
+                                                raw_token.value.length()));
+  }
+
+  // Calculate the total number of skip-grams that can be generated for the
+  // input text.
+  const uint64 num_candidates = GetNumSkipGrams(
+      tokens.size(), model_->max_denom_ngram_length(), model_->max_skips());
+
+  // For each token, see whether it denotes the start of an n-gram in the model.
+  int num_matches = 0;
+  float weight_matches = 0.f;
+  for (size_t start_i = 0; start_i < tokens.size(); ++start_i) {
+    const std::pair<int, int> ngram_range =
+        GetFirstTokenMatches(tokens[start_i]);
+    for (int ngram_idx = ngram_range.first; ngram_idx < ngram_range.second;
+         ++ngram_idx) {
+      const uint16 ngram_tokens_begin =
+          (*model_->ngram_start_offsets())[ngram_idx];
+      const uint16 ngram_tokens_end =
+          (*model_->ngram_start_offsets())[ngram_idx + 1];
+      if (IsNGramMatch(
+              /*tokens=*/tokens.data() + start_i,
+              /*num_tokens=*/tokens.size() - start_i,
+              /*ngram_tokens=*/model_->hashed_ngram_tokens()->data() +
+                  ngram_tokens_begin,
+              /*num_ngram_tokens=*/ngram_tokens_end - ngram_tokens_begin,
+              /*max_skips=*/model_->max_skips())) {
+        ++num_matches;
+        weight_matches += (*model_->ngram_weights())[ngram_idx];
+      }
+    }
+  }
+
+  // Calculate the score. Keep the miss count in 64 bit: `num_candidates` is
+  // uint64, so narrowing the difference to int could overflow on long input.
+  const uint64 num_misses = num_candidates - num_matches;
+  const float internal_score =
+      (weight_matches + (model_->default_token_weight() * num_misses)) /
+      num_candidates;
+  if (score != nullptr) {
+    *score = internal_score;
+  }
+  return internal_score > model_->threshold();
+}
+
+} // namespace libtextclassifier3
diff --git a/actions/ngram-model.h b/actions/ngram-model.h
new file mode 100644
index 0000000..ec0b606
--- /dev/null
+++ b/actions/ngram-model.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_NGRAM_MODEL_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_NGRAM_MODEL_H_
+
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Linear regression model over hashed skip-gram matches in a tokenized text.
+class NGramModel {
+ public:
+  // Creates a model instance. Returns nullptr if `model` is null, or if no
+  // tokenizer is provided and the model carries no tokenizer options either.
+  static std::unique_ptr<NGramModel> Create(
+      const NGramLinearRegressionModel* model, const Tokenizer* tokenizer,
+      const UniLib* unilib);
+
+  // Evaluates an n-gram linear regression model, and tests against the
+  // threshold. Returns true in case of a positive classification. The caller
+  // may also optionally query the score.
+  bool Eval(const UnicodeText& text, float* score = nullptr) const;
+
+  // Calculates the total number of skip-grams creatable from `num_tokens`
+  // tokens. Exposed for testing only.
+  static uint64 GetNumSkipGrams(int num_tokens, int max_ngram_length,
+                                int max_skips);
+
+ private:
+  NGramModel(const NGramLinearRegressionModel* model,
+             const Tokenizer* tokenizer, const UniLib* unilib);
+
+  // Returns the half-open [begin, end) range of n-grams where the first
+  // hashed token matches the given value.
+  std::pair<int, int> GetFirstTokenMatches(uint32 token_hash) const;
+
+  // Returns whether a given n-gram matches the token stream.
+  bool IsNGramMatch(const uint32* tokens, size_t num_tokens,
+                    const uint32* ngram_tokens, size_t num_ngram_tokens,
+                    int max_skips) const;
+
+  const NGramLinearRegressionModel* model_;  // Not owned.
+  // Points either at `owned_tokenizer_` or at a caller-provided tokenizer.
+  const Tokenizer* tokenizer_;
+  std::unique_ptr<Tokenizer> owned_tokenizer_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_NGRAM_MODEL_H_
diff --git a/actions/ranker.cc b/actions/ranker.cc
new file mode 100644
index 0000000..5a03da5
--- /dev/null
+++ b/actions/ranker.cc
@@ -0,0 +1,353 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/ranker.h"
+
+#include <functional>
+#include <set>
+#include <vector>
+
+#include "actions/lua-ranker.h"
+#include "actions/zlib-utils.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/lua-utils.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Orders actions by descending score; equal scores are tie-broken
+// deterministically by ascending action type name.
+void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
+  std::sort(actions->begin(), actions->end(),
+            [](const ActionSuggestion& lhs, const ActionSuggestion& rhs) {
+              if (lhs.score != rhs.score) {
+                return lhs.score > rhs.score;
+              }
+              return lhs.type < rhs.type;
+            });
+}
+
+// Generic three-way comparison helper: negative when left precedes right,
+// positive when right precedes left, zero when neither is smaller.
+template <typename T>
+int Compare(const T& left, const T& right) {
+  if (right < left) {
+    return 1;
+  }
+  return (left < right) ? -1 : 0;
+}
+
+// String specialization: std::string::compare gives the three-way result in
+// a single pass.
+template <>
+int Compare(const std::string& left, const std::string& right) {
+  return left.compare(right);
+}
+
+// Message span specialization: order by message index, then span start,
+// then span end.
+template <>
+int Compare(const MessageTextSpan& span, const MessageTextSpan& other) {
+  if (const int value = Compare(span.message_index, other.message_index)) {
+    return value;
+  }
+  if (const int value = Compare(span.span.first, other.span.first)) {
+    return value;
+  }
+  if (const int value = Compare(span.span.second, other.span.second)) {
+    return value;
+  }
+  return 0;
+}
+
+// Checks whether two spans are identical: same message, same begin and end
+// offsets.
+bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) {
+  return Compare(span, other) == 0;
+}
+
+// Checks whether two spans live in the same message and their character
+// ranges overlap.
+bool TextSpansIntersect(const MessageTextSpan& span,
+                        const MessageTextSpan& other) {
+  return span.message_index == other.message_index &&
+         SpansOverlap(span.span, other.span);
+}
+
+// Annotation specialization: order by span, then name, then entity
+// collection. Classification scores are not part of the comparison.
+template <>
+int Compare(const ActionSuggestionAnnotation& annotation,
+            const ActionSuggestionAnnotation& other) {
+  if (const int value = Compare(annotation.span, other.span)) {
+    return value;
+  }
+  if (const int value = Compare(annotation.name, other.name)) {
+    return value;
+  }
+  if (const int value =
+          Compare(annotation.entity.collection, other.entity.collection)) {
+    return value;
+  }
+  return 0;
+}
+
+// Checks whether two annotations can be considered equivalent, i.e. they
+// agree on span, name and entity collection (scores are ignored).
+bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
+                                  const ActionSuggestionAnnotation& other) {
+  return Compare(annotation, other) == 0;
+}
+
+// Compares actions based on annotations.
+// Shorter annotation lists order first; otherwise annotations are compared
+// pairwise in order and the first difference decides.
+int CompareAnnotationsOnly(const ActionSuggestion& action,
+                           const ActionSuggestion& other) {
+  if (const int value =
+          Compare(action.annotations.size(), other.annotations.size())) {
+    return value;
+  }
+  // size_t index matches the container's size type (avoids sign-compare
+  // warnings with the int the original used).
+  for (size_t i = 0; i < action.annotations.size(); i++) {
+    if (const int value =
+            Compare(action.annotations[i], other.annotations[i])) {
+      return value;
+    }
+  }
+  return 0;
+}
+
+// Checks whether two actions have the same annotations: equal counts and
+// pairwise-equivalent annotations in order.
+bool HaveEquivalentAnnotations(const ActionSuggestion& action,
+                               const ActionSuggestion& other) {
+  return CompareAnnotationsOnly(action, other) == 0;
+}
+
+// Action specialization: order by type, response text, serialized entity
+// data, and finally the annotations. Scores are not compared, so two actions
+// differing only in score are "equal" here.
+template <>
+int Compare(const ActionSuggestion& action, const ActionSuggestion& other) {
+  if (const int value = Compare(action.type, other.type)) {
+    return value;
+  }
+  if (const int value = Compare(action.response_text, other.response_text)) {
+    return value;
+  }
+  if (const int value = Compare(action.serialized_entity_data,
+                                other.serialized_entity_data)) {
+    return value;
+  }
+  return CompareAnnotationsOnly(action, other);
+}
+
+// Checks whether two action suggestions can be considered equivalent
+// (identical except possibly for their scores).
+bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
+                                  const ActionSuggestion& other) {
+  return Compare(action, other) == 0;
+}
+
+// Checks whether `actions` already contains an action equivalent to
+// `action`; linear scan over the collected suggestions.
+bool IsAnyActionEquivalent(const ActionSuggestion& action,
+                           const std::vector<ActionSuggestion>& actions) {
+  for (const ActionSuggestion& existing : actions) {
+    if (IsEquivalentActionSuggestion(action, existing)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Checks whether two annotations conflict with each other.
+bool IsConflicting(const ActionSuggestionAnnotation& annotation,
+                   const ActionSuggestionAnnotation& other) {
+  // Two annotations are conflicting if they are different but refer to
+  // overlapping spans in the conversation.
+  return (!IsEquivalentActionAnnotation(annotation, other) &&
+          TextSpansIntersect(annotation.span, other.span));
+}
+
+// Checks whether two action suggestions can be considered conflicting.
+bool IsConflictingActionSuggestion(const ActionSuggestion& action,
+                                   const ActionSuggestion& other) {
+  // Actions are considered conflicting, iff they refer to the same text span,
+  // but were not generated from the same annotation.
+  // Actions without annotations never conflict with anything.
+  if (action.annotations.empty() || other.annotations.empty()) {
+    return false;
+  }
+  // Any conflicting pair of annotations makes the actions conflict.
+  for (const ActionSuggestionAnnotation& annotation : action.annotations) {
+    for (const ActionSuggestionAnnotation& other_annotation :
+         other.annotations) {
+      if (IsConflicting(annotation, other_annotation)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Checks whether `actions` contains an action that conflicts with `action`;
+// linear scan over the collected suggestions.
+bool IsAnyActionConflicting(const ActionSuggestion& action,
+                            const std::vector<ActionSuggestion>& actions) {
+  for (const ActionSuggestion& existing : actions) {
+    if (IsConflictingActionSuggestion(action, existing)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+} // namespace
+
+// Creates and initializes a ranker from flatbuffer options.
+// Returns nullptr when the options are missing or the configured lua ranking
+// script cannot be precompiled.
+std::unique_ptr<ActionsSuggestionsRanker>
+ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+    const RankingOptions* options, ZlibDecompressor* decompressor,
+    const std::string& smart_reply_action_type) {
+  auto ranker = std::unique_ptr<ActionsSuggestionsRanker>(
+      new ActionsSuggestionsRanker(options, smart_reply_action_type));
+
+  if (!ranker->InitializeAndValidate(decompressor)) {
+    TC3_LOG(ERROR) << "Could not initialize action ranker.";
+    return nullptr;
+  }
+
+  return ranker;
+}
+
+// Validates the ranking options and precompiles the (optionally
+// zlib-compressed) lua ranking script to bytecode. Returns false on missing
+// options or a script that fails to compile.
+bool ActionsSuggestionsRanker::InitializeAndValidate(
+    ZlibDecompressor* decompressor) {
+  if (options_ == nullptr) {
+    TC3_LOG(ERROR) << "No ranking options specified.";
+    return false;
+  }
+
+  // A missing or empty script is fine — ranking then happens without lua.
+  std::string lua_ranking_script;
+  if (GetUncompressedString(options_->lua_ranking_script(),
+                            options_->compressed_lua_ranking_script(),
+                            decompressor, &lua_ranking_script) &&
+      !lua_ranking_script.empty()) {
+    if (!Compile(lua_ranking_script, &lua_bytecode_)) {
+      TC3_LOG(ERROR) << "Could not precompile lua ranking snippet.";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+// Ranks and filters the actions in `response` in place:
+//   1. optionally deduplicates equivalent and span-conflicting suggestions,
+//   2. optionally suppresses smart replies when other actions are present,
+//   3. sorts the result (optionally grouped by annotations),
+//   4. optionally runs the precompiled lua ranking snippet.
+// `conversation` and the schemas are only consumed by the lua ranker.
+bool ActionsSuggestionsRanker::RankActions(
+    const Conversation& conversation, ActionsSuggestionsResponse* response,
+    const reflection::Schema* entity_data_schema,
+    const reflection::Schema* annotations_entity_data_schema) const {
+  if (options_->deduplicate_suggestions() ||
+      options_->deduplicate_suggestions_by_span()) {
+    // First order suggestions by priority score for deduplication, so that
+    // the best representative of each duplicate group is encountered — and
+    // therefore kept — first.
+    std::sort(
+        response->actions.begin(), response->actions.end(),
+        [](const ActionSuggestion& a, const ActionSuggestion& b) {
+          if (a.priority_score != b.priority_score) {
+            return a.priority_score > b.priority_score;
+          }
+          return a.score > b.score;
+        });
+
+    // Deduplicate, keeping the higher score actions.
+    if (options_->deduplicate_suggestions()) {
+      std::vector<ActionSuggestion> deduplicated_actions;
+      for (const ActionSuggestion& candidate : response->actions) {
+        // Check whether we already have an equivalent action.
+        // NOTE: `candidate` is a const reference, so a std::move here would
+        // silently degrade to a copy anyway; copy explicitly.
+        if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) {
+          deduplicated_actions.push_back(candidate);
+        }
+      }
+      response->actions = std::move(deduplicated_actions);
+    }
+
+    // Resolve conflicts between conflicting actions referring to the same
+    // text span.
+    if (options_->deduplicate_suggestions_by_span()) {
+      std::vector<ActionSuggestion> deduplicated_actions;
+      for (const ActionSuggestion& candidate : response->actions) {
+        // Check whether we already have a conflicting action.
+        if (!IsAnyActionConflicting(candidate, deduplicated_actions)) {
+          deduplicated_actions.push_back(candidate);
+        }
+      }
+      response->actions = std::move(deduplicated_actions);
+    }
+  }
+
+  // Suppress smart replies if actions are present.
+  if (options_->suppress_smart_replies_with_actions()) {
+    std::vector<ActionSuggestion> non_smart_reply_actions;
+    for (const ActionSuggestion& action : response->actions) {
+      if (action.type != smart_reply_action_type_) {
+        non_smart_reply_actions.push_back(action);
+      }
+    }
+    response->actions = std::move(non_smart_reply_actions);
+  }
+
+  // Group by annotation if specified.
+  if (options_->group_by_annotations()) {
+    // Maps a representative action to the index of its group in `groups`,
+    // using only the annotations for equality.
+    auto group_id = std::map<
+        ActionSuggestion, int,
+        std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{
+        [](const ActionSuggestion& action, const ActionSuggestion& other) {
+          return (CompareAnnotationsOnly(action, other) < 0);
+        }};
+    typedef std::vector<ActionSuggestion> ActionSuggestionGroup;
+    std::vector<ActionSuggestionGroup> groups;
+
+    // Group actions by the annotation set they are based of.
+    for (const ActionSuggestion& action : response->actions) {
+      // Treat actions with no annotations independently.
+      if (action.annotations.empty()) {
+        groups.emplace_back(1, action);
+        continue;
+      }
+
+      auto it = group_id.find(action);
+      if (it != group_id.end()) {
+        groups[it->second].push_back(action);
+      } else {
+        group_id[action] = groups.size();
+        groups.emplace_back(1, action);
+      }
+    }
+
+    // Sort within each group by score.
+    for (std::vector<ActionSuggestion>& group : groups) {
+      SortByScoreAndType(&group);
+    }
+
+    // Sort groups by the score of their best (first) action; every group is
+    // non-empty by construction.
+    std::sort(groups.begin(), groups.end(),
+              [](const std::vector<ActionSuggestion>& a,
+                 const std::vector<ActionSuggestion>& b) {
+                if (a.begin()->score != b.begin()->score) {
+                  return a.begin()->score > b.begin()->score;
+                }
+                return a.begin()->type < b.begin()->type;
+              });
+
+    // Flatten result.
+    const size_t num_actions = response->actions.size();
+    response->actions.clear();
+    response->actions.reserve(num_actions);
+    for (const std::vector<ActionSuggestion>& actions : groups) {
+      response->actions.insert(response->actions.end(), actions.begin(),
+                               actions.end());
+    }
+
+  } else {
+    // Order suggestions independently by score.
+    SortByScoreAndType(&response->actions);
+  }
+
+  // Run lua ranking snippet, if provided.
+  if (!lua_bytecode_.empty()) {
+    auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
+        conversation, lua_bytecode_, entity_data_schema,
+        annotations_entity_data_schema, response);
+    if (lua_ranker == nullptr || !lua_ranker->RankActions()) {
+      TC3_LOG(ERROR) << "Could not run lua ranking snippet.";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/actions/ranker.h b/actions/ranker.h
new file mode 100644
index 0000000..2ab3146
--- /dev/null
+++ b/actions/ranker.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_RANKER_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_RANKER_H_
+
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "actions/types.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Ranking and filtering of actions suggestions.
+class ActionsSuggestionsRanker {
+ public:
+  // Creates and initializes a ranker. Returns nullptr on missing options or
+  // if the configured lua ranking script cannot be precompiled.
+  static std::unique_ptr<ActionsSuggestionsRanker>
+  CreateActionsSuggestionsRanker(const RankingOptions* options,
+                                 ZlibDecompressor* decompressor,
+                                 const std::string& smart_reply_action_type);
+
+  // Rank and filter actions in `response` in place. The optional schemas are
+  // forwarded to the lua ranker for entity-data access.
+  bool RankActions(
+      const Conversation& conversation, ActionsSuggestionsResponse* response,
+      const reflection::Schema* entity_data_schema = nullptr,
+      const reflection::Schema* annotations_entity_data_schema = nullptr) const;
+
+ private:
+  explicit ActionsSuggestionsRanker(const RankingOptions* options,
+                                    const std::string& smart_reply_action_type)
+      : options_(options), smart_reply_action_type_(smart_reply_action_type) {}
+
+  // Validates the options and precompiles the lua ranking script, if any.
+  bool InitializeAndValidate(ZlibDecompressor* decompressor);
+
+  const RankingOptions* const options_;  // Not owned.
+  // Precompiled lua ranking snippet; empty if none is configured.
+  std::string lua_bytecode_;
+  // Action type treated as a smart reply for suppression purposes.
+  std::string smart_reply_action_type_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_RANKER_H_
diff --git a/actions/ranker_test.cc b/actions/ranker_test.cc
new file mode 100644
index 0000000..b52cf45
--- /dev/null
+++ b/actions/ranker_test.cc
@@ -0,0 +1,382 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/ranker.h"
+
+#include <string>
+
+#include "actions/types.h"
+#include "utils/zlib/zlib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// gMock matcher on ActionSuggestion verifying type, response text and score.
+MATCHER_P3(IsAction, type, response_text, score, "") {
+  return testing::Value(arg.type, type) &&
+         testing::Value(arg.response_text, response_text) &&
+         testing::Value(arg.score, score);
+}
+
+// gMock matcher verifying only the action type.
+MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
+
+TEST(RankingTest, DeduplicationSmartReply) {
+  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
+  ActionsSuggestionsResponse response;
+  // Two equivalent replies differing only in score.
+  response.actions = {
+      {/*response_text=*/"hello there", /*type=*/"text_reply",
+       /*score=*/1.0},
+      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5}};
+
+  RankingOptionsT options;
+  options.deduplicate_suggestions = true;
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(RankingOptions::Pack(builder, &options));
+  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+  ranker->RankActions(conversation, &response);
+  // Only the higher-scoring duplicate survives.
+  EXPECT_THAT(
+      response.actions,
+      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0)}));
+}
+
+TEST(RankingTest, DeduplicationExtraData) {
+  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
+  ActionsSuggestionsResponse response;
+  // The third action differs in serialized entity data, so it is not a
+  // duplicate of the first two.
+  response.actions = {
+      {/*response_text=*/"hello there", /*type=*/"text_reply",
+       /*score=*/1.0, /*priority_score=*/0.0},
+      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5,
+       /*priority_score=*/0.0},
+      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.6,
+       /*priority_score=*/0.0,
+       /*annotations=*/{}, /*serialized_entity_data=*/"test"},
+  };
+
+  RankingOptionsT options;
+  options.deduplicate_suggestions = true;
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(RankingOptions::Pack(builder, &options));
+  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+  ranker->RankActions(conversation, &response);
+  EXPECT_THAT(
+      response.actions,
+      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0),
+                                 // Is kept as it has different entity data.
+                                 IsAction("text_reply", "hello there", 0.6)}));
+}
+
+TEST(RankingTest, DeduplicationAnnotations) {
+  const Conversation conversation = {
+      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
+  ActionsSuggestionsResponse response;
+  // Two equivalent address actions (same span/collection) plus one distinct
+  // phone action.
+  {
+    ActionSuggestionAnnotation annotation;
+    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
+                       /*text=*/"742 Evergreen Terrace"};
+    annotation.entity = ClassificationResult("address", 0.5);
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"view_map",
+                                /*score=*/0.5,
+                                /*priority_score=*/1.0,
+                                /*annotations=*/{annotation}});
+  }
+  {
+    ActionSuggestionAnnotation annotation;
+    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
+                       /*text=*/"742 Evergreen Terrace"};
+    annotation.entity = ClassificationResult("address", 1.0);
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"view_map",
+                                /*score=*/1.0,
+                                /*priority_score=*/2.0,
+                                /*annotations=*/{annotation}});
+  }
+  {
+    ActionSuggestionAnnotation annotation;
+    annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
+                       /*text=*/"1-800-TESTING"};
+    annotation.entity = ClassificationResult("phone", 0.5);
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"call_phone",
+                                /*score=*/0.5,
+                                /*priority_score=*/1.0,
+                                /*annotations=*/{annotation}});
+  }
+
+  RankingOptionsT options;
+  options.deduplicate_suggestions = true;
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(RankingOptions::Pack(builder, &options));
+  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+  ranker->RankActions(conversation, &response);
+  // The higher-scoring address action wins; the phone action is unaffected.
+  EXPECT_THAT(response.actions,
+              testing::ElementsAreArray({IsAction("view_map", "", 1.0),
+                                         IsAction("call_phone", "", 0.5)}));
+}
+
+TEST(RankingTest, DeduplicationAnnotationsByPriorityScore) {
+  const Conversation conversation = {
+      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
+  ActionsSuggestionsResponse response;
+  // Two equivalent address actions where the lower-scoring one carries the
+  // higher priority score.
+  {
+    ActionSuggestionAnnotation annotation;
+    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
+                       /*text=*/"742 Evergreen Terrace"};
+    annotation.entity = ClassificationResult("address", 0.5);
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"view_map",
+                                /*score=*/0.6,
+                                /*priority_score=*/2.0,
+                                /*annotations=*/{annotation}});
+  }
+  {
+    ActionSuggestionAnnotation annotation;
+    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
+                       /*text=*/"742 Evergreen Terrace"};
+    annotation.entity = ClassificationResult("address", 1.0);
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"view_map",
+                                /*score=*/1.0,
+                                /*priority_score=*/1.0,
+                                /*annotations=*/{annotation}});
+  }
+  {
+    ActionSuggestionAnnotation annotation;
+    annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
+                       /*text=*/"1-800-TESTING"};
+    annotation.entity = ClassificationResult("phone", 0.5);
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"call_phone",
+                                /*score=*/0.5,
+                                /*priority_score=*/1.0,
+                                /*annotations=*/{annotation}});
+  }
+
+  RankingOptionsT options;
+  options.deduplicate_suggestions = true;
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(RankingOptions::Pack(builder, &options));
+  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+  ranker->RankActions(conversation, &response);
+  EXPECT_THAT(
+      response.actions,
+      testing::ElementsAreArray(
+          {IsAction("view_map", "",
+                    0.6),  // lower score wins, as priority score is higher
+           IsAction("call_phone", "", 0.5)}));
+}
+
+TEST(RankingTest, DeduplicatesConflictingActions) {
+  const Conversation conversation = {{{/*user_id=*/1, "code A-911"}}};
+  ActionsSuggestionsResponse response;
+  // Two different actions over overlapping spans ({7,10} vs {5,10}); the one
+  // with the higher priority score should suppress the other.
+  {
+    ActionSuggestionAnnotation annotation;
+    annotation.span = {/*message_index=*/0, /*span=*/{7, 10},
+                       /*text=*/"911"};
+    annotation.entity = ClassificationResult("phone", 1.0);
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"call_phone",
+                                /*score=*/1.0,
+                                /*priority_score=*/1.0,
+                                /*annotations=*/{annotation}});
+  }
+  {
+    ActionSuggestionAnnotation annotation;
+    annotation.span = {/*message_index=*/0, /*span=*/{5, 10},
+                       /*text=*/"A-911"};
+    annotation.entity = ClassificationResult("code", 1.0);
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"copy_code",
+                                /*score=*/1.0,
+                                /*priority_score=*/2.0,
+                                /*annotations=*/{annotation}});
+  }
+  RankingOptionsT options;
+  options.deduplicate_suggestions = true;
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(RankingOptions::Pack(builder, &options));
+  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+  ranker->RankActions(conversation, &response);
+  EXPECT_THAT(response.actions,
+              testing::ElementsAreArray({IsAction("copy_code", "", 1.0)}));
+}
+
+TEST(RankingTest, HandlesCompressedLuaScript) {
+  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
+  ActionsSuggestionsResponse response;
+  response.actions = {
+      {/*response_text=*/"hello there", /*type=*/"text_reply",
+       /*score=*/1.0},
+      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
+      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
+  // Lua snippet that keeps only actions whose type is not "text_reply".
+  const std::string test_snippet = R"(
+    local result = {}
+    for id, action in pairs(actions) do
+      if action.type ~= "text_reply" then
+        table.insert(result, id)
+      end
+    end
+    return result
+  )";
+  RankingOptionsT options;
+  // Ship the snippet zlib-compressed to exercise the decompression path.
+  options.compressed_lua_ranking_script.reset(new CompressedBufferT);
+  std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
+  compressor->Compress(test_snippet,
+                       options.compressed_lua_ranking_script.get());
+  options.deduplicate_suggestions = true;
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(RankingOptions::Pack(builder, &options));
+
+  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+      decompressor.get(), /*smart_reply_action_type=*/"text_reply");
+
+  ranker->RankActions(conversation, &response);
+  EXPECT_THAT(response.actions,
+              testing::ElementsAreArray({IsActionType("share_location"),
+                                         IsActionType("add_to_collection")}));
+}
+
+TEST(RankingTest, SuppressSmartRepliesWithAction) {
+ const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
+ ActionsSuggestionsResponse response;
+ {
+ ActionSuggestionAnnotation annotation;
+ annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
+ /*text=*/"911"};
+ annotation.entity = ClassificationResult("phone", 1.0);
+ response.actions.push_back({/*response_text=*/"",
+ /*type=*/"call_phone",
+ /*score=*/1.0,
+ /*priority_score=*/1.0,
+ /*annotations=*/{annotation}});
+ }
+ response.actions.push_back({/*response_text=*/"How are you?",
+ /*type=*/"text_reply"});
+ RankingOptionsT options;
+ // With this option set, the presence of any non-smart-reply action must
+ // drop all actions of the configured smart reply type ("text_reply").
+ options.suppress_smart_replies_with_actions = true;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(RankingOptions::Pack(builder, &options));
+ auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+ flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+ /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+ ranker->RankActions(conversation, &response);
+
+ // Only the "call_phone" action survives; the smart reply is suppressed.
+ EXPECT_THAT(response.actions,
+ testing::ElementsAreArray({IsAction("call_phone", "", 1.0)}));
+}
+
+TEST(RankingTest, GroupsActionsByAnnotations) {
+ const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
+ ActionsSuggestionsResponse response;
+ {
+ ActionSuggestionAnnotation annotation;
+ annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
+ /*text=*/"911"};
+ annotation.entity = ClassificationResult("phone", 1.0);
+ // Both actions share the same "911" annotation and should therefore be
+ // kept together by the grouping logic, regardless of their scores.
+ response.actions.push_back({/*response_text=*/"",
+ /*type=*/"call_phone",
+ /*score=*/1.0,
+ /*priority_score=*/1.0,
+ /*annotations=*/{annotation}});
+ response.actions.push_back({/*response_text=*/"",
+ /*type=*/"add_contact",
+ /*score=*/0.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/{annotation}});
+ }
+ response.actions.push_back({/*response_text=*/"How are you?",
+ /*type=*/"text_reply",
+ /*score=*/0.5});
+ RankingOptionsT options;
+ options.group_by_annotations = true;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(RankingOptions::Pack(builder, &options));
+ auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+ flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+ /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+ ranker->RankActions(conversation, &response);
+
+ // The text reply should be last, even though it has a higher score than the
+ // `add_contact` action.
+ EXPECT_THAT(
+ response.actions,
+ testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
+ IsAction("add_contact", "", 0.0),
+ IsAction("text_reply", "How are you?", 0.5)}));
+}
+
+TEST(RankingTest, SortsActionsByScore) {
+ const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
+ ActionsSuggestionsResponse response;
+ {
+ ActionSuggestionAnnotation annotation;
+ annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
+ /*text=*/"911"};
+ annotation.entity = ClassificationResult("phone", 1.0);
+ response.actions.push_back({/*response_text=*/"",
+ /*type=*/"call_phone",
+ /*score=*/1.0,
+ /*priority_score=*/1.0,
+ /*annotations=*/{annotation}});
+ response.actions.push_back({/*response_text=*/"",
+ /*type=*/"add_contact",
+ /*score=*/0.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/{annotation}});
+ }
+ response.actions.push_back({/*response_text=*/"How are you?",
+ /*type=*/"text_reply",
+ /*score=*/0.5});
+ RankingOptionsT options;
+ // Don't group by annotation.
+ options.group_by_annotations = false;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(RankingOptions::Pack(builder, &options));
+ auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+ flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+ /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+ ranker->RankActions(conversation, &response);
+
+ // Without grouping, ordering is purely by score: 1.0, 0.5, 0.0.
+ EXPECT_THAT(
+ response.actions,
+ testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
+ IsAction("text_reply", "How are you?", 0.5),
+ IsAction("add_contact", "", 0.0)}));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/actions/test_data/actions_suggestions_test.default.model b/actions/test_data/actions_suggestions_test.default.model
new file mode 100644
index 0000000..60f10e6
--- /dev/null
+++ b/actions/test_data/actions_suggestions_test.default.model
Binary files differ
diff --git a/actions/test_data/actions_suggestions_test.hashgram.model b/actions/test_data/actions_suggestions_test.hashgram.model
new file mode 100644
index 0000000..cdc6bdc
--- /dev/null
+++ b/actions/test_data/actions_suggestions_test.hashgram.model
Binary files differ
diff --git a/actions/test_data/actions_suggestions_test.model b/actions/test_data/actions_suggestions_test.model
new file mode 100644
index 0000000..6cec2b7
--- /dev/null
+++ b/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/actions/test_utils.cc b/actions/test_utils.cc
new file mode 100644
index 0000000..187aa67
--- /dev/null
+++ b/actions/test_utils.cc
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/test_utils.h"
+
+namespace libtextclassifier3 {
+
+std::string TestEntityDataSchema() {
+ // Create fake entity data schema meta data.
+ // Cannot use object oriented API here as that is not available for the
+ // reflection schema.
+ // Builds a serialized reflection::Schema describing a single "EntityData"
+ // table with three string fields: "greeting", "location" and "person".
+ // Field vtable offsets start at 4 and increase by 2 per field id.
+ flatbuffers::FlatBufferBuilder schema_builder;
+ std::vector<flatbuffers::Offset<reflection::Field>> fields = {
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("greeting"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/0,
+ /*offset=*/4),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("location"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/1,
+ /*offset=*/6),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("person"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/2,
+ /*offset=*/8)};
+ std::vector<flatbuffers::Offset<reflection::Enum>> enums;
+ std::vector<flatbuffers::Offset<reflection::Object>> objects = {
+ reflection::CreateObject(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("EntityData"),
+ /*fields=*/
+ schema_builder.CreateVectorOfSortedTables(&fields))};
+ schema_builder.Finish(reflection::CreateSchema(
+ schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
+ schema_builder.CreateVectorOfSortedTables(&enums),
+ /*(unused) file_ident=*/0,
+ /*(unused) file_ext=*/0,
+ /*root_table*/ objects[0]));
+
+ // Return the finished buffer as raw bytes.
+ return std::string(
+ reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
+ schema_builder.GetSize());
+}
+
+// Attaches the serialized test schema to `test_model`, replacing any
+// previously set entity data schema bytes.
+void SetTestEntityDataSchema(ActionsModelT* test_model) {
+ const std::string serialized = TestEntityDataSchema();
+ test_model->actions_entity_data_schema.assign(serialized.begin(),
+ serialized.end());
+}
+
+} // namespace libtextclassifier3
diff --git a/actions/test_utils.h b/actions/test_utils.h
new file mode 100644
index 0000000..618523c
--- /dev/null
+++ b/actions/test_utils.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
+
+#include <string>
+#include "actions/actions_model_generated.h"
+#include "utils/flatbuffers.h"
+
+namespace libtextclassifier3 {
+
+// Create test entity data schema.
+std::string TestEntityDataSchema();
+void SetTestEntityDataSchema(ActionsModelT* test_model);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
diff --git a/actions/types.h b/actions/types.h
new file mode 100644
index 0000000..212cfda
--- /dev/null
+++ b/actions/types.h
@@ -0,0 +1,145 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_TYPES_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_TYPES_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "actions/actions-entity-data_generated.h"
+#include "annotator/types.h"
+#include "utils/flatbuffers.h"
+
+namespace libtextclassifier3 {
+
+// A text span in the conversation.
+struct MessageTextSpan {
+ // The referenced message.
+ // -1 if not referencing a particular message in the provided input.
+ int message_index;
+
+ // The span within the reference message.
+ // (-1, -1) if not referencing a particular location.
+ CodepointSpan span;
+
+ // The span text.
+ std::string text;
+
+ // Creates an unanchored span: no message, no location, empty text.
+ explicit MessageTextSpan()
+ : message_index(kInvalidIndex), span({kInvalidIndex, kInvalidIndex}) {}
+ // Creates a span anchored at `span` within message `message_index`.
+ MessageTextSpan(const int message_index, const CodepointSpan span,
+ const std::string& text)
+ : message_index(message_index), span(span), text(text) {}
+};
+
+// An entity associated with an action.
+struct ActionSuggestionAnnotation {
+ // Location and text of the annotated span in the conversation.
+ MessageTextSpan span;
+ // Classification of the annotated span.
+ ClassificationResult entity;
+
+ // Optional annotation name.
+ std::string name;
+};
+
+// Action suggestion that contains a response text and the type of the response.
+struct ActionSuggestion {
+ // Text of the action suggestion.
+ std::string response_text;
+
+ // Type of the action suggestion.
+ std::string type;
+
+ // Score.
+ // Default-initialized so a default-constructed suggestion does not carry
+ // an indeterminate value.
+ float score = 0.f;
+
+ // Priority score for internal conflict resolution.
+ float priority_score = 0.f;
+
+ // The associated annotations.
+ std::vector<ActionSuggestionAnnotation> annotations;
+
+ // Extras information.
+ std::string serialized_entity_data;
+
+ // Typed view onto `serialized_entity_data`; presumably returns nullptr if
+ // the buffer is empty or fails verification — confirm against
+ // LoadAndVerifyFlatbuffer.
+ const ActionsEntityData* entity_data() {
+ return LoadAndVerifyFlatbuffer<ActionsEntityData>(
+ serialized_entity_data.data(), serialized_entity_data.size());
+ }
+};
+
+// Actions suggestions result containing meta - information and the suggested
+// actions.
+struct ActionsSuggestionsResponse {
+ ActionsSuggestionsResponse() = default;
+
+ // The sensitivity assessment.
+ float sensitivity_score = -1;
+ float triggering_score = -1;
+
+ // Whether the output was suppressed by the sensitivity threshold.
+ bool output_filtered_sensitivity = false;
+
+ // Whether the output was suppressed by the triggering score threshold.
+ bool output_filtered_min_triggering_score = false;
+
+ // Whether the output was suppressed by the low confidence patterns.
+ bool output_filtered_low_confidence = false;
+
+ // Whether the output was suppressed due to locale mismatch.
+ bool output_filtered_locale_mismatch = false;
+
+ // The suggested actions.
+ std::vector<ActionSuggestion> actions;
+};
+
+// Represents a single message in the conversation.
+struct ConversationMessage {
+ // User ID distinguishing the user from other users in the conversation.
+ // Default-initialized so a default-constructed message has no
+ // indeterminate fields.
+ int user_id = 0;
+
+ // Text of the message.
+ std::string text;
+
+ // Reference time of this message.
+ int64 reference_time_ms_utc = 0;
+
+ // Timezone in which the input text was written (format as accepted by ICU).
+ std::string reference_timezone;
+
+ // Annotations on the text.
+ std::vector<AnnotatedSpan> annotations;
+
+ // Comma-separated list of BCP 47 language tags of the message.
+ std::string detected_text_language_tags;
+};
+
+// Conversation between multiple users.
+struct Conversation {
+ // Sequence of messages that were exchanged in the conversation.
+ // NOTE(review): presumably ordered chronologically, oldest first —
+ // confirm with callers.
+ std::vector<ConversationMessage> messages;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_TYPES_H_
diff --git a/actions/zlib-utils.cc b/actions/zlib-utils.cc
new file mode 100644
index 0000000..b1d997d
--- /dev/null
+++ b/actions/zlib-utils.cc
@@ -0,0 +1,173 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/zlib-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/intents/zlib-utils.h"
+#include "utils/resources.h"
+
+namespace libtextclassifier3 {
+
+// Compress rule fields in the model.
+// Compresses regex patterns, Lua scripts, resources and the intent model in
+// place; the uncompressed fields are cleared after compression.
+// Returns false if the compressor could not be created.
+bool CompressActionsModel(ActionsModelT* model) {
+ std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
+ if (!zlib_compressor) {
+ TC3_LOG(ERROR) << "Cannot compress model.";
+ return false;
+ }
+
+ // Compress regex rules.
+ if (model->rules != nullptr) {
+ for (int i = 0; i < model->rules->rule.size(); i++) {
+ RulesModel_::RuleT* rule = model->rules->rule[i].get();
+ rule->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(rule->pattern, rule->compressed_pattern.get());
+ rule->pattern.clear();
+ }
+ }
+
+ // Compress low confidence rules.
+ if (model->low_confidence_rules != nullptr) {
+ for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) {
+ RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get();
+ if (!rule->pattern.empty()) {
+ rule->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(rule->pattern,
+ rule->compressed_pattern.get());
+ rule->pattern.clear();
+ }
+ if (!rule->output_pattern.empty()) {
+ rule->compressed_output_pattern.reset(new CompressedBufferT);
+ // Fix: compress the output pattern; the original compressed
+ // rule->pattern here, which was already cleared above, losing the
+ // output pattern entirely.
+ zlib_compressor->Compress(rule->output_pattern,
+ rule->compressed_output_pattern.get());
+ rule->output_pattern.clear();
+ }
+ }
+ }
+
+ // Compress the actions Lua script.
+ if (!model->lua_actions_script.empty()) {
+ model->compressed_lua_actions_script.reset(new CompressedBufferT);
+ zlib_compressor->Compress(model->lua_actions_script,
+ model->compressed_lua_actions_script.get());
+ }
+
+ // Compress the ranking Lua script.
+ if (model->ranking_options != nullptr &&
+ !model->ranking_options->lua_ranking_script.empty()) {
+ model->ranking_options->compressed_lua_ranking_script.reset(
+ new CompressedBufferT);
+ zlib_compressor->Compress(
+ model->ranking_options->lua_ranking_script,
+ model->ranking_options->compressed_lua_ranking_script.get());
+ }
+
+ // Compress resources.
+ if (model->resources != nullptr) {
+ CompressResources(model->resources.get());
+ }
+
+ // Compress intent generator.
+ if (model->android_intent_options != nullptr) {
+ CompressIntentModel(model->android_intent_options.get());
+ }
+
+ return true;
+}
+
+// Decompresses the fields compressed by CompressActionsModel in place and
+// drops the compressed buffers. Returns false on the first decompression
+// failure.
+bool DecompressActionsModel(ActionsModelT* model) {
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return false;
+ }
+
+ // Decompress regex rules.
+ if (model->rules != nullptr) {
+ for (int i = 0; i < model->rules->rule.size(); i++) {
+ RulesModel_::RuleT* rule = model->rules->rule[i].get();
+ if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
+ &rule->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
+ return false;
+ }
+ rule->compressed_pattern.reset(nullptr);
+ }
+ }
+
+ // Decompress low confidence rules.
+ if (model->low_confidence_rules != nullptr) {
+ for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) {
+ RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get();
+ if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
+ &rule->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
+ return false;
+ }
+ if (!zlib_decompressor->MaybeDecompress(
+ rule->compressed_output_pattern.get(), &rule->output_pattern)) {
+ // Fix: report the output pattern in the error (was a copy-paste of
+ // the input-pattern message).
+ TC3_LOG(ERROR) << "Cannot decompress output pattern: " << i;
+ return false;
+ }
+ rule->compressed_pattern.reset(nullptr);
+ rule->compressed_output_pattern.reset(nullptr);
+ }
+ }
+
+ if (!zlib_decompressor->MaybeDecompress(
+ model->compressed_lua_actions_script.get(),
+ &model->lua_actions_script)) {
+ TC3_LOG(ERROR) << "Cannot decompress actions script.";
+ return false;
+ }
+
+ if (model->ranking_options != nullptr &&
+ !zlib_decompressor->MaybeDecompress(
+ model->ranking_options->compressed_lua_ranking_script.get(),
+ &model->ranking_options->lua_ranking_script)) {
+ // Fix: this path handles the ranking script (message was a copy-paste of
+ // the actions-script error above).
+ TC3_LOG(ERROR) << "Cannot decompress ranking script.";
+ return false;
+ }
+
+ return true;
+}
+
+// Returns a re-serialized copy of `model` with its compressible fields
+// compressed. Aborts (TC3_CHECK) if the input does not parse or compression
+// fails.
+std::string CompressSerializedActionsModel(const std::string& model) {
+ // Unpack into the object API, compress in place, then pack again.
+ std::unique_ptr<ActionsModelT> model_object =
+ UnPackActionsModel(model.c_str());
+ TC3_CHECK(model_object != nullptr);
+ TC3_CHECK(CompressActionsModel(model_object.get()));
+ flatbuffers::FlatBufferBuilder fbb;
+ FinishActionsModelBuffer(fbb, ActionsModel::Pack(fbb, model_object.get()));
+ return std::string(reinterpret_cast<const char*>(fbb.GetBufferPointer()),
+ fbb.GetSize());
+}
+
+// Writes the string held by either `uncompressed_buffer` or
+// `compressed_buffer` into `out`. If both are null, `out` is cleared and
+// true is returned.
+bool GetUncompressedString(const flatbuffers::String* uncompressed_buffer,
+ const CompressedBuffer* compressed_buffer,
+ ZlibDecompressor* decompressor, std::string* out) {
+ const bool no_buffer_present =
+ uncompressed_buffer == nullptr && compressed_buffer == nullptr;
+ if (no_buffer_present) {
+ out->clear();
+ return true;
+ }
+ return decompressor->MaybeDecompressOptionallyCompressedBuffer(
+ uncompressed_buffer, compressed_buffer, out);
+}
+
+} // namespace libtextclassifier3
diff --git a/actions/zlib-utils.h b/actions/zlib-utils.h
new file mode 100644
index 0000000..951a4e4
--- /dev/null
+++ b/actions/zlib-utils.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Functions to compress and decompress low entropy entries in the model.
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_ZLIB_UTILS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_ZLIB_UTILS_H_
+
+#include "actions/actions_model_generated.h"
+#include "utils/zlib/buffer_generated.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Compresses regex rules in the model in place.
+bool CompressActionsModel(ActionsModelT* model);
+
+// Decompresses regex rules in the model in place.
+bool DecompressActionsModel(ActionsModelT* model);
+
+// Compresses regex rules in the model.
+std::string CompressSerializedActionsModel(const std::string& model);
+
+bool GetUncompressedString(const flatbuffers::String* uncompressed_buffer,
+ const CompressedBuffer* compressed_buffer,
+ ZlibDecompressor* decompressor, std::string* out);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_ZLIB_UTILS_H_
diff --git a/actions/zlib-utils_test.cc b/actions/zlib-utils_test.cc
new file mode 100644
index 0000000..377f344
--- /dev/null
+++ b/actions/zlib-utils_test.cc
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/zlib-utils.h"
+
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "utils/zlib/zlib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+namespace {
+
+TEST(ZlibUtilsTest, CompressModel) {
+ // Round-trips two rule patterns through compression: verifies the
+ // uncompressed fields are cleared, the packed buffers decompress to the
+ // originals, and DecompressActionsModel restores the plain-text fields.
+ ActionsModelT model;
+ constexpr char kTestPattern1[] = "this is a test pattern";
+ constexpr char kTestPattern2[] = "this is a second test pattern";
+ model.rules.reset(new RulesModelT);
+ model.rules->rule.emplace_back(new RulesModel_::RuleT);
+ model.rules->rule.back()->pattern = kTestPattern1;
+ model.rules->rule.emplace_back(new RulesModel_::RuleT);
+ model.rules->rule.back()->pattern = kTestPattern2;
+
+ // Compress the model.
+ EXPECT_TRUE(CompressActionsModel(&model));
+
+ // Sanity check that uncompressed field is removed.
+ EXPECT_TRUE(model.rules->rule[0]->pattern.empty());
+ EXPECT_TRUE(model.rules->rule[1]->pattern.empty());
+ // Pack and load the model.
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model));
+ const ActionsModel* compressed_model = GetActionsModel(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()));
+ ASSERT_TRUE(compressed_model != nullptr);
+
+ // Decompress the fields again and check that they match the original.
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+ ASSERT_TRUE(decompressor != nullptr);
+ std::string uncompressed_pattern;
+ EXPECT_TRUE(decompressor->MaybeDecompress(
+ compressed_model->rules()->rule()->Get(0)->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, kTestPattern1);
+ EXPECT_TRUE(decompressor->MaybeDecompress(
+ compressed_model->rules()->rule()->Get(1)->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, kTestPattern2);
+ EXPECT_TRUE(DecompressActionsModel(&model));
+ EXPECT_EQ(model.rules->rule[0]->pattern, kTestPattern1);
+ EXPECT_EQ(model.rules->rule[1]->pattern, kTestPattern2);
+}
+
+} // namespace
+
+} // namespace libtextclassifier3
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
index 2be9d3c..53c8d8a 100644
--- a/annotator/annotator.cc
+++ b/annotator/annotator.cc
@@ -21,15 +21,23 @@
#include <cmath>
#include <iterator>
#include <numeric>
+#include <unordered_map>
+#include "annotator/collections.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/checksum.h"
#include "utils/math/softmax.h"
+#include "utils/regex-match.h"
#include "utils/utf8/unicodetext.h"
+#include "utils/zlib/zlib_regex.h"
+
namespace libtextclassifier3 {
-const std::string& Annotator::kOtherCollection =
- *[]() { return new std::string("other"); }();
+
+using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
+
const std::string& Annotator::kPhoneCollection =
*[]() { return new std::string("phone"); }();
const std::string& Annotator::kAddressCollection =
@@ -38,18 +46,8 @@
*[]() { return new std::string("date"); }();
const std::string& Annotator::kUrlCollection =
*[]() { return new std::string("url"); }();
-const std::string& Annotator::kFlightCollection =
- *[]() { return new std::string("flight"); }();
const std::string& Annotator::kEmailCollection =
*[]() { return new std::string("email"); }();
-const std::string& Annotator::kIbanCollection =
- *[]() { return new std::string("iban"); }();
-const std::string& Annotator::kPaymentCardCollection =
- *[]() { return new std::string("payment_card"); }();
-const std::string& Annotator::kIsbnCollection =
- *[]() { return new std::string("isbn"); }();
-const std::string& Annotator::kTrackingNumberCollection =
- *[]() { return new std::string("tracking_number"); }();
namespace {
const Model* LoadAndVerifyModel(const void* addr, int size) {
@@ -150,6 +148,30 @@
return classifier;
}
+std::unique_ptr<Annotator> Annotator::FromScopedMmap(
+ std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib) {
+ // Owning variant: *mmap, unilib and calendarlib are moved into the
+ // returned Annotator. Returns nullptr if mapping, model verification or
+ // initialization fails.
+ if (!(*mmap)->handle().ok()) {
+ TC3_VLOG(1) << "Mmap failed.";
+ return nullptr;
+ }
+
+ const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
+ (*mmap)->handle().num_bytes());
+ if (model == nullptr) {
+ TC3_LOG(ERROR) << "Model verification failed.";
+ return nullptr;
+ }
+
+ auto classifier = std::unique_ptr<Annotator>(
+ new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
+ if (!classifier->IsInitialized()) {
+ return nullptr;
+ }
+
+ return classifier;
+}
+
std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
int fd, int offset, int size, const UniLib* unilib,
const CalendarLib* calendarlib) {
@@ -158,11 +180,25 @@
}
std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
+ int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
+ return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
+}
+
+std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
return FromScopedMmap(&mmap, unilib, calendarlib);
}
+std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
+ int fd, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib) {
+ // Maps the whole file referenced by `fd` and takes ownership of the libs.
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
+ return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
+}
+
std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
const UniLib* unilib,
const CalendarLib* calendarlib) {
@@ -170,6 +206,13 @@
return FromScopedMmap(&mmap, unilib, calendarlib);
}
+std::unique_ptr<Annotator> Annotator::FromPath(
+ const std::string& path, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib) {
+ // Maps the model file at `path` and takes ownership of the passed libs.
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
+ return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
+}
Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
const UniLib* unilib, const CalendarLib* calendarlib)
: model_(model),
@@ -181,6 +224,18 @@
ValidateAndInitialize();
}
+Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
+ std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib)
+ // Owning constructor: keeps the mmap and both libs alive for the lifetime
+ // of the annotator; the raw pointers alias the owned instances.
+ : model_(model),
+ mmap_(std::move(*mmap)),
+ owned_unilib_(std::move(unilib)),
+ unilib_(owned_unilib_.get()),
+ owned_calendarlib_(std::move(calendarlib)),
+ calendarlib_(owned_calendarlib_.get()) {
+ ValidateAndInitialize();
+}
+
Annotator::Annotator(const Model* model, const UniLib* unilib,
const CalendarLib* calendarlib)
: model_(model),
@@ -297,8 +352,8 @@
embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
model_->embedding_model(),
model_->classification_feature_options()->embedding_size(),
- model_->classification_feature_options()
- ->embedding_quantization_bits());
+ model_->classification_feature_options()->embedding_quantization_bits(),
+ model_->embedding_pruning_mask());
if (!embedding_executor_) {
TC3_LOG(ERROR) << "Could not initialize embedding executor.";
return;
@@ -343,6 +398,65 @@
}
}
+ if (model_->number_annotator_options() &&
+ model_->number_annotator_options()->enabled()) {
+ if (selection_feature_processor_ == nullptr) {
+ TC3_LOG(ERROR)
+ << "Could not initialize NumberAnnotator without a feature processor";
+ return;
+ }
+
+ number_annotator_.reset(
+ new NumberAnnotator(model_->number_annotator_options(),
+ selection_feature_processor_.get()));
+ }
+
+ if (model_->duration_annotator_options() &&
+ model_->duration_annotator_options()->enabled()) {
+ duration_annotator_.reset(
+ new DurationAnnotator(model_->duration_annotator_options(),
+ selection_feature_processor_.get()));
+ }
+
+ if (model_->entity_data_schema()) {
+ entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
+ model_->entity_data_schema()->Data(),
+ model_->entity_data_schema()->size());
+ if (entity_data_schema_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not load entity data schema data.";
+ return;
+ }
+
+ entity_data_builder_.reset(
+ new ReflectiveFlatbufferBuilder(entity_data_schema_));
+ } else {
+ entity_data_schema_ = nullptr;
+ entity_data_builder_ = nullptr;
+ }
+
+ if (model_->triggering_locales() &&
+ !ParseLocales(model_->triggering_locales()->c_str(),
+ &model_triggering_locales_)) {
+ TC3_LOG(ERROR) << "Could not parse model supported locales.";
+ return;
+ }
+
+ if (model_->triggering_options() != nullptr &&
+ model_->triggering_options()->locales() != nullptr &&
+ !ParseLocales(model_->triggering_options()->locales()->c_str(),
+ &ml_model_triggering_locales_)) {
+ TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
+ return;
+ }
+
+ if (model_->triggering_options() != nullptr &&
+ model_->triggering_options()->dictionary_locales() != nullptr &&
+ !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
+ &dictionary_locales_)) {
+ TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
+ return;
+ }
+
initialized_ = true;
}
@@ -355,9 +469,10 @@
int regex_pattern_id = 0;
for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
- UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(),
- regex_pattern->compressed_pattern(),
- decompressor);
+ UncompressMakeRegexPattern(
+ *unilib_, regex_pattern->pattern(),
+ regex_pattern->compressed_pattern(),
+ model_->regex_model()->lazy_regex_compilation(), decompressor);
if (!compiled_pattern) {
TC3_LOG(INFO) << "Failed to load regex pattern";
return false;
@@ -373,15 +488,9 @@
selection_regex_patterns_.push_back(regex_pattern_id);
}
regex_patterns_.push_back({
- regex_pattern->collection_name()->str(),
- regex_pattern->target_classification_score(),
- regex_pattern->priority_score(),
+ regex_pattern,
std::move(compiled_pattern),
- regex_pattern->verification_options(),
});
- if (regex_pattern->use_approximate_matching()) {
- regex_approximate_match_pattern_ids_.insert(regex_pattern_id);
- }
++regex_pattern_id;
}
@@ -400,6 +509,29 @@
return true;
}
+bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
+ std::unique_ptr<ContactEngine> contact_engine(
+ new ContactEngine(selection_feature_processor_.get(), unilib_));
+ if (!contact_engine->Initialize(serialized_config)) {
+ TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
+ return false;
+ }
+ contact_engine_ = std::move(contact_engine);
+ return true;
+}
+
+bool Annotator::InitializeInstalledAppEngine(
+ const std::string& serialized_config) {
+ std::unique_ptr<InstalledAppEngine> installed_app_engine(
+ new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
+ if (!installed_app_engine->Initialize(serialized_config)) {
+ TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
+ return false;
+ }
+ installed_app_engine_ = std::move(installed_app_engine);
+ return true;
+}
+
namespace {
int CountDigits(const std::string& str, CodepointSpan selection_indices) {
@@ -415,29 +547,6 @@
return count;
}
-std::string ExtractSelection(const std::string& context,
- CodepointSpan selection_indices) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- auto selection_begin = context_unicode.begin();
- std::advance(selection_begin, selection_indices.first);
- auto selection_end = context_unicode.begin();
- std::advance(selection_end, selection_indices.second);
- return UnicodeText::UTF8Substring(selection_begin, selection_end);
-}
-
-bool VerifyCandidate(const VerificationOptions* verification_options,
- const std::string& match) {
- if (!verification_options) {
- return true;
- }
- if (verification_options->verify_luhn_checksum() &&
- !VerifyLuhnChecksum(match)) {
- return false;
- }
- return true;
-}
-
} // namespace
namespace internal {
@@ -501,6 +610,47 @@
filtered_collections_selection_.end();
}
+namespace {
+inline bool ClassifiedAsOther(
+ const std::vector<ClassificationResult>& classification) {
+ return !classification.empty() &&
+ classification[0].collection == Collections::Other();
+}
+
+float GetPriorityScore(
+ const std::vector<ClassificationResult>& classification) {
+ if (!classification.empty() && !ClassifiedAsOther(classification)) {
+ return classification[0].priority_score;
+ } else {
+ return -1.0;
+ }
+}
+} // namespace
+
+bool Annotator::VerifyRegexMatchCandidate(
+ const std::string& context, const VerificationOptions* verification_options,
+ const std::string& match, const UniLib::RegexMatcher* matcher) const {
+ if (verification_options == nullptr) {
+ return true;
+ }
+ if (verification_options->verify_luhn_checksum() &&
+ !VerifyLuhnChecksum(match)) {
+ return false;
+ }
+ const int lua_verifier = verification_options->lua_verifier();
+ if (lua_verifier >= 0) {
+ if (model_->regex_model()->lua_verifier() == nullptr ||
+ lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
+ TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
+ return false;
+ }
+ return VerifyMatch(
+ context, matcher,
+ model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
+ }
+ return true;
+}
+
CodepointSpan Annotator::SuggestSelection(
const std::string& context, CodepointSpan click_indices,
const SelectionOptions& options) const {
@@ -513,6 +663,19 @@
return original_click_indices;
}
+ std::vector<Locale> detected_text_language_tags;
+ if (!ParseLocales(options.detected_text_language_tags,
+ &detected_text_language_tags)) {
+ TC3_LOG(WARNING)
+ << "Failed to parse the detected_text_language_tags in options: "
+ << options.detected_text_language_tags;
+ }
+ if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ model_triggering_locales_,
+ /*default_value=*/true)) {
+ return original_click_indices;
+ }
+
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
@@ -553,24 +716,51 @@
classification_executor_.get());
std::vector<Token> tokens;
if (!ModelSuggestSelection(context_unicode, click_indices,
- &interpreter_manager, &tokens, &candidates)) {
+ detected_text_language_tags, &interpreter_manager,
+ &tokens, &candidates)) {
TC3_LOG(ERROR) << "Model suggest selection failed.";
return original_click_indices;
}
- if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) {
+ if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
+ /*is_serialized_entity_data_enabled=*/false)) {
TC3_LOG(ERROR) << "Regex suggest selection failed.";
return original_click_indices;
}
- if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
- /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
- options.locales, ModeFlag_SELECTION, &candidates)) {
+ if (!DatetimeChunk(
+ UTF8ToUnicodeText(context, /*do_copy=*/false),
+ /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
+ options.locales, ModeFlag_SELECTION, options.annotation_usecase,
+ /*is_serialized_entity_data_enabled=*/false, &candidates)) {
TC3_LOG(ERROR) << "Datetime suggest selection failed.";
return original_click_indices;
}
- if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) {
+ if (knowledge_engine_ != nullptr &&
+ !knowledge_engine_->Chunk(context, &candidates)) {
TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
return original_click_indices;
}
+ if (contact_engine_ != nullptr &&
+ !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Contact suggest selection failed.";
+ return original_click_indices;
+ }
+ if (installed_app_engine_ != nullptr &&
+ !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Installed app suggest selection failed.";
+ return original_click_indices;
+ }
+ if (number_annotator_ != nullptr &&
+ !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
+ &candidates)) {
+ TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
+ return original_click_indices;
+ }
+ if (duration_annotator_ != nullptr &&
+ !duration_annotator_->FindAll(context_unicode, tokens,
+ options.annotation_usecase, &candidates)) {
+ TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
+ return original_click_indices;
+ }
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
@@ -581,12 +771,19 @@
});
std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
- &candidate_indices)) {
+ if (!ResolveConflicts(candidates, context, tokens,
+ detected_text_language_tags, options.annotation_usecase,
+ &interpreter_manager, &candidate_indices)) {
TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
return original_click_indices;
}
+ std::sort(candidate_indices.begin(), candidate_indices.end(),
+ [&candidates](int a, int b) {
+ return GetPriorityScore(candidates[a].classification) >
+ GetPriorityScore(candidates[b].classification);
+ });
+
for (const int i : candidate_indices) {
if (SpansOverlap(candidates[i].span, click_indices) &&
SpansOverlap(candidates[i].span, original_click_indices)) {
@@ -595,9 +792,10 @@
if (candidates[i].classification.empty() &&
model_->selection_options()->always_classify_suggested_selection() &&
!filtered_collections_selection_.empty()) {
- if (!ModelClassifyText(
- context, candidates[i].span, &interpreter_manager,
- /*embedding_cache=*/nullptr, &candidates[i].classification)) {
+ if (!ModelClassifyText(context, detected_text_language_tags,
+ candidates[i].span, &interpreter_manager,
+ /*embedding_cache=*/nullptr,
+ &candidates[i].classification)) {
return original_click_indices;
}
}
@@ -636,11 +834,12 @@
}
} // namespace
-bool Annotator::ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
- const std::string& context,
- const std::vector<Token>& cached_tokens,
- InterpreterManager* interpreter_manager,
- std::vector<int>* result) const {
+bool Annotator::ResolveConflicts(
+ const std::vector<AnnotatedSpan>& candidates, const std::string& context,
+ const std::vector<Token>& cached_tokens,
+ const std::vector<Locale>& detected_text_language_tags,
+ AnnotationUsecase annotation_usecase,
+ InterpreterManager* interpreter_manager, std::vector<int>* result) const {
result->clear();
result->reserve(candidates.size());
for (int i = 0; i < candidates.size();) {
@@ -650,9 +849,10 @@
const bool conflict_found = first_non_overlapping != (i + 1);
if (conflict_found) {
std::vector<int> candidate_indices;
- if (!ResolveConflict(context, cached_tokens, candidates, i,
- first_non_overlapping, interpreter_manager,
- &candidate_indices)) {
+ if (!ResolveConflict(context, cached_tokens, candidates,
+ detected_text_language_tags, i,
+ first_non_overlapping, annotation_usecase,
+ interpreter_manager, &candidate_indices)) {
return false;
}
result->insert(result->end(), candidate_indices.begin(),
@@ -668,28 +868,53 @@
}
namespace {
-inline bool ClassifiedAsOther(
- const std::vector<ClassificationResult>& classification) {
- return !classification.empty() &&
- classification[0].collection == Annotator::kOtherCollection;
-}
+// Returns true, if the given two sources do conflict in given annotation
+// usecase.
+// - In SMART usecase, all sources do conflict, because there's only 1 possible
+// annotation for a given span.
+// - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
+// and duration), while others not (e.g. duration and number).
+bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
+ const AnnotatedSpan::Source source1,
+ const AnnotatedSpan::Source source2) {
+ uint32 source_mask =
+ (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
-float GetPriorityScore(
- const std::vector<ClassificationResult>& classification) {
- if (!ClassifiedAsOther(classification)) {
- return classification[0].priority_score;
- } else {
- return -1.0;
+ switch (annotation_usecase) {
+ case AnnotationUsecase_ANNOTATION_USECASE_SMART:
+ // In the SMART mode, all annotations conflict.
+ return true;
+
+ case AnnotationUsecase_ANNOTATION_USECASE_RAW:
+ // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
+ // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
+ // hours" (duration).
+ if ((source_mask &
+ (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
+ (source_mask &
+ (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
+ return false;
+ }
+
+ // A KNOWLEDGE entity does not conflict with anything.
+ if ((source_mask &
+ (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
+ return false;
+ }
+
+ // Entities from other sources can conflict.
+ return true;
}
}
} // namespace
-bool Annotator::ResolveConflict(const std::string& context,
- const std::vector<Token>& cached_tokens,
- const std::vector<AnnotatedSpan>& candidates,
- int start_index, int end_index,
- InterpreterManager* interpreter_manager,
- std::vector<int>* chosen_indices) const {
+bool Annotator::ResolveConflict(
+ const std::string& context, const std::vector<Token>& cached_tokens,
+ const std::vector<AnnotatedSpan>& candidates,
+ const std::vector<Locale>& detected_text_language_tags, int start_index,
+ int end_index, AnnotationUsecase annotation_usecase,
+ InterpreterManager* interpreter_manager,
+ std::vector<int>* chosen_indices) const {
std::vector<int> conflicting_indices;
std::unordered_map<int, float> scores;
for (int i = start_index; i < end_index; ++i) {
@@ -705,8 +930,8 @@
// candidate conflicts and comes from the model, we need to run a
// classification to determine its priority:
std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(context, cached_tokens, candidates[i].span,
- interpreter_manager,
+ if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
+ candidates[i].span, interpreter_manager,
/*embedding_cache=*/nullptr, &classification)) {
return false;
}
@@ -719,31 +944,65 @@
std::sort(conflicting_indices.begin(), conflicting_indices.end(),
[&scores](int i, int j) { return scores[i] > scores[j]; });
- // Keeps the candidates sorted by their position in the text (their left span
- // index) for fast retrieval down.
- std::set<int, std::function<bool(int, int)>> chosen_indices_set(
- [&candidates](int a, int b) {
- return candidates[a].span.first < candidates[b].span.first;
- });
+ // Here we keep a set of indices that were chosen, per-source, to enable
+ // effective computation.
+ std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
+ chosen_indices_for_source_map;
// Greedily place the candidates if they don't conflict with the already
// placed ones.
for (int i = 0; i < conflicting_indices.size(); ++i) {
const int considered_candidate = conflicting_indices[i];
- if (!DoesCandidateConflict(considered_candidate, candidates,
- chosen_indices_set)) {
- chosen_indices_set.insert(considered_candidate);
+
+ // See if there is a conflict between the candidate and all already placed
+ // candidates.
+ bool conflict = false;
+ SortedIntSet* chosen_indices_for_source_ptr = nullptr;
+ for (auto& source_set_pair : chosen_indices_for_source_map) {
+ if (source_set_pair.first == candidates[considered_candidate].source) {
+ chosen_indices_for_source_ptr = &source_set_pair.second;
+ }
+
+ if (DoSourcesConflict(annotation_usecase, source_set_pair.first,
+ candidates[considered_candidate].source) &&
+ DoesCandidateConflict(considered_candidate, candidates,
+ source_set_pair.second)) {
+ conflict = true;
+ break;
+ }
}
+
+ // Skip the candidate if a conflict was found.
+ if (conflict) {
+ continue;
+ }
+
+ // If the set of indices for the current source doesn't exist yet,
+ // initialize it.
+ if (chosen_indices_for_source_ptr == nullptr) {
+ SortedIntSet new_set([&candidates](int a, int b) {
+ return candidates[a].span.first < candidates[b].span.first;
+ });
+ chosen_indices_for_source_map[candidates[considered_candidate].source] =
+ std::move(new_set);
+ chosen_indices_for_source_ptr =
+ &chosen_indices_for_source_map[candidates[considered_candidate]
+ .source];
+ }
+
+ // Place the candidate to the output and to the per-source conflict set.
+ chosen_indices->push_back(considered_candidate);
+ chosen_indices_for_source_ptr->insert(considered_candidate);
}
- *chosen_indices =
- std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end());
+ std::sort(chosen_indices->begin(), chosen_indices->end());
return true;
}
bool Annotator::ModelSuggestSelection(
const UnicodeText& context_unicode, CodepointSpan click_indices,
+ const std::vector<Locale>& detected_text_language_tags,
InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const {
if (model_->triggering_options() == nullptr ||
@@ -751,6 +1010,12 @@
return true;
}
+ if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ ml_model_triggering_locales_,
+ /*default_value=*/true)) {
+ return true;
+ }
+
int click_pos;
*tokens = selection_feature_processor_->Tokenize(context_unicode);
selection_feature_processor_->RetokenizeAndFindClick(
@@ -847,16 +1112,13 @@
}
bool Annotator::ModelClassifyText(
- const std::string& context, CodepointSpan selection_indices,
- InterpreterManager* interpreter_manager,
+ const std::string& context,
+ const std::vector<Locale>& detected_text_language_tags,
+ CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const {
- if (model_->triggering_options() == nullptr ||
- !(model_->triggering_options()->enabled_modes() &
- ModeFlag_CLASSIFICATION)) {
- return true;
- }
- return ModelClassifyText(context, {}, selection_indices, interpreter_manager,
+ return ModelClassifyText(context, {}, detected_text_language_tags,
+ selection_indices, interpreter_manager,
embedding_cache, classification_results);
}
@@ -912,17 +1174,53 @@
}
}
+namespace {
+// Sorts the classification results from high score to low score.
+void SortClassificationResults(
+ std::vector<ClassificationResult>* classification_results) {
+ std::sort(classification_results->begin(), classification_results->end(),
+ [](const ClassificationResult& a, const ClassificationResult& b) {
+ return a.score > b.score;
+ });
+}
+} // namespace
+
bool Annotator::ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
+ const std::vector<Locale>& detected_text_language_tags,
CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const {
std::vector<Token> tokens;
+ return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
+ selection_indices, interpreter_manager,
+ embedding_cache, classification_results, &tokens);
+}
+
+bool Annotator::ModelClassifyText(
+ const std::string& context, const std::vector<Token>& cached_tokens,
+ const std::vector<Locale>& detected_text_language_tags,
+ CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+ FeatureProcessor::EmbeddingCache* embedding_cache,
+ std::vector<ClassificationResult>* classification_results,
+ std::vector<Token>* tokens) const {
+ if (model_->triggering_options() == nullptr ||
+ !(model_->triggering_options()->enabled_modes() &
+ ModeFlag_CLASSIFICATION)) {
+ return true;
+ }
+
+ if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ ml_model_triggering_locales_,
+ /*default_value=*/true)) {
+ return true;
+ }
+
if (cached_tokens.empty()) {
- tokens = classification_feature_processor_->Tokenize(context);
+ *tokens = classification_feature_processor_->Tokenize(context);
} else {
- tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
- ClassifyTextUpperBoundNeededTokens());
+ *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
+ ClassifyTextUpperBoundNeededTokens());
}
int click_pos;
@@ -930,14 +1228,14 @@
context, selection_indices,
classification_feature_processor_->GetOptions()
->only_use_line_with_click(),
- &tokens, &click_pos);
+ tokens, &click_pos);
const TokenSpan selection_token_span =
- CodepointSpanToTokenSpan(tokens, selection_indices);
+ CodepointSpanToTokenSpan(*tokens, selection_indices);
const int selection_num_tokens = TokenSpanSize(selection_token_span);
if (model_->classification_options()->max_num_tokens() > 0 &&
model_->classification_options()->max_num_tokens() <
selection_num_tokens) {
- *classification_results = {{kOtherCollection, 1.0}};
+ *classification_results = {{Collections::Other(), 1.0}};
return true;
}
@@ -973,18 +1271,18 @@
/*num_tokens_left=*/context_size,
/*num_tokens_right=*/context_size);
}
- extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});
+ extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
- tokens, extraction_span)) {
- *classification_results = {{kOtherCollection, 1.0}};
+ *tokens, extraction_span)) {
+ *classification_results = {{Collections::Other(), 1.0}};
return true;
}
std::unique_ptr<CachedFeatures> cached_features;
if (!classification_feature_processor_->ExtractFeatures(
- tokens, extraction_span, selection_indices, embedding_executor_.get(),
- embedding_cache,
+ *tokens, extraction_span, selection_indices,
+ embedding_executor_.get(), embedding_cache,
classification_feature_processor_->EmbeddingSize() +
classification_feature_processor_->DenseFeaturesCount(),
&cached_features)) {
@@ -1019,45 +1317,51 @@
const std::vector<float> scores =
ComputeSoftmax(logits.data(), logits.dim(1));
- classification_results->resize(scores.size());
- for (int i = 0; i < scores.size(); i++) {
- (*classification_results)[i] = {
- classification_feature_processor_->LabelToCollection(i), scores[i]};
+ if (scores.empty()) {
+ *classification_results = {{Collections::Other(), 1.0}};
+ return true;
}
- std::sort(classification_results->begin(), classification_results->end(),
- [](const ClassificationResult& a, const ClassificationResult& b) {
- return a.score > b.score;
- });
- // Phone class sanity check.
- if (!classification_results->empty() &&
- classification_results->begin()->collection == kPhoneCollection) {
+ const int best_score_index =
+ std::max_element(scores.begin(), scores.end()) - scores.begin();
+ const std::string top_collection =
+ classification_feature_processor_->LabelToCollection(best_score_index);
+
+ // Sanity checks.
+ if (top_collection == Collections::Phone()) {
const int digit_count = CountDigits(context, selection_indices);
if (digit_count <
model_->classification_options()->phone_min_num_digits() ||
digit_count >
model_->classification_options()->phone_max_num_digits()) {
- *classification_results = {{kOtherCollection, 1.0}};
+ *classification_results = {{Collections::Other(), 1.0}};
+ return true;
}
- }
-
- // Address class sanity check.
- if (!classification_results->empty() &&
- classification_results->begin()->collection == kAddressCollection) {
+ } else if (top_collection == Collections::Address()) {
if (selection_num_tokens <
model_->classification_options()->address_min_num_tokens()) {
- *classification_results = {{kOtherCollection, 1.0}};
+ *classification_results = {{Collections::Other(), 1.0}};
+ return true;
+ }
+ } else if (top_collection == Collections::Dictionary()) {
+ if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ dictionary_locales_,
+ /*default_value=*/false)) {
+ *classification_results = {{Collections::Other(), 1.0}};
+ return true;
}
}
+ *classification_results = {{top_collection, 1.0, scores[best_score_index]}};
return true;
}
bool Annotator::RegexClassifyText(
const std::string& context, CodepointSpan selection_indices,
- ClassificationResult* classification_result) const {
+ std::vector<ClassificationResult>* classification_result) const {
const std::string selection_text =
- ExtractSelection(context, selection_indices);
+ UTF8ToUnicodeText(context, /*do_copy=*/false)
+ .UTF8Substring(selection_indices.first, selection_indices.second);
const UnicodeText selection_text_unicode(
UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
@@ -1068,8 +1372,7 @@
regex_pattern.pattern->Matcher(selection_text_unicode);
int status = UniLib::RegexMatcher::kNoError;
bool matches;
- if (regex_approximate_match_pattern_ids_.find(pattern_id) !=
- regex_approximate_match_pattern_ids_.end()) {
+ if (regex_pattern.config->use_approximate_matching()) {
matches = matcher->ApproximatelyMatches(&status);
} else {
matches = matcher->Matches(&status);
@@ -1077,36 +1380,71 @@
if (status != UniLib::RegexMatcher::kNoError) {
return false;
}
- if (matches &&
- VerifyCandidate(regex_pattern.verification_options, selection_text)) {
- *classification_result = {regex_pattern.collection_name,
- regex_pattern.target_classification_score,
- regex_pattern.priority_score};
- return true;
- }
- if (status != UniLib::RegexMatcher::kNoError) {
- TC3_LOG(ERROR) << "Cound't match regex: " << pattern_id;
+ if (matches && VerifyRegexMatchCandidate(
+ context, regex_pattern.config->verification_options(),
+ selection_text, matcher.get())) {
+ classification_result->push_back(
+ {regex_pattern.config->collection_name()->str(),
+ regex_pattern.config->target_classification_score(),
+ regex_pattern.config->priority_score()});
+ if (!SerializedEntityDataFromRegexMatch(
+ regex_pattern.config, matcher.get(),
+ &classification_result->back().serialized_entity_data)) {
+ TC3_LOG(ERROR) << "Could not get entity data.";
+ return false;
+ }
}
}
- return false;
+ return true;
}
+namespace {
+std::string PickCollectionForDatetime(
+ const DatetimeParseResult& datetime_parse_result) {
+ switch (datetime_parse_result.granularity) {
+ case GRANULARITY_HOUR:
+ case GRANULARITY_MINUTE:
+ case GRANULARITY_SECOND:
+ return Collections::DateTime();
+ default:
+ return Collections::Date();
+ }
+}
+
+std::string CreateDatetimeSerializedEntityData(
+ const DatetimeParseResult& parse_result) {
+ EntityDataT entity_data;
+ entity_data.datetime.reset(new EntityData_::DatetimeT());
+ entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
+ entity_data.datetime->granularity =
+ static_cast<EntityData_::Datetime_::Granularity>(
+ parse_result.granularity);
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+} // namespace
+
bool Annotator::DatetimeClassifyText(
const std::string& context, CodepointSpan selection_indices,
const ClassificationOptions& options,
- ClassificationResult* classification_result) const {
+ std::vector<ClassificationResult>* classification_results) const {
if (!datetime_parser_) {
return false;
}
const std::string selection_text =
- ExtractSelection(context, selection_indices);
+ UTF8ToUnicodeText(context, /*do_copy=*/false)
+ .UTF8Substring(selection_indices.first, selection_indices.second);
std::vector<DatetimeParseResultSpan> datetime_spans;
if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
options.reference_timezone, options.locales,
ModeFlag_CLASSIFICATION,
+ options.annotation_usecase,
/*anchor_start_end=*/true, &datetime_spans)) {
TC3_LOG(ERROR) << "Error during parsing datetime.";
return false;
@@ -1117,13 +1455,20 @@
if (std::make_pair(datetime_span.span.first + selection_indices.first,
datetime_span.span.second + selection_indices.first) ==
selection_indices) {
- *classification_result = {kDateCollection,
- datetime_span.target_classification_score};
- classification_result->datetime_parse_result = datetime_span.data;
+ for (const DatetimeParseResult& parse_result : datetime_span.data) {
+ classification_results->emplace_back(
+ PickCollectionForDatetime(parse_result),
+ datetime_span.target_classification_score);
+ classification_results->back().datetime_parse_result = parse_result;
+ classification_results->back().serialized_entity_data =
+ CreateDatetimeSerializedEntityData(parse_result);
+ classification_results->back().priority_score =
+ datetime_span.priority_score;
+ }
return true;
}
}
- return false;
+ return true;
}
std::vector<ClassificationResult> Annotator::ClassifyText(
@@ -1138,6 +1483,19 @@
return {};
}
+ std::vector<Locale> detected_text_language_tags;
+ if (!ParseLocales(options.detected_text_language_tags,
+ &detected_text_language_tags)) {
+ TC3_LOG(WARNING)
+ << "Failed to parse the detected_text_language_tags in options: "
+ << options.detected_text_language_tags;
+ }
+ if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ model_triggering_locales_,
+ /*default_value=*/true)) {
+ return {};
+ }
+
if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
return {};
}
@@ -1149,66 +1507,145 @@
return {};
}
+ // We'll accumulate a list of candidates, and pick the best candidate in the
+ // end.
+ std::vector<AnnotatedSpan> candidates;
+
// Try the knowledge engine.
+ // TODO(b/126579108): Propagate error status.
ClassificationResult knowledge_result;
if (knowledge_engine_ && knowledge_engine_->ClassifyText(
context, selection_indices, &knowledge_result)) {
- if (!FilteredForClassification(knowledge_result)) {
- return {knowledge_result};
- } else {
- return {{kOtherCollection, 1.0}};
- }
+ candidates.push_back({selection_indices, {knowledge_result}});
+ candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
+ }
+
+ // Try the contact engine.
+ // TODO(b/126579108): Propagate error status.
+ ClassificationResult contact_result;
+ if (contact_engine_ && contact_engine_->ClassifyText(
+ context, selection_indices, &contact_result)) {
+ candidates.push_back({selection_indices, {contact_result}});
+ }
+
+ // Try the installed app engine.
+ // TODO(b/126579108): Propagate error status.
+ ClassificationResult installed_app_result;
+ if (installed_app_engine_ &&
+ installed_app_engine_->ClassifyText(context, selection_indices,
+ &installed_app_result)) {
+ candidates.push_back({selection_indices, {installed_app_result}});
}
// Try the regular expression models.
- ClassificationResult regex_result;
- if (RegexClassifyText(context, selection_indices, ®ex_result)) {
- if (!FilteredForClassification(regex_result)) {
- return {regex_result};
- } else {
- return {{kOtherCollection, 1.0}};
- }
+ std::vector<ClassificationResult> regex_results;
+ if (!RegexClassifyText(context, selection_indices, ®ex_results)) {
+ return {};
+ }
+ for (const ClassificationResult& result : regex_results) {
+ candidates.push_back({selection_indices, {result}});
}
// Try the date model.
- ClassificationResult datetime_result;
- if (DatetimeClassifyText(context, selection_indices, options,
- &datetime_result)) {
- if (!FilteredForClassification(datetime_result)) {
- return {datetime_result};
- } else {
- return {{kOtherCollection, 1.0}};
- }
+ //
+ // DatetimeClassifyText only returns the first result, which can however have
+ // more interpretations. They are inserted in the candidates as a single
+ // AnnotatedSpan, so that they get treated together by the conflict resolution
+ // algorithm.
+ std::vector<ClassificationResult> datetime_results;
+ if (!DatetimeClassifyText(context, selection_indices, options,
+ &datetime_results)) {
+ return {};
+ }
+ if (!datetime_results.empty()) {
+ candidates.push_back({selection_indices, std::move(datetime_results)});
+ candidates.back().source = AnnotatedSpan::Source::DATETIME;
}
- // Fallback to the model.
- std::vector<ClassificationResult> model_result;
+ // Try the number annotator.
+ // TODO(b/126579108): Propagate error status.
+ ClassificationResult number_annotator_result;
+ if (number_annotator_ &&
+ number_annotator_->ClassifyText(
+ UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
+ options.annotation_usecase, &number_annotator_result)) {
+ candidates.push_back({selection_indices, {number_annotator_result}});
+ }
+ // Try the duration annotator.
+ ClassificationResult duration_annotator_result;
+ if (duration_annotator_ &&
+ duration_annotator_->ClassifyText(
+ UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
+ options.annotation_usecase, &duration_annotator_result)) {
+ candidates.push_back({selection_indices, {duration_annotator_result}});
+ candidates.back().source = AnnotatedSpan::Source::DURATION;
+ }
+
+ // Try the ML model.
+ //
+ // The output of the model is considered as an exclusive 1-of-N choice. That's
+ // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
+ // span for each candidate, like e.g. the regex model.
InterpreterManager interpreter_manager(selection_executor_.get(),
classification_executor_.get());
- if (ModelClassifyText(context, selection_indices, &interpreter_manager,
- /*embedding_cache=*/nullptr, &model_result) &&
- !model_result.empty()) {
- if (!FilteredForClassification(model_result[0])) {
- return model_result;
- } else {
- return {{kOtherCollection, 1.0}};
+ std::vector<ClassificationResult> model_results;
+ std::vector<Token> tokens;
+ if (!ModelClassifyText(
+ context, /*cached_tokens=*/{}, detected_text_language_tags,
+ selection_indices, &interpreter_manager,
+ /*embedding_cache=*/nullptr, &model_results, &tokens)) {
+ return {};
+ }
+ if (!model_results.empty()) {
+ candidates.push_back({selection_indices, std::move(model_results)});
+ }
+
+ std::vector<int> candidate_indices;
+ if (!ResolveConflicts(candidates, context, tokens,
+ detected_text_language_tags, options.annotation_usecase,
+ &interpreter_manager, &candidate_indices)) {
+ TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
+ return {};
+ }
+
+ std::vector<ClassificationResult> results;
+ for (const int i : candidate_indices) {
+ for (const ClassificationResult& result : candidates[i].classification) {
+ if (!FilteredForClassification(result)) {
+ results.push_back(result);
+ }
}
}
- // No classifications.
- return {};
+ // Sort results according to score.
+ std::sort(results.begin(), results.end(),
+ [](const ClassificationResult& a, const ClassificationResult& b) {
+ return a.score > b.score;
+ });
+
+ if (results.empty()) {
+ results = {{Collections::Other(), 1.0}};
+ }
+ return results;
}
-bool Annotator::ModelAnnotate(const std::string& context,
- InterpreterManager* interpreter_manager,
- std::vector<Token>* tokens,
- std::vector<AnnotatedSpan>* result) const {
+bool Annotator::ModelAnnotate(
+ const std::string& context,
+ const std::vector<Locale>& detected_text_language_tags,
+ InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
+ std::vector<AnnotatedSpan>* result) const {
if (model_->triggering_options() == nullptr ||
!(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
return true;
}
+ if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ ml_model_triggering_locales_,
+ /*default_value=*/true)) {
+ return true;
+ }
+
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
std::vector<UnicodeTextRange> lines;
@@ -1223,8 +1660,8 @@
? model_->triggering_options()->min_annotate_confidence()
: 0.f);
- FeatureProcessor::EmbeddingCache embedding_cache;
for (const UnicodeTextRange& line : lines) {
+ FeatureProcessor::EmbeddingCache embedding_cache;
const std::string line_str =
UnicodeText::UTF8Substring(line.first, line.second);
@@ -1272,9 +1709,9 @@
// Skip empty spans.
if (codepoint_span.first != codepoint_span.second) {
std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(line_str, *tokens, codepoint_span,
- interpreter_manager, &embedding_cache,
- &classification)) {
+ if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
+ codepoint_span, interpreter_manager,
+ &embedding_cache, &classification)) {
TC3_LOG(ERROR) << "Could not classify text: "
<< (codepoint_span.first + offset) << " "
<< (codepoint_span.second + offset);
@@ -1309,6 +1746,29 @@
return datetime_parser_.get();
}
+void Annotator::RemoveNotEnabledEntityTypes(
+ const EnabledEntityTypes& is_entity_type_enabled,
+ std::vector<AnnotatedSpan>* annotated_spans) const {
+ for (AnnotatedSpan& annotated_span : *annotated_spans) {
+ std::vector<ClassificationResult>& classifications =
+ annotated_span.classification;
+ classifications.erase(
+ std::remove_if(classifications.begin(), classifications.end(),
+ [&is_entity_type_enabled](
+ const ClassificationResult& classification_result) {
+ return !is_entity_type_enabled(
+ classification_result.collection);
+ }),
+ classifications.end());
+ }
+ annotated_spans->erase(
+ std::remove_if(annotated_spans->begin(), annotated_spans->end(),
+ [](const AnnotatedSpan& annotated_span) {
+ return annotated_span.classification.empty();
+ }),
+ annotated_spans->end());
+}
+
std::vector<AnnotatedSpan> Annotator::Annotate(
const std::string& context, const AnnotationOptions& options) const {
std::vector<AnnotatedSpan> candidates;
@@ -1317,30 +1777,53 @@
return {};
}
- if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ if (!context_unicode.is_valid()) {
+ return {};
+ }
+
+ std::vector<Locale> detected_text_language_tags;
+ if (!ParseLocales(options.detected_text_language_tags,
+ &detected_text_language_tags)) {
+ TC3_LOG(WARNING)
+ << "Failed to parse the detected_text_language_tags in options: "
+ << options.detected_text_language_tags;
+ }
+ if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ model_triggering_locales_,
+ /*default_value=*/true)) {
return {};
}
InterpreterManager interpreter_manager(selection_executor_.get(),
classification_executor_.get());
+
// Annotate with the selection model.
std::vector<Token> tokens;
- if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) {
+ if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
+ &tokens, &candidates)) {
TC3_LOG(ERROR) << "Couldn't run ModelAnnotate.";
return {};
}
// Annotate with the regular expression models.
if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
- annotation_regex_patterns_, &candidates)) {
+ annotation_regex_patterns_, &candidates,
+ options.is_serialized_entity_data_enabled)) {
TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
return {};
}
// Annotate with the datetime model.
- if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
+ if ((is_entity_type_enabled(Collections::Date()) ||
+ is_entity_type_enabled(Collections::DateTime())) &&
+ !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
options.reference_time_ms_utc, options.reference_timezone,
- options.locales, ModeFlag_ANNOTATION, &candidates)) {
+ options.locales, ModeFlag_ANNOTATION,
+ options.annotation_usecase,
+ options.is_serialized_entity_data_enabled, &candidates)) {
TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
return {};
}
@@ -1351,6 +1834,37 @@
return {};
}
+ // Annotate with the contact engine.
+ if (contact_engine_ &&
+ !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Couldn't run contact engine Chunk.";
+ return {};
+ }
+
+ // Annotate with the installed app engine.
+ if (installed_app_engine_ &&
+ !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk.";
+ return {};
+ }
+
+ // Annotate with the number annotator.
+ if (number_annotator_ != nullptr &&
+ !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
+ &candidates)) {
+ TC3_LOG(ERROR) << "Couldn't run number annotator FindAll.";
+ return {};
+ }
+
+ // Annotate with the duration annotator.
+ if (is_entity_type_enabled(Collections::Duration()) &&
+ duration_annotator_ != nullptr &&
+ !duration_annotator_->FindAll(context_unicode, tokens,
+ options.annotation_usecase, &candidates)) {
+ TC3_LOG(ERROR) << "Couldn't run duration annotator FindAll.";
+ return {};
+ }
+
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
// contiguous block.
@@ -1360,28 +1874,156 @@
});
std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
- &candidate_indices)) {
+ if (!ResolveConflicts(candidates, context, tokens,
+ detected_text_language_tags, options.annotation_usecase,
+ &interpreter_manager, &candidate_indices)) {
TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
return {};
}
std::vector<AnnotatedSpan> result;
result.reserve(candidate_indices.size());
+ AnnotatedSpan aggregated_span;
for (const int i : candidate_indices) {
- if (!candidates[i].classification.empty() &&
- !ClassifiedAsOther(candidates[i].classification) &&
- !FilteredForAnnotation(candidates[i])) {
- result.push_back(std::move(candidates[i]));
+ if (candidates[i].span != aggregated_span.span) {
+ if (!aggregated_span.classification.empty()) {
+ result.push_back(std::move(aggregated_span));
+ }
+ aggregated_span =
+ AnnotatedSpan(candidates[i].span, /*arg_classification=*/{});
}
+ if (candidates[i].classification.empty() ||
+ ClassifiedAsOther(candidates[i].classification) ||
+ FilteredForAnnotation(candidates[i])) {
+ continue;
+ }
+ for (ClassificationResult& classification : candidates[i].classification) {
+ aggregated_span.classification.push_back(std::move(classification));
+ }
+ }
+ if (!aggregated_span.classification.empty()) {
+ result.push_back(std::move(aggregated_span));
+ }
+
+ // We generate all candidates and remove them later (with the exception of
+ // date/time/duration entities) because there are complex interdependencies
+ // between the entity types. E.g., the TLD of an email can be interpreted as a
+ // URL, but most likely a user of the API does not want such annotations if
+ // "url" is enabled and "email" is not.
+ RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
+
+ for (AnnotatedSpan& annotated_span : result) {
+ SortClassificationResults(&annotated_span.classification);
}
return result;
}
+CodepointSpan Annotator::ComputeSelectionBoundaries(
+ const UniLib::RegexMatcher* match,
+ const RegexModel_::Pattern* config) const {
+ if (config->capturing_group() == nullptr) {
+ // Use first capturing group to specify the selection.
+ int status = UniLib::RegexMatcher::kNoError;
+ const CodepointSpan result = {match->Start(1, &status),
+ match->End(1, &status)};
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return {kInvalidIndex, kInvalidIndex};
+ }
+ return result;
+ }
+
+ CodepointSpan result = {kInvalidIndex, kInvalidIndex};
+ const int num_groups = config->capturing_group()->size();
+ for (int i = 0; i < num_groups; i++) {
+ if (!config->capturing_group()->Get(i)->extend_selection()) {
+ continue;
+ }
+
+ int status = UniLib::RegexMatcher::kNoError;
+ // Check match and adjust bounds.
+ const int group_start = match->Start(i, &status);
+ const int group_end = match->End(i, &status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return {kInvalidIndex, kInvalidIndex};
+ }
+ if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
+ continue;
+ }
+ if (result.first == kInvalidIndex) {
+ result = {group_start, group_end};
+ } else {
+ result.first = std::min(result.first, group_start);
+ result.second = std::max(result.second, group_end);
+ }
+ }
+ return result;
+}
+
+bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
+ if (pattern->serialized_entity_data() != nullptr) {
+ return true;
+ }
+ if (pattern->capturing_group() != nullptr) {
+ for (const RegexModel_::Pattern_::CapturingGroup* group :
+ *pattern->capturing_group()) {
+ if (group->entity_field_path() != nullptr) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool Annotator::SerializedEntityDataFromRegexMatch(
+ const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
+ std::string* serialized_entity_data) const {
+ if (!HasEntityData(pattern)) {
+ serialized_entity_data->clear();
+ return true;
+ }
+ TC3_CHECK(entity_data_builder_ != nullptr);
+
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+
+ TC3_CHECK(entity_data != nullptr);
+
+ // Set static entity data.
+ if (pattern->serialized_entity_data() != nullptr) {
+ TC3_CHECK(entity_data != nullptr);
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(pattern->serialized_entity_data()->c_str(),
+ pattern->serialized_entity_data()->size()));
+ }
+
+ // Add entity data from rule capturing groups.
+ if (pattern->capturing_group() != nullptr) {
+ const int num_groups = pattern->capturing_group()->size();
+ for (int i = 0; i < num_groups; i++) {
+ const FlatbufferFieldPath* field_path =
+ pattern->capturing_group()->Get(i)->entity_field_path();
+ if (field_path == nullptr) {
+ continue;
+ }
+ TC3_CHECK(entity_data != nullptr);
+ if (!SetFieldFromCapturingGroup(/*group_id=*/i, field_path, matcher,
+ entity_data.get())) {
+ TC3_LOG(ERROR)
+ << "Could not set entity data from rule capturing group.";
+ return false;
+ }
+ }
+ }
+
+ *serialized_entity_data = entity_data->Serialize();
+ return true;
+}
+
bool Annotator::RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
- std::vector<AnnotatedSpan>* result) const {
+ std::vector<AnnotatedSpan>* result,
+ bool is_serialized_entity_data_enabled) const {
for (int pattern_id : rules) {
const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
@@ -1393,21 +2035,38 @@
int status = UniLib::RegexMatcher::kNoError;
while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (regex_pattern.verification_options) {
- if (!VerifyCandidate(regex_pattern.verification_options,
- matcher->Group(1, &status).ToUTF8String())) {
+ if (regex_pattern.config->verification_options()) {
+ if (!VerifyRegexMatchCandidate(
+ context_unicode.ToUTF8String(),
+ regex_pattern.config->verification_options(),
+ matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
continue;
}
}
+
+ std::string serialized_entity_data;
+ if (is_serialized_entity_data_enabled) {
+ if (!SerializedEntityDataFromRegexMatch(
+ regex_pattern.config, matcher.get(), &serialized_entity_data)) {
+ TC3_LOG(ERROR) << "Could not get entity data.";
+ return false;
+ }
+ }
+
result->emplace_back();
+
// Selection/annotation regular expressions need to specify a capturing
// group specifying the selection.
- result->back().span = {matcher->Start(1, &status),
- matcher->End(1, &status)};
+ result->back().span =
+ ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
+
result->back().classification = {
- {regex_pattern.collection_name,
- regex_pattern.target_classification_score,
- regex_pattern.priority_score}};
+ {regex_pattern.config->collection_name()->str(),
+ regex_pattern.config->target_classification_score(),
+ regex_pattern.config->priority_score()}};
+
+ result->back().classification[0].serialized_entity_data =
+ serialized_entity_data;
}
}
return true;
@@ -1660,6 +2319,8 @@
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& locales, ModeFlag mode,
+ AnnotationUsecase annotation_usecase,
+ bool is_serialized_entity_data_enabled,
std::vector<AnnotatedSpan>* result) const {
if (!datetime_parser_) {
return true;
@@ -1668,22 +2329,35 @@
std::vector<DatetimeParseResultSpan> datetime_spans;
if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
reference_timezone, locales, mode,
+ annotation_usecase,
/*anchor_start_end=*/false, &datetime_spans)) {
return false;
}
for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
AnnotatedSpan annotated_span;
annotated_span.span = datetime_span.span;
- annotated_span.classification = {{kDateCollection,
- datetime_span.target_classification_score,
- datetime_span.priority_score}};
- annotated_span.classification[0].datetime_parse_result = datetime_span.data;
-
+ for (const DatetimeParseResult& parse_result : datetime_span.data) {
+ annotated_span.classification.emplace_back(
+ PickCollectionForDatetime(parse_result),
+ datetime_span.target_classification_score,
+ datetime_span.priority_score);
+ annotated_span.classification.back().datetime_parse_result = parse_result;
+ if (is_serialized_entity_data_enabled) {
+ annotated_span.classification.back().serialized_entity_data =
+ CreateDatetimeSerializedEntityData(parse_result);
+ }
+ }
+ annotated_span.source = AnnotatedSpan::Source::DATETIME;
result->push_back(std::move(annotated_span));
}
return true;
}
+const Model* Annotator::model() const { return model_; }
+const reflection::Schema* Annotator::entity_data_schema() const {
+ return entity_data_schema_;
+}
+
const Model* ViewModel(const void* buffer, int size) {
if (!buffer) {
return nullptr;
@@ -1692,4 +2366,10 @@
return LoadAndVerifyModel(buffer, size);
}
+bool Annotator::LookUpKnowledgeEntity(
+ const std::string& id, std::string* serialized_knowledge_result) const {
+ return knowledge_engine_ &&
+ knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
+}
+
} // namespace libtextclassifier3
diff --git a/annotator/annotator.h b/annotator/annotator.h
index c58c03d..0b1c9f9 100644
--- a/annotator/annotator.h
+++ b/annotator/annotator.h
@@ -22,28 +22,52 @@
#include <memory>
#include <set>
#include <string>
+#include <unordered_set>
#include <vector>
+#include "annotator/contact/contact-engine.h"
#include "annotator/datetime/parser.h"
+#include "annotator/duration/duration.h"
#include "annotator/feature-processor.h"
+#include "annotator/installed_app/installed-app-engine.h"
#include "annotator/knowledge/knowledge-engine.h"
#include "annotator/model-executor.h"
#include "annotator/model_generated.h"
+#include "annotator/number/number.h"
#include "annotator/strip-unpaired-brackets.h"
#include "annotator/types.h"
#include "annotator/zlib-utils.h"
+#include "utils/flatbuffers.h"
+#include "utils/i18n/locale.h"
#include "utils/memory/mmap.h"
#include "utils/utf8/unilib.h"
#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
+// Aliases for long enum values.
+const AnnotationUsecase ANNOTATION_USECASE_SMART =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART;
+const AnnotationUsecase ANNOTATION_USECASE_RAW =
+ AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
struct SelectionOptions {
// Comma-separated list of locale specification for the input text (BCP 47
// tags).
std::string locales;
- static SelectionOptions Default() { return SelectionOptions(); }
+ // Comma-separated list of BCP 47 language tags.
+ std::string detected_text_language_tags;
+
+ // Tailors the output annotations according to the specified use-case.
+ AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
+
+ bool operator==(const SelectionOptions& other) const {
+ return this->locales == other.locales &&
+ this->annotation_usecase == other.annotation_usecase &&
+ this->detected_text_language_tags ==
+ other.detected_text_language_tags;
+ }
};
struct ClassificationOptions {
@@ -59,7 +83,20 @@
// tags).
std::string locales;
- static ClassificationOptions Default() { return ClassificationOptions(); }
+ // Comma-separated list of language tags.
+ std::string detected_text_language_tags;
+
+ // Tailors the output annotations according to the specified use-case.
+ AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
+
+ bool operator==(const ClassificationOptions& other) const {
+ return this->reference_time_ms_utc == other.reference_time_ms_utc &&
+ this->reference_timezone == other.reference_timezone &&
+ this->locales == other.locales &&
+ this->detected_text_language_tags ==
+ other.detected_text_language_tags &&
+ this->annotation_usecase == other.annotation_usecase;
+ }
};
struct AnnotationOptions {
@@ -75,7 +112,28 @@
// tags).
std::string locales;
- static AnnotationOptions Default() { return AnnotationOptions(); }
+ // Comma-separated list of language tags.
+ std::string detected_text_language_tags;
+
+ // List of entity types that should be used for annotation.
+ std::unordered_set<std::string> entity_types;
+
+ // If true, serialized_entity_data in the results is populated.
+ bool is_serialized_entity_data_enabled = false;
+
+ // Tailors the output annotations according to the specified use-case.
+ AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
+
+ bool operator==(const AnnotationOptions& other) const {
+ return this->reference_time_ms_utc == other.reference_time_ms_utc &&
+ this->reference_timezone == other.reference_timezone &&
+ this->locales == other.locales &&
+ this->detected_text_language_tags ==
+ other.detected_text_language_tags &&
+ this->annotation_usecase == other.annotation_usecase &&
+ this->is_serialized_entity_data_enabled ==
+ other.is_serialized_entity_data_enabled;
+ }
};
// Holds TFLite interpreters for selection and classification models.
@@ -105,6 +163,23 @@
std::unique_ptr<tflite::Interpreter> classification_interpreter_;
};
+// Stores entity types enabled for annotation, and provides operator() for
+// checking whether a given entity type is enabled.
+class EnabledEntityTypes {
+ public:
+ explicit EnabledEntityTypes(
+ const std::unordered_set<std::string>& entity_types)
+ : entity_types_(entity_types) {}
+
+ bool operator()(const std::string& entity_type) const {
+ return entity_types_.empty() ||
+ entity_types_.find(entity_type) != entity_types_.cend();
+ }
+
+ private:
+ const std::unordered_set<std::string>& entity_types_;
+};
+
// A text processing model that provides text classification, annotation,
// selection suggestion for various types.
// NOTE: This class is not thread-safe.
@@ -117,15 +192,27 @@
static std::unique_ptr<Annotator> FromScopedMmap(
std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr,
const CalendarLib* calendarlib = nullptr);
+ static std::unique_ptr<Annotator> FromScopedMmap(
+ std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib);
static std::unique_ptr<Annotator> FromFileDescriptor(
int fd, int offset, int size, const UniLib* unilib = nullptr,
const CalendarLib* calendarlib = nullptr);
static std::unique_ptr<Annotator> FromFileDescriptor(
+ int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib);
+ static std::unique_ptr<Annotator> FromFileDescriptor(
int fd, const UniLib* unilib = nullptr,
const CalendarLib* calendarlib = nullptr);
+ static std::unique_ptr<Annotator> FromFileDescriptor(
+ int fd, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib);
static std::unique_ptr<Annotator> FromPath(
const std::string& path, const UniLib* unilib = nullptr,
const CalendarLib* calendarlib = nullptr);
+ static std::unique_ptr<Annotator> FromPath(
+ const std::string& path, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib);
// Returns true if the model is ready for use.
bool IsInitialized() { return initialized_; }
@@ -133,6 +220,12 @@
// Initializes the knowledge engine with the given config.
bool InitializeKnowledgeEngine(const std::string& serialized_config);
+ // Initializes the contact engine with the given config.
+ bool InitializeContactEngine(const std::string& serialized_config);
+
+ // Initializes the installed app engine with the given config.
+ bool InitializeInstalledAppEngine(const std::string& serialized_config);
+
// Runs inference for given a context and current selection (i.e. index
// of the first and one past last selected characters (utf8 codepoint
// offsets)). Returns the indices (utf8 codepoint offsets) of the selection
@@ -143,20 +236,27 @@
// Requires that the model is a smart selection model.
CodepointSpan SuggestSelection(
const std::string& context, CodepointSpan click_indices,
- const SelectionOptions& options = SelectionOptions::Default()) const;
+ const SelectionOptions& options = SelectionOptions()) const;
// Classifies the selected text given the context string.
// Returns an empty result if an error occurs.
std::vector<ClassificationResult> ClassifyText(
const std::string& context, CodepointSpan selection_indices,
- const ClassificationOptions& options =
- ClassificationOptions::Default()) const;
+ const ClassificationOptions& options = ClassificationOptions()) const;
// Annotates given input text. The annotations are sorted by their position
// in the context string and exclude spans classified as 'other'.
std::vector<AnnotatedSpan> Annotate(
const std::string& context,
- const AnnotationOptions& options = AnnotationOptions::Default()) const;
+ const AnnotationOptions& options = AnnotationOptions()) const;
+
+ // Looks up a knowledge entity by its id. If successful, populates the
+ // serialized knowledge result and returns true.
+ bool LookUpKnowledgeEntity(const std::string& id,
+ std::string* serialized_knowledge_result) const;
+
+ const Model* model() const;
+ const reflection::Schema* entity_data_schema() const;
// Exposes the feature processor for tests and evaluations.
const FeatureProcessor* SelectionFeatureProcessorForTests() const;
@@ -165,18 +265,11 @@
// Exposes the date time parser for tests and evaluations.
const DatetimeParser* DatetimeParserForTests() const;
- // String collection names for various classes.
- static const std::string& kOtherCollection;
static const std::string& kPhoneCollection;
static const std::string& kAddressCollection;
static const std::string& kDateCollection;
static const std::string& kUrlCollection;
- static const std::string& kFlightCollection;
static const std::string& kEmailCollection;
- static const std::string& kIbanCollection;
- static const std::string& kPaymentCardCollection;
- static const std::string& kIsbnCollection;
- static const std::string& kTrackingNumberCollection;
protected:
struct ScoredChunk {
@@ -188,11 +281,14 @@
// Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
const UniLib* unilib, const CalendarLib* calendarlib);
+ Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
+ std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib);
// Constructs, validates and initializes text classifier from given model.
// Does not own the buffer that backs 'model'.
- explicit Annotator(const Model* model, const UniLib* unilib,
- const CalendarLib* calendarlib);
+ Annotator(const Model* model, const UniLib* unilib,
+ const CalendarLib* calendarlib);
// Checks that model contains all required fields, and initializes internal
// datastructures.
@@ -208,6 +304,8 @@
bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
const std::string& context,
const std::vector<Token>& cached_tokens,
+ const std::vector<Locale>& detected_text_language_tags,
+ AnnotationUsecase annotation_usecase,
InterpreterManager* interpreter_manager,
std::vector<int>* result) const;
@@ -217,31 +315,45 @@
bool ResolveConflict(const std::string& context,
const std::vector<Token>& cached_tokens,
const std::vector<AnnotatedSpan>& candidates,
+ const std::vector<Locale>& detected_text_language_tags,
int start_index, int end_index,
+ AnnotationUsecase annotation_usecase,
InterpreterManager* interpreter_manager,
std::vector<int>* chosen_indices) const;
// Gets selection candidates from the ML model.
// Provides the tokens produced during tokenization of the context string for
// reuse.
- bool ModelSuggestSelection(const UnicodeText& context_unicode,
- CodepointSpan click_indices,
- InterpreterManager* interpreter_manager,
- std::vector<Token>* tokens,
- std::vector<AnnotatedSpan>* result) const;
+ bool ModelSuggestSelection(
+ const UnicodeText& context_unicode, CodepointSpan click_indices,
+ const std::vector<Locale>& detected_text_language_tags,
+ InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
+ std::vector<AnnotatedSpan>* result) const;
// Classifies the selected text given the context string with the
// classification model.
// Returns true if no error occurred.
bool ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
+ const std::vector<Locale>& locales, CodepointSpan selection_indices,
+ InterpreterManager* interpreter_manager,
+ FeatureProcessor::EmbeddingCache* embedding_cache,
+ std::vector<ClassificationResult>* classification_results,
+ std::vector<Token>* tokens) const;
+
+ // Same as above but doesn't output tokens.
+ bool ModelClassifyText(
+ const std::string& context, const std::vector<Token>& cached_tokens,
+ const std::vector<Locale>& detected_text_language_tags,
CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const;
+ // Same as above but doesn't take cached tokens and doesn't output tokens.
bool ModelClassifyText(
- const std::string& context, CodepointSpan selection_indices,
- InterpreterManager* interpreter_manager,
+ const std::string& context,
+ const std::vector<Locale>& detected_text_language_tags,
+ CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const;
@@ -251,17 +363,17 @@
TokenSpan ClassifyTextUpperBoundNeededTokens() const;
// Classifies the selected text with the regular expressions models.
- // Returns true if any regular expression matched and the result was set.
- bool RegexClassifyText(const std::string& context,
- CodepointSpan selection_indices,
- ClassificationResult* classification_result) const;
+ // Returns true if no error happened, false otherwise.
+ bool RegexClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ std::vector<ClassificationResult>* classification_result) const;
// Classifies the selected text with the date time model.
- // Returns true if there was a match and the result was set.
- bool DatetimeClassifyText(const std::string& context,
- CodepointSpan selection_indices,
- const ClassificationOptions& options,
- ClassificationResult* classification_result) const;
+ // Returns true if no error happened, false otherwise.
+ bool DatetimeClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ const ClassificationOptions& options,
+ std::vector<ClassificationResult>* classification_results) const;
// Chunks given input text with the selection model and classifies the spans
// with the classification model.
@@ -270,6 +382,7 @@
// Provides the tokens produced during tokenization of the context string for
// reuse.
bool ModelAnnotate(const std::string& context,
+ const std::vector<Locale>& detected_text_language_tags,
InterpreterManager* interpreter_manager,
std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const;
@@ -309,13 +422,16 @@
// Produces chunks isolated by a set of regular expressions.
bool RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
- std::vector<AnnotatedSpan>* result) const;
+ std::vector<AnnotatedSpan>* result,
+ bool is_serialized_entity_data_enabled) const;
// Produces chunks from the datetime parser.
bool DatetimeChunk(const UnicodeText& context_unicode,
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& locales, ModeFlag mode,
+ AnnotationUsecase annotation_usecase,
+ bool is_serialized_entity_data_enabled,
std::vector<AnnotatedSpan>* result) const;
// Returns whether a classification should be filtered.
@@ -324,6 +440,25 @@
const ClassificationResult& classification) const;
bool FilteredForSelection(const AnnotatedSpan& span) const;
+ // Computes the selection boundaries from a regular expression match.
+ CodepointSpan ComputeSelectionBoundaries(
+ const UniLib::RegexMatcher* match,
+ const RegexModel_::Pattern* config) const;
+
+ // Returns whether a regex pattern provides entity data from a match.
+ bool HasEntityData(const RegexModel_::Pattern* pattern) const;
+
+ // Constructs and serializes entity data from regex matches.
+ bool SerializedEntityDataFromRegexMatch(
+ const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
+ std::string* serialized_entity_data) const;
+
+ // Verifies a regex match and returns true if verification was successful.
+ bool VerifyRegexMatchCandidate(
+ const std::string& context,
+ const VerificationOptions* verification_options, const std::string& match,
+ const UniLib::RegexMatcher* matcher) const;
+
const Model* model_;
std::unique_ptr<const ModelExecutor> selection_executor_;
@@ -337,13 +472,16 @@
private:
struct CompiledRegexPattern {
- std::string collection_name;
- float target_classification_score;
- float priority_score;
+ const RegexModel_::Pattern* config;
std::unique_ptr<UniLib::RegexPattern> pattern;
- const VerificationOptions* verification_options;
};
+ // Removes annotations the entity type of which is not in the set of enabled
+ // entity types.
+ void RemoveNotEnabledEntityTypes(
+ const EnabledEntityTypes& is_entity_type_enabled,
+ std::vector<AnnotatedSpan>* annotated_spans) const;
+
std::unique_ptr<ScopedMmap> mmap_;
bool initialized_ = false;
bool enabled_for_annotation_ = false;
@@ -354,7 +492,6 @@
std::unordered_set<std::string> filtered_collections_selection_;
std::vector<CompiledRegexPattern> regex_patterns_;
- std::unordered_set<int> regex_approximate_match_pattern_ids_;
// Indices into regex_patterns_ for the different modes.
std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
@@ -366,6 +503,23 @@
const CalendarLib* calendarlib_;
std::unique_ptr<const KnowledgeEngine> knowledge_engine_;
+ std::unique_ptr<const ContactEngine> contact_engine_;
+ std::unique_ptr<const InstalledAppEngine> installed_app_engine_;
+ std::unique_ptr<const NumberAnnotator> number_annotator_;
+ std::unique_ptr<const DurationAnnotator> duration_annotator_;
+
+ // Builder for creating extra data.
+ const reflection::Schema* entity_data_schema_;
+ std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
+
+ // Locales for which the entire model triggers.
+ std::vector<Locale> model_triggering_locales_;
+
+ // Locales for which the ML model triggers.
+ std::vector<Locale> ml_model_triggering_locales_;
+
+ // Locales that the dictionary classification support.
+ std::vector<Locale> dictionary_locales_;
};
namespace internal {
@@ -388,6 +542,23 @@
// Interprets the buffer as a Model flatbuffer and returns it for reading.
const Model* ViewModel(const void* buffer, int size);
+// Opens model from given path and runs a function, passing the loaded Model
+// flatbuffer as an argument.
+//
+// This is mainly useful if we don't want to pay the cost for the model
+// initialization because we'll be only reading some flatbuffer values from the
+// file.
+template <typename ReturnType, typename Func>
+ReturnType VisitAnnotatorModel(const std::string& path, Func function) {
+  ScopedMmap mmap(path);
+  // Bail out early on mmap failure: without this `return`, execution would
+  // fall through, call ViewModel() on an invalid handle, and invoke
+  // `function` a second time with a garbage Model pointer.
+  if (!mmap.handle().ok()) {
+    return function(/*model=*/nullptr);
+  }
+  return function(ViewModel(mmap.handle().start(), mmap.handle().num_bytes()));
+}
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
diff --git a/annotator/annotator_jni.cc b/annotator/annotator_jni.cc
index 9bda35a..9118f30 100644
--- a/annotator/annotator_jni.cc
+++ b/annotator/annotator_jni.cc
@@ -24,11 +24,16 @@
#include "annotator/annotator.h"
#include "annotator/annotator_jni_common.h"
+#include "annotator/types.h"
#include "utils/base/integral_types.h"
#include "utils/calendar/calendar.h"
+#include "utils/intents/intent-generator.h"
+#include "utils/intents/jni.h"
+#include "utils/java/jni-cache.h"
#include "utils/java/scoped_local_ref.h"
#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"
+#include "utils/strings/stringpiece.h"
#include "utils/utf8/unilib.h"
#ifdef TC3_UNILIB_JAVAICU
@@ -58,10 +63,183 @@
using libtextclassifier3::CodepointSpan;
namespace {
+class AnnotatorJniContext {
+ public:
+ static AnnotatorJniContext* Create(
+ const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
+ std::unique_ptr<Annotator> model) {
+ if (jni_cache == nullptr || model == nullptr) {
+ return nullptr;
+ }
+ std::unique_ptr<IntentGenerator> intent_generator =
+ IntentGenerator::Create(model->model()->intent_options(),
+ model->model()->resources(), jni_cache);
+ std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
+ libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
+ if (template_handler == nullptr) {
+ return nullptr;
+ }
+ return new AnnotatorJniContext(jni_cache, std::move(model),
+ std::move(intent_generator),
+ std::move(template_handler));
+ }
-jobjectArray ClassificationResultsToJObjectArray(
- JNIEnv* env,
- const std::vector<ClassificationResult>& classification_result) {
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
+ return jni_cache_;
+ }
+
+ Annotator* model() const { return model_.get(); }
+
+ IntentGenerator* intent_generator() const { return intent_generator_.get(); }
+
+ RemoteActionTemplatesHandler* template_handler() const {
+ return template_handler_.get();
+ }
+
+ private:
+ AnnotatorJniContext(
+ const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
+ std::unique_ptr<Annotator> model,
+ std::unique_ptr<IntentGenerator> intent_generator,
+ std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
+ : jni_cache_(jni_cache),
+ model_(std::move(model)),
+ intent_generator_(std::move(intent_generator)),
+ template_handler_(std::move(template_handler)) {}
+
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
+ std::unique_ptr<Annotator> model_;
+ std::unique_ptr<IntentGenerator> intent_generator_;
+ std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
+};
+
+jobject ClassificationResultWithIntentsToJObject(
+ JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
+ jclass result_class, jmethodID result_class_constructor,
+ jclass datetime_parse_class, jmethodID datetime_parse_class_constructor,
+ const jstring device_locales, const ClassificationOptions* options,
+ const std::string& context, const CodepointSpan& selection_indices,
+ const ClassificationResult& classification_result, bool generate_intents) {
+ jstring row_string =
+ env->NewStringUTF(classification_result.collection.c_str());
+
+ jobject row_datetime_parse = nullptr;
+ if (classification_result.datetime_parse_result.IsSet()) {
+ row_datetime_parse =
+ env->NewObject(datetime_parse_class, datetime_parse_class_constructor,
+ classification_result.datetime_parse_result.time_ms_utc,
+ classification_result.datetime_parse_result.granularity);
+ }
+
+ jbyteArray serialized_knowledge_result = nullptr;
+ const std::string& serialized_knowledge_result_string =
+ classification_result.serialized_knowledge_result;
+ if (!serialized_knowledge_result_string.empty()) {
+ serialized_knowledge_result =
+ env->NewByteArray(serialized_knowledge_result_string.size());
+ env->SetByteArrayRegion(serialized_knowledge_result, 0,
+ serialized_knowledge_result_string.size(),
+ reinterpret_cast<const jbyte*>(
+ serialized_knowledge_result_string.data()));
+ }
+
+ jstring contact_name = nullptr;
+ if (!classification_result.contact_name.empty()) {
+ contact_name =
+ env->NewStringUTF(classification_result.contact_name.c_str());
+ }
+
+ jstring contact_given_name = nullptr;
+ if (!classification_result.contact_given_name.empty()) {
+ contact_given_name =
+ env->NewStringUTF(classification_result.contact_given_name.c_str());
+ }
+
+ jstring contact_nickname = nullptr;
+ if (!classification_result.contact_nickname.empty()) {
+ contact_nickname =
+ env->NewStringUTF(classification_result.contact_nickname.c_str());
+ }
+
+ jstring contact_email_address = nullptr;
+ if (!classification_result.contact_email_address.empty()) {
+ contact_email_address =
+ env->NewStringUTF(classification_result.contact_email_address.c_str());
+ }
+
+ jstring contact_phone_number = nullptr;
+ if (!classification_result.contact_phone_number.empty()) {
+ contact_phone_number =
+ env->NewStringUTF(classification_result.contact_phone_number.c_str());
+ }
+
+ jstring contact_id = nullptr;
+ if (!classification_result.contact_id.empty()) {
+ contact_id = env->NewStringUTF(classification_result.contact_id.c_str());
+ }
+
+ jstring app_name = nullptr;
+ if (!classification_result.app_name.empty()) {
+ app_name = env->NewStringUTF(classification_result.app_name.c_str());
+ }
+
+ jstring app_package_name = nullptr;
+ if (!classification_result.app_package_name.empty()) {
+ app_package_name =
+ env->NewStringUTF(classification_result.app_package_name.c_str());
+ }
+
+ jobject extras = nullptr;
+ if (model_context->model()->entity_data_schema() != nullptr &&
+ !classification_result.serialized_entity_data.empty()) {
+ extras = model_context->template_handler()->EntityDataAsNamedVariantArray(
+ model_context->model()->entity_data_schema(),
+ classification_result.serialized_entity_data);
+ }
+
+ jbyteArray serialized_entity_data = nullptr;
+ if (!classification_result.serialized_entity_data.empty()) {
+ serialized_entity_data =
+ env->NewByteArray(classification_result.serialized_entity_data.size());
+ env->SetByteArrayRegion(
+ serialized_entity_data, 0,
+ classification_result.serialized_entity_data.size(),
+ reinterpret_cast<const jbyte*>(
+ classification_result.serialized_entity_data.data()));
+ }
+
+ jobject remote_action_templates_result = nullptr;
+ // Only generate RemoteActionTemplate for the top classification result
+ // as classifyText does not need RemoteAction from other results anyway.
+ if (generate_intents && model_context->intent_generator() != nullptr) {
+ std::vector<RemoteActionTemplate> remote_action_templates;
+ if (model_context->intent_generator()->GenerateIntents(
+ device_locales, classification_result,
+ options->reference_time_ms_utc, context, selection_indices,
+ app_context, model_context->model()->entity_data_schema(),
+ &remote_action_templates)) {
+ remote_action_templates_result =
+ model_context->template_handler()
+ ->RemoteActionTemplatesToJObjectArray(remote_action_templates);
+ }
+ }
+
+ return env->NewObject(
+ result_class, result_class_constructor, row_string,
+ static_cast<jfloat>(classification_result.score), row_datetime_parse,
+ serialized_knowledge_result, contact_name, contact_given_name,
+ contact_nickname, contact_email_address, contact_phone_number, contact_id,
+ app_name, app_package_name, extras, serialized_entity_data,
+ remote_action_templates_result, classification_result.duration_ms,
+ classification_result.numeric_value);
+}
+
+jobjectArray ClassificationResultsWithIntentsToJObjectArray(
+ JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
+ const jstring device_locales, const ClassificationOptions* options,
+ const std::string& context, const CodepointSpan& selection_indices,
+ const std::vector<ClassificationResult>& classification_result,
+ bool generate_intents) {
const ScopedLocalRef<jclass> result_class(
env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
"$ClassificationResult"),
@@ -82,46 +260,43 @@
const jmethodID result_class_constructor = env->GetMethodID(
result_class.get(), "<init>",
"(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult;[B)V");
+ "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/String;"
+ "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;"
+ "Ljava/lang/String;[L" TC3_PACKAGE_PATH TC3_NAMED_VARIANT_CLASS_NAME_STR
+ ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR
+ ";JJ)V");
const jmethodID datetime_parse_class_constructor =
env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
const jobjectArray results = env->NewObjectArray(classification_result.size(),
result_class.get(), nullptr);
for (int i = 0; i < classification_result.size(); i++) {
- jstring row_string =
- env->NewStringUTF(classification_result[i].collection.c_str());
-
- jobject row_datetime_parse = nullptr;
- if (classification_result[i].datetime_parse_result.IsSet()) {
- row_datetime_parse = env->NewObject(
- datetime_parse_class.get(), datetime_parse_class_constructor,
- classification_result[i].datetime_parse_result.time_ms_utc,
- classification_result[i].datetime_parse_result.granularity);
- }
-
- jbyteArray serialized_knowledge_result = nullptr;
- const std::string& serialized_knowledge_result_string =
- classification_result[i].serialized_knowledge_result;
- if (!serialized_knowledge_result_string.empty()) {
- serialized_knowledge_result =
- env->NewByteArray(serialized_knowledge_result_string.size());
- env->SetByteArrayRegion(serialized_knowledge_result, 0,
- serialized_knowledge_result_string.size(),
- reinterpret_cast<const jbyte*>(
- serialized_knowledge_result_string.data()));
- }
-
- jobject result =
- env->NewObject(result_class.get(), result_class_constructor, row_string,
- static_cast<jfloat>(classification_result[i].score),
- row_datetime_parse, serialized_knowledge_result);
+ jobject result = ClassificationResultWithIntentsToJObject(
+ env, model_context, app_context, result_class.get(),
+ result_class_constructor, datetime_parse_class.get(),
+ datetime_parse_class_constructor, device_locales, options, context,
+ selection_indices, classification_result[i],
+ generate_intents && (i == 0));
env->SetObjectArrayElement(results, i, result);
env->DeleteLocalRef(result);
}
return results;
}
+jobjectArray ClassificationResultsToJObjectArray(
+ JNIEnv* env, const AnnotatorJniContext* model_context,
+ const std::vector<ClassificationResult>& classification_result) {
+ return ClassificationResultsWithIntentsToJObjectArray(
+ env, model_context,
+ /*(unused) app_context=*/nullptr,
+      /*(unused) device_locales=*/nullptr,
+      /*(unused) options=*/nullptr,
+      /*(unused) context=*/"",
+ /*(unused) selection_indices=*/{kInvalidIndex, kInvalidIndex},
+ classification_result,
+ /*generate_intents=*/false);
+}
+
CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
CodepointSpan orig_indices,
bool from_utf8) {
@@ -217,7 +392,9 @@
} // namespace libtextclassifier3
+using libtextclassifier3::AnnotatorJniContext;
using libtextclassifier3::ClassificationResultsToJObjectArray;
+using libtextclassifier3::ClassificationResultsWithIntentsToJObjectArray;
using libtextclassifier3::ConvertIndicesBMPToUTF8;
using libtextclassifier3::ConvertIndicesUTF8ToBMP;
using libtextclassifier3::FromJavaAnnotationOptions;
@@ -227,47 +404,52 @@
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
(JNIEnv* env, jobject thiz, jint fd) {
-#ifdef TC3_USE_JAVAICU
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
- return reinterpret_cast<jlong>(
- Annotator::FromFileDescriptor(fd, new UniLib(jni_cache),
- new CalendarLib(jni_cache))
- .release());
+#ifdef TC3_USE_JAVAICU
+ return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
+ jni_cache,
+ Annotator::FromFileDescriptor(
+ fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+ std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
#else
- return reinterpret_cast<jlong>(Annotator::FromFileDescriptor(fd).release());
+ return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
+ jni_cache, Annotator::FromFileDescriptor(fd)));
#endif
}
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
const std::string path_str = ToStlString(env, path);
-#ifdef TC3_USE_JAVAICU
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
- return reinterpret_cast<jlong>(Annotator::FromPath(path_str,
- new UniLib(jni_cache),
- new CalendarLib(jni_cache))
- .release());
+#ifdef TC3_USE_JAVAICU
+ return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
+ jni_cache,
+ Annotator::FromPath(
+ path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+ std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
#else
- return reinterpret_cast<jlong>(Annotator::FromPath(path_str).release());
+ return reinterpret_cast<jlong>(
+ AnnotatorJniContext::Create(jni_cache, Annotator::FromPath(path_str)));
#endif
}
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME,
nativeNewAnnotatorFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
- const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
-#ifdef TC3_USE_JAVAICU
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
- return reinterpret_cast<jlong>(
- Annotator::FromFileDescriptor(fd, offset, size, new UniLib(jni_cache),
- new CalendarLib(jni_cache))
- .release());
+ const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
+#ifdef TC3_USE_JAVAICU
+ return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
+ jni_cache,
+ Annotator::FromFileDescriptor(
+ fd, offset, size, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+ std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
#else
- return reinterpret_cast<jlong>(
- Annotator::FromFileDescriptor(fd, offset, size).release());
+ return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
+ jni_cache, Annotator::FromFileDescriptor(fd, offset, size)));
#endif
}
@@ -278,7 +460,7 @@
return false;
}
- Annotator* model = reinterpret_cast<Annotator*>(ptr);
+ Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
std::string serialized_config_string;
const int length = env->GetArrayLength(serialized_config);
@@ -290,15 +472,60 @@
return model->InitializeKnowledgeEngine(serialized_config_string);
}
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeContactEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
+ if (!ptr) {
+ return false;
+ }
+
+ Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
+
+ std::string serialized_config_string;
+ const int length = env->GetArrayLength(serialized_config);
+ serialized_config_string.resize(length);
+ env->GetByteArrayRegion(serialized_config, 0, length,
+ reinterpret_cast<jbyte*>(const_cast<char*>(
+ serialized_config_string.data())));
+
+ return model->InitializeContactEngine(serialized_config_string);
+}
+
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeInstalledAppEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
+ if (!ptr) {
+ return false;
+ }
+
+ Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
+
+ std::string serialized_config_string;
+ const int length = env->GetArrayLength(serialized_config);
+ serialized_config_string.resize(length);
+ env->GetByteArrayRegion(serialized_config, 0, length,
+ reinterpret_cast<jbyte*>(const_cast<char*>(
+ serialized_config_string.data())));
+
+ return model->InitializeInstalledAppEngine(serialized_config_string);
+}
+
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeGetNativeModelPtr)
+(JNIEnv* env, jobject thiz, jlong ptr) {
+ if (!ptr) {
+ return 0L;
+ }
+ return reinterpret_cast<jlong>(
+ reinterpret_cast<AnnotatorJniContext*>(ptr)->model());
+}
+
TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
jint selection_end, jobject options) {
if (!ptr) {
return nullptr;
}
-
- Annotator* model = reinterpret_cast<Annotator*>(ptr);
-
+ const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
const std::string context_utf8 = ToStlString(env, context);
CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
@@ -314,20 +541,31 @@
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jobject options) {
+ jint selection_end, jobject options, jobject app_context,
+ jstring device_locales) {
if (!ptr) {
return nullptr;
}
- Annotator* ff_model = reinterpret_cast<Annotator*>(ptr);
+ const AnnotatorJniContext* model_context =
+ reinterpret_cast<AnnotatorJniContext*>(ptr);
const std::string context_utf8 = ToStlString(env, context);
const CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
+ const libtextclassifier3::ClassificationOptions classification_options =
+ FromJavaClassificationOptions(env, options);
const std::vector<ClassificationResult> classification_result =
- ff_model->ClassifyText(context_utf8, input_indices,
- FromJavaClassificationOptions(env, options));
-
- return ClassificationResultsToJObjectArray(env, classification_result);
+ model_context->model()->ClassifyText(context_utf8, input_indices,
+ classification_options);
+ if (app_context != nullptr) {
+ return ClassificationResultsWithIntentsToJObjectArray(
+ env, model_context, app_context, device_locales,
+ &classification_options, context_utf8, input_indices,
+ classification_result,
+ /*generate_intents=*/true);
+ }
+ return ClassificationResultsToJObjectArray(env, model_context,
+ classification_result);
}
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
@@ -335,10 +573,12 @@
if (!ptr) {
return nullptr;
}
- Annotator* model = reinterpret_cast<Annotator*>(ptr);
- std::string context_utf8 = ToStlString(env, context);
- std::vector<AnnotatedSpan> annotations =
- model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options));
+ const AnnotatorJniContext* model_context =
+ reinterpret_cast<AnnotatorJniContext*>(ptr);
+ const std::string context_utf8 = ToStlString(env, context);
+ const std::vector<AnnotatedSpan> annotations =
+ model_context->model()->Annotate(context_utf8,
+ FromJavaAnnotationOptions(env, options));
jclass result_class = env->FindClass(
TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan");
@@ -360,11 +600,11 @@
for (int i = 0; i < annotations.size(); ++i) {
CodepointSpan span_bmp =
ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
- jobject result = env->NewObject(result_class, result_class_constructor,
- static_cast<jint>(span_bmp.first),
- static_cast<jint>(span_bmp.second),
- ClassificationResultsToJObjectArray(
- env, annotations[i].classification));
+ jobject result = env->NewObject(
+ result_class, result_class_constructor,
+ static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
+ ClassificationResultsToJObjectArray(env, model_context,
+ annotations[i].classification));
env->SetObjectArrayElement(results, i, result);
env->DeleteLocalRef(result);
}
@@ -372,10 +612,30 @@
return results;
}
+TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
+ nativeLookUpKnowledgeEntity)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring id) {
+ if (!ptr) {
+ return nullptr;
+ }
+ const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
+ const std::string id_utf8 = ToStlString(env, id);
+ std::string serialized_knowledge_result;
+ if (!model->LookUpKnowledgeEntity(id_utf8, &serialized_knowledge_result)) {
+ return nullptr;
+ }
+ jbyteArray result = env->NewByteArray(serialized_knowledge_result.size());
+ env->SetByteArrayRegion(
+ result, 0, serialized_knowledge_result.size(),
+ reinterpret_cast<const jbyte*>(serialized_knowledge_result.data()));
+ return result;
+}
+
TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
(JNIEnv* env, jobject thiz, jlong ptr) {
- Annotator* model = reinterpret_cast<Annotator*>(ptr);
- delete model;
+ const AnnotatorJniContext* context =
+ reinterpret_cast<AnnotatorJniContext*>(ptr);
+ delete context;
}
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage)
diff --git a/annotator/annotator_jni.h b/annotator/annotator_jni.h
index 47715b4..bca1dcd 100644
--- a/annotator/annotator_jni.h
+++ b/annotator/annotator_jni.h
@@ -42,17 +42,33 @@
nativeInitializeKnowledgeEngine)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeContactEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
+
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeInstalledAppEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
+
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeGetNativeModelPtr)
+(JNIEnv* env, jobject thiz, jlong ptr);
+
TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
jint selection_end, jobject options);
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jobject options);
+ jint selection_end, jobject options, jobject app_context,
+ jstring device_locales);
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options);
+TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
+ nativeLookUpKnowledgeEntity)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring id);
+
TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
(JNIEnv* env, jobject thiz, jlong ptr);
diff --git a/annotator/annotator_jni_common.cc b/annotator/annotator_jni_common.cc
index 0fdb87b..55f14e6 100644
--- a/annotator/annotator_jni_common.cc
+++ b/annotator/annotator_jni_common.cc
@@ -21,6 +21,20 @@
namespace libtextclassifier3 {
namespace {
+
+std::unordered_set<std::string> EntityTypesFromJObject(JNIEnv* env,
+ const jobject& jobject) {
+ std::unordered_set<std::string> entity_types;
+ jobjectArray jentity_types = reinterpret_cast<jobjectArray>(jobject);
+ const int size = env->GetArrayLength(jentity_types);
+ for (int i = 0; i < size; ++i) {
+ jstring jentity_type =
+ reinterpret_cast<jstring>(env->GetObjectArrayElement(jentity_types, i));
+ entity_types.insert(ToStlString(env, jentity_type));
+ }
+ return entity_types;
+}
+
template <typename T>
T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
const std::string& class_name) {
@@ -45,9 +59,18 @@
CallJniMethod0<int64>(env, joptions, options_class.get(),
&JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
"J");
+ const std::pair<bool, jobject> status_or_detected_text_language_tags =
+ CallJniMethod0<jobject>(
+ env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
+ "getDetectedTextLanguageTags", "Ljava/lang/String;");
+ const std::pair<bool, int> status_or_annotation_usecase =
+ CallJniMethod0<int>(env, joptions, options_class.get(),
+ &JNIEnv::CallIntMethod, "getAnnotationUsecase", "I");
if (!status_or_locales.first || !status_or_reference_timezone.first ||
- !status_or_reference_time_ms_utc.first) {
+ !status_or_reference_time_ms_utc.first ||
+ !status_or_detected_text_language_tags.first ||
+ !status_or_annotation_usecase.first) {
return {};
}
@@ -57,6 +80,11 @@
options.reference_timezone = ToStlString(
env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
+ options.detected_text_language_tags = ToStlString(
+ env,
+ reinterpret_cast<jstring>(status_or_detected_text_language_tags.second));
+ options.annotation_usecase =
+ static_cast<AnnotationUsecase>(status_or_annotation_usecase.second);
return options;
}
} // namespace
@@ -73,13 +101,18 @@
const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
"getLocales", "Ljava/lang/String;");
- if (!status_or_locales.first) {
+ const std::pair<bool, int> status_or_annotation_usecase =
+ CallJniMethod0<int>(env, joptions, options_class.get(),
+ &JNIEnv::CallIntMethod, "getAnnotationUsecase", "I");
+ if (!status_or_locales.first || !status_or_annotation_usecase.first) {
return {};
}
SelectionOptions options;
options.locales =
ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
+ options.annotation_usecase =
+ static_cast<AnnotationUsecase>(status_or_annotation_usecase.second);
return options;
}
@@ -92,9 +125,31 @@
}
AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
- return FromJavaOptionsInternal<AnnotationOptions>(
- env, joptions,
- TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions");
+ if (!joptions) return {};
+ const ScopedLocalRef<jclass> options_class(
+ env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$AnnotationOptions"),
+ env);
+ if (!options_class) return {};
+ const std::pair<bool, jobject> status_or_entity_types =
+ CallJniMethod0<jobject>(env, joptions, options_class.get(),
+ &JNIEnv::CallObjectMethod, "getEntityTypes",
+ "[Ljava/lang/String;");
+ if (!status_or_entity_types.first) return {};
+ const std::pair<bool, bool> status_or_enable_serialized_entity_data =
+ CallJniMethod0<bool>(env, joptions, options_class.get(),
+ &JNIEnv::CallBooleanMethod,
+ "isSerializedEntityDataEnabled", "Z");
+ if (!status_or_enable_serialized_entity_data.first) return {};
+ AnnotationOptions annotation_options =
+ FromJavaOptionsInternal<AnnotationOptions>(
+ env, joptions,
+ TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions");
+ annotation_options.entity_types =
+ EntityTypesFromJObject(env, status_or_entity_types.second);
+ annotation_options.is_serialized_entity_data_enabled =
+ status_or_enable_serialized_entity_data.second;
+ return annotation_options;
}
} // namespace libtextclassifier3
diff --git a/annotator/annotator_test.cc b/annotator/annotator_test.cc
deleted file mode 100644
index fbaf039..0000000
--- a/annotator/annotator_test.cc
+++ /dev/null
@@ -1,1254 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/annotator.h"
-
-#include <fstream>
-#include <iostream>
-#include <memory>
-#include <string>
-
-#include "annotator/model_generated.h"
-#include "annotator/types-test-util.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAreArray;
-using testing::IsEmpty;
-using testing::Pair;
-using testing::Values;
-
-std::string FirstResult(const std::vector<ClassificationResult>& results) {
- if (results.empty()) {
- return "<INVALID RESULTS>";
- }
- return results[0].collection;
-}
-
-MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
- return testing::Value(arg.span, Pair(start, end)) &&
- testing::Value(FirstResult(arg.classification), best_class);
-}
-
-std::string ReadFile(const std::string& file_name) {
- std::ifstream file_stream(file_name);
- return std::string(std::istreambuf_iterator<char>(file_stream), {});
-}
-
-std::string GetModelPath() {
- return TC3_TEST_DATA_DIR;
-}
-
-class AnnotatorTest : public ::testing::TestWithParam<const char*> {
- protected:
- AnnotatorTest()
- : INIT_UNILIB_FOR_TESTING(unilib_),
- INIT_CALENDARLIB_FOR_TESTING(calendarlib_) {}
- UniLib unilib_;
- CalendarLib calendarlib_;
-};
-
-TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
- std::unique_ptr<Annotator> classifier = Annotator::FromPath(
- GetModelPath() + "wrong_embeddings.fb", &unilib_, &calendarlib_);
- EXPECT_FALSE(classifier);
-}
-
-INSTANTIATE_TEST_CASE_P(ClickContext, AnnotatorTest,
- Values("test_model_cc.fb"));
-INSTANTIATE_TEST_CASE_P(BoundsSensitive, AnnotatorTest,
- Values("test_model.fb"));
-
-TEST_P(AnnotatorTest, ClassifyText) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ("other",
- FirstResult(classifier->ClassifyText(
- "this afternoon Barack Obama gave a speech at", {15, 27})));
- EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
- "Call me at (800) 123-456 today", {11, 24})));
-
- // More lines.
- EXPECT_EQ("other",
- FirstResult(classifier->ClassifyText(
- "this afternoon Barack Obama gave a speech at|Visit "
- "www.google.com every today!|Call me at (800) 123-456 today.",
- {15, 27})));
- EXPECT_EQ("phone",
- FirstResult(classifier->ClassifyText(
- "this afternoon Barack Obama gave a speech at|Visit "
- "www.google.com every today!|Call me at (800) 123-456 today.",
- {90, 103})));
-
- // Single word.
- EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
- EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
- EXPECT_EQ("<INVALID RESULTS>",
- FirstResult(classifier->ClassifyText("asdf", {0, 0})));
-
- // Junk.
- EXPECT_EQ("<INVALID RESULTS>",
- FirstResult(classifier->ClassifyText("", {0, 0})));
- EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
- "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
- // Test invalid utf8 input.
- EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
- "\xf0\x9f\x98\x8b\x8b", {0, 0})));
-}
-
-TEST_P(AnnotatorTest, ClassifyTextDisabledFail) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- unpacked_model->classification_model.clear();
- unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
- unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
-
- // The classification model is still needed for selection scores.
- ASSERT_FALSE(classifier);
-}
-
-TEST_P(AnnotatorTest, ClassifyTextDisabled) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
- unpacked_model->triggering_options->enabled_modes =
- ModeFlag_ANNOTATION_AND_SELECTION;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_THAT(
- classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
- IsEmpty());
-}
-
-TEST_P(AnnotatorTest, ClassifyTextFilteredCollections) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
- "Call me at (800) 123-456 today", {11, 24})));
-
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
- unpacked_model->output_options.reset(new OutputOptionsT);
-
- // Disable phone classification
- unpacked_model->output_options->filtered_collections_classification.push_back(
- "phone");
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
- "Call me at (800) 123-456 today", {11, 24})));
-
- // Check that the address classification still passes.
- EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
- "350 Third Street, Cambridge", {0, 27})));
-}
-
-std::unique_ptr<RegexModel_::PatternT> MakePattern(
- const std::string& collection_name, const std::string& pattern,
- const bool enabled_for_classification, const bool enabled_for_selection,
- const bool enabled_for_annotation, const float score) {
- std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
- result->collection_name = collection_name;
- result->pattern = pattern;
- // We cannot directly operate with |= on the flag, so use an int here.
- int enabled_modes = ModeFlag_NONE;
- if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
- if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
- if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
- result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
- result->target_classification_score = score;
- result->priority_score = score;
- return result;
-}
-
-#ifdef TC3_UNILIB_ICU
-TEST_P(AnnotatorTest, ClassifyTextRegularExpression) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Add test regex models.
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "person", "Barack Obama", /*enabled_for_classification=*/true,
- /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
- /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));
- std::unique_ptr<RegexModel_::PatternT> verified_pattern =
- MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
- /*enabled_for_classification=*/true,
- /*enabled_for_selection=*/false,
- /*enabled_for_annotation=*/false, 1.0);
- verified_pattern->verification_options.reset(new VerificationOptionsT);
- verified_pattern->verification_options->verify_luhn_checksum = true;
- unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ("flight",
- FirstResult(classifier->ClassifyText(
- "Your flight LX373 is delayed by 3 hours.", {12, 17})));
- EXPECT_EQ("person",
- FirstResult(classifier->ClassifyText(
- "this afternoon Barack Obama gave a speech at", {15, 27})));
- EXPECT_EQ("email",
- FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
- EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
- "Contact me at you@android.com", {14, 29})));
-
- EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
- "Visit www.google.com every today!", {6, 20})));
-
- EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
- EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
- {7, 12})));
- EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
- "cc: 4012 8888 8888 1881", {4, 23})));
- EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
- "2221 0067 4735 6281", {0, 19})));
- // Luhn check fails.
- EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282",
- {0, 19})));
-
- // More lines.
- EXPECT_EQ("url",
- FirstResult(classifier->ClassifyText(
- "this afternoon Barack Obama gave a speech at|Visit "
- "www.google.com every today!|Call me at (800) 123-456 today.",
- {51, 65})));
-}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_UNILIB_ICU
-TEST_P(AnnotatorTest, SuggestSelectionRegularExpression) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Add test regex models.
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
- unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
- std::unique_ptr<RegexModel_::PatternT> verified_pattern =
- MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
- /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/true,
- /*enabled_for_annotation=*/false, 1.0);
- verified_pattern->verification_options.reset(new VerificationOptionsT);
- verified_pattern->verification_options->verify_luhn_checksum = true;
- unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- // Check regular expression selection.
- EXPECT_EQ(classifier->SuggestSelection(
- "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
- std::make_pair(12, 19));
- EXPECT_EQ(classifier->SuggestSelection(
- "this afternoon Barack Obama gave a speech at", {15, 21}),
- std::make_pair(15, 27));
- EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
- std::make_pair(4, 23));
-}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_UNILIB_ICU
-TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Add test regex models.
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
- unpacked_model->regex_model->patterns.back()->priority_score = 0.5;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
- ASSERT_TRUE(classifier);
-
- // Check conflict resolution.
- EXPECT_EQ(
- classifier->SuggestSelection(
- "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
- {55, 57}),
- std::make_pair(26, 62));
-}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_UNILIB_ICU
-TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Add test regex models.
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
- unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
- ASSERT_TRUE(classifier);
-
- // Check conflict resolution.
- EXPECT_EQ(
- classifier->SuggestSelection(
- "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
- {55, 57}),
- std::make_pair(55, 62));
-}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_UNILIB_ICU
-TEST_P(AnnotatorTest, AnnotateRegex) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Add test regex models.
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));
- std::unique_ptr<RegexModel_::PatternT> verified_pattern =
- MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
- /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/false,
- /*enabled_for_annotation=*/true, 1.0);
- verified_pattern->verification_options.reset(new VerificationOptionsT);
- verified_pattern->verification_options->verify_luhn_checksum = true;
- unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- const std::string test_string =
- "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
- EXPECT_THAT(classifier->Annotate(test_string),
- ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
- IsAnnotatedSpan(28, 55, "address"),
- IsAnnotatedSpan(79, 91, "phone"),
- IsAnnotatedSpan(107, 126, "payment_card")}));
-}
-#endif // TC3_UNILIB_ICU
-
-TEST_P(AnnotatorTest, PhoneFiltering) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
- "phone: (123) 456 789", {7, 20})));
- EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
- "phone: (123) 456 789,0001112", {7, 25})));
- EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
- "phone: (123) 456 789,0001112", {7, 28})));
-}
-
-TEST_P(AnnotatorTest, SuggestSelection) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(classifier->SuggestSelection(
- "this afternoon Barack Obama gave a speech at", {15, 21}),
- std::make_pair(15, 21));
-
- // Try passing whole string.
- // If more than 1 token is specified, we should return back what entered.
- EXPECT_EQ(
- classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
- std::make_pair(0, 27));
-
- // Single letter.
- EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), std::make_pair(0, 1));
-
- // Single word.
- EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), std::make_pair(0, 4));
-
- EXPECT_EQ(
- classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
- std::make_pair(11, 23));
-
- // Unpaired bracket stripping.
- EXPECT_EQ(
- classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
- std::make_pair(11, 25));
- EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}),
- std::make_pair(12, 15));
- EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}),
- std::make_pair(11, 15));
- EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}),
- std::make_pair(12, 15));
-
- // If the resulting selection would be empty, the original span is returned.
- EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}),
- std::make_pair(11, 13));
- EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}),
- std::make_pair(11, 12));
- EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}),
- std::make_pair(11, 12));
-}
-
-TEST_P(AnnotatorTest, SuggestSelectionDisabledFail) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Disable the selection model.
- unpacked_model->selection_model.clear();
- unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
- unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- // Selection model needs to be present for annotation.
- ASSERT_FALSE(classifier);
-}
-
-TEST_P(AnnotatorTest, SuggestSelectionDisabled) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Disable the selection model.
- unpacked_model->selection_model.clear();
- unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
- unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
- unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(
- classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
- std::make_pair(11, 14));
-
- EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
- "call me at (800) 123-456 today", {11, 24})));
-
- EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"),
- IsEmpty());
-}
-
-TEST_P(AnnotatorTest, SuggestSelectionFilteredCollections) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(
- classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
- std::make_pair(11, 23));
-
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
- unpacked_model->output_options.reset(new OutputOptionsT);
-
- // Disable phone selection
- unpacked_model->output_options->filtered_collections_selection.push_back(
- "phone");
- // We need to force this for filtering.
- unpacked_model->selection_options->always_classify_suggested_selection = true;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(
- classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
- std::make_pair(11, 14));
-
- // Address selection should still work.
- EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
- std::make_pair(0, 27));
-}
-
-TEST_P(AnnotatorTest, SuggestSelectionsAreSymmetric) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}),
- std::make_pair(0, 27));
- EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
- std::make_pair(0, 27));
- EXPECT_EQ(
- classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}),
- std::make_pair(0, 27));
- EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
- {16, 22}),
- std::make_pair(6, 33));
-}
-
-TEST_P(AnnotatorTest, SuggestSelectionWithNewLine) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}),
- std::make_pair(4, 16));
- EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}),
- std::make_pair(0, 12));
-
- SelectionOptions options;
- EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
- std::make_pair(0, 7));
-}
-
-TEST_P(AnnotatorTest, SuggestSelectionWithPunctuation) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- // From the right.
- EXPECT_EQ(classifier->SuggestSelection(
- "this afternoon BarackObama, gave a speech at", {15, 26}),
- std::make_pair(15, 26));
-
- // From the right multiple.
- EXPECT_EQ(classifier->SuggestSelection(
- "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
- std::make_pair(15, 26));
-
- // From the left multiple.
- EXPECT_EQ(classifier->SuggestSelection(
- "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
- std::make_pair(21, 32));
-
- // From both sides.
- EXPECT_EQ(classifier->SuggestSelection(
- "this afternoon !BarackObama,- gave a speech at", {16, 27}),
- std::make_pair(16, 27));
-}
-
-TEST_P(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- // Try passing in bunch of invalid selections.
- EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), std::make_pair(0, 27));
- EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}),
- std::make_pair(-10, 27));
- EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}),
- std::make_pair(0, 27));
- EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}),
- std::make_pair(-30, 300));
- EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}),
- std::make_pair(-10, -1));
- EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
- std::make_pair(100, 17));
-
- // Try passing invalid utf8.
- EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}),
- std::make_pair(-1, -1));
-}
-
-TEST_P(AnnotatorTest, SuggestSelectionSelectSpace) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(
- classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
- std::make_pair(11, 23));
- EXPECT_EQ(
- classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
- std::make_pair(10, 11));
- EXPECT_EQ(
- classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
- std::make_pair(23, 24));
- EXPECT_EQ(
- classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
- std::make_pair(23, 24));
- EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today",
- {14, 17}),
- std::make_pair(11, 25));
- EXPECT_EQ(
- classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
- std::make_pair(11, 23));
- EXPECT_EQ(
- classifier->SuggestSelection(
- "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
- std::make_pair(14, 40));
- EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
- std::make_pair(4, 5));
- EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
- std::make_pair(7, 8));
-
- // With a punctuation around the selected whitespace.
- EXPECT_EQ(
- classifier->SuggestSelection(
- "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
- std::make_pair(14, 41));
-
- // When all's whitespace, should return the original indices.
- EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}),
- std::make_pair(0, 1));
- EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}),
- std::make_pair(0, 3));
- EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}),
- std::make_pair(2, 3));
- EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}),
- std::make_pair(5, 6));
-}
-
-TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
- UnicodeText text;
-
- text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
- EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
- std::make_pair(3, 4));
- text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
- EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
- std::make_pair(3, 4));
-
- // Nothing on the left.
- text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
- EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
- std::make_pair(4, 5));
- text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
- EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_),
- std::make_pair(0, 1));
-
- // Whitespace only.
- text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
- EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib_),
- std::make_pair(2, 3));
- text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
- EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
- std::make_pair(4, 5));
- text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
- EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_),
- std::make_pair(0, 1));
-}
-
-TEST_P(AnnotatorTest, Annotate) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- const std::string test_string =
- "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225 3556";
- EXPECT_THAT(classifier->Annotate(test_string),
- ElementsAreArray({
- IsAnnotatedSpan(28, 55, "address"),
- IsAnnotatedSpan(79, 91, "phone"),
- }));
-
- AnnotationOptions options;
- EXPECT_THAT(classifier->Annotate("853 225 3556", options),
- ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
- EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
-
- // Try passing invalid utf8.
- EXPECT_TRUE(
- classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
- .empty());
-}
-
-
-TEST_P(AnnotatorTest, AnnotateSmallBatches) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Set the batch size.
- unpacked_model->selection_options->batch_size = 4;
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- const std::string test_string =
- "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225 3556";
- EXPECT_THAT(classifier->Annotate(test_string),
- ElementsAreArray({
- IsAnnotatedSpan(28, 55, "address"),
- IsAnnotatedSpan(79, 91, "phone"),
- }));
-
- AnnotationOptions options;
- EXPECT_THAT(classifier->Annotate("853 225 3556", options),
- ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
- EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
-}
-
-#ifdef TC3_UNILIB_ICU
-TEST_P(AnnotatorTest, AnnotateFilteringDiscardAll) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
- // Add test threshold.
- unpacked_model->triggering_options->min_annotate_confidence =
- 2.f; // Discards all results.
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- const std::string test_string =
- "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225 3556";
-
- EXPECT_EQ(classifier->Annotate(test_string).size(), 0);
-}
-#endif // TC3_UNILIB_ICU
-
-TEST_P(AnnotatorTest, AnnotateFilteringKeepAll) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Add test thresholds.
- unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
- unpacked_model->triggering_options->min_annotate_confidence =
- 0.f; // Keeps all results.
- unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- const std::string test_string =
- "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225 3556";
- EXPECT_EQ(classifier->Annotate(test_string).size(), 2);
-}
-
-TEST_P(AnnotatorTest, AnnotateDisabled) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Disable the model for annotation.
- unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
- const std::string test_string =
- "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225 3556";
- EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
-}
-
-TEST_P(AnnotatorTest, AnnotateFilteredCollections) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- const std::string test_string =
- "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225 3556";
-
- EXPECT_THAT(classifier->Annotate(test_string),
- ElementsAreArray({
- IsAnnotatedSpan(28, 55, "address"),
- IsAnnotatedSpan(79, 91, "phone"),
- }));
-
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
- unpacked_model->output_options.reset(new OutputOptionsT);
-
- // Disable phone annotation
- unpacked_model->output_options->filtered_collections_annotation.push_back(
- "phone");
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_THAT(classifier->Annotate(test_string),
- ElementsAreArray({
- IsAnnotatedSpan(28, 55, "address"),
- }));
-}
-
-#ifdef TC3_UNILIB_ICU
-TEST_P(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- const std::string test_string =
- "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225 3556";
-
- EXPECT_THAT(classifier->Annotate(test_string),
- ElementsAreArray({
- IsAnnotatedSpan(28, 55, "address"),
- IsAnnotatedSpan(79, 91, "phone"),
- }));
-
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
- unpacked_model->output_options.reset(new OutputOptionsT);
-
- // We add a custom annotator that wins against the phone classification
- // below and that we subsequently suppress.
- unpacked_model->output_options->filtered_collections_annotation.push_back(
- "suppress");
-
- unpacked_model->regex_model->patterns.push_back(MakePattern(
- "suppress", "(\\d{3} ?\\d{4})",
- /*enabled_for_classification=*/false,
- /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_THAT(classifier->Annotate(test_string),
- ElementsAreArray({
- IsAnnotatedSpan(28, 55, "address"),
- }));
-}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_CALENDAR_ICU
-TEST_P(AnnotatorTest, ClassifyTextDate) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam());
- EXPECT_TRUE(classifier);
-
- std::vector<ClassificationResult> result;
- ClassificationOptions options;
-
- options.reference_timezone = "Europe/Zurich";
- result = classifier->ClassifyText("january 1, 2017", {0, 15}, options);
-
- ASSERT_EQ(result.size(), 1);
- EXPECT_THAT(result[0].collection, "date");
- EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
- EXPECT_EQ(result[0].datetime_parse_result.granularity,
- DatetimeGranularity::GRANULARITY_DAY);
- result.clear();
-
- options.reference_timezone = "America/Los_Angeles";
- result = classifier->ClassifyText("march 1, 2017", {0, 13}, options);
- ASSERT_EQ(result.size(), 1);
- EXPECT_THAT(result[0].collection, "date");
- EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000);
- EXPECT_EQ(result[0].datetime_parse_result.granularity,
- DatetimeGranularity::GRANULARITY_DAY);
- result.clear();
-
- options.reference_timezone = "America/Los_Angeles";
- result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options);
- ASSERT_EQ(result.size(), 1);
- EXPECT_THAT(result[0].collection, "date");
- EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
- EXPECT_EQ(result[0].datetime_parse_result.granularity,
- DatetimeGranularity::GRANULARITY_SECOND);
- result.clear();
-
- // Date on another line.
- options.reference_timezone = "Europe/Zurich";
- result = classifier->ClassifyText(
- "hello world this is the first line\n"
- "january 1, 2017",
- {35, 50}, options);
- ASSERT_EQ(result.size(), 1);
- EXPECT_THAT(result[0].collection, "date");
- EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
- EXPECT_EQ(result[0].datetime_parse_result.granularity,
- DatetimeGranularity::GRANULARITY_DAY);
-}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_CALENDAR_ICU
-TEST_P(AnnotatorTest, ClassifyTextDatePriorities) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam());
- EXPECT_TRUE(classifier);
-
- std::vector<ClassificationResult> result;
- ClassificationOptions options;
-
- result.clear();
- options.reference_timezone = "Europe/Zurich";
- options.locales = "en-US";
- result = classifier->ClassifyText("03.05.1970", {0, 10}, options);
-
- ASSERT_EQ(result.size(), 1);
- EXPECT_THAT(result[0].collection, "date");
- EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000);
- EXPECT_EQ(result[0].datetime_parse_result.granularity,
- DatetimeGranularity::GRANULARITY_DAY);
-
- result.clear();
- options.reference_timezone = "Europe/Zurich";
- options.locales = "de";
- result = classifier->ClassifyText("03.05.1970", {0, 10}, options);
-
- ASSERT_EQ(result.size(), 1);
- EXPECT_THAT(result[0].collection, "date");
- EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000);
- EXPECT_EQ(result[0].datetime_parse_result.granularity,
- DatetimeGranularity::GRANULARITY_DAY);
-}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_CALENDAR_ICU
-TEST_P(AnnotatorTest, SuggestTextDateDisabled) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- // Disable the patterns for selection.
- for (int i = 0; i < unpacked_model->datetime_model->patterns.size(); i++) {
- unpacked_model->datetime_model->patterns[i]->enabled_modes =
- ModeFlag_ANNOTATION_AND_CLASSIFICATION;
- }
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
- EXPECT_EQ("date",
- FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
- EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
- std::make_pair(0, 7));
- EXPECT_THAT(classifier->Annotate("january 1, 2017"),
- ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
-}
-#endif // TC3_UNILIB_ICU
-
-class TestingAnnotator : public Annotator {
- public:
- TestingAnnotator(const std::string& model, const UniLib* unilib,
- const CalendarLib* calendarlib)
- : Annotator(ViewModel(model.data(), model.size()), unilib, calendarlib) {}
-
- using Annotator::ResolveConflicts;
-};
-
-AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
- const std::string& collection,
- const float score) {
- AnnotatedSpan result;
- result.span = span;
- result.classification.push_back({collection, score});
- return result;
-}
-
-TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
- TestingAnnotator classifier("", &unilib_, &calendarlib_);
-
- std::vector<AnnotatedSpan> candidates{
- {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
-
- std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- /*interpreter_manager=*/nullptr, &chosen);
- EXPECT_THAT(chosen, ElementsAreArray({0}));
-}
-
-TEST_F(AnnotatorTest, ResolveConflictsSequence) {
- TestingAnnotator classifier("", &unilib_, &calendarlib_);
-
- std::vector<AnnotatedSpan> candidates{{
- MakeAnnotatedSpan({0, 1}, "phone", 1.0),
- MakeAnnotatedSpan({1, 2}, "phone", 1.0),
- MakeAnnotatedSpan({2, 3}, "phone", 1.0),
- MakeAnnotatedSpan({3, 4}, "phone", 1.0),
- MakeAnnotatedSpan({4, 5}, "phone", 1.0),
- }};
-
- std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- /*interpreter_manager=*/nullptr, &chosen);
- EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
-}
-
-TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
- TestingAnnotator classifier("", &unilib_, &calendarlib_);
-
- std::vector<AnnotatedSpan> candidates{{
- MakeAnnotatedSpan({0, 3}, "phone", 1.0),
- MakeAnnotatedSpan({1, 5}, "phone", 0.5), // Looser!
- MakeAnnotatedSpan({3, 7}, "phone", 1.0),
- }};
-
- std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- /*interpreter_manager=*/nullptr, &chosen);
- EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
-}
-
-TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
- TestingAnnotator classifier("", &unilib_, &calendarlib_);
-
- std::vector<AnnotatedSpan> candidates{{
- MakeAnnotatedSpan({0, 3}, "phone", 0.5), // Looser!
- MakeAnnotatedSpan({1, 5}, "phone", 1.0),
- MakeAnnotatedSpan({3, 7}, "phone", 0.6), // Looser!
- }};
-
- std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- /*interpreter_manager=*/nullptr, &chosen);
- EXPECT_THAT(chosen, ElementsAreArray({1}));
-}
-
-TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
- TestingAnnotator classifier("", &unilib_, &calendarlib_);
-
- std::vector<AnnotatedSpan> candidates{{
- MakeAnnotatedSpan({0, 3}, "phone", 0.5),
- MakeAnnotatedSpan({1, 5}, "other", 1.0), // Looser!
- MakeAnnotatedSpan({3, 7}, "phone", 0.6),
- MakeAnnotatedSpan({8, 12}, "phone", 0.6), // Looser!
- MakeAnnotatedSpan({11, 15}, "phone", 0.9),
- }};
-
- std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- /*interpreter_manager=*/nullptr, &chosen);
- EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
-}
-
-#ifdef TC3_UNILIB_ICU
-TEST_P(AnnotatorTest, LongInput) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- for (const auto& type_value_pair :
- std::vector<std::pair<std::string, std::string>>{
- {"address", "350 Third Street, Cambridge"},
- {"phone", "123 456-7890"},
- {"url", "www.google.com"},
- {"email", "someone@gmail.com"},
- {"flight", "LX 38"},
- {"date", "September 1, 2018"}}) {
- const std::string input_100k = std::string(50000, ' ') +
- type_value_pair.second +
- std::string(50000, ' ');
- const int value_length = type_value_pair.second.size();
-
- EXPECT_THAT(classifier->Annotate(input_100k),
- ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length,
- type_value_pair.first)}));
- EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001}),
- std::make_pair(50000, 50000 + value_length));
- EXPECT_EQ(type_value_pair.first,
- FirstResult(classifier->ClassifyText(
- input_100k, {50000, 50000 + value_length})));
- }
-}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_UNILIB_ICU
-// These coarse tests are there only to make sure the execution happens in
-// reasonable amount of time.
-TEST_P(AnnotatorTest, LongInputNoResultCheck) {
- std::unique_ptr<Annotator> classifier =
- Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- for (const std::string& value :
- std::vector<std::string>{"http://www.aaaaaaaaaaaaaaaaaaaa.com "}) {
- const std::string input_100k =
- std::string(50000, ' ') + value + std::string(50000, ' ');
- const int value_length = value.size();
-
- classifier->Annotate(input_100k);
- classifier->SuggestSelection(input_100k, {50000, 50001});
- classifier->ClassifyText(input_100k, {50000, 50000 + value_length});
- }
-}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_UNILIB_ICU
-TEST_P(AnnotatorTest, MaxTokenLength) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- std::unique_ptr<Annotator> classifier;
-
- // With unrestricted number of tokens should behave normally.
- unpacked_model->classification_options->max_num_tokens = -1;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
- classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(FirstResult(classifier->ClassifyText(
- "I live at 350 Third Street, Cambridge.", {10, 37})),
- "address");
-
- // Raise the maximum number of tokens to suppress the classification.
- unpacked_model->classification_options->max_num_tokens = 3;
-
- flatbuffers::FlatBufferBuilder builder2;
- FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
- classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder2.GetBufferPointer()),
- builder2.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(FirstResult(classifier->ClassifyText(
- "I live at 350 Third Street, Cambridge.", {10, 37})),
- "other");
-}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_UNILIB_ICU
-TEST_P(AnnotatorTest, MinAddressTokenLength) {
- const std::string test_model = ReadFile(GetModelPath() + GetParam());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
-
- std::unique_ptr<Annotator> classifier;
-
- // With unrestricted number of address tokens should behave normally.
- unpacked_model->classification_options->address_min_num_tokens = 0;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
- classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(FirstResult(classifier->ClassifyText(
- "I live at 350 Third Street, Cambridge.", {10, 37})),
- "address");
-
- // Raise number of address tokens to suppress the address classification.
- unpacked_model->classification_options->address_min_num_tokens = 5;
-
- flatbuffers::FlatBufferBuilder builder2;
- FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
- classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder2.GetBufferPointer()),
- builder2.GetSize(), &unilib_, &calendarlib_);
- ASSERT_TRUE(classifier);
-
- EXPECT_EQ(FirstResult(classifier->ClassifyText(
- "I live at 350 Third Street, Cambridge.", {10, 37})),
- "other");
-}
-#endif // TC3_UNILIB_ICU
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/annotator/collections.h b/annotator/collections.h
new file mode 100644
index 0000000..a23623e
--- /dev/null
+++ b/annotator/collections.h
@@ -0,0 +1,126 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
+
+#include <string>
+
+namespace libtextclassifier3 {
+
+// String collection names for various classes.
+class Collections {
+ public:
+ static const std::string& Address() {
+ static const std::string& value =
+ *[]() { return new std::string("address"); }();
+ return value;
+ }
+ static const std::string& App() {
+ static const std::string& value =
+ *[]() { return new std::string("app"); }();
+ return value;
+ }
+ static const std::string& Contact() {
+ static const std::string& value =
+ *[]() { return new std::string("contact"); }();
+ return value;
+ }
+ static const std::string& Date() {
+ static const std::string& value =
+ *[]() { return new std::string("date"); }();
+ return value;
+ }
+ static const std::string& DateTime() {
+ static const std::string& value =
+ *[]() { return new std::string("datetime"); }();
+ return value;
+ }
+ static const std::string& Dictionary() {
+ static const std::string& value =
+ *[]() { return new std::string("dictionary"); }();
+ return value;
+ }
+ static const std::string& Duration() {
+ static const std::string& value =
+ *[]() { return new std::string("duration"); }();
+ return value;
+ }
+ static const std::string& Email() {
+ static const std::string& value =
+ *[]() { return new std::string("email"); }();
+ return value;
+ }
+ static const std::string& Entity() {
+ static const std::string& value =
+ *[]() { return new std::string("entity"); }();
+ return value;
+ }
+ static const std::string& Flight() {
+ static const std::string& value =
+ *[]() { return new std::string("flight"); }();
+ return value;
+ }
+ static const std::string& Iban() {
+ static const std::string& value =
+ *[]() { return new std::string("iban"); }();
+ return value;
+ }
+ static const std::string& Isbn() {
+ static const std::string& value =
+ *[]() { return new std::string("isbn"); }();
+ return value;
+ }
+ static const std::string& Money() {
+ static const std::string& value =
+ *[]() { return new std::string("money"); }();
+ return value;
+ }
+ static const std::string& Number() {
+ static const std::string& value =
+ *[]() { return new std::string("number"); }();
+ return value;
+ }
+ static const std::string& Other() {
+ static const std::string& value =
+ *[]() { return new std::string("other"); }();
+ return value;
+ }
+ static const std::string& PaymentCard() {
+ static const std::string& value =
+ *[]() { return new std::string("payment_card"); }();
+ return value;
+ }
+ static const std::string& Phone() {
+ static const std::string& value =
+ *[]() { return new std::string("phone"); }();
+ return value;
+ }
+ static const std::string& TrackingNumber() {
+ static const std::string& value =
+ *[]() { return new std::string("tracking_number"); }();
+ return value;
+ }
+ static const std::string& Url() {
+ static const std::string& value =
+ *[]() { return new std::string("url"); }();
+ return value;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
diff --git a/annotator/contact/contact-engine-dummy.h b/annotator/contact/contact-engine-dummy.h
new file mode 100644
index 0000000..c7a389d
--- /dev/null
+++ b/annotator/contact/contact-engine-dummy.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/feature-processor.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// A dummy implementation of the contact engine.
+class ContactEngine {
+ public:
+ explicit ContactEngine(const FeatureProcessor* feature_processor,
+ const UniLib* unilib) {}
+
+ bool Initialize(const std::string& serialized_config) {
+ TC3_LOG(ERROR) << "No contact engine to initialize.";
+ return false;
+ }
+
+ bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
+ ClassificationResult* classification_result) const {
+ return false;
+ }
+
+ bool Chunk(const UnicodeText& context_unicode,
+ const std::vector<Token>& tokens,
+ std::vector<AnnotatedSpan>* result) const {
+ return true;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
diff --git a/annotator/contact/contact-engine.h b/annotator/contact/contact-engine.h
new file mode 100644
index 0000000..01d3323
--- /dev/null
+++ b/annotator/contact/contact-engine.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_H_
+
+#include "annotator/contact/contact-engine-dummy.h"
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_H_
diff --git a/annotator/datetime/extractor.cc b/annotator/datetime/extractor.cc
index 31229dd..b9d0c30 100644
--- a/annotator/datetime/extractor.cc
+++ b/annotator/datetime/extractor.cc
@@ -376,7 +376,7 @@
}
bool DatetimeExtractor::ParseAMPM(const UnicodeText& input,
- int* parsed_ampm) const {
+ DateParseData::AMPM* parsed_ampm) const {
return MapInput(input,
{
{DatetimeExtractorType_AM, DateParseData::AMPM::AM},
@@ -420,50 +420,25 @@
return MapInput(
input,
{
- {DatetimeExtractorType_MONDAY, DateParseData::MONDAY},
- {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY},
- {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY},
- {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY},
- {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY},
- {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY},
- {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY},
- {DatetimeExtractorType_DAY, DateParseData::DAY},
- {DatetimeExtractorType_WEEK, DateParseData::WEEK},
- {DatetimeExtractorType_MONTH, DateParseData::MONTH},
- {DatetimeExtractorType_YEAR, DateParseData::YEAR},
+ {DatetimeExtractorType_MONDAY, DateParseData::RelationType::MONDAY},
+ {DatetimeExtractorType_TUESDAY, DateParseData::RelationType::TUESDAY},
+ {DatetimeExtractorType_WEDNESDAY,
+ DateParseData::RelationType::WEDNESDAY},
+ {DatetimeExtractorType_THURSDAY,
+ DateParseData::RelationType::THURSDAY},
+ {DatetimeExtractorType_FRIDAY, DateParseData::RelationType::FRIDAY},
+ {DatetimeExtractorType_SATURDAY,
+ DateParseData::RelationType::SATURDAY},
+ {DatetimeExtractorType_SUNDAY, DateParseData::RelationType::SUNDAY},
+ {DatetimeExtractorType_SECONDS, DateParseData::RelationType::SECOND},
+ {DatetimeExtractorType_MINUTES, DateParseData::RelationType::MINUTE},
+ {DatetimeExtractorType_HOURS, DateParseData::RelationType::HOUR},
+ {DatetimeExtractorType_DAY, DateParseData::RelationType::DAY},
+ {DatetimeExtractorType_WEEK, DateParseData::RelationType::WEEK},
+ {DatetimeExtractorType_MONTH, DateParseData::RelationType::MONTH},
+ {DatetimeExtractorType_YEAR, DateParseData::RelationType::YEAR},
},
parsed_relation_type);
}
-bool DatetimeExtractor::ParseTimeUnit(const UnicodeText& input,
- int* parsed_time_unit) const {
- return MapInput(input,
- {
- {DatetimeExtractorType_DAYS, DateParseData::DAYS},
- {DatetimeExtractorType_WEEKS, DateParseData::WEEKS},
- {DatetimeExtractorType_MONTHS, DateParseData::MONTHS},
- {DatetimeExtractorType_HOURS, DateParseData::HOURS},
- {DatetimeExtractorType_MINUTES, DateParseData::MINUTES},
- {DatetimeExtractorType_SECONDS, DateParseData::SECONDS},
- {DatetimeExtractorType_YEARS, DateParseData::YEARS},
- },
- parsed_time_unit);
-}
-
-bool DatetimeExtractor::ParseWeekday(const UnicodeText& input,
- int* parsed_weekday) const {
- return MapInput(
- input,
- {
- {DatetimeExtractorType_MONDAY, DateParseData::MONDAY},
- {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY},
- {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY},
- {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY},
- {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY},
- {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY},
- {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY},
- },
- parsed_weekday);
-}
-
} // namespace libtextclassifier3
diff --git a/annotator/datetime/extractor.h b/annotator/datetime/extractor.h
index 4c17aa7..95e7f7c 100644
--- a/annotator/datetime/extractor.h
+++ b/annotator/datetime/extractor.h
@@ -86,16 +86,19 @@
bool ParseWrittenNumber(const UnicodeText& input, int* parsed_number) const;
bool ParseYear(const UnicodeText& input, int* parsed_year) const;
bool ParseMonth(const UnicodeText& input, int* parsed_month) const;
- bool ParseAMPM(const UnicodeText& input, int* parsed_ampm) const;
+ bool ParseAMPM(const UnicodeText& input,
+ DateParseData::AMPM* parsed_ampm) const;
bool ParseRelation(const UnicodeText& input,
DateParseData::Relation* parsed_relation) const;
bool ParseRelationDistance(const UnicodeText& input,
int* parsed_distance) const;
- bool ParseTimeUnit(const UnicodeText& input, int* parsed_time_unit) const;
+ bool ParseTimeUnit(const UnicodeText& input,
+ DateParseData::TimeUnit* parsed_time_unit) const;
bool ParseRelationType(
const UnicodeText& input,
DateParseData::RelationType* parsed_relation_type) const;
- bool ParseWeekday(const UnicodeText& input, int* parsed_weekday) const;
+ bool ParseWeekday(const UnicodeText& input,
+ DateParseData::RelationType* parsed_weekday) const;
const CompiledRule& rule_;
const UniLib::RegexMatcher& matcher_;
diff --git a/annotator/datetime/parser.cc b/annotator/datetime/parser.cc
index ac3a62d..6d844f4 100644
--- a/annotator/datetime/parser.cc
+++ b/annotator/datetime/parser.cc
@@ -23,6 +23,7 @@
#include "utils/calendar/calendar.h"
#include "utils/i18n/locale.h"
#include "utils/strings/split.h"
+#include "utils/zlib/zlib_regex.h"
namespace libtextclassifier3 {
std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
@@ -51,9 +52,9 @@
if (pattern->regexes()) {
for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- UncompressMakeRegexPattern(unilib, regex->pattern(),
- regex->compressed_pattern(),
- decompressor);
+ UncompressMakeRegexPattern(
+ unilib, regex->pattern(), regex->compressed_pattern(),
+ model->lazy_regex_compilation(), decompressor);
if (!regex_pattern) {
TC3_LOG(ERROR) << "Couldn't create rule pattern.";
return;
@@ -72,9 +73,9 @@
if (model->extractors() != nullptr) {
for (const DatetimeModelExtractor* extractor : *model->extractors()) {
std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- UncompressMakeRegexPattern(unilib, extractor->pattern(),
- extractor->compressed_pattern(),
- decompressor);
+ UncompressMakeRegexPattern(
+ unilib, extractor->pattern(), extractor->compressed_pattern(),
+ model->lazy_regex_compilation(), decompressor);
if (!regex_pattern) {
TC3_LOG(ERROR) << "Couldn't create extractor pattern";
return;
@@ -103,6 +104,8 @@
}
use_extractors_for_locating_ = model->use_extractors_for_locating();
+ generate_alternative_interpretations_when_ambiguous_ =
+ model->generate_alternative_interpretations_when_ambiguous();
initialized_ = true;
}
@@ -110,17 +113,18 @@
bool DatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, bool anchor_start_end,
+ ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const {
return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
reference_time_ms_utc, reference_timezone, locales, mode,
- anchor_start_end, results);
+ annotation_usecase, anchor_start_end, results);
}
bool DatetimeParser::FindSpansUsingLocales(
const std::vector<int>& locale_ids, const UnicodeText& input,
const int64 reference_time_ms_utc, const std::string& reference_timezone,
- ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
+ ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
+ const std::string& reference_locale,
std::unordered_set<int>* executed_rules,
std::vector<DatetimeParseResultSpan>* found_spans) const {
for (const int locale_id : locale_ids) {
@@ -135,6 +139,11 @@
continue;
}
+ if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
+ (1 << annotation_usecase)) == 0) {
+ continue;
+ }
+
if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
continue;
}
@@ -154,7 +163,7 @@
bool DatetimeParser::Parse(
const UnicodeText& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, bool anchor_start_end,
+ ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const {
std::vector<DatetimeParseResultSpan> found_spans;
std::unordered_set<int> executed_rules;
@@ -162,16 +171,16 @@
const std::vector<int> requested_locales =
ParseAndExpandLocales(locales, &reference_locale);
if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
- reference_timezone, mode, anchor_start_end,
- reference_locale, &executed_rules, &found_spans)) {
+ reference_timezone, mode, annotation_usecase,
+ anchor_start_end, reference_locale,
+ &executed_rules, &found_spans)) {
return false;
}
std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
- int counter = 0;
- for (const auto& found_span : found_spans) {
- indexed_found_spans.push_back({found_span, counter});
- counter++;
+ indexed_found_spans.reserve(found_spans.size());
+ for (int i = 0; i < found_spans.size(); i++) {
+ indexed_found_spans.push_back({found_spans[i], i});
}
// Resolve conflicts by always picking the longer span and breaking ties by
@@ -224,21 +233,28 @@
}
DatetimeParseResultSpan parse_result;
+ std::vector<DatetimeParseResult> alternatives;
if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
- reference_locale, locale_id, &(parse_result.data),
+ reference_locale, locale_id, &alternatives,
&parse_result.span)) {
return false;
}
+
if (!use_extractors_for_locating_) {
parse_result.span = {start, end};
}
+
if (parse_result.span.first != kInvalidIndex &&
parse_result.span.second != kInvalidIndex) {
parse_result.target_classification_score =
rule.pattern->target_classification_score();
parse_result.priority_score = rule.pattern->priority_score();
- result->push_back(parse_result);
+
+ for (DatetimeParseResult& alternative : alternatives) {
+ parse_result.data.push_back(alternative);
+ }
}
+ result->push_back(parse_result);
return true;
}
@@ -329,60 +345,54 @@
return result;
}
-namespace {
+void DatetimeParser::FillInterpretations(
+ const DateParseData& parse,
+ std::vector<DateParseData>* interpretations) const {
+ DatetimeGranularity granularity = calendarlib_.GetGranularity(parse);
-DatetimeGranularity GetGranularity(const DateParseData& data) {
- DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
- if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::YEAR))) {
- granularity = DatetimeGranularity::GRANULARITY_YEAR;
+ DateParseData modified_parse(parse);
+ // If the relation field is not set, but relation_type field *is*, assume
+ // the relation field is NEXT_OR_SAME. This is necessary to handle e.g.
+ // "monday 3pm" (otherwise only "this monday 3pm" would work).
+ if (!(modified_parse.field_set_mask &
+ DateParseData::Fields::RELATION_FIELD) &&
+ (modified_parse.field_set_mask &
+ DateParseData::Fields::RELATION_TYPE_FIELD)) {
+ modified_parse.relation = DateParseData::Relation::NEXT_OR_SAME;
+ modified_parse.field_set_mask |= DateParseData::Fields::RELATION_FIELD;
}
- if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::MONTH))) {
- granularity = DatetimeGranularity::GRANULARITY_MONTH;
+
+ // Multiple interpretations of ambiguous datetime expressions are generated
+ // here.
+ if (granularity > DatetimeGranularity::GRANULARITY_DAY &&
+ (modified_parse.field_set_mask & DateParseData::Fields::HOUR_FIELD) &&
+ modified_parse.hour <= 12 &&
+ !(modified_parse.field_set_mask & DateParseData::Fields::AMPM_FIELD)) {
+ // If it's not clear if the time is AM or PM, generate all variants.
+ interpretations->push_back(modified_parse);
+ interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
+ interpretations->back().ampm = DateParseData::AMPM::AM;
+
+ interpretations->push_back(modified_parse);
+ interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
+ interpretations->back().ampm = DateParseData::AMPM::PM;
+ } else {
+ // Otherwise just generate 1 variant.
+ interpretations->push_back(modified_parse);
}
- if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::WEEK)) {
- granularity = DatetimeGranularity::GRANULARITY_WEEK;
- }
- if (data.field_set_mask & DateParseData::DAY_FIELD ||
- (data.field_set_mask & DateParseData::RELATION_FIELD &&
- (data.relation == DateParseData::Relation::NOW ||
- data.relation == DateParseData::Relation::TOMORROW ||
- data.relation == DateParseData::Relation::YESTERDAY)) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::MONDAY ||
- data.relation_type == DateParseData::RelationType::TUESDAY ||
- data.relation_type == DateParseData::RelationType::WEDNESDAY ||
- data.relation_type == DateParseData::RelationType::THURSDAY ||
- data.relation_type == DateParseData::RelationType::FRIDAY ||
- data.relation_type == DateParseData::RelationType::SATURDAY ||
- data.relation_type == DateParseData::RelationType::SUNDAY ||
- data.relation_type == DateParseData::RelationType::DAY))) {
- granularity = DatetimeGranularity::GRANULARITY_DAY;
- }
- if (data.field_set_mask & DateParseData::HOUR_FIELD) {
- granularity = DatetimeGranularity::GRANULARITY_HOUR;
- }
- if (data.field_set_mask & DateParseData::MINUTE_FIELD) {
- granularity = DatetimeGranularity::GRANULARITY_MINUTE;
- }
- if (data.field_set_mask & DateParseData::SECOND_FIELD) {
- granularity = DatetimeGranularity::GRANULARITY_SECOND;
- }
- return granularity;
+ // TODO(zilka): Add support for generating alternatives for "monday" -> "this
+ // monday", "next monday", "last monday". The previous implementation did not
+ // work as expected, because it didn't work correctly for this/previous day
+ // of week, and sometimes resulted in the same date being proposed.
}
-} // namespace
-
bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
const UniLib::RegexMatcher& matcher,
const int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& reference_locale,
- int locale_id, DatetimeParseResult* result,
+ int locale_id,
+ std::vector<DatetimeParseResult>* results,
CodepointSpan* result_span) const {
DateParseData parse;
DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
@@ -392,14 +402,23 @@
return false;
}
- result->granularity = GetGranularity(parse);
-
- if (!calendarlib_.InterpretParseData(
- parse, reference_time_ms_utc, reference_timezone, reference_locale,
- result->granularity, &(result->time_ms_utc))) {
- return false;
+ std::vector<DateParseData> interpretations;
+ if (generate_alternative_interpretations_when_ambiguous_) {
+ FillInterpretations(parse, &interpretations);
+ } else {
+ interpretations.push_back(parse);
}
+ results->reserve(results->size() + interpretations.size());
+ for (const DateParseData& interpretation : interpretations) {
+ DatetimeParseResult result;
+ if (!calendarlib_.InterpretParseData(
+ interpretation, reference_time_ms_utc, reference_timezone,
+ reference_locale, &(result.time_ms_utc), &(result.granularity))) {
+ return false;
+ }
+ results->push_back(result);
+ }
return true;
}
diff --git a/annotator/datetime/parser.h b/annotator/datetime/parser.h
index c7eaf1f..3f0c143 100644
--- a/annotator/datetime/parser.h
+++ b/annotator/datetime/parser.h
@@ -47,15 +47,23 @@
// beginning of 'input' and end at the end of it.
bool Parse(const std::string& input, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, bool anchor_start_end,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
// Same as above but takes UnicodeText.
bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, bool anchor_start_end,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
+#ifdef TC3_TEST_ONLY
+ void TestOnlySetGenerateAlternativeInterpretationsWhenAmbiguous(bool value) {
+ generate_alternative_interpretations_when_ambiguous_ = value;
+ }
+#endif // TC3_TEST_ONLY
+
protected:
DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
const CalendarLib& calendarlib,
@@ -71,7 +79,8 @@
bool FindSpansUsingLocales(
const std::vector<int>& locale_ids, const UnicodeText& input,
const int64 reference_time_ms_utc, const std::string& reference_timezone,
- ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end, const std::string& reference_locale,
std::unordered_set<int>* executed_rules,
std::vector<DatetimeParseResultSpan>* found_spans) const;
@@ -82,13 +91,16 @@
bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* result) const;
+ void FillInterpretations(const DateParseData& parse,
+ std::vector<DateParseData>* interpretations) const;
+
// Converts the current match in 'matcher' into DatetimeParseResult.
bool ExtractDatetime(const CompiledRule& rule,
const UniLib::RegexMatcher& matcher,
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& reference_locale, int locale_id,
- DatetimeParseResult* result,
+ std::vector<DatetimeParseResult>* results,
CodepointSpan* result_span) const;
// Parse and extract information from current match in 'matcher'.
@@ -111,6 +123,7 @@
std::unordered_map<std::string, int> locale_string_to_id_;
std::vector<int> default_locale_ids_;
bool use_extractors_for_locating_;
+ bool generate_alternative_interpretations_when_ambiguous_;
};
} // namespace libtextclassifier3
diff --git a/annotator/datetime/parser_test.cc b/annotator/datetime/parser_test.cc
index d46accf..8196fa7 100644
--- a/annotator/datetime/parser_test.cc
+++ b/annotator/datetime/parser_test.cc
@@ -27,6 +27,7 @@
#include "annotator/datetime/parser.h"
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
+#include "utils/testing/annotator.h"
using testing::ElementsAreArray;
@@ -42,30 +43,32 @@
return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
-std::string FormatMillis(int64 time_ms_utc) {
- long time_seconds = time_ms_utc / 1000; // NOLINT
- // Format time, "ddd yyyy-mm-dd hh:mm:ss zzz"
- char buffer[512];
- strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z",
- localtime(&time_seconds));
- return std::string(buffer);
-}
-
class ParserTest : public testing::Test {
public:
void SetUp() override {
- model_buffer_ = ReadFile(GetModelPath() + "test_model.fb");
+ // Loads default unmodified model. Individual tests can call LoadModel to
+ // make changes.
+ LoadModel([](ModelT* model) {});
+ }
+
+ template <typename Fn>
+ void LoadModel(Fn model_visitor_fn) {
+ std::string model_buffer = ReadFile(GetModelPath() + "test_model.fb");
+ model_buffer_ = ModifyAnnotatorModel(model_buffer, model_visitor_fn);
classifier_ = Annotator::FromUnownedBuffer(model_buffer_.data(),
model_buffer_.size(), &unilib_);
TC3_CHECK(classifier_);
parser_ = classifier_->DatetimeParserForTests();
+ TC3_CHECK(parser_);
}
bool HasNoResult(const std::string& text, bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich") {
+ const std::string& timezone = "Europe/Zurich",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART) {
std::vector<DatetimeParseResultSpan> results;
if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
- anchor_start_end, &results)) {
+ annotation_usecase, anchor_start_end, &results)) {
TC3_LOG(ERROR) << text;
TC3_CHECK(false);
}
@@ -73,11 +76,13 @@
}
bool ParsesCorrectly(const std::string& marked_text,
- const int64 expected_ms_utc,
+ const std::vector<int64>& expected_ms_utcs,
DatetimeGranularity expected_granularity,
bool anchor_start_end = false,
const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US") {
+ const std::string& locales = "en-US",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART) {
const UnicodeText marked_text_unicode =
UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
auto brace_open_it =
@@ -97,7 +102,7 @@
std::vector<DatetimeParseResultSpan> results;
if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION,
- anchor_start_end, &results)) {
+ annotation_usecase, anchor_start_end, &results)) {
TC3_LOG(ERROR) << text;
TC3_CHECK(false);
}
@@ -120,25 +125,50 @@
}
}
- const std::vector<DatetimeParseResultSpan> expected{
+ std::vector<DatetimeParseResultSpan> expected{
{{expected_start_index, expected_end_index},
- {expected_ms_utc, expected_granularity},
+ {},
/*target_classification_score=*/1.0,
/*priority_score=*/0.1}};
+ expected[0].data.resize(expected_ms_utcs.size());
+ for (int i = 0; i < expected_ms_utcs.size(); i++) {
+ expected[0].data[i] = {expected_ms_utcs[i], expected_granularity};
+ }
+
const bool matches =
testing::Matches(ElementsAreArray(expected))(filtered_results);
if (!matches) {
- TC3_LOG(ERROR) << "Expected: " << expected[0] << " which corresponds to: "
- << FormatMillis(expected[0].data.time_ms_utc);
- for (int i = 0; i < filtered_results.size(); ++i) {
- TC3_LOG(ERROR) << "Actual[" << i << "]: " << filtered_results[i]
- << " which corresponds to: "
- << FormatMillis(filtered_results[i].data.time_ms_utc);
+ TC3_LOG(ERROR) << "Expected: " << expected[0];
+ if (filtered_results.empty()) {
+ TC3_LOG(ERROR) << "But got no results.";
}
+ TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
}
+
return matches;
}
+ bool ParsesCorrectly(const std::string& marked_text,
+ const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ const std::string& locales = "en-US",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART) {
+ return ParsesCorrectly(marked_text, std::vector<int64>{expected_ms_utc},
+ expected_granularity, anchor_start_end, timezone,
+ locales, annotation_usecase);
+ }
+
+ bool ParsesCorrectlyGerman(const std::string& marked_text,
+ const std::vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity) {
+ return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+ }
+
bool ParsesCorrectlyGerman(const std::string& marked_text,
const int64 expected_ms_utc,
DatetimeGranularity expected_granularity) {
@@ -173,24 +203,32 @@
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{Jun 09 2011 15:28:14}", 1307626094000,
GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{Mar 16 08:12:04}", {6419524000, 6462724000},
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29}",
+ {1277512289000, 1277555489000},
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}",
+ {1137899465000, 1137942665000},
+ GRANULARITY_SECOND));
EXPECT_TRUE(
- ParsesCorrectly("{Mar 16 08:12:04}", 6419524000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29}", 1277512289000,
+ ParsesCorrectly("{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23/Apr 11:42:35}", {9715355000, 9758555000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}", 1137899465000,
+ EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{11:42:35}", 38555000, GRANULARITY_SECOND));
- EXPECT_TRUE(
- ParsesCorrectly("{23/Apr 11:42:35}", 9715355000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}", 1429782155000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000,
GRANULARITY_SECOND));
@@ -205,25 +243,31 @@
"think order event music. Incommode so intention defective at "
"convinced. Led income months itself and houses you. After nor "
"you leave might share court balls. ",
- 1271651775000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000,
+ {1271651775000, 1271694975000}, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}",
+ {1514777400000, 1514820600000},
GRANULARITY_MINUTE));
EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30 am}", 1514777400000,
GRANULARITY_MINUTE));
EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000,
GRANULARITY_HOUR));
- EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", -3600000, GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", -57600000, GRANULARITY_MINUTE,
- /*anchor_start_end=*/false,
- "America/Los_Angeles"));
- EXPECT_TRUE(
- ParsesCorrectly("{tomorrow at 4:00}", 97200000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", {-3600000, 39600000},
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{today at 0:00}", {-57600000, -14400000}, GRANULARITY_MINUTE,
+ /*anchor_start_end=*/false, "America/Los_Angeles"));
+ EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4:00}", {97200000, 140400000},
+ GRANULARITY_MINUTE));
EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4am}", 97200000, GRANULARITY_HOUR));
EXPECT_TRUE(
ParsesCorrectly("{wednesday at 4am}", 529200000, GRANULARITY_HOUR));
EXPECT_TRUE(ParsesCorrectly("last seen {today at 9:01 PM}", 72060000,
GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("set an alarm for {7am tomorrow}", 108000000,
+ GRANULARITY_HOUR));
+ EXPECT_TRUE(
+ ParsesCorrectly("set an alarm for {7 a.m}", 21600000, GRANULARITY_HOUR));
}
TEST_F(ParserTest, ParseWithAnchor) {
@@ -237,6 +281,50 @@
/*anchor_start_end=*/true));
}
+TEST_F(ParserTest, ParseWithRawUsecase) {
+ // Annotated for RAW usecase.
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow}", 82800000, GRANULARITY_DAY, /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me {in two hours}", 7200000, GRANULARITY_HOUR,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me {next month}", 2674800000, GRANULARITY_MONTH,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+ EXPECT_TRUE(ParsesCorrectly(
+ "what's the time {now}", -3600000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me on {Saturday}", 169200000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ // Not annotated for Smart usecase.
+ EXPECT_TRUE(HasNoResult(
+ "{tomorrow}", /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_SMART));
+}
+
+TEST_F(ParserTest, ParsesNoonAndMidnightCorrectly) {
+ EXPECT_TRUE(ParsesCorrectly("{January 1, 1988 12:30am}", 567991800000,
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{January 1, 1988 12:30pm}", 568035000000,
+ GRANULARITY_MINUTE));
+}
+
TEST_F(ParserTest, ParseGerman) {
EXPECT_TRUE(
ParsesCorrectlyGerman("{Januar 1 2018}", 1514761200000, GRANULARITY_DAY));
@@ -244,39 +332,51 @@
ParsesCorrectlyGerman("{1 2 2018}", 1517439600000, GRANULARITY_DAY));
EXPECT_TRUE(ParsesCorrectlyGerman("lorem {1 Januar 2018} ipsum",
1514761200000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectlyGerman("{19/Apr/2010:06:36:15}", 1271651775000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{19/Apr/2010:06:36:15}",
+ {1271651775000, 1271694975000},
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectlyGerman("{09/März/2004 22:02:40}", 1078866160000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{Dez 2, 2010 2:39:58}", 1291253998000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{Dez 2, 2010 2:39:58}",
+ {1291253998000, 1291297198000},
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectlyGerman("{Juni 09 2011 15:28:14}", 1307626094000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{März 16 08:12:04}", 6419524000,
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{März 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29}",
+ {1277512289000, 1277555489000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29}", 1277512289000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}",
+ {1137899465000, 1137942665000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}", 1137899465000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{11:42:35}", {38555000, 81755000},
GRANULARITY_SECOND));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{11:42:35}", 38555000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr 11:42:35}", 9715355000,
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}",
+ {1271651775000, 1271694975000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}", 1271651775000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}", 1514777400000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}",
+ {1514777400000, 1514820600000},
GRANULARITY_MINUTE));
EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30 nachm}",
1514820600000, GRANULARITY_MINUTE));
@@ -284,10 +384,10 @@
GRANULARITY_HOUR));
EXPECT_TRUE(
ParsesCorrectlyGerman("{14.03.2017}", 1489446000000, GRANULARITY_DAY));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{morgen 0:00}", 82800000, GRANULARITY_MINUTE));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{morgen um 4:00}", 97200000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{morgen 0:00}", {82800000, 126000000},
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{morgen um 4:00}", {97200000, 140400000},
+ GRANULARITY_MINUTE));
EXPECT_TRUE(
ParsesCorrectlyGerman("{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR));
}
@@ -320,6 +420,30 @@
/*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
}
+TEST_F(ParserTest, WhenAlternativesEnabledGeneratesAlternatives) {
+ LoadModel([](ModelT* model) {
+ model->datetime_model->generate_alternative_interpretations_when_ambiguous =
+ true;
+ });
+
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}",
+ {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{monday 3pm}", 396000000, GRANULARITY_HOUR));
+ EXPECT_TRUE(ParsesCorrectly("{monday 3:00}", {352800000, 396000000},
+ GRANULARITY_MINUTE));
+}
+
+TEST_F(ParserTest, WhenAlternativesDisabledDoesNotGenerateAlternatives) {
+ LoadModel([](ModelT* model) {
+ model->datetime_model->generate_alternative_interpretations_when_ambiguous =
+ false;
+ });
+
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000,
+ GRANULARITY_MINUTE));
+}
+
class ParserLocaleTest : public testing::Test {
public:
void SetUp() override;
@@ -376,9 +500,10 @@
bool ParserLocaleTest::HasResult(const std::string& input,
const std::string& locales) {
std::vector<DatetimeParseResultSpan> results;
- EXPECT_TRUE(parser_->Parse(input, /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"", locales,
- ModeFlag_ANNOTATION, false, &results));
+ EXPECT_TRUE(parser_->Parse(
+ input, /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"", locales, ModeFlag_ANNOTATION,
+ AnnotationUsecase_ANNOTATION_USECASE_SMART, false, &results));
return results.size() == 1;
}
diff --git a/annotator/duration/duration.cc b/annotator/duration/duration.cc
new file mode 100644
index 0000000..d442dc6
--- /dev/null
+++ b/annotator/duration/duration.cc
@@ -0,0 +1,290 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/duration/duration.h"
+
+#include <climits>
+#include <cstdlib>
+
+#include "annotator/collections.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/strings/numbers.h"
+
+namespace libtextclassifier3 {
+
+using DurationUnit = internal::DurationUnit;
+
+namespace internal {
+
+namespace {
+void FillDurationUnitMap(
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
+ expressions,
+ DurationUnit duration_unit,
+ std::unordered_map<std::string, DurationUnit>* target_map) {
+ if (expressions == nullptr) {
+ return;
+ }
+
+ for (const flatbuffers::String* expression_string : *expressions) {
+ (*target_map)[expression_string->c_str()] = duration_unit;
+ }
+}
+} // namespace
+
+std::unordered_map<std::string, DurationUnit> BuildTokenToDurationUnitMapping(
+ const DurationAnnotatorOptions* options) {
+ std::unordered_map<std::string, DurationUnit> mapping;
+ FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK,
+ &mapping);
+ FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping);
+ FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR,
+ &mapping);
+ FillDurationUnitMap(options->minute_expressions(), DurationUnit::MINUTE,
+ &mapping);
+ FillDurationUnitMap(options->second_expressions(), DurationUnit::SECOND,
+ &mapping);
+ return mapping;
+}
+
+std::unordered_set<std::string> BuildStringSet(
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
+ strings) {
+ std::unordered_set<std::string> result;
+ if (strings == nullptr) {
+ return result;
+ }
+
+ for (const flatbuffers::String* string_value : *strings) {
+ result.insert(string_value->c_str());
+ }
+
+ return result;
+}
+
+} // namespace internal
+
+bool DurationAnnotator::ClassifyText(
+ const UnicodeText& context, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
+ ClassificationResult* classification_result) const {
+ if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
+ (1 << annotation_usecase))) == 0) {
+ return false;
+ }
+
+ const UnicodeText selection =
+ UnicodeText::Substring(context, selection_indices.first,
+ selection_indices.second, /*do_copy=*/false);
+ const std::vector<Token> tokens = feature_processor_->Tokenize(selection);
+
+ AnnotatedSpan annotated_span;
+ if (FindDurationStartingAt(context, tokens, 0, &annotated_span) !=
+ tokens.size()) {
+ return false;
+ }
+
+ TC3_CHECK(!annotated_span.classification.empty());
+
+ *classification_result = annotated_span.classification[0];
+ return true;
+}
+
+bool DurationAnnotator::FindAll(const UnicodeText& context,
+ const std::vector<Token>& tokens,
+ AnnotationUsecase annotation_usecase,
+ std::vector<AnnotatedSpan>* results) const {
+ if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
+ (1 << annotation_usecase))) == 0) {
+ return true;
+ }
+
+ for (int i = 0; i < tokens.size();) {
+ AnnotatedSpan span;
+ const int next_i = FindDurationStartingAt(context, tokens, i, &span);
+ if (next_i != i) {
+ results->push_back(span);
+ i = next_i;
+ } else {
+ i++;
+ }
+ }
+ return true;
+}
+
+int DurationAnnotator::FindDurationStartingAt(const UnicodeText& context,
+ const std::vector<Token>& tokens,
+ int start_token_index,
+ AnnotatedSpan* result) const {
+ CodepointIndex start_index = kInvalidIndex;
+ CodepointIndex end_index = kInvalidIndex;
+
+ bool has_quantity = false;
+ ParsedDurationAtom parsed_duration;
+
+ std::vector<ParsedDurationAtom> parsed_duration_atoms;
+
+ // This is the core algorithm for finding the duration expressions. It
+ // basically iterates over tokens and changes the state variables above as it
+ // goes.
+ int token_index;
+ for (token_index = start_token_index; token_index < tokens.size();
+ token_index++) {
+ const Token& token = tokens[token_index];
+
+ if (ParseQuantityToken(token, &parsed_duration)) {
+ has_quantity = true;
+ if (start_index == kInvalidIndex) {
+ start_index = token.start;
+ }
+ end_index = token.end;
+ } else if (ParseDurationUnitToken(token, &parsed_duration.unit)) {
+ if (start_index == kInvalidIndex) {
+ start_index = token.start;
+ }
+ end_index = token.end;
+ parsed_duration_atoms.push_back(parsed_duration);
+ has_quantity = false;
+ parsed_duration = ParsedDurationAtom();
+ } else if (ParseFillerToken(token)) {
+ } else {
+ break;
+ }
+ }
+
+ if (parsed_duration_atoms.empty()) {
+ return start_token_index;
+ }
+
+ const bool parse_ended_without_unit_for_last_mentioned_quantity =
+ has_quantity;
+
+ ClassificationResult classification{Collections::Duration(),
+ options_->score()};
+ classification.priority_score = options_->priority_score();
+ classification.duration_ms =
+ ParsedDurationAtomsToMillis(parsed_duration_atoms);
+
+ // Process suffix expressions like "and half" that don't have the
+ // duration_unit explicitly mentioned.
+ if (parse_ended_without_unit_for_last_mentioned_quantity &&
+ parsed_duration.plus_half) {
+ ParsedDurationAtom atom = ParsedDurationAtom::Half();
+ atom.unit = parsed_duration_atoms.rbegin()->unit;
+ classification.duration_ms += ParsedDurationAtomsToMillis({atom});
+ }
+
+ result->span = feature_processor_->StripBoundaryCodepoints(
+ context, {start_index, end_index});
+ result->classification.push_back(classification);
+ result->source = AnnotatedSpan::Source::DURATION;
+
+ return token_index;
+}
+
+int64 DurationAnnotator::ParsedDurationAtomsToMillis(
+ const std::vector<ParsedDurationAtom>& atoms) const {
+ int64 result = 0;
+ for (auto atom : atoms) {
+ int multiplier;
+ switch (atom.unit) {
+ case DurationUnit::WEEK:
+ multiplier = 7 * 24 * 60 * 60 * 1000;
+ break;
+ case DurationUnit::DAY:
+ multiplier = 24 * 60 * 60 * 1000;
+ break;
+ case DurationUnit::HOUR:
+ multiplier = 60 * 60 * 1000;
+ break;
+ case DurationUnit::MINUTE:
+ multiplier = 60 * 1000;
+ break;
+ case DurationUnit::SECOND:
+ multiplier = 1000;
+ break;
+ case DurationUnit::UNKNOWN:
+ TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
+ return -1;
+ break;
+ }
+
+ int value = atom.value;
+ // This condition handles expressions like "an hour", where the quantity is
+ // not specified. In this case we assume quantity 1. Except for cases like
+ // "half hour".
+ if (value == 0 && !atom.plus_half) {
+ value = 1;
+ }
+ result += value * multiplier;
+ result += atom.plus_half * multiplier / 2;
+ }
+ return result;
+}
+
+bool DurationAnnotator::ParseQuantityToken(const Token& token,
+ ParsedDurationAtom* value) const {
+ if (token.value.empty()) {
+ return false;
+ }
+
+ std::string token_value_buffer;
+ const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
+ token.value, &token_value_buffer);
+
+ if (half_expressions_.find(token_value) != half_expressions_.end()) {
+ value->plus_half = true;
+ return true;
+ }
+
+ int32 parsed_value;
+ if (ParseInt32(token_value.c_str(), &parsed_value)) {
+ value->value = parsed_value;
+ return true;
+ }
+
+ return false;
+}
+
+bool DurationAnnotator::ParseDurationUnitToken(
+ const Token& token, DurationUnit* duration_unit) const {
+ std::string token_value_buffer;
+ const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
+ token.value, &token_value_buffer);
+
+ const auto it = token_value_to_duration_unit_.find(token_value);
+ if (it == token_value_to_duration_unit_.end()) {
+ return false;
+ }
+
+ *duration_unit = it->second;
+ return true;
+}
+
+bool DurationAnnotator::ParseFillerToken(const Token& token) const {
+ std::string token_value_buffer;
+ const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
+ token.value, &token_value_buffer);
+
+ if (filler_expressions_.find(token_value) == filler_expressions_.end()) {
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/duration/duration.h b/annotator/duration/duration.h
new file mode 100644
index 0000000..4311afc
--- /dev/null
+++ b/annotator/duration/duration.h
@@ -0,0 +1,128 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DURATION_DURATION_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DURATION_DURATION_H_
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/feature-processor.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+namespace internal {
+enum class DurationUnit {
+ UNKNOWN = -1,
+ WEEK = 0,
+ DAY = 1,
+ HOUR = 2,
+ MINUTE = 3,
+ SECOND = 4
+
+ // NOTE: If we want to add MONTH and YEAR we'll have to think of a different
+ // parsing format, because MONTH and YEAR don't have a fixed number of
+ // milliseconds, unlike week/day/hour/minute/second. We ignore daylight
+ // saving time and assume the day is always 24 hours.
+};
+
+// Prepares the mapping between token values and duration unit types.
+std::unordered_map<std::string, internal::DurationUnit>
+BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions* options);
+
+// Creates a set of strings from a flatbuffer string vector.
+std::unordered_set<std::string> BuildStringSet(
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*);
+
+} // namespace internal
+
+// Annotator of duration expressions like "3 minutes 30 seconds".
+class DurationAnnotator {
+ public:
+ explicit DurationAnnotator(const DurationAnnotatorOptions* options,
+ const FeatureProcessor* feature_processor)
+ : options_(options),
+ feature_processor_(feature_processor),
+ token_value_to_duration_unit_(
+ internal::BuildTokenToDurationUnitMapping(options)),
+ filler_expressions_(
+ internal::BuildStringSet(options->filler_expressions())),
+ half_expressions_(
+ internal::BuildStringSet(options->half_expressions())) {}
+
+ // Classifies given text, and if it is a duration, it passes the result in
+ // 'classification_result' and returns true, otherwise returns false.
+ bool ClassifyText(const UnicodeText& context, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
+ ClassificationResult* classification_result) const;
+
+ // Finds all duration instances in the input text.
+ bool FindAll(const UnicodeText& context, const std::vector<Token>& tokens,
+ AnnotationUsecase annotation_usecase,
+ std::vector<AnnotatedSpan>* results) const;
+
+ private:
+ // Represents a component of duration parsed from text (e.g. "3 hours" from
+ // the expression "3 hours and 20 minutes").
+ struct ParsedDurationAtom {
+ // Unit of the duration.
+ internal::DurationUnit unit = internal::DurationUnit::UNKNOWN;
+
+ // Quantity of the duration unit.
+ int value = 0;
+
+ // True, if half a unit was specified (either in addition, or exclusively).
+ // E.g. "hour and a half".
+ // NOTE: Quarter, three-quarters etc. is not supported.
+ bool plus_half = false;
+
+ static ParsedDurationAtom Half() {
+ ParsedDurationAtom result;
+ result.plus_half = true;
+ return result;
+ }
+ };
+
+ // Starts consuming tokens and returns the index past the last consumed token.
+ int FindDurationStartingAt(const UnicodeText& context,
+ const std::vector<Token>& tokens,
+ int start_token_index,
+ AnnotatedSpan* result) const;
+
+ bool ParseQuantityToken(const Token& token, ParsedDurationAtom* value) const;
+ bool ParseDurationUnitToken(const Token& token,
+ internal::DurationUnit* duration_unit) const;
+ bool ParseFillerToken(const Token& token) const;
+
+ int64 ParsedDurationAtomsToMillis(
+ const std::vector<ParsedDurationAtom>& atoms) const;
+
+ const DurationAnnotatorOptions* options_;
+ const FeatureProcessor* feature_processor_;
+ const std::unordered_map<std::string, internal::DurationUnit>
+ token_value_to_duration_unit_;
+ const std::unordered_set<std::string> filler_expressions_;
+ const std::unordered_set<std::string> half_expressions_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DURATION_DURATION_H_
diff --git a/annotator/duration/duration_test.cc b/annotator/duration/duration_test.cc
new file mode 100644
index 0000000..78548fe
--- /dev/null
+++ b/annotator/duration/duration_test.cc
@@ -0,0 +1,320 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/duration/duration.h"
+
+#include <string>
+#include <vector>
+
+#include "annotator/collections.h"
+#include "annotator/model_generated.h"
+#include "annotator/types-test-util.h"
+#include "annotator/types.h"
+#include "utils/test-utils.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::AllOf;
+using testing::ElementsAre;
+using testing::Field;
+
+const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() {
+ static const flatbuffers::DetachedBuffer* options_data = []() {
+ DurationAnnotatorOptionsT options;
+ options.enabled = true;
+
+ options.week_expressions.push_back("week");
+ options.week_expressions.push_back("weeks");
+
+ options.day_expressions.push_back("day");
+ options.day_expressions.push_back("days");
+
+ options.hour_expressions.push_back("hour");
+ options.hour_expressions.push_back("hours");
+
+ options.minute_expressions.push_back("minute");
+ options.minute_expressions.push_back("minutes");
+
+ options.second_expressions.push_back("second");
+ options.second_expressions.push_back("seconds");
+
+ options.filler_expressions.push_back("and");
+ options.filler_expressions.push_back("a");
+ options.filler_expressions.push_back("an");
+ options.filler_expressions.push_back("one");
+
+ options.half_expressions.push_back("half");
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+ }();
+
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
+}
+
+FeatureProcessor BuildFeatureProcessor(const UniLib* unilib) {
+ static const flatbuffers::DetachedBuffer* options_data = []() {
+ FeatureProcessorOptionsT options;
+ options.context_size = 1;
+ options.max_selection_span = 1;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.ignored_span_boundary_codepoints.push_back(',');
+
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+ }();
+
+ const FeatureProcessorOptions* feature_processor_options =
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
+
+ return FeatureProcessor(feature_processor_options, unilib);
+}
+
+class DurationAnnotatorTest : public ::testing::Test {
+ protected:
+ DurationAnnotatorTest()
+ : INIT_UNILIB_FOR_TESTING(unilib_),
+ feature_processor_(BuildFeatureProcessor(&unilib_)),
+ duration_annotator_(TestingDurationAnnotatorOptions(),
+ &feature_processor_) {}
+
+ std::vector<Token> Tokenize(const UnicodeText& text) {
+ return feature_processor_.Tokenize(text);
+ }
+
+ UniLib unilib_;
+ FeatureProcessor feature_processor_;
+ DurationAnnotator duration_annotator_;
+};
+
+TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
+ ClassificationResult classification;
+ EXPECT_TRUE(duration_annotator_.ClassifyText(
+ UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
+
+ EXPECT_THAT(classification,
+ AllOf(Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
+}
+
+TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
+ ClassificationResult classification;
+ EXPECT_TRUE(duration_annotator_.ClassifyText(
+ UTF8ToUnicodeText("Wake me up in15 minutesok?"), {13, 23},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
+
+ EXPECT_THAT(classification,
+ AllOf(Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
+}
+
+TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
+ const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 15 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 3.5 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Wake me up in 3 hours and 5 seconds ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 35)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 3 * 60 * 60 * 1000 + 5 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
+ const UnicodeText text = UTF8ToUnicodeText("Set a timer for half an hour");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 28)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 0.5 * 60 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 1 hour and a half");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 33)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 1.5 * 60 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for an hour and a half");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(19, 34)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 1.5 * 60 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest,
+ FindsCorrectlyWhenSecondsComeSecondAndDontHaveNumber) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 10 minutes and a second ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 39)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 10 * 60 * 1000 + 1 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
+ const UnicodeText text = UTF8ToUnicodeText(
+ "Set a timer for a a a 10 minutes and 2 seconds an and an ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(22, 46)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 10 * 60 * 1000 + 2 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
+ const UnicodeText text = UTF8ToUnicodeText("Set a timer for half ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ ASSERT_EQ(result.size(), 0);
+}
+
+TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 10 ,minutes, ,and, ,2, seconds, ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 46)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 10 * 60 * 1000 + 2 * 1000)))))));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/entity-data.fbs b/annotator/entity-data.fbs
new file mode 100755
index 0000000..2143e28
--- /dev/null
+++ b/annotator/entity-data.fbs
@@ -0,0 +1,69 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.EntityData_.Datetime_;
+enum Granularity : int {
+ GRANULARITY_UNKNOWN = -1,
+ GRANULARITY_YEAR = 0,
+ GRANULARITY_MONTH = 1,
+ GRANULARITY_WEEK = 2,
+ GRANULARITY_DAY = 3,
+ GRANULARITY_HOUR = 4,
+ GRANULARITY_MINUTE = 5,
+ GRANULARITY_SECOND = 6,
+}
+
+namespace libtextclassifier3.EntityData_;
+table Datetime {
+ time_ms_utc:long;
+ granularity:Datetime_.Granularity = GRANULARITY_UNKNOWN;
+}
+
+namespace libtextclassifier3.EntityData_;
+table Contact {
+ name:string;
+ given_name:string;
+ nickname:string;
+ email_address:string;
+ phone_number:string;
+ contact_id:string;
+}
+
+namespace libtextclassifier3.EntityData_;
+table App {
+ name:string;
+ package_name:string;
+}
+
+// Represents an entity annotated in text.
+namespace libtextclassifier3;
+table EntityData {
+ // Codepoint indices of the annotation, start is inclusive, end is
+ // exclusive.
+ start:int;
+
+ end:int;
+
+ // The entity type, as in the TextClassifier APIs.
+ type:string;
+
+ datetime:EntityData_.Datetime;
+ reserved_5:int (deprecated);
+ contact:EntityData_.Contact;
+ app:EntityData_.App;
+}
+
+root_type libtextclassifier3.EntityData;
diff --git a/annotator/feature-processor.cc b/annotator/feature-processor.cc
index a18393b..c0f5c82 100644
--- a/annotator/feature-processor.cc
+++ b/annotator/feature-processor.cc
@@ -28,6 +28,29 @@
namespace internal {
+Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
+ const UniLib* unilib) {
+ std::vector<const TokenizationCodepointRange*> codepoint_config;
+ if (options->tokenization_codepoint_config() != nullptr) {
+ codepoint_config.insert(codepoint_config.end(),
+ options->tokenization_codepoint_config()->begin(),
+ options->tokenization_codepoint_config()->end());
+ }
+ std::vector<const CodepointRange*> internal_codepoint_config;
+ if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
+ internal_codepoint_config.insert(
+ internal_codepoint_config.end(),
+ options->internal_tokenizer_codepoint_ranges()->begin(),
+ options->internal_tokenizer_codepoint_ranges()->end());
+ }
+ const bool tokenize_on_script_change =
+ options->tokenization_codepoint_config() != nullptr &&
+ options->tokenize_on_script_change();
+ return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
+ internal_codepoint_config, tokenize_on_script_change,
+ options->icu_preserve_whitespace_tokens());
+}
+
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
const FeatureProcessorOptions* const options) {
TokenFeatureExtractorOptions extractor_options;
@@ -166,33 +189,12 @@
}
std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
- const UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
- return Tokenize(text_unicode);
+ return tokenizer_.Tokenize(text);
}
std::vector<Token> FeatureProcessor::Tokenize(
const UnicodeText& text_unicode) const {
- if (options_->tokenization_type() ==
- FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER) {
- return tokenizer_.Tokenize(text_unicode);
- } else if (options_->tokenization_type() ==
- FeatureProcessorOptions_::TokenizationType_ICU ||
- options_->tokenization_type() ==
- FeatureProcessorOptions_::TokenizationType_MIXED) {
- std::vector<Token> result;
- if (!ICUTokenize(text_unicode, &result)) {
- return {};
- }
- if (options_->tokenization_type() ==
- FeatureProcessorOptions_::TokenizationType_MIXED) {
- InternalRetokenize(text_unicode, &result);
- }
- return result;
- } else {
- TC3_LOG(ERROR) << "Unknown tokenization type specified. Using "
- "internal.";
- return tokenizer_.Tokenize(text_unicode);
- }
+ return tokenizer_.Tokenize(text_unicode);
}
bool FeatureProcessor::LabelToSpan(
@@ -471,25 +473,6 @@
return true;
}
-void FeatureProcessor::PrepareCodepointRanges(
- const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
- codepoint_ranges,
- std::vector<CodepointRange>* prepared_codepoint_ranges) {
- prepared_codepoint_ranges->clear();
- prepared_codepoint_ranges->reserve(codepoint_ranges.size());
- for (const FeatureProcessorOptions_::CodepointRange* range :
- codepoint_ranges) {
- prepared_codepoint_ranges->push_back(
- CodepointRange(range->start(), range->end()));
- }
-
- std::sort(prepared_codepoint_ranges->begin(),
- prepared_codepoint_ranges->end(),
- [](const CodepointRange& a, const CodepointRange& b) {
- return a.start < b.start;
- });
-}
-
void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
if (options_->ignored_span_boundary_codepoints() != nullptr) {
for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
@@ -591,6 +574,16 @@
UnicodeText::const_iterator span_end = context_unicode.begin();
std::advance(span_end, span.second);
+ return StripBoundaryCodepoints(span_begin, span_end, span);
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span) const {
+ if (!ValidNonEmptySpan(span) || span_begin == span_end) {
+ return span;
+ }
+
const int start_offset = CountIgnoredSpanBoundaryCodepoints(
span_begin, span_end, /*count_from_beginning=*/true);
const int end_offset = CountIgnoredSpanBoundaryCodepoints(
@@ -620,32 +613,21 @@
return static_cast<float>(num_supported) / static_cast<float>(num_total);
}
-bool FeatureProcessor::IsCodepointInRanges(
- int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const {
- auto it = std::lower_bound(codepoint_ranges.begin(), codepoint_ranges.end(),
- codepoint,
- [](const CodepointRange& range, int codepoint) {
- // This function compares range with the
- // codepoint for the purpose of finding the first
- // greater or equal range. Because of the use of
- // std::lower_bound it needs to return true when
- // range < codepoint; the first time it will
- // return false the lower bound is found and
- // returned.
- //
- // It might seem weird that the condition is
- // range.end <= codepoint here but when codepoint
- // == range.end it means it's actually just
- // outside of the range, thus the range is less
- // than the codepoint.
- return range.end <= codepoint;
- });
- if (it != codepoint_ranges.end() && it->start <= codepoint &&
- it->end > codepoint) {
- return true;
- } else {
- return false;
+const std::string& FeatureProcessor::StripBoundaryCodepoints(
+ const std::string& value, std::string* buffer) const {
+ const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
+ const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
+ const CodepointSpan stripped_span =
+ StripBoundaryCodepoints(value_unicode, initial_span);
+
+ if (initial_span != stripped_span) {
+ const UnicodeText stripped_token_value =
+ UnicodeText::Substring(value_unicode, stripped_span.first,
+ stripped_span.second, /*do_copy=*/false);
+ *buffer = stripped_token_value.ToUTF8String();
+ return *buffer;
}
+ return value;
}
int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
@@ -813,113 +795,6 @@
return true;
}
-bool FeatureProcessor::ICUTokenize(const UnicodeText& context_unicode,
- std::vector<Token>* result) const {
- std::unique_ptr<UniLib::BreakIterator> break_iterator =
- unilib_->CreateBreakIterator(context_unicode);
- if (!break_iterator) {
- return false;
- }
- int last_break_index = 0;
- int break_index = 0;
- int last_unicode_index = 0;
- int unicode_index = 0;
- auto token_begin_it = context_unicode.begin();
- while ((break_index = break_iterator->Next()) !=
- UniLib::BreakIterator::kDone) {
- const int token_length = break_index - last_break_index;
- unicode_index = last_unicode_index + token_length;
-
- auto token_end_it = token_begin_it;
- std::advance(token_end_it, token_length);
-
- // Determine if the whole token is whitespace.
- bool is_whitespace = true;
- for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
- if (!unilib_->IsWhitespace(*char_it)) {
- is_whitespace = false;
- break;
- }
- }
-
- const std::string token =
- context_unicode.UTF8Substring(token_begin_it, token_end_it);
-
- if (!is_whitespace || options_->icu_preserve_whitespace_tokens()) {
- result->push_back(Token(token, last_unicode_index, unicode_index));
- }
-
- last_break_index = break_index;
- last_unicode_index = unicode_index;
- token_begin_it = token_end_it;
- }
-
- return true;
-}
-
-void FeatureProcessor::InternalRetokenize(const UnicodeText& unicode_text,
- std::vector<Token>* tokens) const {
- std::vector<Token> result;
- CodepointSpan span(-1, -1);
- for (Token& token : *tokens) {
- const UnicodeText unicode_token_value =
- UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- bool should_retokenize = true;
- for (const int codepoint : unicode_token_value) {
- if (!IsCodepointInRanges(codepoint,
- internal_tokenizer_codepoint_ranges_)) {
- should_retokenize = false;
- break;
- }
- }
-
- if (should_retokenize) {
- if (span.first < 0) {
- span.first = token.start;
- }
- span.second = token.end;
- } else {
- TokenizeSubstring(unicode_text, span, &result);
- span.first = -1;
- result.emplace_back(std::move(token));
- }
- }
- TokenizeSubstring(unicode_text, span, &result);
-
- *tokens = std::move(result);
-}
-
-void FeatureProcessor::TokenizeSubstring(const UnicodeText& unicode_text,
- CodepointSpan span,
- std::vector<Token>* result) const {
- if (span.first < 0) {
- // There is no span to tokenize.
- return;
- }
-
- // Extract the substring.
- UnicodeText::const_iterator it_begin = unicode_text.begin();
- for (int i = 0; i < span.first; ++i) {
- ++it_begin;
- }
- UnicodeText::const_iterator it_end = unicode_text.begin();
- for (int i = 0; i < span.second; ++i) {
- ++it_end;
- }
- const std::string text = unicode_text.UTF8Substring(it_begin, it_end);
-
- // Run the tokenizer and update the token bounds to reflect the offset of the
- // substring.
- std::vector<Token> tokens = tokenizer_.Tokenize(text);
- // Avoids progressive capacity increases in the for loop.
- result->reserve(result->size() + tokens.size());
- for (Token& token : tokens) {
- token.start += span.first;
- token.end += span.first;
- result->emplace_back(std::move(token));
- }
-}
-
bool FeatureProcessor::AppendTokenFeaturesWithCache(
const Token& token, CodepointSpan selection_span_for_feature,
const EmbeddingExecutor* embedding_executor,
diff --git a/annotator/feature-processor.h b/annotator/feature-processor.h
index 2d04253..4a753b0 100644
--- a/annotator/feature-processor.h
+++ b/annotator/feature-processor.h
@@ -27,11 +27,11 @@
#include "annotator/cached-features.h"
#include "annotator/model_generated.h"
-#include "annotator/token-feature-extractor.h"
-#include "annotator/tokenizer.h"
#include "annotator/types.h"
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
+#include "utils/token-feature-extractor.h"
+#include "utils/tokenizer.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
@@ -41,6 +41,9 @@
namespace internal {
+Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
+ const UniLib* unilib);
+
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
const FeatureProcessorOptions* options);
@@ -89,27 +92,15 @@
typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
FeatureProcessor(const FeatureProcessorOptions* options, const UniLib* unilib)
- : unilib_(unilib),
- feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
- *unilib_),
+ : feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
+ *unilib),
options_(options),
- tokenizer_(
- options->tokenization_codepoint_config() != nullptr
- ? Tokenizer({options->tokenization_codepoint_config()->begin(),
- options->tokenization_codepoint_config()->end()},
- options->tokenize_on_script_change())
- : Tokenizer({}, /*split_on_script_change=*/false)) {
+ tokenizer_(internal::BuildTokenizer(options, unilib)) {
MakeLabelMaps();
if (options->supported_codepoint_ranges() != nullptr) {
- PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(),
- options->supported_codepoint_ranges()->end()},
- &supported_codepoint_ranges_);
- }
- if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
- PrepareCodepointRanges(
- {options->internal_tokenizer_codepoint_ranges()->begin(),
- options->internal_tokenizer_codepoint_ranges()->end()},
- &internal_tokenizer_codepoint_ranges_);
+ SortCodepointRanges({options->supported_codepoint_ranges()->begin(),
+ options->supported_codepoint_ranges()->end()},
+ &supported_codepoint_ranges_);
}
PrepareIgnoredSpanBoundaryCodepoints();
}
@@ -190,16 +181,18 @@
CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
CodepointSpan span) const;
+ // Same as above but takes a pair of iterators for the span, for efficiency.
+ CodepointSpan StripBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span) const;
+
+ // Same as above, but takes an optional buffer for saving the modified value.
+ // As an optimization, returns pointer to 'value' if nothing was stripped, or
+ // pointer to 'buffer' if something was stripped.
+ const std::string& StripBoundaryCodepoints(const std::string& value,
+ std::string* buffer) const;
+
protected:
- // Represents a codepoint range [start, end).
- struct CodepointRange {
- int32 start;
- int32 end;
-
- CodepointRange(int32 arg_start, int32 arg_end)
- : start(arg_start), end(arg_end) {}
- };
-
// Returns the class id corresponding to the given string collection
// identifier. There is a catch-all class id that the function returns for
// unknown collections.
@@ -227,21 +220,11 @@
// Converts a token span to the corresponding label.
int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
- void PrepareCodepointRanges(
- const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
- codepoint_ranges,
- std::vector<CodepointRange>* prepared_codepoint_ranges);
-
// Returns the ratio of supported codepoints to total number of codepoints in
// the given token span.
float SupportedCodepointsRatio(const TokenSpan& token_span,
const std::vector<Token>& tokens) const;
- // Returns true if given codepoint is covered by the given sorted vector of
- // codepoint ranges.
- bool IsCodepointInRanges(
- int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
-
void PrepareIgnoredSpanBoundaryCodepoints();
// Counts the number of span boundary codepoints. If count_from_beginning is
@@ -259,21 +242,6 @@
int FindCenterToken(CodepointSpan span,
const std::vector<Token>& tokens) const;
- // Tokenizes the input text using ICU tokenizer.
- bool ICUTokenize(const UnicodeText& context_unicode,
- std::vector<Token>* result) const;
-
- // Takes the result of ICU tokenization and retokenizes stretches of tokens
- // made of a specific subset of characters using the internal tokenizer.
- void InternalRetokenize(const UnicodeText& unicode_text,
- std::vector<Token>* tokens) const;
-
- // Tokenizes a substring of the unicode string, appending the resulting tokens
- // to the output vector. The resulting tokens have bounds relative to the full
- // string. Does nothing if the start of the span is negative.
- void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
- std::vector<Token>* result) const;
-
// Removes all tokens from tokens that are not on a line (defined by calling
// SplitContext on the context) to which span points.
void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
@@ -293,21 +261,12 @@
EmbeddingCache* embedding_cache,
std::vector<float>* output_features) const;
- private:
- const UniLib* unilib_;
-
protected:
const TokenFeatureExtractor feature_extractor_;
// Codepoint ranges that define what codepoints are supported by the model.
// NOTE: Must be sorted.
- std::vector<CodepointRange> supported_codepoint_ranges_;
-
- // Codepoint ranges that define which tokens (consisting of which codepoints)
- // should be re-tokenized with the internal tokenizer in the mixed
- // tokenization mode.
- // NOTE: Must be sorted.
- std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
+ std::vector<CodepointRangeStruct> supported_codepoint_ranges_;
private:
// Set of codepoints that will be stripped from beginning and end of
diff --git a/annotator/feature-processor_test.cc b/annotator/feature-processor_test.cc
index c9f0e0d..5337776 100644
--- a/annotator/feature-processor_test.cc
+++ b/annotator/feature-processor_test.cc
@@ -53,8 +53,6 @@
public:
using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
using FeatureProcessor::FeatureProcessor;
- using FeatureProcessor::ICUTokenize;
- using FeatureProcessor::IsCodepointInRanges;
using FeatureProcessor::SpanToLabel;
using FeatureProcessor::StripTokensFromOtherLines;
using FeatureProcessor::supported_codepoint_ranges_;
@@ -531,24 +529,21 @@
config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
{
- options.supported_codepoint_ranges.emplace_back(
- new FeatureProcessorOptions_::CodepointRangeT());
+ options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
auto& range = options.supported_codepoint_ranges.back();
range->start = 0;
range->end = 128;
}
{
- options.supported_codepoint_ranges.emplace_back(
- new FeatureProcessorOptions_::CodepointRangeT());
+ options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
auto& range = options.supported_codepoint_ranges.back();
range->start = 10000;
range->end = 10001;
}
{
- options.supported_codepoint_ranges.emplace_back(
- new FeatureProcessorOptions_::CodepointRangeT());
+ options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
auto& range = options.supported_codepoint_ranges.back();
range->start = 20000;
range->end = 30000;
@@ -567,23 +562,23 @@
EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
{0, 3}, feature_processor.Tokenize("ěěě řřř ěěě")),
FloatEq(0.0));
- EXPECT_FALSE(feature_processor.IsCodepointInRanges(
- -1, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(feature_processor.IsCodepointInRanges(
- 0, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(feature_processor.IsCodepointInRanges(
- 10, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(feature_processor.IsCodepointInRanges(
- 127, feature_processor.supported_codepoint_ranges_));
- EXPECT_FALSE(feature_processor.IsCodepointInRanges(
- 128, feature_processor.supported_codepoint_ranges_));
- EXPECT_FALSE(feature_processor.IsCodepointInRanges(
- 9999, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ EXPECT_FALSE(
+ IsCodepointInRanges(-1, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(
+ IsCodepointInRanges(0, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(
+ IsCodepointInRanges(10, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(
+ IsCodepointInRanges(127, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(
+ IsCodepointInRanges(128, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(
+ IsCodepointInRanges(9999, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(IsCodepointInRanges(
10000, feature_processor.supported_codepoint_ranges_));
- EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+ EXPECT_FALSE(IsCodepointInRanges(
10001, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ EXPECT_TRUE(IsCodepointInRanges(
25000, feature_processor.supported_codepoint_ranges_));
const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7),
@@ -834,151 +829,6 @@
EXPECT_EQ(click_index, 5);
}
-TEST_F(FeatureProcessorTest, InternalTokenizeOnScriptChange) {
- FeatureProcessorOptionsT options;
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- {
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 0;
- config->end = 256;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- config->script_id = 1;
- }
- options.tokenize_on_script_change = false;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- EXPECT_EQ(feature_processor.Tokenize("앨라배마123웹사이트"),
- std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)}));
-
- options.tokenize_on_script_change = true;
- flatbuffers::DetachedBuffer options_fb2 =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb2.data()),
- &unilib_);
-
- EXPECT_EQ(feature_processor2.Tokenize("앨라배마123웹사이트"),
- std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7),
- Token("웹사이트", 7, 11)}));
-}
-
-#ifdef TC3_TEST_ICU
-TEST_F(FeatureProcessorTest, ICUTokenize) {
- FeatureProcessorOptionsT options;
- options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- UniLib unilib;
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib);
- std::vector<Token> tokens = feature_processor.Tokenize("พระบาทสมเด็จพระปรมิ");
- ASSERT_EQ(tokens,
- // clang-format off
- std::vector<Token>({Token("พระบาท", 0, 6),
- Token("สมเด็จ", 6, 12),
- Token("พระ", 12, 15),
- Token("ปร", 15, 17),
- Token("มิ", 17, 19)}));
- // clang-format on
-}
-#endif
-
-#ifdef TC3_TEST_ICU
-TEST_F(FeatureProcessorTest, ICUTokenizeWithWhitespaces) {
- FeatureProcessorOptionsT options;
- options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;
- options.icu_preserve_whitespace_tokens = true;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- UniLib unilib;
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib);
- std::vector<Token> tokens =
- feature_processor.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
- ASSERT_EQ(tokens,
- // clang-format off
- std::vector<Token>({Token("พระบาท", 0, 6),
- Token(" ", 6, 7),
- Token("สมเด็จ", 7, 13),
- Token(" ", 13, 14),
- Token("พระ", 14, 17),
- Token(" ", 17, 18),
- Token("ปร", 18, 20),
- Token(" ", 20, 21),
- Token("มิ", 21, 23)}));
- // clang-format on
-}
-#endif
-
-#ifdef TC3_TEST_ICU
-TEST_F(FeatureProcessorTest, MixedTokenize) {
- FeatureProcessorOptionsT options;
- options.tokenization_type = FeatureProcessorOptions_::TokenizationType_MIXED;
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- {
- options.internal_tokenizer_codepoint_ranges.emplace_back(
- new FeatureProcessorOptions_::CodepointRangeT());
- auto& range = options.internal_tokenizer_codepoint_ranges.back();
- range->start = 0;
- range->end = 128;
- }
-
- {
- options.internal_tokenizer_codepoint_ranges.emplace_back(
- new FeatureProcessorOptions_::CodepointRangeT());
- auto& range = options.internal_tokenizer_codepoint_ranges.back();
- range->start = 128;
- range->end = 256;
- }
-
- {
- options.internal_tokenizer_codepoint_ranges.emplace_back(
- new FeatureProcessorOptions_::CodepointRangeT());
- auto& range = options.internal_tokenizer_codepoint_ranges.back();
- range->start = 256;
- range->end = 384;
- }
-
- {
- options.internal_tokenizer_codepoint_ranges.emplace_back(
- new FeatureProcessorOptions_::CodepointRangeT());
- auto& range = options.internal_tokenizer_codepoint_ranges.back();
- range->start = 384;
- range->end = 592;
- }
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- UniLib unilib;
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib);
- std::vector<Token> tokens = feature_processor.Tokenize(
- "こんにちはJapanese-ląnguagę text 世界 http://www.google.com/");
- ASSERT_EQ(tokens,
- // clang-format off
- std::vector<Token>({Token("こんにちは", 0, 5),
- Token("Japanese-ląnguagę", 5, 22),
- Token("text", 23, 27),
- Token("世界", 28, 30),
- Token("http://www.google.com/", 31, 53)}));
- // clang-format on
-}
-#endif
-
TEST_F(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
FeatureProcessorOptionsT options;
options.ignored_span_boundary_codepoints.push_back('.');
diff --git a/annotator/installed_app/installed-app-engine-dummy.h b/annotator/installed_app/installed-app-engine-dummy.h
new file mode 100644
index 0000000..2f2b62f
--- /dev/null
+++ b/annotator/installed_app/installed-app-engine-dummy.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_DUMMY_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/feature-processor.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// A dummy implementation of the installed app engine.
+class InstalledAppEngine {
+ public:
+ explicit InstalledAppEngine(const FeatureProcessor* feature_processor,
+ const UniLib* unilib) {}
+
+ bool Initialize(const std::string& serialized_config) {
+ TC3_LOG(ERROR) << "No installed app engine to initialize.";
+ return false;
+ }
+
+ bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
+ ClassificationResult* classification_result) const {
+ return false;
+ }
+
+ bool Chunk(const UnicodeText& context_unicode,
+ const std::vector<Token>& tokens,
+ std::vector<AnnotatedSpan>* result) const {
+ return true;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_DUMMY_H_
diff --git a/annotator/installed_app/installed-app-engine.h b/annotator/installed_app/installed-app-engine.h
new file mode 100644
index 0000000..d05d357
--- /dev/null
+++ b/annotator/installed_app/installed-app-engine.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_H_
+
+#include "annotator/installed_app/installed-app-engine-dummy.h"
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_H_
diff --git a/annotator/knowledge/knowledge-engine-dummy.h b/annotator/knowledge/knowledge-engine-dummy.h
index a6285dc..96d77c5 100644
--- a/annotator/knowledge/knowledge-engine-dummy.h
+++ b/annotator/knowledge/knowledge-engine-dummy.h
@@ -40,6 +40,11 @@
std::vector<AnnotatedSpan>* result) const {
return true;
}
+
+ bool LookUpEntity(const std::string& id,
+ std::string* serialized_knowledge_result) const {
+ return false;
+ }
};
} // namespace libtextclassifier3
diff --git a/annotator/model-executor.cc b/annotator/model-executor.cc
index 7c57e8f..5466cc6 100644
--- a/annotator/model-executor.cc
+++ b/annotator/model-executor.cc
@@ -44,7 +44,8 @@
std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::FromBuffer(
const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
- int quantization_bits) {
+ int quantization_bits,
+ const Model_::EmbeddingPruningMask* embedding_pruning_mask) {
std::unique_ptr<TfLiteModelExecutor> executor =
TfLiteModelExecutor::FromBuffer(model_spec_buffer);
if (!executor) {
@@ -81,14 +82,16 @@
return std::unique_ptr<TFLiteEmbeddingExecutor>(new TFLiteEmbeddingExecutor(
std::move(executor), quantization_bits, num_buckets, bytes_per_embedding,
- embedding_size, scales, embeddings, std::move(interpreter)));
+ embedding_size, scales, embeddings, std::move(interpreter),
+ embedding_pruning_mask));
}
TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor(
std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
int num_buckets, int bytes_per_embedding, int output_embedding_size,
const TfLiteTensor* scales, const TfLiteTensor* embeddings,
- std::unique_ptr<tflite::Interpreter> interpreter)
+ std::unique_ptr<tflite::Interpreter> interpreter,
+ const Model_::EmbeddingPruningMask* embedding_pruning_mask)
: executor_(std::move(executor)),
quantization_bits_(quantization_bits),
num_buckets_(num_buckets),
@@ -96,7 +99,55 @@
output_embedding_size_(output_embedding_size),
scales_(scales),
embeddings_(embeddings),
- interpreter_(std::move(interpreter)) {}
+ interpreter_(std::move(interpreter)) {
+ if ((embedding_pruning_mask != nullptr) &&
+ (embedding_pruning_mask->enabled())) {
+ for (int i = 0; i < embedding_pruning_mask->pruning_mask()->size(); i++) {
+ pruning_mask_.push_back((*(embedding_pruning_mask->pruning_mask()))[i]);
+ }
+ ComputePrefixCounts();
+ full_num_buckets_ = embedding_pruning_mask->full_num_buckets();
+ pruned_row_bucket_id_ = embedding_pruning_mask->pruned_row_bucket_id();
+ } else {
+ full_num_buckets_ = num_buckets;
+ }
+}
+
+void TFLiteEmbeddingExecutor::ComputePrefixCounts() {
+ // Pre-compute the prefix sums.
+ // For each i in {0, 1,...,pruning_mask_.size()-1}, we compute number of 1s
+ // in binary representations of the uint64 values in pruning_mask_ before
+ // index i. We set pruned_row_bucket_id_ to the total number of 1s
+ // in binary representations of all values in pruning_mask_.
+ int count = 0;
+ for (const uint64 mask : pruning_mask_) {
+ prefix_counts_.push_back(count);
+ count += __builtin_popcountll(mask);
+ }
+}
+
+int TFLiteEmbeddingExecutor::PruneBucketId(int bucket_id) const {
+ // Implements auxiliary data structure for computing the pruned index of a
+ // given bucket_id.
+ // If bucket_id is present in pruning_mask_, we compute floor(bucket_id/64),
+ // look it up in the auxiliary array prefix_counts_, and add to it the number
+ // of 1s before bucket_id % 64 in the 64-bit sequence
+ // pruning_mask_[floor(bucket_id/64)].
+ // If bucket_id is absent from pruning_mask_, we return pruned_row_bucket_id_.
+ const int bucket_id_major = bucket_id >> 6;
+ const int bucket_id_minor = bucket_id & 63;
+ uint64_t one = 1;
+ if (!(pruning_mask_[bucket_id_major] & (one << bucket_id_minor)))
+ return pruned_row_bucket_id_;
+ const uint64 zero = 0;
+ uint64 minor_mask;
+ if (bucket_id_minor == 0)
+ minor_mask = zero;
+ else
+ minor_mask = ((~zero) >> (64 - bucket_id_minor));
+ return prefix_counts_[bucket_id_major] +
+ __builtin_popcountll(pruning_mask_[bucket_id_major] & minor_mask);
+}
bool TFLiteEmbeddingExecutor::AddEmbedding(
const TensorView<int>& sparse_features, float* dest, int dest_size) const {
@@ -108,13 +159,24 @@
const int num_sparse_features = sparse_features.size();
for (int i = 0; i < num_sparse_features; ++i) {
const int bucket_id = sparse_features.data()[i];
- if (bucket_id >= num_buckets_) {
+ int full_num_buckets;
+ if (!pruning_mask_.empty()) {
+ full_num_buckets = full_num_buckets_;
+ } else {
+ full_num_buckets = num_buckets_;
+ }
+ if (bucket_id >= full_num_buckets) {
return false;
}
-
+ int final_bucket_id;
+ if (!pruning_mask_.empty()) {
+ final_bucket_id = PruneBucketId(bucket_id);
+ } else {
+ final_bucket_id = bucket_id;
+ }
if (!DequantizeAdd(scales_->data.f, embeddings_->data.uint8,
bytes_per_embedding_, num_sparse_features,
- quantization_bits_, bucket_id, dest, dest_size)) {
+ quantization_bits_, final_bucket_id, dest, dest_size)) {
return false;
}
}
diff --git a/annotator/model-executor.h b/annotator/model-executor.h
index 5ad3a7f..bcc318b 100644
--- a/annotator/model-executor.h
+++ b/annotator/model-executor.h
@@ -78,19 +78,28 @@
public:
static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
- int quantization_bits);
+ int quantization_bits,
+ const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
// Embeds the sparse_features into a dense embedding and adds (+) it
// element-wise to the dest vector.
bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
int dest_size) const;
+ // Auxiliary function for computing prefixes used in implementation of
+ // efficient mask indexing data structure.
+ void ComputePrefixCounts();
+
+ // Function implementing mask indexing based on efficient data structure
+ int PruneBucketId(int bucket_id) const;
+
protected:
explicit TFLiteEmbeddingExecutor(
std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
int num_buckets, int bytes_per_embedding, int output_embedding_size,
const TfLiteTensor* scales, const TfLiteTensor* embeddings,
- std::unique_ptr<tflite::Interpreter> interpreter);
+ std::unique_ptr<tflite::Interpreter> interpreter,
+ const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
std::unique_ptr<TfLiteModelExecutor> executor_;
@@ -104,6 +113,13 @@
// NOTE: This interpreter is used in a read-only way (as a storage for the
// model params), thus is still thread-safe.
std::unique_ptr<tflite::Interpreter> interpreter_;
+
+ std::vector<uint64> pruning_mask_;
+ std::vector<uint16> prefix_counts_;
+ int full_num_buckets_ = -1;
+
+ // Index of row of embedding table corresponding to all pruned buckets.
+ int pruned_row_bucket_id_ = -1;
};
} // namespace libtextclassifier3
diff --git a/annotator/model.fbs b/annotator/model.fbs
index 3682994..9d18779 100755
--- a/annotator/model.fbs
+++ b/annotator/model.fbs
@@ -14,7 +14,11 @@
// limitations under the License.
//
+include "utils/codepoint-range.fbs";
+include "utils/flatbuffers.fbs";
include "utils/intents/intent-config.fbs";
+include "utils/resources.fbs";
+include "utils/tokenizer.fbs";
include "utils/zlib/buffer.fbs";
file_identifier "TC2 ";
@@ -32,6 +36,17 @@
ALL = 7,
}
+// Enum for specifying the annotation usecase.
+namespace libtextclassifier3;
+enum AnnotationUsecase : int {
+ // Results are optimized for Smart{Select,Share,Linkify}.
+ ANNOTATION_USECASE_SMART = 0,
+
+ // Results are optimized for using TextClassifier as an infrastructure that
+ // annotates as much as possible.
+ ANNOTATION_USECASE_RAW = 1,
+}
+
namespace libtextclassifier3;
enum DatetimeExtractorType : int {
UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0,
@@ -71,7 +86,10 @@
DAYS = 34,
WEEKS = 35,
MONTHS = 36,
+
+ // TODO(zilka): Make the following 3 values singular for consistency.
HOURS = 37,
+
MINUTES = 38,
SECONDS = 39,
YEARS = 40,
@@ -171,6 +189,22 @@
namespace libtextclassifier3;
table VerificationOptions {
verify_luhn_checksum:bool = false;
+
+ // Lua verifier to use.
+ // Index of the lua verifier in the model.
+ lua_verifier:int = -1;
+}
+
+// Behaviour of capturing groups.
+namespace libtextclassifier3.RegexModel_.Pattern_;
+table CapturingGroup {
+ // If true, the span of the capturing group will be used to
+ // extend the selection.
+ extend_selection:bool = true;
+
+ // If set, the text of the capturing group will be used to set a field in
+ // the classification result entity data.
+ entity_field_path:FlatbufferFieldPath;
}
// List of regular expression matchers to check.
@@ -180,11 +214,10 @@
collection_name:string;
// The pattern to check.
- // Can specify a single capturing group used as match boundaries.
pattern:string;
// The modes for which to apply the patterns.
- enabled_modes:libtextclassifier3.ModeFlag = ALL;
+ enabled_modes:ModeFlag = ALL;
// The final score to assign to the results of this pattern.
target_classification_score:float = 1;
@@ -197,15 +230,34 @@
// use the first Find() result and then check that it spans the whole input.
use_approximate_matching:bool = false;
- compressed_pattern:libtextclassifier3.CompressedBuffer;
+ compressed_pattern:CompressedBuffer;
// Verification to apply on a match.
- verification_options:libtextclassifier3.VerificationOptions;
+ verification_options:VerificationOptions;
+
+ capturing_group:[Pattern_.CapturingGroup];
+
+ // Serialized entity data to set for a match.
+ serialized_entity_data:string;
}
namespace libtextclassifier3;
table RegexModel {
- patterns:[libtextclassifier3.RegexModel_.Pattern];
+ patterns:[RegexModel_.Pattern];
+
+ // If true, will compile the regexes only on first use.
+ lazy_regex_compilation:bool = true;
+
+ // Lua scripts for match verification.
+ // The verifier can access:
+ // * `context`: The context as a string.
+ // * `match`: The groups of the regex match as an array, each group gives
+ // * `begin`: span start
+ // * `end`: span end
+ // * `text`: the text
+ // The verifier is expected to return a boolean, indicating whether the
+ // verification succeeded or not.
+ lua_verifier:[string];
}
// List of regex patterns.
@@ -215,14 +267,14 @@
// The ith entry specifies the type of the ith capturing group.
// This is used to decide how the matched content has to be parsed.
- groups:[libtextclassifier3.DatetimeGroupType];
+ groups:[DatetimeGroupType];
- compressed_pattern:libtextclassifier3.CompressedBuffer;
+ compressed_pattern:CompressedBuffer;
}
namespace libtextclassifier3;
table DatetimeModelPattern {
- regexes:[libtextclassifier3.DatetimeModelPattern_.Regex];
+ regexes:[DatetimeModelPattern_.Regex];
// List of locale indices in DatetimeModel that represent the locales that
// these patterns should be used for. If empty, can be used for all locales.
@@ -235,15 +287,19 @@
priority_score:float = 0;
// The modes for which to apply the patterns.
- enabled_modes:libtextclassifier3.ModeFlag = ALL;
+ enabled_modes:ModeFlag = ALL;
+
+ // The annotation usecases for which to apply the patterns.
+ // This is a flag field for values of AnnotationUsecase.
+ enabled_annotation_usecases:uint = 4294967295;
}
namespace libtextclassifier3;
table DatetimeModelExtractor {
- extractor:libtextclassifier3.DatetimeExtractorType;
+ extractor:DatetimeExtractorType;
pattern:string;
locales:[int];
- compressed_pattern:libtextclassifier3.CompressedBuffer;
+ compressed_pattern:CompressedBuffer;
}
namespace libtextclassifier3;
@@ -252,8 +308,8 @@
// model. The individual patterns refer back to them using an index.
locales:[string];
- patterns:[libtextclassifier3.DatetimeModelPattern];
- extractors:[libtextclassifier3.DatetimeModelExtractor];
+ patterns:[DatetimeModelPattern];
+ extractors:[DatetimeModelExtractor];
// If true, will use the extractors for determining the match location as
// opposed to using the location where the global pattern matched.
@@ -262,18 +318,25 @@
// List of locale ids, rules of whose are always run, after the requested
// ones.
default_locales:[int];
+
+ // If true, will generate the alternative interpretations for ambiguous
+ // datetime expressions.
+ generate_alternative_interpretations_when_ambiguous:bool = false;
+
+ // If true, will compile the regexes only on first use.
+ lazy_regex_compilation:bool = true;
}
namespace libtextclassifier3.DatetimeModelLibrary_;
table Item {
key:string;
- value:libtextclassifier3.DatetimeModel;
+ value:DatetimeModel;
}
// A set of named DateTime models.
namespace libtextclassifier3;
table DatetimeModelLibrary {
- models:[libtextclassifier3.DatetimeModelLibrary_.Item];
+ models:[DatetimeModelLibrary_.Item];
}
// Options controlling the output of the Tensorflow Lite models.
@@ -283,7 +346,16 @@
min_annotate_confidence:float = 0;
// The modes for which to enable the models.
- enabled_modes:libtextclassifier3.ModeFlag = ALL;
+ enabled_modes:ModeFlag = ALL;
+
+ // Comma-separated list of locales (BCP 47 tags) that dictionary
+ // classification supports.
+ dictionary_locales:string;
+
+ // Comma-separated list of locales (BCP 47 tags) that the model supports, that
+ // are used to prevent triggering on input in unsupported languages. If
+ // empty, the model will trigger on all inputs.
+ locales:string;
}
// Options controlling the output of the classifier.
@@ -300,6 +372,23 @@
filtered_collections_selection:[string];
}
+namespace libtextclassifier3.Model_;
+table EmbeddingPruningMask {
+ // If true, use pruning mask. In this case, we use mask
+ // pruning_mask to determine the mapping of hashed-charactergrams.
+ enabled:bool;
+
+ // Packing of the binary pruning mask into uint64 values.
+ pruning_mask:[ulong] (force_align: 16);
+
+ // Number of buckets before pruning.
+ full_num_buckets:int;
+
+ // Index of row of compressed embedding matrix to which all pruned buckets
+ // are mapped.
+ pruned_row_bucket_id:int;
+}
+
namespace libtextclassifier3;
table Model {
// Comma-separated list of locales supported by the model as BCP 47 tags.
@@ -310,8 +399,8 @@
// A name for the model that can be used for e.g. logging.
name:string;
- selection_feature_options:libtextclassifier3.FeatureProcessorOptions;
- classification_feature_options:libtextclassifier3.FeatureProcessorOptions;
+ selection_feature_options:FeatureProcessorOptions;
+ classification_feature_options:FeatureProcessorOptions;
// Tensorflow Lite models.
selection_model:[ubyte] (force_align: 16);
@@ -320,18 +409,18 @@
embedding_model:[ubyte] (force_align: 16);
// Options for the different models.
- selection_options:libtextclassifier3.SelectionModelOptions;
+ selection_options:SelectionModelOptions;
- classification_options:libtextclassifier3.ClassificationModelOptions;
- regex_model:libtextclassifier3.RegexModel;
- datetime_model:libtextclassifier3.DatetimeModel;
+ classification_options:ClassificationModelOptions;
+ regex_model:RegexModel;
+ datetime_model:DatetimeModel;
// Options controlling the output of the models.
- triggering_options:libtextclassifier3.ModelTriggeringOptions;
+ triggering_options:ModelTriggeringOptions;
// Global switch that controls if SuggestSelection(), ClassifyText() and
// Annotate() will run. If a mode is disabled it returns empty/no-op results.
- enabled_modes:libtextclassifier3.ModeFlag = ALL;
+ enabled_modes:ModeFlag = ALL;
// If true, will snap the selections that consist only of whitespaces to the
// containing suggested span. Otherwise, no suggestion is proposed, since the
@@ -340,50 +429,28 @@
// Global configuration for the output of SuggestSelection(), ClassifyText()
// and Annotate().
- output_options:libtextclassifier3.OutputOptions;
+ output_options:OutputOptions;
// Configures how Intents should be generated on Android.
- // TODO(smillius): Remove deprecated factory options.
- android_intent_options:libtextclassifier3.AndroidIntentFactoryOptions;
+ android_intent_options:AndroidIntentFactoryOptions;
- intent_options:libtextclassifier3.IntentFactoryModel;
-}
+ intent_options:IntentFactoryModel;
-// Role of the codepoints in the range.
-namespace libtextclassifier3.TokenizationCodepointRange_;
-enum Role : int {
- // Concatenates the codepoint to the current run of codepoints.
- DEFAULT_ROLE = 0,
+ // Model resources.
+ resources:ResourcePool;
- // Splits a run of codepoints before the current codepoint.
- SPLIT_BEFORE = 1,
+ // Schema data for handling entity data.
+ entity_data_schema:[ubyte];
- // Splits a run of codepoints after the current codepoint.
- SPLIT_AFTER = 2,
+ number_annotator_options:NumberAnnotatorOptions;
+ duration_annotator_options:DurationAnnotatorOptions;
- // Each codepoint will be a separate token. Good e.g. for Chinese
- // characters.
- TOKEN_SEPARATOR = 3,
+ // Comma-separated list of locales (BCP 47 tags) that the model supports, that
+ // are used to prevent triggering on input in unsupported languages. If
+ // empty, the model will trigger on all inputs.
+ triggering_locales:string;
- // Discards the codepoint.
- DISCARD_CODEPOINT = 4,
-
- // Common values:
- // Splits on the characters and discards them. Good e.g. for the space
- // character.
- WHITESPACE_SEPARATOR = 7,
-}
-
-// Represents a codepoint range [start, end) with its role for tokenization.
-namespace libtextclassifier3;
-table TokenizationCodepointRange {
- start:int;
- end:int;
- role:libtextclassifier3.TokenizationCodepointRange_.Role;
-
- // Integer identifier of the script this range denotes. Negative values are
- // reserved for Tokenizer's internal use.
- script_id:int;
+ embedding_pruning_mask:Model_.EmbeddingPruningMask;
}
// Method for selecting the center token.
@@ -399,30 +466,6 @@
CENTER_TOKEN_MIDDLE_OF_SELECTION = 2,
}
-// Controls the type of tokenization the model will use for the input text.
-namespace libtextclassifier3.FeatureProcessorOptions_;
-enum TokenizationType : int {
- INVALID_TOKENIZATION_TYPE = 0,
-
- // Use the internal tokenizer for tokenization.
- INTERNAL_TOKENIZER = 1,
-
- // Use ICU for tokenization.
- ICU = 2,
-
- // First apply ICU tokenization. Then identify stretches of tokens
- // consisting only of codepoints in internal_tokenizer_codepoint_ranges
- // and re-tokenize them using the internal tokenizer.
- MIXED = 3,
-}
-
-// Range of codepoints start - end, where end is exclusive.
-namespace libtextclassifier3.FeatureProcessorOptions_;
-table CodepointRange {
- start:int;
- end:int;
-}
-
// Bounds-sensitive feature extraction configuration.
namespace libtextclassifier3.FeatureProcessorOptions_;
table BoundsSensitiveFeatures {
@@ -530,20 +573,20 @@
// Codepoint ranges that determine how different codepoints are tokenized.
// The ranges must not overlap.
- tokenization_codepoint_config:[libtextclassifier3.TokenizationCodepointRange];
+ tokenization_codepoint_config:[TokenizationCodepointRange];
- center_token_selection_method:libtextclassifier3.FeatureProcessorOptions_.CenterTokenSelectionMethod;
+ center_token_selection_method:FeatureProcessorOptions_.CenterTokenSelectionMethod;
// If true, span boundaries will be snapped to containing tokens and not
// required to exactly match token boundaries.
snap_label_span_boundaries_to_containing_tokens:bool;
// A set of codepoint ranges supported by the model.
- supported_codepoint_ranges:[libtextclassifier3.FeatureProcessorOptions_.CodepointRange];
+ supported_codepoint_ranges:[CodepointRange];
// A set of codepoint ranges to use in the mixed tokenization mode to identify
// stretches of tokens to re-tokenize using the internal tokenizer.
- internal_tokenizer_codepoint_ranges:[libtextclassifier3.FeatureProcessorOptions_.CodepointRange];
+ internal_tokenizer_codepoint_ranges:[CodepointRange];
// Minimum ratio of supported codepoints in the input context. If the ratio
// is lower than this, the feature computation will fail.
@@ -559,14 +602,14 @@
// to it. So the resulting feature vector has two regions.
feature_version:int = 0;
- tokenization_type:libtextclassifier3.FeatureProcessorOptions_.TokenizationType = INTERNAL_TOKENIZER;
+ tokenization_type:TokenizationType = INTERNAL_TOKENIZER;
icu_preserve_whitespace_tokens:bool = false;
// List of codepoints that will be stripped from beginning and end of
// predicted spans.
ignored_span_boundary_codepoints:[int];
- bounds_sensitive_features:libtextclassifier3.FeatureProcessorOptions_.BoundsSensitiveFeatures;
+ bounds_sensitive_features:FeatureProcessorOptions_.BoundsSensitiveFeatures;
// List of allowed charactergrams. The extracted charactergrams are filtered
// using this list, and charactergrams that are not present are interpreted as
@@ -580,4 +623,67 @@
tokenize_on_script_change:bool = false;
}
+namespace libtextclassifier3;
+table NumberAnnotatorOptions {
+ // If true, number annotations will be produced.
+ enabled:bool = false;
+
+ // Score to assign to the annotated numbers from the annotator.
+ score:float = 1;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+
+ // The modes in which to enable number annotations.
+ enabled_modes:ModeFlag = ALL;
+
+ // The annotation usecases for which to produce number annotations.
+ // This is a flag field for values of AnnotationUsecase.
+ enabled_annotation_usecases:uint = 4294967295;
+
+ // A list of codepoints that can form a prefix of a valid number.
+ allowed_prefix_codepoints:[int];
+
+ // A list of codepoints that can form a suffix of a valid number.
+ allowed_suffix_codepoints:[int];
+}
+
+// The DurationAnnotator is currently tailored for English only.
+namespace libtextclassifier3;
+table DurationAnnotatorOptions {
+ // If true, duration annotations will be produced.
+ enabled:bool = false;
+
+ // Score to assign to the annotated durations from the annotator.
+ score:float = 1;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+
+ // The modes in which to enable duration annotations.
+ enabled_modes:ModeFlag = ALL;
+
+ // The annotation usecases for which to produce duration annotations.
+ enabled_annotation_usecases:uint = 4294967295;
+
+ // Durations typically look like XX hours and XX minutes etc... The list of
+ // strings below enumerate variants of "hours", "minutes", etc. in these
+ // expressions. These are verbatim strings that are matched against tokens in
+ // the input.
+ week_expressions:[string];
+
+ day_expressions:[string];
+ hour_expressions:[string];
+ minute_expressions:[string];
+ second_expressions:[string];
+
+  // List of expressions that don't break a duration expression (they can
+  // become a part of it) but have no semantic meaning of their own.
+ filler_expressions:[string];
+
+ // List of expressions that mean half of a unit of duration (e.g. "half an
+ // hour").
+ half_expressions:[string];
+}
+
root_type libtextclassifier3.Model;
diff --git a/annotator/number/number.cc b/annotator/number/number.cc
new file mode 100644
index 0000000..bc3a2fe
--- /dev/null
+++ b/annotator/number/number.cc
@@ -0,0 +1,187 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/number/number.h"
+
+#include <climits>
+#include <cstdlib>
+
+#include "annotator/collections.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+bool NumberAnnotator::ClassifyText(
+ const UnicodeText& context, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
+ ClassificationResult* classification_result) const {
+ int64 parsed_value;
+ int num_prefix_codepoints;
+ int num_suffix_codepoints;
+ if (ParseNumber(UnicodeText::Substring(context, selection_indices.first,
+ selection_indices.second),
+ &parsed_value, &num_prefix_codepoints,
+ &num_suffix_codepoints)) {
+ ClassificationResult classification{Collections::Number(), 1.0};
+ TC3_CHECK(classification_result != nullptr);
+ classification_result->collection = Collections::Number();
+ classification_result->score = options_->score();
+ classification_result->priority_score = options_->priority_score();
+ classification_result->numeric_value = parsed_value;
+ return true;
+ }
+ return false;
+}
+
+bool NumberAnnotator::FindAll(const UnicodeText& context,
+ AnnotationUsecase annotation_usecase,
+ std::vector<AnnotatedSpan>* result) const {
+ if (!options_->enabled() || ((1 << annotation_usecase) &
+ options_->enabled_annotation_usecases()) == 0) {
+ return true;
+ }
+
+ const std::vector<Token> tokens = feature_processor_->Tokenize(context);
+ for (const Token& token : tokens) {
+ const UnicodeText token_text =
+ UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+ int64 parsed_value;
+ int num_prefix_codepoints;
+ int num_suffix_codepoints;
+ if (ParseNumber(token_text, &parsed_value, &num_prefix_codepoints,
+ &num_suffix_codepoints)) {
+ ClassificationResult classification{Collections::Number(),
+ options_->score()};
+ classification.numeric_value = parsed_value;
+ classification.priority_score = options_->priority_score();
+
+ AnnotatedSpan annotated_span;
+ annotated_span.span = {token.start + num_prefix_codepoints,
+ token.end - num_suffix_codepoints};
+ annotated_span.classification.push_back(classification);
+
+ result->push_back(annotated_span);
+ }
+ }
+
+ return true;
+}
+
+std::unordered_set<int> NumberAnnotator::FlatbuffersVectorToSet(
+ const flatbuffers::Vector<int32_t>* codepoints) {
+ if (codepoints == nullptr) {
+ return std::unordered_set<int>{};
+ }
+
+ std::unordered_set<int> result;
+ for (const int codepoint : *codepoints) {
+ result.insert(codepoint);
+ }
+ return result;
+}
+
+namespace {
+UnicodeText::const_iterator ConsumeAndParseNumber(
+ const UnicodeText::const_iterator& it_begin,
+ const UnicodeText::const_iterator& it_end, int64* result) {
+ *result = 0;
+
+ // See if there's a sign in the beginning of the number.
+ int sign = 1;
+ auto it = it_begin;
+ if (it != it_end) {
+ if (*it == '-') {
+ ++it;
+ sign = -1;
+ } else if (*it == '+') {
+ ++it;
+ sign = 1;
+ }
+ }
+
+ while (it != it_end) {
+ if (*it >= '0' && *it <= '9') {
+ // When overflow is imminent we'll fail to parse the number.
+ if (*result > INT64_MAX / 10) {
+ return it_begin;
+ }
+ *result *= 10;
+ *result += *it - '0';
+ } else {
+ *result *= sign;
+ return it;
+ }
+
+ ++it;
+ }
+
+ *result *= sign;
+ return it_end;
+}
+} // namespace
+
+bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* result,
+ int* num_prefix_codepoints,
+ int* num_suffix_codepoints) const {
+ TC3_CHECK(result != nullptr && num_prefix_codepoints != nullptr &&
+ num_suffix_codepoints != nullptr);
+ auto it = text.begin();
+ auto it_end = text.end();
+
+ // Strip boundary codepoints from both ends.
+ const CodepointSpan original_span{0, text.size_codepoints()};
+ const CodepointSpan stripped_span =
+ feature_processor_->StripBoundaryCodepoints(text, original_span);
+ const int num_stripped_end = (original_span.second - stripped_span.second);
+ std::advance(it, stripped_span.first);
+ std::advance(it_end, -num_stripped_end);
+
+ // Consume prefix codepoints.
+ *num_prefix_codepoints = stripped_span.first;
+ while (it != text.end()) {
+ if (allowed_prefix_codepoints_.find(*it) ==
+ allowed_prefix_codepoints_.end()) {
+ break;
+ }
+
+ ++it;
+ ++(*num_prefix_codepoints);
+ }
+
+ auto it_start = it;
+ it = ConsumeAndParseNumber(it, text.end(), result);
+ if (it == it_start) {
+ return false;
+ }
+
+ // Consume suffix codepoints.
+ bool valid_suffix = true;
+ *num_suffix_codepoints = 0;
+ while (it != it_end) {
+ if (allowed_suffix_codepoints_.find(*it) ==
+ allowed_suffix_codepoints_.end()) {
+ valid_suffix = false;
+ break;
+ }
+
+ ++it;
+ ++(*num_suffix_codepoints);
+ }
+ *num_suffix_codepoints += num_stripped_end;
+ return valid_suffix;
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/number/number.h b/annotator/number/number.h
new file mode 100644
index 0000000..488f5ea
--- /dev/null
+++ b/annotator/number/number.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_H_
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/feature-processor.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+// Annotator of numbers in text.
+//
+// Only supports values that fit into an int64; parsing fails on overflow.
+//
+// TODO(zilka): Add support for non-ASCII digits.
+// TODO(zilka): Add support for written-out numbers.
+class NumberAnnotator {
+ public:
+ explicit NumberAnnotator(const NumberAnnotatorOptions* options,
+ const FeatureProcessor* feature_processor)
+ : options_(options),
+ feature_processor_(feature_processor),
+ allowed_prefix_codepoints_(
+ FlatbuffersVectorToSet(options->allowed_prefix_codepoints())),
+ allowed_suffix_codepoints_(
+ FlatbuffersVectorToSet(options->allowed_suffix_codepoints())) {}
+
+ // Classifies given text, and if it is a number, it passes the result in
+ // 'classification_result' and returns true, otherwise returns false.
+ bool ClassifyText(const UnicodeText& context, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
+ ClassificationResult* classification_result) const;
+
+ // Finds all number instances in the input text.
+ bool FindAll(const UnicodeText& context_unicode,
+ AnnotationUsecase annotation_usecase,
+ std::vector<AnnotatedSpan>* result) const;
+
+ private:
+ static std::unordered_set<int> FlatbuffersVectorToSet(
+ const flatbuffers::Vector<int32_t>* codepoints);
+
+ // Parses the text to an int64 value and returns true if succeeded, otherwise
+ // false. Also returns the number of prefix/suffix codepoints that were
+ // stripped from the number.
+ bool ParseNumber(const UnicodeText& text, int64* result,
+ int* num_prefix_codepoints,
+ int* num_suffix_codepoints) const;
+
+ const NumberAnnotatorOptions* options_;
+ const FeatureProcessor* feature_processor_;
+ const std::unordered_set<int> allowed_prefix_codepoints_;
+ const std::unordered_set<int> allowed_suffix_codepoints_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_H_
diff --git a/annotator/number/number_test.cc b/annotator/number/number_test.cc
new file mode 100644
index 0000000..d3b2e8c
--- /dev/null
+++ b/annotator/number/number_test.cc
@@ -0,0 +1,258 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/number/number.h"
+
+#include <string>
+#include <vector>
+
+#include "annotator/collections.h"
+#include "annotator/model_generated.h"
+#include "annotator/types-test-util.h"
+#include "annotator/types.h"
+#include "utils/test-utils.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::AllOf;
+using testing::ElementsAre;
+using testing::Field;
+
+const NumberAnnotatorOptions* TestingNumberAnnotatorOptions() {
+ static const flatbuffers::DetachedBuffer* options_data = []() {
+ NumberAnnotatorOptionsT options;
+ options.enabled = true;
+ options.allowed_prefix_codepoints.push_back('$');
+ options.allowed_suffix_codepoints.push_back('%');
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(NumberAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+ }();
+
+ return flatbuffers::GetRoot<NumberAnnotatorOptions>(options_data->data());
+}
+
+FeatureProcessor BuildFeatureProcessor(const UniLib* unilib) {
+ static const flatbuffers::DetachedBuffer* options_data = []() {
+ FeatureProcessorOptionsT options;
+ options.context_size = 1;
+ options.max_selection_span = 1;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.ignored_span_boundary_codepoints.push_back(',');
+
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+ }();
+
+ const FeatureProcessorOptions* feature_processor_options =
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
+
+ return FeatureProcessor(feature_processor_options, unilib);
+}
+
+class NumberAnnotatorTest : public ::testing::Test {
+ protected:
+ NumberAnnotatorTest()
+ : INIT_UNILIB_FOR_TESTING(unilib_),
+ feature_processor_(BuildFeatureProcessor(&unilib_)),
+ number_annotator_(TestingNumberAnnotatorOptions(),
+ &feature_processor_) {}
+
+ UniLib unilib_;
+ FeatureProcessor feature_processor_;
+ NumberAnnotator number_annotator_;
+};
+
+TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberCorrectly) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 12345 ..."), {4, 9},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_EQ(classification_result.collection, "number");
+ EXPECT_EQ(classification_result.numeric_value, 12345);
+}
+
+TEST_F(NumberAnnotatorTest, ClassifiesNonNumberCorrectly) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("... 123a45 ..."), {4, 10},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... 12345 ... 9 is my number and I paid $99 and "
+ "sometimes 27% but not 68# nor #68"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ ASSERT_EQ(result.size(), 4);
+ ASSERT_EQ(result[0].classification.size(), 1);
+ EXPECT_EQ(result[0].classification[0].collection, "number");
+ EXPECT_EQ(result[0].classification[0].numeric_value, 12345);
+ ASSERT_EQ(result[1].classification.size(), 1);
+ EXPECT_EQ(result[1].classification[0].collection, "number");
+ EXPECT_EQ(result[1].classification[0].numeric_value, 9);
+ ASSERT_EQ(result[2].classification.size(), 1);
+ EXPECT_EQ(result[2].classification[0].collection, "number");
+ EXPECT_EQ(result[2].classification[0].numeric_value, 99);
+ ASSERT_EQ(result[3].classification.size(), 1);
+ EXPECT_EQ(result[3].classification[0].collection, "number");
+ EXPECT_EQ(result[3].classification[0].numeric_value, 27);
+}
+
+TEST_F(NumberAnnotatorTest, FindsNumberWithPunctuation) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("Come at 9, ok?"),
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(8, 9)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, 9)))))));
+}
+
+TEST_F(NumberAnnotatorTest, HandlesNumbersAtBeginning) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("-5"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 2)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, -5)))))));
+}
+
+TEST_F(NumberAnnotatorTest, WhenLowestSupportedNumberParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("-999999999999999999"), {0, 19},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_THAT(
+ classification_result,
+ AllOf(Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, -999999999999999999L)));
+}
+
+TEST_F(NumberAnnotatorTest, WhenLargestSupportedNumberParsesIt) {
+ ClassificationResult classification_result;
+ EXPECT_TRUE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("999999999999999999"), {0, 18},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+
+ EXPECT_THAT(
+ classification_result,
+ AllOf(Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, 999999999999999999L)));
+}
+
+TEST_F(NumberAnnotatorTest, WhenFirstLowestNonSupportedNumberDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("-10000000000000000000"), {0, 21},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenFirstLargestNonSupportedNumberDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("10000000000000000000"), {0, 20},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenLargeNumberDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("1234567890123456789012345678901234567890"), {0, 40},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenMultipleMinusSignsDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("--10"), {0, 4},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenMinusSignSuffixDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("10-"), {0, 3},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenMinusInTheMiddleDoesNotParseIt) {
+ ClassificationResult classification_result;
+ EXPECT_FALSE(number_annotator_.ClassifyText(
+ UTF8ToUnicodeText("2016-2017"), {0, 9},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
+}
+
+TEST_F(NumberAnnotatorTest, WhenSuffixWithoutNumberDoesNotParseIt) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... % ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ ASSERT_EQ(result.size(), 0);
+}
+
+TEST_F(NumberAnnotatorTest, WhenPrefixWithoutNumberDoesNotParseIt) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... $ ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ ASSERT_EQ(result.size(), 0);
+}
+
+TEST_F(NumberAnnotatorTest, WhenPrefixAndSuffixWithoutNumberDoesNotParseIt) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("... $% ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ ASSERT_EQ(result.size(), 0);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/test_data/test_model.fb b/annotator/test_data/test_model.fb
index fa9cec5..0f2ec16 100644
--- a/annotator/test_data/test_model.fb
+++ b/annotator/test_data/test_model.fb
Binary files differ
diff --git a/annotator/test_data/test_model_cc.fb b/annotator/test_data/test_model_cc.fb
deleted file mode 100644
index b73d84f..0000000
--- a/annotator/test_data/test_model_cc.fb
+++ /dev/null
Binary files differ
diff --git a/annotator/test_data/wrong_embeddings.fb b/annotator/test_data/wrong_embeddings.fb
index ba71cdd..5439623 100644
--- a/annotator/test_data/wrong_embeddings.fb
+++ b/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/annotator/tokenizer.cc b/annotator/tokenizer.cc
deleted file mode 100644
index 099dccc..0000000
--- a/annotator/tokenizer.cc
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/tokenizer.h"
-
-#include <algorithm>
-
-#include "utils/base/logging.h"
-#include "utils/strings/utf8.h"
-
-namespace libtextclassifier3 {
-
-Tokenizer::Tokenizer(
- const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
- bool split_on_script_change)
- : split_on_script_change_(split_on_script_change) {
- for (const TokenizationCodepointRange* range : codepoint_ranges) {
- codepoint_ranges_.emplace_back(range->UnPack());
- }
-
- std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
- const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
- return a->start < b->start;
- });
-}
-
-const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
- int codepoint) const {
- auto it = std::lower_bound(
- codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
- [](const std::unique_ptr<const TokenizationCodepointRangeT>& range,
- int codepoint) {
- // This function compares range with the codepoint for the purpose of
- // finding the first greater or equal range. Because of the use of
- // std::lower_bound it needs to return true when range < codepoint;
- // the first time it will return false the lower bound is found and
- // returned.
- //
- // It might seem weird that the condition is range.end <= codepoint
- // here but when codepoint == range.end it means it's actually just
- // outside of the range, thus the range is less than the codepoint.
- return range->end <= codepoint;
- });
- if (it != codepoint_ranges_.end() && (*it)->start <= codepoint &&
- (*it)->end > codepoint) {
- return it->get();
- } else {
- return nullptr;
- }
-}
-
-void Tokenizer::GetScriptAndRole(char32 codepoint,
- TokenizationCodepointRange_::Role* role,
- int* script) const {
- const TokenizationCodepointRangeT* range = FindTokenizationRange(codepoint);
- if (range) {
- *role = range->role;
- *script = range->script_id;
- } else {
- *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- *script = kUnknownScript;
- }
-}
-
-std::vector<Token> Tokenizer::Tokenize(const std::string& text) const {
- UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
- return Tokenize(text_unicode);
-}
-
-std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const {
- std::vector<Token> result;
- Token new_token("", 0, 0);
- int codepoint_index = 0;
-
- int last_script = kInvalidScript;
- for (auto it = text_unicode.begin(); it != text_unicode.end();
- ++it, ++codepoint_index) {
- TokenizationCodepointRange_::Role role;
- int script;
- GetScriptAndRole(*it, &role, &script);
-
- if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE ||
- (split_on_script_change_ && last_script != kInvalidScript &&
- last_script != script)) {
- if (!new_token.value.empty()) {
- result.push_back(new_token);
- }
- new_token = Token("", codepoint_index, codepoint_index);
- }
- if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) {
- new_token.value += std::string(
- it.utf8_data(),
- it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
- ++new_token.end;
- }
- if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) {
- if (!new_token.value.empty()) {
- result.push_back(new_token);
- }
- new_token = Token("", codepoint_index + 1, codepoint_index + 1);
- }
-
- last_script = script;
- }
- if (!new_token.value.empty()) {
- result.push_back(new_token);
- }
-
- return result;
-}
-
-} // namespace libtextclassifier3
diff --git a/annotator/tokenizer.h b/annotator/tokenizer.h
deleted file mode 100644
index ec33f2d..0000000
--- a/annotator/tokenizer.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_
-
-#include <string>
-#include <vector>
-
-#include "annotator/model_generated.h"
-#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3 {
-
-const int kInvalidScript = -1;
-const int kUnknownScript = -2;
-
-// Tokenizer splits the input string into a sequence of tokens, according to the
-// configuration.
-class Tokenizer {
- public:
- explicit Tokenizer(
- const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
- bool split_on_script_change);
-
- // Tokenizes the input string using the selected tokenization method.
- std::vector<Token> Tokenize(const std::string& text) const;
-
- // Same as above but takes UnicodeText.
- std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
-
- protected:
- // Finds the tokenization codepoint range config for given codepoint.
- // Internally uses binary search so should be O(log(# of codepoint_ranges)).
- const TokenizationCodepointRangeT* FindTokenizationRange(int codepoint) const;
-
- // Finds the role and script for given codepoint. If not found, DEFAULT_ROLE
- // and kUnknownScript are assigned.
- void GetScriptAndRole(char32 codepoint,
- TokenizationCodepointRange_::Role* role,
- int* script) const;
-
- private:
- // Codepoint ranges that determine how different codepoints are tokenized.
- // The ranges must not overlap.
- std::vector<std::unique_ptr<const TokenizationCodepointRangeT>>
- codepoint_ranges_;
-
- // If true, tokens will be additionally split when the codepoint's script_id
- // changes.
- bool split_on_script_change_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TOKENIZER_H_
diff --git a/annotator/tokenizer_test.cc b/annotator/tokenizer_test.cc
deleted file mode 100644
index a3ab9da..0000000
--- a/annotator/tokenizer_test.cc
+++ /dev/null
@@ -1,334 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/tokenizer.h"
-
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAreArray;
-
-class TestingTokenizer : public Tokenizer {
- public:
- explicit TestingTokenizer(
- const std::vector<const TokenizationCodepointRange*>&
- codepoint_range_configs,
- bool split_on_script_change)
- : Tokenizer(codepoint_range_configs, split_on_script_change) {}
-
- using Tokenizer::FindTokenizationRange;
-};
-
-class TestingTokenizerProxy {
- public:
- explicit TestingTokenizerProxy(
- const std::vector<TokenizationCodepointRangeT>& codepoint_range_configs,
- bool split_on_script_change) {
- int num_configs = codepoint_range_configs.size();
- std::vector<const TokenizationCodepointRange*> configs_fb;
- buffers_.reserve(num_configs);
- for (int i = 0; i < num_configs; i++) {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateTokenizationCodepointRange(
- builder, &codepoint_range_configs[i]));
- buffers_.push_back(builder.Release());
- configs_fb.push_back(
- flatbuffers::GetRoot<TokenizationCodepointRange>(buffers_[i].data()));
- }
- tokenizer_ = std::unique_ptr<TestingTokenizer>(
- new TestingTokenizer(configs_fb, split_on_script_change));
- }
-
- TokenizationCodepointRange_::Role TestFindTokenizationRole(int c) const {
- const TokenizationCodepointRangeT* range =
- tokenizer_->FindTokenizationRange(c);
- if (range != nullptr) {
- return range->role;
- } else {
- return TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- }
- }
-
- std::vector<Token> Tokenize(const std::string& utf8_text) const {
- return tokenizer_->Tokenize(utf8_text);
- }
-
- private:
- std::vector<flatbuffers::DetachedBuffer> buffers_;
- std::unique_ptr<TestingTokenizer> tokenizer_;
-};
-
-TEST(TokenizerTest, FindTokenizationRange) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 10;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 1234;
- config->end = 12345;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
-
- // Test hits to the first group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(0),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(5),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(10),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-
- // Test a hit to the second group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(31),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(32),
- TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(33),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-
- // Test hits to the third group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-
- // Test a hit outside.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(99),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-}
-
-TEST(TokenizerTest, TokenizeOnSpace) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- configs.emplace_back();
- config = &configs.back();
- // Space character.
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
- std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
-
- EXPECT_THAT(tokens,
- ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
-}
-
-TEST(TokenizerTest, TokenizeOnSpaceAndScriptChange) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- // Latin.
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 32;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- config->script_id = 1;
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
- config->script_id = 1;
- configs.emplace_back();
- config = &configs.back();
- config->start = 33;
- config->end = 0x77F + 1;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- config->script_id = 1;
-
- TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/true);
- EXPECT_THAT(tokenizer.Tokenize("앨라배마 주 전화(123) 456-789웹사이트"),
- std::vector<Token>({Token("앨라배마", 0, 4), Token("주", 5, 6),
- Token("전화", 7, 10), Token("(123)", 10, 15),
- Token("456-789", 16, 23),
- Token("웹사이트", 23, 28)}));
-} // namespace
-
-TEST(TokenizerTest, TokenizeComplex) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt
- // Latin - cyrilic.
- // 0000..007F; Basic Latin
- // 0080..00FF; Latin-1 Supplement
- // 0100..017F; Latin Extended-A
- // 0180..024F; Latin Extended-B
- // 0250..02AF; IPA Extensions
- // 02B0..02FF; Spacing Modifier Letters
- // 0300..036F; Combining Diacritical Marks
- // 0370..03FF; Greek and Coptic
- // 0400..04FF; Cyrillic
- // 0500..052F; Cyrillic Supplement
- // 0530..058F; Armenian
- // 0590..05FF; Hebrew
- // 0600..06FF; Arabic
- // 0700..074F; Syriac
- // 0750..077F; Arabic Supplement
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 32;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 33;
- config->end = 0x77F + 1;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
-
- // CJK
- // 2E80..2EFF; CJK Radicals Supplement
- // 3000..303F; CJK Symbols and Punctuation
- // 3040..309F; Hiragana
- // 30A0..30FF; Katakana
- // 3100..312F; Bopomofo
- // 3130..318F; Hangul Compatibility Jamo
- // 3190..319F; Kanbun
- // 31A0..31BF; Bopomofo Extended
- // 31C0..31EF; CJK Strokes
- // 31F0..31FF; Katakana Phonetic Extensions
- // 3200..32FF; Enclosed CJK Letters and Months
- // 3300..33FF; CJK Compatibility
- // 3400..4DBF; CJK Unified Ideographs Extension A
- // 4DC0..4DFF; Yijing Hexagram Symbols
- // 4E00..9FFF; CJK Unified Ideographs
- // A000..A48F; Yi Syllables
- // A490..A4CF; Yi Radicals
- // A4D0..A4FF; Lisu
- // A500..A63F; Vai
- // F900..FAFF; CJK Compatibility Ideographs
- // FE30..FE4F; CJK Compatibility Forms
- // 20000..2A6DF; CJK Unified Ideographs Extension B
- // 2A700..2B73F; CJK Unified Ideographs Extension C
- // 2B740..2B81F; CJK Unified Ideographs Extension D
- // 2B820..2CEAF; CJK Unified Ideographs Extension E
- // 2CEB0..2EBEF; CJK Unified Ideographs Extension F
- // 2F800..2FA1F; CJK Compatibility Ideographs Supplement
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2E80;
- config->end = 0x2EFF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x3000;
- config->end = 0xA63F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0xF900;
- config->end = 0xFAFF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0xFE30;
- config->end = 0xFE4F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x20000;
- config->end = 0x2A6DF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2A700;
- config->end = 0x2B73F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2B740;
- config->end = 0x2B81F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2B820;
- config->end = 0x2CEAF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2CEB0;
- config->end = 0x2EBEF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2F800;
- config->end = 0x2FA1F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- // Thai.
- // 0E00..0E7F; Thai
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x0E00;
- config->end = 0x0E7F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
- std::vector<Token> tokens;
-
- tokens = tokenizer.Tokenize(
- "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。");
- EXPECT_EQ(tokens.size(), 30);
-
- tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ");
- // clang-format off
- EXPECT_THAT(
- tokens,
- ElementsAreArray({Token("問", 0, 1),
- Token("少", 1, 2),
- Token("目", 2, 3),
- Token("hello", 4, 9),
- Token("木", 10, 11),
- Token("輸", 11, 12),
- Token("ย", 12, 13),
- Token("า", 13, 14),
- Token("ม", 14, 15),
- Token("き", 15, 16),
- Token("ゃ", 16, 17)}));
- // clang-format on
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/annotator/types-test-util.h b/annotator/types-test-util.h
index fbbdd63..c0b0980 100644
--- a/annotator/types-test-util.h
+++ b/annotator/types-test-util.h
@@ -24,25 +24,21 @@
namespace libtextclassifier3 {
-inline std::ostream& operator<<(std::ostream& stream, const Token& value) {
- logging::LoggingStringStream tmp_stream;
- tmp_stream << value;
- return stream << tmp_stream.message;
-}
+#define TC3_DECLARE_PRINT_OPERATOR(TYPE_NAME) \
+ inline std::ostream& operator<<(std::ostream& stream, \
+ const TYPE_NAME& value) { \
+ logging::LoggingStringStream tmp_stream; \
+ tmp_stream << value; \
+ return stream << tmp_stream.message; \
+ }
-inline std::ostream& operator<<(std::ostream& stream,
- const AnnotatedSpan& value) {
- logging::LoggingStringStream tmp_stream;
- tmp_stream << value;
- return stream << tmp_stream.message;
-}
+TC3_DECLARE_PRINT_OPERATOR(AnnotatedSpan)
+TC3_DECLARE_PRINT_OPERATOR(ClassificationResult)
+TC3_DECLARE_PRINT_OPERATOR(DateParseData)
+TC3_DECLARE_PRINT_OPERATOR(DatetimeParseResultSpan)
+TC3_DECLARE_PRINT_OPERATOR(Token)
-inline std::ostream& operator<<(std::ostream& stream,
- const DatetimeParseResultSpan& value) {
- logging::LoggingStringStream tmp_stream;
- tmp_stream << value;
- return stream << tmp_stream.message;
-}
+#undef TC3_DECLARE_PRINT_OPERATOR
} // namespace libtextclassifier3
diff --git a/annotator/types.cc b/annotator/types.cc
new file mode 100644
index 0000000..ee150c8
--- /dev/null
+++ b/annotator/types.cc
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Token& token) {
+ if (!token.is_padding) {
+ return stream << "Token(\"" << token.value << "\", " << token.start << ", "
+ << token.end << ")";
+ } else {
+ return stream << "Token()";
+ }
+}
+
+namespace {
+std::string FormatMillis(int64 time_ms_utc) {
+ long time_seconds = time_ms_utc / 1000; // NOLINT
+ char buffer[512];
+ strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z",
+ localtime(&time_seconds));
+ return std::string(buffer);
+}
+} // namespace
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DatetimeParseResultSpan& value) {
+ stream << "DatetimeParseResultSpan({" << value.span.first << ", "
+ << value.span.second << "}, {";
+ for (const DatetimeParseResult& data : value.data) {
+ stream << "{/*time_ms_utc=*/ " << data.time_ms_utc << " /* "
+ << FormatMillis(data.time_ms_utc) << " */, /*granularity=*/ "
+ << data.granularity << "}, ";
+ }
+ stream << "})";
+ return stream;
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const ClassificationResult& result) {
+ return stream << "ClassificationResult(" << result.collection
+ << ", /*score=*/ " << result.score << ", /*priority_score=*/ "
+ << result.priority_score << ")";
+}
+
+logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream,
+ const std::vector<ClassificationResult>& results) {
+ stream = stream << "{\n";
+ for (const ClassificationResult& result : results) {
+ stream = stream << " " << result << "\n";
+ }
+ stream = stream << "}";
+ return stream;
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const AnnotatedSpan& span) {
+ std::string best_class;
+ float best_score = -1;
+ if (!span.classification.empty()) {
+ best_class = span.classification[0].collection;
+ best_score = span.classification[0].score;
+ }
+ return stream << "Span(" << span.span.first << ", " << span.span.second
+ << ", " << best_class << ", " << best_score << ")";
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DateParseData& data) {
+ // TODO(zilka): Add human-readable form of field_set_mask and the enum fields.
+ stream = stream << "DateParseData {\n";
+ stream = stream << " field_set_mask: " << data.field_set_mask << "\n";
+ stream = stream << " year: " << data.year << "\n";
+ stream = stream << " month: " << data.month << "\n";
+ stream = stream << " day_of_month: " << data.day_of_month << "\n";
+ stream = stream << " hour: " << data.hour << "\n";
+ stream = stream << " minute: " << data.minute << "\n";
+ stream = stream << " second: " << data.second << "\n";
+ stream = stream << " ampm: " << static_cast<int>(data.ampm) << "\n";
+ stream = stream << " zone_offset: " << data.zone_offset << "\n";
+ stream = stream << " dst_offset: " << data.dst_offset << "\n";
+ stream = stream << " relation: " << static_cast<int>(data.relation) << "\n";
+ stream = stream << " relation_type: " << static_cast<int>(data.relation_type)
+ << "\n";
+ stream = stream << " relation_distance: " << data.relation_distance << "\n";
+ stream = stream << "}";
+ return stream;
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/types.h b/annotator/types.h
index 38bce41..48fefe4 100644
--- a/annotator/types.h
+++ b/annotator/types.h
@@ -17,6 +17,7 @@
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
+#include <time.h>
#include <algorithm>
#include <cmath>
#include <functional>
@@ -26,8 +27,10 @@
#include <utility>
#include <vector>
+#include "annotator/entity-data_generated.h"
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
+#include "utils/flatbuffers.h"
#include "utils/variant.h"
namespace libtextclassifier3 {
@@ -147,15 +150,8 @@
};
// Pretty-printing function for Token.
-inline logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream, const Token& token) {
- if (!token.is_padding) {
- return stream << "Token(\"" << token.value << "\", " << token.start << ", "
- << token.end << ")";
- } else {
- return stream << "Token()";
- }
-}
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Token& token);
enum DatetimeGranularity {
GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this
@@ -170,9 +166,7 @@
};
struct DatetimeParseResult {
- // The absolute time in milliseconds since the epoch in UTC. This is derived
- // from the reference time and the fields specified in the text - so it may
- // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm)
+ // The absolute time in milliseconds since the epoch in UTC.
int64 time_ms_utc;
// The precision of the estimate then in to calculating the milliseconds
@@ -195,13 +189,12 @@
struct DatetimeParseResultSpan {
CodepointSpan span;
- DatetimeParseResult data;
+ std::vector<DatetimeParseResult> data;
float target_classification_score;
float priority_score;
bool operator==(const DatetimeParseResultSpan& other) const {
- return span == other.span && data.granularity == other.data.granularity &&
- data.time_ms_utc == other.data.time_ms_utc &&
+ return span == other.span && data == other.data &&
std::abs(target_classification_score -
other.target_classification_score) < kFloatCompareEpsilon &&
std::abs(priority_score - other.priority_score) <
@@ -210,26 +203,32 @@
};
// Pretty-printing function for DatetimeParseResultSpan.
-inline logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream,
- const DatetimeParseResultSpan& value) {
- return stream << "DatetimeParseResultSpan({" << value.span.first << ", "
- << value.span.second << "}, {/*time_ms_utc=*/ "
- << value.data.time_ms_utc << ", /*granularity=*/ "
- << value.data.granularity << "})";
-}
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DatetimeParseResultSpan& value);
struct ClassificationResult {
std::string collection;
float score;
DatetimeParseResult datetime_parse_result;
std::string serialized_knowledge_result;
+ std::string contact_name, contact_given_name, contact_nickname,
+ contact_email_address, contact_phone_number, contact_id;
+ std::string app_name, app_package_name;
+ int64 numeric_value;
+
+ // Length of the parsed duration in milliseconds.
+ int64 duration_ms;
// Internal score used for conflict resolution.
float priority_score;
- // Extra information.
- std::map<std::string, Variant> extra;
+
+ // Entity data information.
+ std::string serialized_entity_data;
+ const EntityData* entity_data() {
+ return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
+ serialized_entity_data.size());
+ }
explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
@@ -246,45 +245,37 @@
};
// Pretty-printing function for ClassificationResult.
-inline logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream, const ClassificationResult& result) {
- return stream << "ClassificationResult(" << result.collection << ", "
- << result.score << ")";
-}
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const ClassificationResult& result);
// Pretty-printing function for std::vector<ClassificationResult>.
-inline logging::LoggingStringStream& operator<<(
+logging::LoggingStringStream& operator<<(
logging::LoggingStringStream& stream,
- const std::vector<ClassificationResult>& results) {
- stream = stream << "{\n";
- for (const ClassificationResult& result : results) {
- stream = stream << " " << result << "\n";
- }
- stream = stream << "}";
- return stream;
-}
+ const std::vector<ClassificationResult>& results);
// Represents a result of Annotate call.
struct AnnotatedSpan {
+ enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME };
+
// Unicode codepoint indices in the input string.
CodepointSpan span = {kInvalidIndex, kInvalidIndex};
// Classification result for the span.
std::vector<ClassificationResult> classification;
+
+ // The source of the annotation, used in conflict resolution.
+ Source source = Source::OTHER;
+
+ AnnotatedSpan() = default;
+
+ AnnotatedSpan(CodepointSpan arg_span,
+ std::vector<ClassificationResult> arg_classification)
+ : span(arg_span), classification(std::move(arg_classification)) {}
};
// Pretty-printing function for AnnotatedSpan.
-inline logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream, const AnnotatedSpan& span) {
- std::string best_class;
- float best_score = -1;
- if (!span.classification.empty()) {
- best_class = span.classification[0].collection;
- best_score = span.classification[0].score;
- }
- return stream << "Span(" << span.span.first << ", " << span.span.second
- << ", " << best_class << ", " << best_score << ")";
-}
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const AnnotatedSpan& span);
// StringPiece analogue for std::vector<T>.
template <class T>
@@ -312,7 +303,8 @@
};
struct DateParseData {
- enum Relation {
+ enum class Relation {
+ UNSPECIFIED = 0,
NEXT = 1,
NEXT_OR_SAME = 2,
LAST = 3,
@@ -323,7 +315,8 @@
FUTURE = 8
};
- enum RelationType {
+ enum class RelationType {
+ UNSPECIFIED = 0,
SUNDAY = 1,
MONDAY = 2,
TUESDAY = 3,
@@ -334,7 +327,10 @@
DAY = 8,
WEEK = 9,
MONTH = 10,
- YEAR = 11
+ YEAR = 11,
+ HOUR = 12,
+ MINUTE = 13,
+ SECOND = 14,
};
enum Fields {
@@ -352,9 +348,9 @@
RELATION_DISTANCE_FIELD = 1 << 11
};
- enum AMPM { AM = 0, PM = 1 };
+ enum class AMPM { AM = 0, PM = 1 };
- enum TimeUnit {
+ enum class TimeUnit {
DAYS = 1,
WEEKS = 2,
MONTHS = 3,
@@ -365,38 +361,63 @@
};
// Bit mask of fields which have been set on the struct
- int field_set_mask;
+ int field_set_mask = 0;
// Fields describing absolute date fields.
// Year of the date seen in the text match.
- int year;
+ int year = 0;
// Month of the year starting with January = 1.
- int month;
+ int month = 0;
// Day of the month starting with 1.
- int day_of_month;
+ int day_of_month = 0;
// Hour of the day with a range of 0-23,
// values less than 12 need the AMPM field below or heuristics
// to definitively determine the time.
- int hour;
+ int hour = 0;
// Hour of the day with a range of 0-59.
- int minute;
+ int minute = 0;
// Hour of the day with a range of 0-59.
- int second;
+ int second = 0;
// 0 == AM, 1 == PM
- int ampm;
+ AMPM ampm = AMPM::AM;
// Number of hours offset from UTC this date time is in.
- int zone_offset;
+ int zone_offset = 0;
// Number of hours offest for DST
- int dst_offset;
+ int dst_offset = 0;
// The permutation from now that was made to find the date time.
- Relation relation;
+ Relation relation = Relation::UNSPECIFIED;
// The unit of measure of the change to the date time.
- RelationType relation_type;
+ RelationType relation_type = RelationType::UNSPECIFIED;
// The number of units of change that were made.
- int relation_distance;
+ int relation_distance = 0;
+
+ DateParseData() = default;
+
+ DateParseData(int field_set_mask, int year, int month, int day_of_month,
+ int hour, int minute, int second, AMPM ampm, int zone_offset,
+ int dst_offset, Relation relation, RelationType relation_type,
+ int relation_distance) {
+ this->field_set_mask = field_set_mask;
+ this->year = year;
+ this->month = month;
+ this->day_of_month = day_of_month;
+ this->hour = hour;
+ this->minute = minute;
+ this->second = second;
+ this->ampm = ampm;
+ this->zone_offset = zone_offset;
+ this->dst_offset = dst_offset;
+ this->relation = relation;
+ this->relation_type = relation_type;
+ this->relation_distance = relation_distance;
+ }
};
+// Pretty-printing function for DateParseData.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DateParseData& data);
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
diff --git a/annotator/zlib-utils.cc b/annotator/zlib-utils.cc
index 6efe025..ec2392b 100644
--- a/annotator/zlib-utils.cc
+++ b/annotator/zlib-utils.cc
@@ -19,6 +19,8 @@
#include <memory>
#include "utils/base/logging.h"
+#include "utils/intents/zlib-utils.h"
+#include "utils/resources.h"
#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
@@ -63,6 +65,17 @@
extractor->pattern.clear();
}
}
+
+ // Compress resources.
+ if (model->resources != nullptr) {
+ CompressResources(model->resources.get());
+ }
+
+ // Compress intent generator.
+ if (model->intent_options != nullptr) {
+ CompressIntentModel(model->intent_options.get());
+ }
+
return true;
}
diff --git a/java/com/google/android/textclassifier/ActionsSuggestionsModel.java b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
new file mode 100644
index 0000000..9132b1f
--- /dev/null
+++ b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -0,0 +1,265 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Java wrapper for ActionsSuggestions native library interface. This library is used to suggest
+ * actions and replies in a given conversation.
+ *
+ * @hide
+ */
+public final class ActionsSuggestionsModel implements AutoCloseable {
+ private final AtomicBoolean isClosed = new AtomicBoolean(false);
+
+ static {
+ System.loadLibrary("textclassifier");
+ }
+
+ private long actionsModelPtr;
+
+ /**
+ * Creates a new instance of Actions predictor, using the provided model image, given as a file
+ * descriptor.
+ */
+ public ActionsSuggestionsModel(int fileDescriptor, byte[] serializedPreconditions) {
+ actionsModelPtr = nativeNewActionsModel(fileDescriptor, serializedPreconditions);
+ if (actionsModelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
+ }
+ }
+
+ public ActionsSuggestionsModel(int fileDescriptor) {
+ this(fileDescriptor, /* serializedPreconditions= */ null);
+ }
+
+ /**
+ * Creates a new instance of Actions predictor, using the provided model image, given as a file
+ * path.
+ */
+ public ActionsSuggestionsModel(String path, byte[] serializedPreconditions) {
+ actionsModelPtr = nativeNewActionsModelFromPath(path, serializedPreconditions);
+ if (actionsModelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize actions model from given file.");
+ }
+ }
+
+ public ActionsSuggestionsModel(String path) {
+ this(path, /* serializedPreconditions= */ null);
+ }
+
+ /** Suggests actions / replies to the given conversation. */
+ public ActionSuggestion[] suggestActions(
+ Conversation conversation, ActionSuggestionOptions options, AnnotatorModel annotator) {
+ return nativeSuggestActions(
+ actionsModelPtr,
+ conversation,
+ options,
+ (annotator != null ? annotator.getNativeAnnotator() : 0),
+ /* appContext= */ null,
+ /* deviceLocales= */ null,
+ /* generateAndroidIntents= */ false);
+ }
+
+ public ActionSuggestion[] suggestActionsWithIntents(
+ Conversation conversation,
+ ActionSuggestionOptions options,
+ Object appContext,
+ String deviceLocales,
+ AnnotatorModel annotator) {
+ return nativeSuggestActions(
+ actionsModelPtr,
+ conversation,
+ options,
+ (annotator != null ? annotator.getNativeAnnotator() : 0),
+ appContext,
+ deviceLocales,
+ /* generateAndroidIntents= */ true);
+ }
+
+ /** Frees up the allocated memory. */
+ @Override
+ public void close() {
+ if (isClosed.compareAndSet(false, true)) {
+ nativeCloseActionsModel(actionsModelPtr);
+ actionsModelPtr = 0L;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
+ public static String getLocales(int fd) {
+ return nativeGetLocales(fd);
+ }
+
+ /** Returns the version of the model. */
+ public static int getVersion(int fd) {
+ return nativeGetVersion(fd);
+ }
+
+ /** Returns the name of the model. */
+ public static String getName(int fd) {
+ return nativeGetName(fd);
+ }
+
+ /** Action suggestion that contains a response text and the type of the response. */
+ public static final class ActionSuggestion {
+ private final String responseText;
+ private final String actionType;
+ private final float score;
+ private final NamedVariant[] entityData;
+ private final byte[] serializedEntityData;
+ private final RemoteActionTemplate[] remoteActionTemplates;
+
+ public ActionSuggestion(
+ String responseText,
+ String actionType,
+ float score,
+ NamedVariant[] entityData,
+ byte[] serializedEntityData,
+ RemoteActionTemplate[] remoteActionTemplates) {
+ this.responseText = responseText;
+ this.actionType = actionType;
+ this.score = score;
+ this.entityData = entityData;
+ this.serializedEntityData = serializedEntityData;
+ this.remoteActionTemplates = remoteActionTemplates;
+ }
+
+ public String getResponseText() {
+ return responseText;
+ }
+
+ public String getActionType() {
+ return actionType;
+ }
+
+ /** Confidence score between 0 and 1 */
+ public float getScore() {
+ return score;
+ }
+
+ public NamedVariant[] getEntityData() {
+ return entityData;
+ }
+
+ public byte[] getSerializedEntityData() {
+ return serializedEntityData;
+ }
+
+ public RemoteActionTemplate[] getRemoteActionTemplates() {
+ return remoteActionTemplates;
+ }
+ }
+
+ /** Represents a single message in the conversation. */
+ public static final class ConversationMessage {
+ private final int userId;
+ private final String text;
+ private final long referenceTimeMsUtc;
+ private final String referenceTimezone;
+ private final String detectedTextLanguageTags;
+
+ public ConversationMessage(
+ int userId,
+ String text,
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String detectedTextLanguageTags) {
+ this.userId = userId;
+ this.text = text;
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
+ this.referenceTimezone = referenceTimezone;
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ }
+
+ /** The identifier of the sender */
+ public int getUserId() {
+ return userId;
+ }
+
+ public String getText() {
+ return text;
+ }
+
+ /**
+ * Return the reference time of the message, for example, it could be compose time or send time.
+ * {@code 0} means unspecified.
+ */
+ public long getReferenceTimeMsUtc() {
+ return referenceTimeMsUtc;
+ }
+
+ public String getReferenceTimezone() {
+ return referenceTimezone;
+ }
+
+ /** Returns a comma separated list of BCP 47 language tags. */
+ public String getDetectedTextLanguageTags() {
+ return detectedTextLanguageTags;
+ }
+ }
+
+ /** Represents conversation between multiple users. */
+ public static final class Conversation {
+ public final ConversationMessage[] conversationMessages;
+
+ public Conversation(ConversationMessage[] conversationMessages) {
+ this.conversationMessages = conversationMessages;
+ }
+
+ public ConversationMessage[] getConversationMessages() {
+ return conversationMessages;
+ }
+ }
+
+ /** Represents options for the SuggestActions call. */
+ public static final class ActionSuggestionOptions {
+ public ActionSuggestionOptions() {}
+ }
+
+ private static native long nativeNewActionsModel(int fd, byte[] serializedPreconditions);
+
+ private static native long nativeNewActionsModelFromPath(
+ String path, byte[] preconditionsOverwrite);
+
+ private static native String nativeGetLocales(int fd);
+
+ private static native int nativeGetVersion(int fd);
+
+ private static native String nativeGetName(int fd);
+
+ private native ActionSuggestion[] nativeSuggestActions(
+ long context,
+ Conversation conversation,
+ ActionSuggestionOptions options,
+ long annotatorPtr,
+ Object appContext,
+ String deviceLocales,
+ boolean generateAndroidIntents);
+
+ private native void nativeCloseActionsModel(long ptr);
+}
diff --git a/java/com/google/android/textclassifier/AnnotatorModel.java b/java/com/google/android/textclassifier/AnnotatorModel.java
index 08a4455..5f99f74 100644
--- a/java/com/google/android/textclassifier/AnnotatorModel.java
+++ b/java/com/google/android/textclassifier/AnnotatorModel.java
@@ -16,6 +16,7 @@
package com.google.android.textclassifier;
+import java.util.Collection;
import java.util.concurrent.atomic.AtomicBoolean;
/**
@@ -44,6 +45,28 @@
private long annotatorPtr;
+ /** Enumeration for specifying the usecase of the annotations. */
+ public static enum AnnotationUsecase {
+ /** Results are optimized for Smart{Select,Share,Linkify}. */
+ SMART(0),
+
+ /**
+ * Results are optimized for using TextClassifier as an infrastructure that annotates as much as
+ * possible.
+ */
+ RAW(1);
+
+ private final int value;
+
+ AnnotationUsecase(int value) {
+ this.value = value;
+ }
+
+ public int getValue() {
+ return value;
+ }
+ };
+
/**
* Creates a new instance of SmartSelect predictor, using the provided model image, given as a
* file descriptor.
@@ -73,6 +96,20 @@
}
}
+ /** Initializes the contact engine, passing the given serialized config to it. */
+ public void initializeContactEngine(byte[] serializedConfig) {
+ if (!nativeInitializeContactEngine(annotatorPtr, serializedConfig)) {
+ throw new IllegalArgumentException("Couldn't initialize the contact engine");
+ }
+ }
+
+ /** Initializes the installed app engine, passing the given serialized config to it. */
+ public void initializeInstalledAppEngine(byte[] serializedConfig) {
+ if (!nativeInitializeInstalledAppEngine(annotatorPtr, serializedConfig)) {
+ throw new IllegalArgumentException("Couldn't initialize the installed app engine");
+ }
+ }
+
/**
* Given a string context and current selection, computes the selection suggestion.
*
@@ -98,7 +135,24 @@
*/
public ClassificationResult[] classifyText(
String context, int selectionBegin, int selectionEnd, ClassificationOptions options) {
- return nativeClassifyText(annotatorPtr, context, selectionBegin, selectionEnd, options);
+ return classifyText(
+ context,
+ selectionBegin,
+ selectionEnd,
+ options,
+ /*appContext=*/ null,
+ /*deviceLocales=*/ null);
+ }
+
+ public ClassificationResult[] classifyText(
+ String context,
+ int selectionBegin,
+ int selectionEnd,
+ ClassificationOptions options,
+ Object appContext,
+ String deviceLocales) {
+ return nativeClassifyText(
+ annotatorPtr, context, selectionBegin, selectionEnd, options, appContext, deviceLocales);
}
/**
@@ -109,6 +163,14 @@
return nativeAnnotate(annotatorPtr, text, options);
}
+ /**
+ * Looks up a knowledge entity by its identifier. Returns null if the entity is not found or on
+ * error.
+ */
+ public byte[] lookUpKnowledgeEntity(String id) {
+ return nativeLookUpKnowledgeEntity(annotatorPtr, id);
+ }
+
/** Frees up the allocated memory. */
@Override
public void close() {
@@ -145,18 +207,18 @@
/** Information about a parsed time/date. */
public static final class DatetimeResult {
- static final int GRANULARITY_YEAR = 0;
- static final int GRANULARITY_MONTH = 1;
- static final int GRANULARITY_WEEK = 2;
- static final int GRANULARITY_DAY = 3;
- static final int GRANULARITY_HOUR = 4;
- static final int GRANULARITY_MINUTE = 5;
- static final int GRANULARITY_SECOND = 6;
+ public static final int GRANULARITY_YEAR = 0;
+ public static final int GRANULARITY_MONTH = 1;
+ public static final int GRANULARITY_WEEK = 2;
+ public static final int GRANULARITY_DAY = 3;
+ public static final int GRANULARITY_HOUR = 4;
+ public static final int GRANULARITY_MINUTE = 5;
+ public static final int GRANULARITY_SECOND = 6;
private final long timeMsUtc;
private final int granularity;
- DatetimeResult(long timeMsUtc, int granularity) {
+ public DatetimeResult(long timeMsUtc, int granularity) {
this.timeMsUtc = timeMsUtc;
this.granularity = granularity;
}
@@ -176,30 +238,59 @@
private final float score;
private final DatetimeResult datetimeResult;
private final byte[] serializedKnowledgeResult;
+ private final String contactName;
+ private final String contactGivenName;
+ private final String contactNickname;
+ private final String contactEmailAddress;
+ private final String contactPhoneNumber;
+ private final String contactId;
+ private final String appName;
+ private final String appPackageName;
+ private final NamedVariant[] entityData;
+ private final byte[] serializedEntityData;
+ private final RemoteActionTemplate[] remoteActionTemplates;
+ private final long durationMs;
+ private final long numericValue;
public ClassificationResult(
String collection,
float score,
DatetimeResult datetimeResult,
- byte[] serializedKnowledgeResult) {
+ byte[] serializedKnowledgeResult,
+ String contactName,
+ String contactGivenName,
+ String contactNickname,
+ String contactEmailAddress,
+ String contactPhoneNumber,
+ String contactId,
+ String appName,
+ String appPackageName,
+ NamedVariant[] entityData,
+ byte[] serializedEntityData,
+ RemoteActionTemplate[] remoteActionTemplates,
+ long durationMs,
+ long numericValue) {
this.collection = collection;
this.score = score;
this.datetimeResult = datetimeResult;
this.serializedKnowledgeResult = serializedKnowledgeResult;
+ this.contactName = contactName;
+ this.contactGivenName = contactGivenName;
+ this.contactNickname = contactNickname;
+ this.contactEmailAddress = contactEmailAddress;
+ this.contactPhoneNumber = contactPhoneNumber;
+ this.contactId = contactId;
+ this.appName = appName;
+ this.appPackageName = appPackageName;
+ this.entityData = entityData;
+ this.serializedEntityData = serializedEntityData;
+ this.remoteActionTemplates = remoteActionTemplates;
+ this.durationMs = durationMs;
+ this.numericValue = numericValue;
}
/** Returns the classified entity type. */
public String getCollection() {
- if (TYPE_DATE.equals(collection) && datetimeResult != null) {
- switch (datetimeResult.getGranularity()) {
- case DatetimeResult.GRANULARITY_HOUR:
- case DatetimeResult.GRANULARITY_MINUTE:
- case DatetimeResult.GRANULARITY_SECOND:
- return TYPE_DATE_TIME;
- default:
- return TYPE_DATE;
- }
- }
return collection;
}
@@ -212,9 +303,61 @@
return datetimeResult;
}
- byte[] getSerializedKnowledgeResult() {
+ public byte[] getSerializedKnowledgeResult() {
return serializedKnowledgeResult;
}
+
+ public String getContactName() {
+ return contactName;
+ }
+
+ public String getContactGivenName() {
+ return contactGivenName;
+ }
+
+ public String getContactNickname() {
+ return contactNickname;
+ }
+
+ public String getContactEmailAddress() {
+ return contactEmailAddress;
+ }
+
+ public String getContactPhoneNumber() {
+ return contactPhoneNumber;
+ }
+
+ public String getContactId() {
+ return contactId;
+ }
+
+ public String getAppName() {
+ return appName;
+ }
+
+ public String getAppPackageName() {
+ return appPackageName;
+ }
+
+ public NamedVariant[] getEntityData() {
+ return entityData;
+ }
+
+ public byte[] getSerializedEntityData() {
+ return serializedEntityData;
+ }
+
+ public RemoteActionTemplate[] getRemoteActionTemplates() {
+ return remoteActionTemplates;
+ }
+
+ public long getDurationMs() {
+ return durationMs;
+ }
+
+ public long getNumericValue() {
+ return numericValue;
+ }
}
/** Represents a result of Annotate call. */
@@ -245,14 +388,32 @@
/** Represents options for the suggestSelection call. */
public static final class SelectionOptions {
private final String locales;
+ private final String detectedTextLanguageTags;
+ private final int annotationUsecase;
- public SelectionOptions(String locales) {
+ public SelectionOptions(
+ String locales, String detectedTextLanguageTags, int annotationUsecase) {
this.locales = locales;
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ this.annotationUsecase = annotationUsecase;
+ }
+
+ public SelectionOptions(String locales, String detectedTextLanguageTags) {
+ this(locales, detectedTextLanguageTags, AnnotationUsecase.SMART.getValue());
}
public String getLocales() {
return locales;
}
+
+ /** Returns a comma separated list of BCP 47 language tags. */
+ public String getDetectedTextLanguageTags() {
+ return detectedTextLanguageTags;
+ }
+
+ public int getAnnotationUsecase() {
+ return annotationUsecase;
+ }
}
/** Represents options for the classifyText call. */
@@ -260,11 +421,33 @@
private final long referenceTimeMsUtc;
private final String referenceTimezone;
private final String locales;
+ private final String detectedTextLanguageTags;
+ private final int annotationUsecase;
- public ClassificationOptions(long referenceTimeMsUtc, String referenceTimezone, String locale) {
+ public ClassificationOptions(
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String locales,
+ String detectedTextLanguageTags,
+ int annotationUsecase) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
- this.locales = locale;
+ this.locales = locales;
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ this.annotationUsecase = annotationUsecase;
+ }
+
+ public ClassificationOptions(
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String locales,
+ String detectedTextLanguageTags) {
+ this(
+ referenceTimeMsUtc,
+ referenceTimezone,
+ locales,
+ detectedTextLanguageTags,
+ AnnotationUsecase.SMART.getValue());
}
public long getReferenceTimeMsUtc() {
@@ -278,6 +461,15 @@
public String getLocale() {
return locales;
}
+
+ /** Returns a comma separated list of BCP 47 language tags. */
+ public String getDetectedTextLanguageTags() {
+ return detectedTextLanguageTags;
+ }
+
+ public int getAnnotationUsecase() {
+ return annotationUsecase;
+ }
}
/** Represents options for the annotate call. */
@@ -285,11 +477,41 @@
private final long referenceTimeMsUtc;
private final String referenceTimezone;
private final String locales;
+ private final String detectedTextLanguageTags;
+ private final String[] entityTypes;
+ private final int annotationUsecase;
+ private final boolean isSerializedEntityDataEnabled;
- public AnnotationOptions(long referenceTimeMsUtc, String referenceTimezone, String locale) {
+ public AnnotationOptions(
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String locales,
+ String detectedTextLanguageTags,
+ Collection<String> entityTypes,
+ int annotationUsecase,
+ boolean isSerializedEntityDataEnabled) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
- this.locales = locale;
+ this.locales = locales;
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ this.entityTypes = entityTypes == null ? new String[0] : entityTypes.toArray(new String[0]);
+ this.annotationUsecase = annotationUsecase;
+ this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled;
+ }
+
+ public AnnotationOptions(
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String locales,
+ String detectedTextLanguageTags) {
+ this(
+ referenceTimeMsUtc,
+ referenceTimezone,
+ locales,
+ detectedTextLanguageTags,
+ null,
+ AnnotationUsecase.SMART.getValue(),
+ /* isSerializedEntityDataEnabled */ false);
}
public long getReferenceTimeMsUtc() {
@@ -303,6 +525,23 @@
public String getLocale() {
return locales;
}
+
+ /** Returns a comma separated list of BCP 47 language tags. */
+ public String getDetectedTextLanguageTags() {
+ return detectedTextLanguageTags;
+ }
+
+ public String[] getEntityTypes() {
+ return entityTypes;
+ }
+
+ public int getAnnotationUsecase() {
+ return annotationUsecase;
+ }
+
+ public boolean isSerializedEntityDataEnabled() {
+ return isSerializedEntityDataEnabled;
+ }
}
/**
@@ -310,7 +549,7 @@
* as the pointer is used.
*/
long getNativeAnnotator() {
- return annotatorPtr;
+ return nativeGetNativeModelPtr(annotatorPtr);
}
private static native long nativeNewAnnotator(int fd);
@@ -323,8 +562,14 @@
private static native String nativeGetName(int fd);
+ private native long nativeGetNativeModelPtr(long context);
+
private native boolean nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig);
+ private native boolean nativeInitializeContactEngine(long context, byte[] serializedConfig);
+
+ private native boolean nativeInitializeInstalledAppEngine(long context, byte[] serializedConfig);
+
private native int[] nativeSuggestSelection(
long context, String text, int selectionBegin, int selectionEnd, SelectionOptions options);
@@ -333,10 +578,14 @@
String text,
int selectionBegin,
int selectionEnd,
- ClassificationOptions options);
+ ClassificationOptions options,
+ Object appContext,
+ String deviceLocales);
private native AnnotatedSpan[] nativeAnnotate(
long context, String text, AnnotationOptions options);
+ private native byte[] nativeLookUpKnowledgeEntity(long context, String id);
+
private native void nativeCloseAnnotator(long context);
}
diff --git a/java/com/google/android/textclassifier/LangIdModel.java b/java/com/google/android/textclassifier/LangIdModel.java
new file mode 100644
index 0000000..d3e166f
--- /dev/null
+++ b/java/com/google/android/textclassifier/LangIdModel.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Java wrapper for LangId native library interface. This class is used to detect languages in text.
+ *
+ * @hide
+ */
+public final class LangIdModel implements AutoCloseable {
+ private final AtomicBoolean isClosed = new AtomicBoolean(false);
+
+ static {
+ System.loadLibrary("textclassifier");
+ }
+
+ private long modelPtr;
+
+ /** Creates a new instance of LangId predictor, using the provided model image. */
+ public LangIdModel(int fd) {
+ modelPtr = nativeNew(fd);
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from given file descriptor.");
+ }
+ }
+
+ /** Creates a new instance of LangId predictor, using the provided model image. */
+ public LangIdModel(String modelPath) {
+ modelPtr = nativeNewFromPath(modelPath);
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from given file.");
+ }
+ }
+
+ /** Detects the languages for given text. */
+ public LanguageResult[] detectLanguages(String text) {
+ return nativeDetectLanguages(modelPtr, text);
+ }
+
+ /** Frees up the allocated memory. */
+ @Override
+ public void close() {
+ if (isClosed.compareAndSet(false, true)) {
+ nativeClose(modelPtr);
+ modelPtr = 0L;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /** Result for detectLanguages method. */
+ public static final class LanguageResult {
+ final String mLanguage;
+ final float mScore;
+
+ LanguageResult(String language, float score) {
+ mLanguage = language;
+ mScore = score;
+ }
+
+ public final String getLanguage() {
+ return mLanguage;
+ }
+
+ public final float getScore() {
+ return mScore;
+ }
+ }
+
+ /** Returns the version of the LangId model used. */
+ public int getVersion() {
+ return nativeGetVersion(modelPtr);
+ }
+
+ public float getLangIdThreshold() {
+ return nativeGetLangIdThreshold(modelPtr);
+ }
+
+ public static int getVersion(int fd) {
+ return nativeGetVersionFromFd(fd);
+ }
+
+ private static native long nativeNew(int fd);
+
+ private static native long nativeNewFromPath(String path);
+
+ private native LanguageResult[] nativeDetectLanguages(long nativePtr, String text);
+
+ private native void nativeClose(long nativePtr);
+
+ private native int nativeGetVersion(long nativePtr);
+
+ private static native int nativeGetVersionFromFd(int fd);
+
+ private native float nativeGetLangIdThreshold(long nativePtr);
+}
diff --git a/java/com/google/android/textclassifier/NamedVariant.java b/java/com/google/android/textclassifier/NamedVariant.java
new file mode 100644
index 0000000..d04bb11
--- /dev/null
+++ b/java/com/google/android/textclassifier/NamedVariant.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+/**
+ * Represents a union of different basic types.
+ *
+ * @hide
+ */
+public final class NamedVariant {
+ public static final int TYPE_EMPTY = 0;
+ public static final int TYPE_INT = 1;
+ public static final int TYPE_LONG = 2;
+ public static final int TYPE_FLOAT = 3;
+ public static final int TYPE_DOUBLE = 4;
+ public static final int TYPE_BOOL = 5;
+ public static final int TYPE_STRING = 6;
+
+ public NamedVariant(String name, int value) {
+ this.name = name;
+ this.intValue = value;
+ this.type = TYPE_INT;
+ }
+
+ public NamedVariant(String name, long value) {
+ this.name = name;
+ this.longValue = value;
+ this.type = TYPE_LONG;
+ }
+
+ public NamedVariant(String name, float value) {
+ this.name = name;
+ this.floatValue = value;
+ this.type = TYPE_FLOAT;
+ }
+
+ public NamedVariant(String name, double value) {
+ this.name = name;
+ this.doubleValue = value;
+ this.type = TYPE_DOUBLE;
+ }
+
+ public NamedVariant(String name, boolean value) {
+ this.name = name;
+ this.boolValue = value;
+ this.type = TYPE_BOOL;
+ }
+
+ public NamedVariant(String name, String value) {
+ this.name = name;
+ this.stringValue = value;
+ this.type = TYPE_STRING;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public int getType() {
+ return type;
+ }
+
+ public int getInt() {
+ assert (type == TYPE_INT);
+ return intValue;
+ }
+
+ public long getLong() {
+ assert (type == TYPE_LONG);
+ return longValue;
+ }
+
+ public float getFloat() {
+ assert (type == TYPE_FLOAT);
+ return floatValue;
+ }
+
+ public double getDouble() {
+ assert (type == TYPE_DOUBLE);
+ return doubleValue;
+ }
+
+ public boolean getBool() {
+ assert (type == TYPE_BOOL);
+ return boolValue;
+ }
+
+ public String getString() {
+ assert (type == TYPE_STRING);
+ return stringValue;
+ }
+
+ private final String name;
+ private final int type;
+ private int intValue;
+ private long longValue;
+ private float floatValue;
+ private double doubleValue;
+ private boolean boolValue;
+ private String stringValue;
+}
diff --git a/java/com/google/android/textclassifier/RemoteActionTemplate.java b/java/com/google/android/textclassifier/RemoteActionTemplate.java
new file mode 100644
index 0000000..308d809
--- /dev/null
+++ b/java/com/google/android/textclassifier/RemoteActionTemplate.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+/**
+ * Represents a template for an Android RemoteAction.
+ *
+ * @hide
+ */
+public final class RemoteActionTemplate {
+ /** Title shown for the action (see: RemoteAction.getTitle). */
+ public final String titleWithoutEntity;
+
+ /** Title with entity for the action. */
+ public final String titleWithEntity;
+
+ /** Description shown for the action (see: RemoteAction.getContentDescription). */
+ public final String description;
+
+ /**
+ * Description shown for the action (see: RemoteAction.getContentDescription) when app name is
+ * available. Caller is expected to replace the placeholder by the name of the app that is going
+ * to handle the action.
+ */
+ public final String descriptionWithAppName;
+
+ /** The action to set on the Intent (see: Intent.setAction). */
+ public final String action;
+
+ /** The data to set on the Intent (see: Intent.setData). */
+ public final String data;
+
+ /** The type to set on the Intent (see: Intent.setType). */
+ public final String type;
+
+ /** Flags for launching the Intent (see: Intent.setFlags). */
+ public final Integer flags;
+
+ /** Categories to set on the Intent (see: Intent.addCategory). */
+ public final String[] category;
+
+ /** Explicit application package to set on the Intent (see: Intent.setPackage). */
+ public final String packageName;
+
+ /** The list of all the extras to add to the Intent. */
+ public final NamedVariant[] extras;
+
+ /** Private request code to use for the Intent. */
+ public final Integer requestCode;
+
+ public RemoteActionTemplate(
+ String titleWithoutEntity,
+ String titleWithEntity,
+ String description,
+ String descriptionWithAppName,
+ String action,
+ String data,
+ String type,
+ Integer flags,
+ String[] category,
+ String packageName,
+ NamedVariant[] extras,
+ Integer requestCode) {
+ this.titleWithoutEntity = titleWithoutEntity;
+ this.titleWithEntity = titleWithEntity;
+ this.description = description;
+ this.descriptionWithAppName = descriptionWithAppName;
+ this.action = action;
+ this.data = data;
+ this.type = type;
+ this.flags = flags;
+ this.category = category;
+ this.packageName = packageName;
+ this.extras = extras;
+ this.requestCode = requestCode;
+ }
+}
diff --git a/lang_id/common/embedding-feature-extractor.cc b/lang_id/common/embedding-feature-extractor.cc
new file mode 100644
index 0000000..6235f89
--- /dev/null
+++ b/lang_id/common/embedding-feature-extractor.cc
@@ -0,0 +1,73 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/embedding-feature-extractor.h"
+
+#include <stddef.h>
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+#include "lang_id/common/lite_strings/str-split.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+bool GenericEmbeddingFeatureExtractor::Setup(TaskContext *context) {
+ // Don't use version to determine how to get feature FML.
+ const string features = context->Get(GetParamName("features"), "");
+ const string embedding_names =
+ context->Get(GetParamName("embedding_names"), "");
+ const string embedding_dims =
+ context->Get(GetParamName("embedding_dims"), "");
+
+ // NOTE: unfortunately, LiteStrSplit returns a vector of StringPieces pointing
+ // to the original string, in this case |features|, which is local to this
+ // method. We need to explicitly create new strings.
+ for (StringPiece sp : LiteStrSplit(features, ';')) {
+ embedding_fml_.emplace_back(sp);
+ }
+
+ // Same here.
+ for (StringPiece sp : LiteStrSplit(embedding_names, ';')) {
+ embedding_names_.emplace_back(sp);
+ }
+
+ std::vector<StringPiece> dim_strs = LiteStrSplit(embedding_dims, ';');
+ for (const auto &dim_str : dim_strs) {
+ int dim = 0;
+ if (!LiteAtoi(dim_str, &dim)) {
+ SAFTM_LOG(ERROR) << "Unable to parse " << dim_str;
+ return false;
+ }
+ embedding_dims_.push_back(dim);
+ }
+ return true;
+}
+
+bool GenericEmbeddingFeatureExtractor::Init(TaskContext *context) {
+ return true;
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/embedding-feature-extractor.h b/lang_id/common/embedding-feature-extractor.h
new file mode 100644
index 0000000..f51b6e5
--- /dev/null
+++ b/lang_id/common/embedding-feature-extractor.h
@@ -0,0 +1,174 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/attributes.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// An EmbeddingFeatureExtractor manages the extraction of features for
+// embedding-based models. It wraps a sequence of underlying classes of feature
+// extractors, along with associated predicate maps. Each class of feature
+// extractors is associated with a name, e.g., "words", "labels", "tags".
+//
+// The class is split between a generic abstract version,
+// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
+// signature of the ExtractFeatures method) and a typed version.
+//
+// The predicate maps must be initialized before use: they can be loaded using
+// Read() or updated via UpdateMapsForExample.
+class GenericEmbeddingFeatureExtractor {
+ public:
+ // Constructs this GenericEmbeddingFeatureExtractor.
+ //
+ // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
+ // avoid name clashes. See GetParamName().
+ explicit GenericEmbeddingFeatureExtractor(const string &arg_prefix)
+ : arg_prefix_(arg_prefix) {}
+
+ virtual ~GenericEmbeddingFeatureExtractor() {}
+
+ // Sets/inits up predicate maps and embedding space names that are common for
+ // all embedding based feature extractors.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context);
+ SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context);
+
+ // Requests workspace for the underlying feature extractors. This is
+ // implemented in the typed class.
+ virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
+
+ // Returns number of embedding spaces.
+ int NumEmbeddings() const { return embedding_dims_.size(); }
+
+ const std::vector<string> &embedding_fml() const { return embedding_fml_; }
+
+ // Get parameter name by concatenating the prefix and the original name.
+  string GetParamName(const string &param_name) const {
+ string full_name = arg_prefix_;
+ full_name.push_back('_');
+ full_name.append(param_name);
+ return full_name;
+ }
+
+ private:
+ // Prefix for TaskContext parameters.
+ const string arg_prefix_;
+
+ // Embedding space names for parameter sharing.
+ std::vector<string> embedding_names_;
+
+ // FML strings for each feature extractor.
+ std::vector<string> embedding_fml_;
+
+ // Size of each of the embedding spaces (maximum predicate id).
+ std::vector<int> embedding_sizes_;
+
+ // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
+ std::vector<int> embedding_dims_;
+};
+
+// Templated, object-specific implementation of the
+// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
+// ARGS...> class that has the appropriate FeatureTraits() to ensure that
+// locator type features work.
+//
+// Note: for backwards compatibility purposes, this always reads the FML spec
+// from "<prefix>_features".
+template <class EXTRACTOR, class OBJ, class... ARGS>
+class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
+ public:
+ // Constructs this EmbeddingFeatureExtractor.
+ //
+ // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
+ // avoid name clashes. See GetParamName().
+ explicit EmbeddingFeatureExtractor(const string &arg_prefix)
+ : GenericEmbeddingFeatureExtractor(arg_prefix) {}
+
+ // Sets up all predicate maps, feature extractors, and flags.
+ SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
+ if (!GenericEmbeddingFeatureExtractor::Setup(context)) {
+ return false;
+ }
+ feature_extractors_.resize(embedding_fml().size());
+ for (int i = 0; i < embedding_fml().size(); ++i) {
+ feature_extractors_[i].reset(new EXTRACTOR());
+ if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false;
+ if (!feature_extractors_[i]->Setup(context)) return false;
+ }
+ return true;
+ }
+
+ // Initializes resources needed by the feature extractors.
+ SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
+ if (!GenericEmbeddingFeatureExtractor::Init(context)) return false;
+ for (auto &feature_extractor : feature_extractors_) {
+ if (!feature_extractor->Init(context)) return false;
+ }
+ return true;
+ }
+
+ // Requests workspaces from the registry. Must be called after Init(), and
+ // before Preprocess().
+ void RequestWorkspaces(WorkspaceRegistry *registry) override {
+ for (auto &feature_extractor : feature_extractors_) {
+ feature_extractor->RequestWorkspaces(registry);
+ }
+ }
+
+  // Must be called on the object once for each sentence, before any
+  // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
+ void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
+ for (auto &feature_extractor : feature_extractors_) {
+ feature_extractor->Preprocess(workspaces, obj);
+ }
+ }
+
+ // Extracts features using the extractors. Note that features must already
+ // be initialized to the correct number of feature extractors. No predicate
+ // mapping is applied.
+ void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
+ ARGS... args,
+ std::vector<FeatureVector> *features) const {
+ // DCHECK(features != nullptr);
+ // DCHECK_EQ(features->size(), feature_extractors_.size());
+ for (int i = 0; i < feature_extractors_.size(); ++i) {
+ (*features)[i].clear();
+ feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
+ &(*features)[i]);
+ }
+ }
+
+ private:
+ // Templated feature extractor class.
+ std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
diff --git a/lang_id/common/embedding-feature-interface.h b/lang_id/common/embedding-feature-interface.h
new file mode 100644
index 0000000..87576c6
--- /dev/null
+++ b/lang_id/common/embedding-feature-interface.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/embedding-feature-extractor.h"
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/attributes.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+template <class EXTRACTOR, class OBJ, class... ARGS>
+class EmbeddingFeatureInterface {
+ public:
+ // Constructs this EmbeddingFeatureInterface.
+ //
+ // |arg_prefix| is a string prefix for the TaskContext parameters, passed to
+  // the underlying EmbeddingFeatureExtractor.
+ explicit EmbeddingFeatureInterface(const string &arg_prefix)
+ : feature_extractor_(arg_prefix) {}
+
+ // Sets up feature extractors and flags for processing (inference).
+ SAFTM_MUST_USE_RESULT bool SetupForProcessing(TaskContext *context) {
+ return feature_extractor_.Setup(context);
+ }
+
+ // Initializes feature extractor resources for processing (inference)
+ // including requesting a workspace for caching extracted features.
+ SAFTM_MUST_USE_RESULT bool InitForProcessing(TaskContext *context) {
+ if (!feature_extractor_.Init(context)) return false;
+ feature_extractor_.RequestWorkspaces(&workspace_registry_);
+ return true;
+ }
+
+ // Preprocesses *obj using the internal workspace registry.
+ void Preprocess(WorkspaceSet *workspace, OBJ *obj) const {
+ workspace->Reset(workspace_registry_);
+ feature_extractor_.Preprocess(workspace, obj);
+ }
+
+ // Extract features from |obj|. On return, FeatureVector features[i]
+ // contains the features for the embedding space #i.
+ //
+ // This function uses the precomputed info from |workspace|. Usage pattern:
+ //
+ // EmbeddingFeatureInterface<...> feature_interface;
+ // ...
+ // OBJ obj;
+ // WorkspaceSet workspace;
+ // feature_interface.Preprocess(&workspace, &obj);
+ //
+ // // For the same obj, but with different args:
+ // std::vector<FeatureVector> features;
+ // feature_interface.GetFeatures(obj, args, workspace, &features);
+ //
+ // This pattern is useful (more efficient) if you can pre-compute some info
+ // for the entire |obj|, which is reused by the feature extraction performed
+ // for different args. If that is not the case, you can use the simpler
+ // version GetFeaturesNoCaching below.
+ void GetFeatures(const OBJ &obj, ARGS... args, const WorkspaceSet &workspace,
+ std::vector<FeatureVector> *features) const {
+ feature_extractor_.ExtractFeatures(workspace, obj, args..., features);
+ }
+
+ // Simpler version of GetFeatures(), for cases when there is no opportunity to
+ // reuse computation between feature extractions for the same |obj|, but with
+ // different |args|. Returns the extracted features. For more info, see the
+ // doc for GetFeatures().
+ std::vector<FeatureVector> GetFeaturesNoCaching(OBJ *obj,
+ ARGS... args) const {
+ // Technically, we still use a workspace, because
+ // feature_extractor_.ExtractFeatures requires one. But there is no real
+ // caching here, as we start from scratch for each call to ExtractFeatures.
+ WorkspaceSet workspace;
+ Preprocess(&workspace, obj);
+ std::vector<FeatureVector> features(NumEmbeddings());
+ GetFeatures(*obj, args..., workspace, &features);
+ return features;
+ }
+
+ // Returns number of embedding spaces.
+ int NumEmbeddings() const { return feature_extractor_.NumEmbeddings(); }
+
+ private:
+ // Typed feature extractor for embeddings.
+ EmbeddingFeatureExtractor<EXTRACTOR, OBJ, ARGS...> feature_extractor_;
+
+ // The registry of shared workspaces in the feature extractor.
+ WorkspaceRegistry workspace_registry_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
diff --git a/lang_id/common/embedding-network-params.cc b/lang_id/common/embedding-network-params.cc
new file mode 100644
index 0000000..be7c80e
--- /dev/null
+++ b/lang_id/common/embedding-network-params.cc
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/embedding-network-params.h"
+
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+
+QuantizationType ParseQuantizationType(const string &s) {
+ if (s == "NONE") {
+ return QuantizationType::NONE;
+ }
+ if (s == "UINT8") {
+ return QuantizationType::UINT8;
+ }
+ if (s == "UINT4") {
+ return QuantizationType::UINT4;
+ }
+ if (s == "FLOAT16") {
+ return QuantizationType::FLOAT16;
+ }
+ SAFTM_LOG(FATAL) << "Unsupported quantization type: " << s;
+
+ // Execution should never reach this point; just to keep the compiler happy.
+ // TODO(salcianu): implement SAFTM_LOG(FATAL) in a way that doesn't require
+ // this trick.
+ return QuantizationType::NONE;
+}
+
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/embedding-network-params.h b/lang_id/common/embedding-network-params.h
new file mode 100755
index 0000000..f43c653
--- /dev/null
+++ b/lang_id/common/embedding-network-params.h
@@ -0,0 +1,316 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
+
+#include <string>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_base/float16.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+
+enum class QuantizationType {
+ NONE = 0,
+
+ // Quantization to 8 bit unsigned ints.
+ UINT8,
+
+ // Quantization to 4 bit unsigned ints.
+ UINT4,
+
+ // Quantization to 16 bit floats, the type defined in
+ // lang_id/common/float16.h
+ FLOAT16,
+
+ // NOTE: for backward compatibility, if you add a new value to this enum, add
+ // it *at the end*, such that you do not change the integer values of the
+ // existing enum values.
+};
+
+// Converts "UINT8" -> QuantizationType::UINT8, and so on.
+QuantizationType ParseQuantizationType(const string &s);
+
+// API for accessing parameters for a feed-forward neural network with
+// embeddings.
+//
+//
+// In fact, we provide two APIs: a high-level (and highly-recommended) API, with
+// methods named using the BigCamel notation (e.g., GetEmbeddingMatrix()) and a
+// low-level API, using C-style names (e.g., softmax_num_cols()).
+//
+// Note: the API below is meant to allow the inference code (the class
+// libtextclassifier3::mobile::EmbeddingNetwork) to use the data directly, with no need
+// for transposing any matrix (which would require extra overhead on mobile
+// devices). Hence, as indicated by the comments for the API methods, some of
+// the matrices below are the transposes of the corresponding matrices from the
+// original proto.
+class EmbeddingNetworkParams {
+ public:
+ virtual ~EmbeddingNetworkParams() {}
+
+ // Returns true if these params are valid. False otherwise (e.g., if the
+ // underlying data is corrupted). If is_valid() returns false, clients should
+ // not call any other method on that instance of EmbeddingNetworkParams. If
+ // is_valid() returns true, then calls to the API methods below should not
+ // crash *if they are called with index parameters in bounds*. E.g., if
+ // is_valid() and 0 <= i < embeddings_size(), then GetEmbeddingMatrix(i)
+ // should not crash.
+ virtual bool is_valid() const = 0;
+
+ // **** High-level API.
+
+ // Simple representation of a matrix. This small struct that doesn't own any
+ // resource intentionally supports copy / assign, to simplify our APIs.
+ struct Matrix {
+ // Number of rows.
+ int rows = 0;
+
+ // Number of columns.
+ int cols = 0;
+
+ QuantizationType quant_type = QuantizationType::NONE;
+
+ // Pointer to matrix elements, in row-major order
+ // (https://en.wikipedia.org/wiki/Row-major_order) Not owned.
+ const void *elements = nullptr;
+
+ // Quantization scales: one scale for each row.
+ const ::libtextclassifier3::mobile::float16 *quant_scales = nullptr;
+ };
+
+ // Returns i-th embedding matrix. Crashes on out of bounds indices.
+ //
+ // This is the transpose of the corresponding matrix from the original proto.
+ Matrix GetEmbeddingMatrix(int i) const {
+ CheckIndex(i, embeddings_size(), "embedding matrix");
+ Matrix matrix;
+ matrix.rows = embeddings_num_rows(i);
+ matrix.cols = embeddings_num_cols(i);
+ matrix.elements = embeddings_weights(i);
+ matrix.quant_type = embeddings_quant_type(i);
+ matrix.quant_scales = embeddings_quant_scales(i);
+ return matrix;
+ }
+
+ // Returns weight matrix for i-th hidden layer. Crashes on out of bounds
+ // indices.
+ //
+ // This is the transpose of the corresponding matrix from the original proto.
+ Matrix GetHiddenLayerMatrix(int i) const {
+ CheckIndex(i, hidden_size(), "hidden layer");
+ Matrix matrix;
+ matrix.rows = hidden_num_rows(i);
+ matrix.cols = hidden_num_cols(i);
+
+ // Quantization not supported here.
+ matrix.quant_type = hidden_weights_quant_type(i);
+ matrix.elements = hidden_weights(i);
+ return matrix;
+ }
+
+ // Returns bias for i-th hidden layer. Technically a Matrix, but we expect it
+ // to be a row/column vector (i.e., num rows or num cols is 1). However, we
+ // don't CHECK for that: we just provide access to underlying data. Crashes
+ // on out of bounds indices.
+ Matrix GetHiddenLayerBias(int i) const {
+ CheckIndex(i, hidden_bias_size(), "hidden layer bias");
+ Matrix matrix;
+ matrix.rows = hidden_bias_num_rows(i);
+ matrix.cols = hidden_bias_num_cols(i);
+
+ // Quantization not supported here.
+ matrix.quant_type = QuantizationType::NONE;
+ matrix.elements = hidden_bias_weights(i);
+ return matrix;
+ }
+
+ // Returns true if a softmax layer exists.
+ bool HasSoftmax() const {
+ return softmax_size() == 1;
+ }
+
+ // Returns weight matrix for the softmax layer. Note: should be called only
+ // if HasSoftmax() is true.
+ //
+ // This is the transpose of the corresponding matrix from the original proto.
+ Matrix GetSoftmaxMatrix() const {
+ SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
+ Matrix matrix;
+ matrix.rows = softmax_num_rows(0);
+ matrix.cols = softmax_num_cols(0);
+
+ // Quantization not supported here.
+ matrix.quant_type = softmax_weights_quant_type(0);
+ matrix.elements = softmax_weights(0);
+ return matrix;
+ }
+
+ // Returns bias for the softmax layer. Technically a Matrix, but we expect it
+ // to be a row/column vector (i.e., num rows or num cols is 1). However, we
+ // don't CHECK for that: we just provide access to underlying data.
+ Matrix GetSoftmaxBias() const {
+ SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
+ Matrix matrix;
+ matrix.rows = softmax_bias_num_rows(0);
+ matrix.cols = softmax_bias_num_cols(0);
+
+ // Quantization not supported here.
+ matrix.quant_type = QuantizationType::NONE;
+ matrix.elements = softmax_bias_weights(0);
+ return matrix;
+ }
+
+ // Updates the EmbeddingNetwork-related parameters from task_context. Returns
+ // true on success, false on error.
+ virtual bool UpdateTaskContextParameters(
+ mobile::TaskContext *task_context) = 0;
+
+ // **** Low-level API.
+ //
+ // * Most low-level API methods are documented by giving an equivalent
+ // function call on proto, the original proto (of type
+ // EmbeddingNetworkProto) which was used to generate the C++ code.
+ //
+ // * To simplify our generation code, optional proto fields of message type
+ // are treated as repeated fields with 0 or 1 instances. As such, we have
+ // *_size() methods for such optional fields: they return 0 or 1.
+ //
+ // * "transpose(M)" denotes the transpose of a matrix M.
+
+ // ** Access methods for repeated MatrixParams embeddings.
+ //
+ // Returns proto.embeddings_size().
+ virtual int embeddings_size() const = 0;
+
+ // Returns number of rows of transpose(proto.embeddings(i)).
+ virtual int embeddings_num_rows(int i) const = 0;
+
+ // Returns number of columns of transpose(proto.embeddings(i)).
+ virtual int embeddings_num_cols(int i) const = 0;
+
+ // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
+ // order. NOTE: for unquantized embeddings, this returns a pointer to float;
+ // for quantized embeddings, this returns a pointer to uint8.
+ virtual const void *embeddings_weights(int i) const = 0;
+
+ virtual QuantizationType embeddings_quant_type(int i) const {
+ return QuantizationType::NONE;
+ }
+
+ virtual const ::libtextclassifier3::mobile::float16 *embeddings_quant_scales(
+ int i) const {
+ return nullptr;
+ }
+
+ // ** Access methods for repeated MatrixParams hidden.
+ //
+ // Returns embedding_network_proto.hidden_size().
+ virtual int hidden_size() const = 0;
+
+ // Returns embedding_network_proto.hidden(i).rows().
+ virtual int hidden_num_rows(int i) const = 0;
+
+  // Returns embedding_network_proto.hidden(i).cols().
+ virtual int hidden_num_cols(int i) const = 0;
+
+ // Returns quantization mode for the weights of the i-th hidden layer.
+ virtual QuantizationType hidden_weights_quant_type(int i) const {
+ return QuantizationType::NONE;
+ }
+
+ // Returns pointer to beginning of array of floats with all values from
+ // embedding_network_proto.hidden(i).
+ virtual const void *hidden_weights(int i) const = 0;
+
+ // ** Access methods for repeated MatrixParams hidden_bias.
+ //
+ // Returns proto.hidden_bias_size().
+ virtual int hidden_bias_size() const = 0;
+
+ // Returns number of rows of proto.hidden_bias(i).
+ virtual int hidden_bias_num_rows(int i) const = 0;
+
+ // Returns number of columns of proto.hidden_bias(i).
+ virtual int hidden_bias_num_cols(int i) const = 0;
+
+ // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
+ virtual const void *hidden_bias_weights(int i) const = 0;
+
+ // ** Access methods for optional MatrixParams softmax.
+ //
+ // Returns 1 if proto has optional field softmax, 0 otherwise.
+ virtual int softmax_size() const = 0;
+
+ // Returns number of rows of transpose(proto.softmax()).
+ virtual int softmax_num_rows(int i) const = 0;
+
+ // Returns number of columns of transpose(proto.softmax()).
+ virtual int softmax_num_cols(int i) const = 0;
+
+ // Returns quantization mode for the softmax weights.
+ virtual QuantizationType softmax_weights_quant_type(int i) const {
+ return QuantizationType::NONE;
+ }
+
+ // Returns pointer to elements of transpose(proto.softmax()), in row-major
+ // order.
+ virtual const void *softmax_weights(int i) const = 0;
+
+ // ** Access methods for optional MatrixParams softmax_bias.
+ //
+ // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
+ virtual int softmax_bias_size() const = 0;
+
+ // Returns number of rows of proto.softmax_bias().
+ virtual int softmax_bias_num_rows(int i) const = 0;
+
+ // Returns number of columns of proto.softmax_bias().
+ virtual int softmax_bias_num_cols(int i) const = 0;
+
+ // Returns pointer to elements of proto.softmax_bias(), in row-major order.
+ virtual const void *softmax_bias_weights(int i) const = 0;
+
+ // ** Access methods for repeated int32 embedding_num_features.
+ //
+ // Returns proto.embedding_num_features_size().
+ virtual int embedding_num_features_size() const = 0;
+
+ // Returns proto.embedding_num_features(i).
+ virtual int embedding_num_features(int i) const = 0;
+
+ // ** Access methods for is_precomputed
+ //
+ // Returns proto.has_is_precomputed().
+ virtual bool has_is_precomputed() const = 0;
+
+ // Returns proto.is_precomputed().
+ virtual bool is_precomputed() const = 0;
+
+ protected:
+ void CheckIndex(int index, int size, const string &description) const {
+ SAFTM_CHECK_GE(index, 0)
+ << "Out-of-range index for " << description << ": " << index;
+ SAFTM_CHECK_LT(index, size)
+ << "Out-of-range index for " << description << ": " << index;
+ }
+}; // class EmbeddingNetworkParams
+
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
diff --git a/lang_id/common/embedding-network.cc b/lang_id/common/embedding-network.cc
new file mode 100644
index 0000000..469cb1f
--- /dev/null
+++ b/lang_id/common/embedding-network.cc
@@ -0,0 +1,323 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/embedding-network.h"
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace {
+
+void CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) {
+ SAFTM_CHECK_EQ(static_cast<int>(QuantizationType::NONE),
+ static_cast<int>(matrix.quant_type))
+ << "Quantization not allowed here";
+}
+
+int GetMatrixRowSizeInBytes(const EmbeddingNetworkParams::Matrix &matrix) {
+ int cols = matrix.cols;
+ QuantizationType quant_type = matrix.quant_type;
+ switch (quant_type) {
+ case QuantizationType::NONE:
+ return cols * sizeof(float);
+ case QuantizationType::UINT8:
+ return cols * sizeof(uint8);
+ case QuantizationType::UINT4:
+ SAFTM_DCHECK_EQ(cols % 2, 0) << "UINT4 with odd #cols = " << cols;
+ return cols / 2;
+ case QuantizationType::FLOAT16:
+ return cols * sizeof(float16);
+ default:
+ SAFTM_LOG(FATAL) << "Unknown quant type: "
+ << static_cast<int>(quant_type);
+ }
+}
+
+// Computes y = weights * Relu(x) + b where Relu is optionally applied.
+//
+// weights and b are the weight matrix, respectively the bias vector of a neural
+// network layer.
+//
+// Note: in the research literature, usually Relu (the activation function) is
+// the last part of a neural layer. From that perspective, this function
+// computes the Relu part of the previous layer (if any) and next the first half
+// (the computation of the state) for the current layer.
+//
+// Note: weights is expected to be the transposed version of the real weight
+// matrix. Hence, instead of computing a linear combination of the columns of
+// weights, we compute a linear combination of its rows; but we are mindful that
+// these rows are the columns of the original matrix, hence the name
+// weights_col_i in the code.
+void SparseReluProductPlusBias(bool apply_relu,
+ const EmbeddingNetworkParams::Matrix &weights,
+ const EmbeddingNetworkParams::Matrix &b,
+ const std::vector<float> &x,
+ std::vector<float> *y) {
+  // Initialize y to b. b is a column matrix (i.e., b.cols == 1); we already
+  // CHECK-ed that in the EmbeddingNetwork constructor.
+ const float *b_start = reinterpret_cast<const float *>(b.elements);
+ SAFTM_DCHECK_EQ(b.cols, 1);
+ y->assign(b_start, b_start + b.rows);
+
+ float *const y_data = y->data();
+ const int y_size = y->size();
+ SAFTM_CHECK_EQ(weights.cols, y_size);
+ const int x_size = x.size();
+ SAFTM_CHECK_EQ(weights.rows, x_size);
+
+ // NOTE: the code below reads x_size * y_size elements from weights; these
+ // reads are safe as long as weights.elements contains weights.rows *
+ // weights.cols elements (where the element size depends on the quantization
+ // type). That requirement is checked by the params provider, e.g., by
+ // EmbeddingNetworkParamsFromFlatbuffer.
+
+ // There is some code duplication between the two main cases of the switch
+ // below: the idea was to "lift" the switch outside the loops, to reduce the
+ // number of tests at runtime.
+ switch (weights.quant_type) {
+ case QuantizationType::NONE: {
+ // We compute a linear combination of the rows from |weights|, using
+ // elements of x (optionally, Relu(x)) as scaling factors (the i-th row
+ // gets multiplied by x[i] before being added with the other rows). Note:
+ // elements of |weights| are stored in row-major order: first the elements
+ // of row #0, next the elements of row #1, etc. In the comments below, we
+ // write "weights[i][j]" to refer to the j-th element from the i-th row of
+ // weights.
+ const float *weight_ptr =
+ reinterpret_cast<const float *>(weights.elements);
+ for (int i = 0; i < x_size; ++i) {
+ // Invariant 1: weight_ptr points to the beginning of the i-th row from
+ // weights (i.e., weights[i][0]).
+ const float scale = x[i];
+ if (!apply_relu || (scale > 0)) {
+ for (int j = 0; j < y_size; ++j, ++weight_ptr) {
+ // Invariant 2: weight_ptr points to weights[i][j].
+ y_data[j] += (*weight_ptr) * scale;
+ }
+ } else {
+ // We don't update y_data, but we still have to move weight_ptr to the
+ // next row (to satisfy Invariant 1). We do this by adding y_size ==
+ // weights.cols() (see earlier CHECK_EQ).
+ weight_ptr += y_size;
+ }
+ }
+ break;
+ }
+ case QuantizationType::FLOAT16: {
+ // See comments for the QuantizationType::NONE case: the code is almost
+ // identical, except for float16 (instead of float) and the Float16To32
+ // conversion. We could unify these two cases using a template, but since
+ // this is a critical loop, don't want to risk that e.g., inlining of the
+ // conversion function doesn't happen.
+ const float16 *weight_ptr =
+ reinterpret_cast<const float16 *>(weights.elements);
+ for (int i = 0; i < x_size; ++i) {
+ const float scale = x[i];
+ if (!apply_relu || (scale > 0)) {
+ for (int j = 0; j < y_size; ++j, ++weight_ptr) {
+ y_data[j] += Float16To32(*weight_ptr) * scale;
+ }
+ } else {
+ weight_ptr += y_size;
+ }
+ }
+ break;
+ }
+ default:
+ SAFTM_LOG(FATAL) << "Unsupported weights quantization type: "
+ << static_cast<int>(weights.quant_type);
+ }
+}
+} // namespace
+
+void EmbeddingNetwork::ConcatEmbeddings(
+ const std::vector<FeatureVector> &feature_vectors,
+ std::vector<float> *concat) const {
+ concat->resize(concat_layer_size_);
+
+ // "es_index" stands for "embedding space index".
+ for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
+ const int concat_offset = concat_offset_[es_index];
+
+ const EmbeddingNetworkParams::Matrix &embedding_matrix =
+ embedding_matrices_[es_index];
+ const int embedding_dim = embedding_matrix.cols;
+ const int embedding_row_size_in_bytes =
+ embedding_row_size_in_bytes_[es_index];
+
+ const FeatureVector &feature_vector = feature_vectors[es_index];
+ const int num_features = feature_vector.size();
+ for (int fi = 0; fi < num_features; ++fi) {
+ const FeatureType *feature_type = feature_vector.type(fi);
+ int feature_offset = concat_offset + feature_type->base() * embedding_dim;
+ SAFTM_CHECK_LE(feature_offset + embedding_dim, concat->size());
+
+ // Weighted embeddings will be added starting from this address.
+ float *concat_ptr = concat->data() + feature_offset;
+
+ // Multiplier for each embedding weight. Includes feature weight (for
+ // continuous features) and quantization scale (for quantized embeddings).
+ float multiplier;
+ int feature_id;
+ const FeatureValue feature_value = feature_vector.value(fi);
+ if (feature_type->is_continuous()) {
+ // Continuous features (encoded as FloatFeatureValue).
+ FloatFeatureValue float_feature_value(feature_value);
+ feature_id = float_feature_value.id;
+ multiplier = float_feature_value.weight;
+ } else {
+ // Discrete features: every present feature has implicit value 1.0.
+ feature_id = feature_value;
+ multiplier = 1.0;
+ }
+
+ SAFTM_CHECK_GE(feature_id, 0);
+ SAFTM_CHECK_LT(feature_id, embedding_matrix.rows);
+
+ // Pointer to float / uint8 weights for relevant embedding.
+ const void *embedding_data =
+ (reinterpret_cast<const char *>(embedding_matrix.elements) +
+ feature_id * embedding_row_size_in_bytes);
+
+ switch (embedding_matrix.quant_type) {
+ case QuantizationType::NONE: {
+ const float *weights =
+ reinterpret_cast<const float *>(embedding_data);
+ for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
+ *concat_ptr += *weights * multiplier;
+ }
+ break;
+ }
+ case QuantizationType::UINT8: {
+ multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
+ const uint8 *quant_weights =
+ reinterpret_cast<const uint8 *>(embedding_data);
+ for (int i = 0; i < embedding_dim;
+ ++i, ++quant_weights, ++concat_ptr) {
+ // 128 is bias for UINT8 quantization.
+ *concat_ptr +=
+ (static_cast<int>(*quant_weights) - 128) * multiplier;
+ }
+ break;
+ }
+ case QuantizationType::UINT4: {
+ multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
+ const uint8 *quant_weights =
+ reinterpret_cast<const uint8 *>(embedding_data);
+ for (int i = 0; i < embedding_dim / 2; ++i, ++quant_weights) {
+ const uint8 qq = *quant_weights;
+ concat_ptr[0] +=
+ (static_cast<int>((qq & 0xF0) | 0x08) - 128) * multiplier;
+ concat_ptr[1] +=
+ (static_cast<int>(((qq & 0x0F) << 4) | 0x08) - 128) *
+ multiplier;
+ concat_ptr += 2;
+ }
+ break;
+ }
+ default:
+ // We already checked (in GetMatrixRowSizeInBytes) that each embedding
+ // matrix has a known quantization type. Hence, DLOG is enough here.
+ SAFTM_DLOG(ERROR) << "Unknown embeddings quantization type "
+ << static_cast<int>(embedding_matrix.quant_type);
+ break;
+ }
+ }
+ }
+}
+
+void EmbeddingNetwork::ComputeFinalScores(
+ const std::vector<FeatureVector> &features,
+ std::vector<float> *scores) const {
+ ComputeFinalScores(features, {}, scores);
+}
+
+void EmbeddingNetwork::ComputeFinalScores(
+ const std::vector<FeatureVector> &features,
+ const std::vector<float> &extra_inputs, std::vector<float> *scores) const {
+ // Construct the input layer for our feed-forward neural network (FFNN).
+ std::vector<float> input;
+ ConcatEmbeddings(features, &input);
+ if (!extra_inputs.empty()) {
+ input.reserve(input.size() + extra_inputs.size());
+ for (int i = 0; i < extra_inputs.size(); i++) {
+ input.push_back(extra_inputs[i]);
+ }
+ }
+
+ // Propagate input through all layers of our FFNN.
+
+ // Alternating storage for activations of the different layers. We can't use
+ // a single vector because all activations of the previous layer are required
+ // when computing the activations of the next one.
+ std::vector<float> storage[2];
+ const std::vector<float> *v_in = &input;
+ const int num_layers = layer_weights_.size();
+ for (int i = 0; i < num_layers; ++i) {
+ std::vector<float> *v_out = nullptr;
+ if (i == num_layers - 1) {
+ // Final layer: write results directly into |scores|.
+ v_out = scores;
+ } else {
+ // Hidden layer: write results into the alternating storage. The i % 2
+ // trick ensures the alternation.
+ v_out = &(storage[i % 2]);
+ }
+ const bool apply_relu = i > 0;
+ SparseReluProductPlusBias(
+ apply_relu, layer_weights_[i], layer_bias_[i], *v_in, v_out);
+ v_in = v_out;
+ }
+}
+
+EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model)
+ : model_(model) {
+ int offset_sum = 0;
+ for (int i = 0; i < model_->embedding_num_features_size(); ++i) {
+ concat_offset_.push_back(offset_sum);
+ EmbeddingNetworkParams::Matrix matrix = model_->GetEmbeddingMatrix(i);
+ offset_sum += matrix.cols * model_->embedding_num_features(i);
+
+ // NOTE: each Matrix is a small struct that doesn't own the actual matrix
+ // weights. Hence, the push_back below is fast.
+ embedding_matrices_.push_back(matrix);
+ embedding_row_size_in_bytes_.push_back(GetMatrixRowSizeInBytes(matrix));
+ }
+ concat_layer_size_ = offset_sum;
+
+ SAFTM_CHECK_EQ(model_->hidden_size(), model_->hidden_bias_size());
+ for (int i = 0; i < model_->hidden_size(); ++i) {
+ layer_weights_.push_back(model_->GetHiddenLayerMatrix(i));
+
+ EmbeddingNetworkParams::Matrix bias = model_->GetHiddenLayerBias(i);
+ SAFTM_CHECK_EQ(1, bias.cols);
+ CheckNoQuantization(bias);
+ layer_bias_.push_back(bias);
+ }
+
+ SAFTM_CHECK(model_->HasSoftmax());
+ layer_weights_.push_back(model_->GetSoftmaxMatrix());
+
+ EmbeddingNetworkParams::Matrix softmax_bias = model_->GetSoftmaxBias();
+ SAFTM_CHECK_EQ(1, softmax_bias.cols);
+ CheckNoQuantization(softmax_bias);
+ layer_bias_.push_back(softmax_bias);
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/embedding-network.h b/lang_id/common/embedding-network.h
new file mode 100644
index 0000000..54094d7
--- /dev/null
+++ b/lang_id/common/embedding-network.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
+
+#include <vector>
+
+#include "lang_id/common/embedding-network-params.h"
+#include "lang_id/common/fel/feature-extractor.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Classifier using a hand-coded feed-forward neural network.
+//
+// No gradient computation, just inference.
+//
+// Based on the more general nlp_saft::EmbeddingNetwork (without ::mobile).
+//
+// Classification works as follows:
+//
+// Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax
+//
+// In words: given some discrete features, this class extracts the embeddings
+// for these features, concatenates them, passes them through one or more hidden
+// layers (each layer uses Relu) and next through a softmax layer that computes
+// an unnormalized score for each possible class. Note: there is always a
+// softmax layer at the end.
+class EmbeddingNetwork {
+ public:
+ // Constructs an embedding network using the parameters from model.
+ //
+ // Note: model should stay alive for at least the lifetime of this
+ // EmbeddingNetwork object.
+ explicit EmbeddingNetwork(const EmbeddingNetworkParams *model);
+
+ virtual ~EmbeddingNetwork() {}
+
+ // Runs forward computation to fill scores with unnormalized output unit
+ // scores. This is useful for making predictions.
+ void ComputeFinalScores(const std::vector<FeatureVector> &features,
+ std::vector<float> *scores) const;
+
+ // Same as above, but allows specification of extra neural network
+ // inputs that will be appended to the embedding vector built from features.
+ void ComputeFinalScores(const std::vector<FeatureVector> &features,
+ const std::vector<float> &extra_inputs,
+ std::vector<float> *scores) const;
+
+ private:
+ // Constructs the concatenated input embedding vector in place in output
+ // vector concat.
+ void ConcatEmbeddings(const std::vector<FeatureVector> &features,
+ std::vector<float> *concat) const;
+
+ // Pointer to the model object passed to the constructor. Not owned.
+ const EmbeddingNetworkParams *model_;
+
+ // Network parameters.
+
+ // One weight matrix for each embedding.
+ std::vector<EmbeddingNetworkParams::Matrix> embedding_matrices_;
+
+ // embedding_row_size_in_bytes_[i] is the size (in bytes) of a row from
+ // embedding_matrices_[i]. We precompute this in order to quickly find the
+ // beginning of the k-th row from an embedding matrix (which is stored in
+ // row-major order).
+ std::vector<int> embedding_row_size_in_bytes_;
+
+ // concat_offset_[i] is the input layer offset for i-th embedding space.
+ std::vector<int> concat_offset_;
+
+ // Size of the input ("concatenation") layer.
+ int concat_layer_size_ = 0;
+
+ // One weight matrix and one vector of bias weights for each layer of neurons.
+ // Last layer is the softmax layer, the previous ones are the hidden layers.
+ std::vector<EmbeddingNetworkParams::Matrix> layer_weights_;
+ std::vector<EmbeddingNetworkParams::Matrix> layer_bias_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
diff --git a/lang_id/common/fel/feature-descriptors.cc b/lang_id/common/fel/feature-descriptors.cc
new file mode 100644
index 0000000..bf03dd5
--- /dev/null
+++ b/lang_id/common/fel/feature-descriptors.cc
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/feature-descriptors.h"
+
+#include "lang_id/common/lite_strings/str-cat.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+void ToFELFunction(const FeatureFunctionDescriptor &function, string *output) {
+ LiteStrAppend(output, function.type());
+ if (function.argument() != 0 || function.parameter_size() > 0) {
+ LiteStrAppend(output, "(");
+ bool first = true;
+ if (function.argument() != 0) {
+ LiteStrAppend(output, function.argument());
+ first = false;
+ }
+ for (int i = 0; i < function.parameter_size(); ++i) {
+ if (!first) LiteStrAppend(output, ",");
+ LiteStrAppend(output, function.parameter(i).name(), "=\"",
+ function.parameter(i).value(), "\"");
+ first = false;
+ }
+ LiteStrAppend(output, ")");
+ }
+}
+
+void ToFEL(const FeatureFunctionDescriptor &function, string *output) {
+ ToFELFunction(function, output);
+ if (function.feature_size() == 1) {
+ LiteStrAppend(output, ".");
+ ToFEL(function.feature(0), output);
+ } else if (function.feature_size() > 1) {
+ LiteStrAppend(output, " { ");
+ for (int i = 0; i < function.feature_size(); ++i) {
+ if (i > 0) LiteStrAppend(output, " ");
+ ToFEL(function.feature(i), output);
+ }
+ LiteStrAppend(output, " } ");
+ }
+}
+
+void ToFEL(const FeatureExtractorDescriptor &extractor, string *output) {
+ for (int i = 0; i < extractor.feature_size(); ++i) {
+ ToFEL(extractor.feature(i), output);
+ LiteStrAppend(output, "\n");
+ }
+}
+
+string FeatureFunctionDescriptor::DebugString() const {
+ string str;
+ ToFEL(*this, &str);
+ return str;
+}
+
+string FeatureExtractorDescriptor::DebugString() const {
+ string str;
+ ToFEL(*this, &str);
+ return str;
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/fel/feature-descriptors.h b/lang_id/common/fel/feature-descriptors.h
new file mode 100644
index 0000000..a9408c9
--- /dev/null
+++ b/lang_id/common/fel/feature-descriptors.h
@@ -0,0 +1,159 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Named feature parameter.
+class Parameter {
+ public:
+ Parameter() {}
+
+ void set_name(const string &name) { name_ = name; }
+ const string &name() const { return name_; }
+
+ void set_value(const string &value) { value_ = value; }
+ const string &value() const { return value_; }
+
+ private:
+ string name_;
+ string value_;
+};
+
+// Descriptor for a feature function. Used to store the results of parsing one
+// feature function.
+class FeatureFunctionDescriptor {
+ public:
+ FeatureFunctionDescriptor() {}
+
+ // Accessors for the feature function type. The function type is the string
+ // that the feature extractor code is registered under.
+ void set_type(const string &type) { type_ = type; }
+ const string &type() const { return type_; }
+
+ // Accessors for the feature function name. The function name (if available)
+ // is used for some log messages. Otherwise, a more precise, but also more
+ // verbose name based on the feature specification is used.
+ void set_name(const string &name) { name_ = name; }
+ const string &name() const { return name_; }
+
+ // Accessors for the default (name-less) parameter.
+ void set_argument(int32 argument) { argument_ = argument; }
+ bool has_argument() const {
+ // If argument has not been specified, clients should treat it as 0. This
+ // makes the test below correct, without having a separate has_argument_
+ // bool field.
+ return argument_ != 0;
+ }
+ int32 argument() const { return argument_; }
+
+ // Accessors for the named parameters.
+ Parameter *add_parameter() {
+ parameters_.emplace_back();
+ return &(parameters_.back());
+ }
+ int parameter_size() const { return parameters_.size(); }
+ const Parameter &parameter(int i) const {
+ SAFTM_DCHECK((i >= 0) && (i < parameter_size()));
+ return parameters_[i];
+ }
+
+ // Accessors for the sub (i.e., nested) features. Nested features: as in
+ // offset(1).label.
+ FeatureFunctionDescriptor *add_feature() {
+ sub_features_.emplace_back(new FeatureFunctionDescriptor());
+ return sub_features_.back().get();
+ }
+ int feature_size() const { return sub_features_.size(); }
+ const FeatureFunctionDescriptor &feature(int i) const {
+ SAFTM_DCHECK((i >= 0) && (i < feature_size()));
+ return *(sub_features_[i].get());
+ }
+
+ // Returns human-readable representation of this FeatureFunctionDescriptor.
+ string DebugString() const;
+
+ private:
+ // See comments for set_type().
+ string type_;
+
+ // See comments for set_name().
+ string name_;
+
+ // See comments for set_argument().
+ int32 argument_ = 0;
+
+ // See comments for add_parameter().
+ std::vector<Parameter> parameters_;
+
+ // See comments for add_feature().
+ std::vector<std::unique_ptr<FeatureFunctionDescriptor>> sub_features_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureFunctionDescriptor);
+};
+
+// List of FeatureFunctionDescriptors. Used to store the result of parsing the
+// spec for several feature functions.
+class FeatureExtractorDescriptor {
+ public:
+ FeatureExtractorDescriptor() {}
+
+ int feature_size() const { return features_.size(); }
+
+ FeatureFunctionDescriptor *add_feature() {
+ features_.emplace_back(new FeatureFunctionDescriptor());
+ return features_.back().get();
+ }
+
+ const FeatureFunctionDescriptor &feature(int i) const {
+ SAFTM_DCHECK((i >= 0) && (i < feature_size()));
+ return *(features_[i].get());
+ }
+
+ // Returns human-readable representation of this FeatureExtractorDescriptor.
+ string DebugString() const;
+
+ private:
+ std::vector<std::unique_ptr<FeatureFunctionDescriptor>> features_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureExtractorDescriptor);
+};
+
+// Appends to |*output| the FEL representation of the top-level feature from
+// |function|, without diving into the nested features.
+void ToFELFunction(const FeatureFunctionDescriptor &function, string *output);
+
+// Appends to |*output| the FEL representation of |function|.
+void ToFEL(const FeatureFunctionDescriptor &function, string *output);
+
+// Appends to |*output| the FEL representation of |extractor|.
+void ToFEL(const FeatureExtractorDescriptor &extractor, string *output);
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
diff --git a/lang_id/common/fel/feature-extractor.cc b/lang_id/common/fel/feature-extractor.cc
new file mode 100644
index 0000000..c256257
--- /dev/null
+++ b/lang_id/common/fel/feature-extractor.cc
@@ -0,0 +1,139 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/feature-extractor.h"
+
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/fel-parser.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+constexpr FeatureValue GenericFeatureFunction::kNone;
+
+GenericFeatureExtractor::GenericFeatureExtractor() {}
+
+GenericFeatureExtractor::~GenericFeatureExtractor() {}
+
+bool GenericFeatureExtractor::Parse(const string &source) {
+ // Parse feature specification into descriptor.
+ FELParser parser;
+
+ if (!parser.Parse(source, mutable_descriptor())) {
+ SAFTM_LOG(ERROR) << "Error parsing the FEL spec " << source;
+ return false;
+ }
+
+ // Initialize feature extractor from descriptor.
+ return InitializeFeatureFunctions();
+}
+
+bool GenericFeatureExtractor::InitializeFeatureTypes() {
+ // Register all feature types.
+ GetFeatureTypes(&feature_types_);
+ for (size_t i = 0; i < feature_types_.size(); ++i) {
+ FeatureType *ft = feature_types_[i];
+ ft->set_base(i);
+
+ // Check for feature space overflow.
+ double domain_size = ft->GetDomainSize();
+ if (domain_size < 0) {
+ SAFTM_LOG(ERROR) << "Illegal domain size for feature " << ft->name()
+ << ": " << domain_size;
+ return false;
+ }
+ }
+ return true;
+}
+
+string GenericFeatureFunction::GetParameter(const string &name,
+ const string &default_value) const {
+ // Find named parameter in feature descriptor.
+ for (int i = 0; i < descriptor_->parameter_size(); ++i) {
+ if (name == descriptor_->parameter(i).name()) {
+ return descriptor_->parameter(i).value();
+ }
+ }
+ return default_value;
+}
+
+GenericFeatureFunction::GenericFeatureFunction() {}
+
+GenericFeatureFunction::~GenericFeatureFunction() { delete feature_type_; }
+
+int GenericFeatureFunction::GetIntParameter(const string &name,
+ int default_value) const {
+ string value_str = GetParameter(name, "");
+ if (value_str.empty()) {
+ // Parameter not specified, use default value for it.
+ return default_value;
+ }
+ int value = 0;
+ if (!LiteAtoi(value_str, &value)) {
+ SAFTM_LOG(DFATAL) << "Unable to parse '" << value_str
+ << "' as int for parameter " << name;
+ return default_value;
+ }
+ return value;
+}
+
+bool GenericFeatureFunction::GetBoolParameter(const string &name,
+ bool default_value) const {
+ string value = GetParameter(name, "");
+ if (value.empty()) return default_value;
+ if (value == "true") return true;
+ if (value == "false") return false;
+ SAFTM_LOG(DFATAL) << "Illegal value '" << value << "' for bool parameter "
+ << name;
+ return default_value;
+}
+
+void GenericFeatureFunction::GetFeatureTypes(
+ std::vector<FeatureType *> *types) const {
+ if (feature_type_ != nullptr) types->push_back(feature_type_);
+}
+
+FeatureType *GenericFeatureFunction::GetFeatureType() const {
+ // If a single feature type has been registered return it.
+ if (feature_type_ != nullptr) return feature_type_;
+
+ // Get feature types for function.
+ std::vector<FeatureType *> types;
+ GetFeatureTypes(&types);
+
+ // If there is exactly one feature type return this, else return null.
+ if (types.size() == 1) return types[0];
+ return nullptr;
+}
+
+string GenericFeatureFunction::name() const {
+ string output;
+ if (descriptor_->name().empty()) {
+ if (!prefix_.empty()) {
+ output.append(prefix_);
+ output.append(".");
+ }
+ ToFEL(*descriptor_, &output);
+ } else {
+ output = descriptor_->name();
+ }
+ return output;
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/fel/feature-extractor.h b/lang_id/common/fel/feature-extractor.h
new file mode 100644
index 0000000..8763852
--- /dev/null
+++ b/lang_id/common/fel/feature-extractor.h
@@ -0,0 +1,651 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Generic feature extractor for extracting features from objects. The feature
+// extractor can be used for extracting features from any object. The feature
+// extractor and feature function classes are template classes that have to
+// be instantiated for extracting feature from a specific object type.
+//
+// A feature extractor consists of a hierarchy of feature functions. Each
+// feature function extracts one or more feature type and value pairs from the
+// object.
+//
+// The feature extractor has a modular design where new feature functions can be
+// registered as components. The feature extractor is initialized from a
+// descriptor (FeatureExtractorDescriptor). The feature extractor can also
+// be initialized from a text-based source specification of the feature
+// extractor. Feature specification parsers can be added as components. By
+// default the feature extractor can be initialized from a specification in
+// the Feature Extraction Language (FEL).
+
+// A feature function is invoked with a focus. Nested feature functions can be
+// invoked with another focus determined by the parent feature function.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
+
+#include <stddef.h>
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-descriptors.h"
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/attributes.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+#include "lang_id/common/registry.h"
+#include "lang_id/common/stl-util.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// TODO(djweiss) Clean this up as well.
+// Use the same type for feature values as is used for predicates.
+typedef int64 Predicate;
+typedef Predicate FeatureValue;
+
+// A union used to represent discrete and continuous feature values.
+union FloatFeatureValue {
+ public:
+ explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {}
+ FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {}
+ FeatureValue discrete_value;
+ struct {
+ uint32 id;
+ float weight;
+ };
+};
+
+// A feature vector contains feature type and value pairs.
+class FeatureVector {
+ public:
+ FeatureVector() {}
+
+ // Adds feature type and value pair to feature vector.
+ void add(FeatureType *type, FeatureValue value) {
+ features_.emplace_back(type, value);
+ }
+
+ // Removes all elements from the feature vector.
+ void clear() { features_.clear(); }
+
+ // Returns the number of elements in the feature vector.
+ int size() const { return features_.size(); }
+
+ // Reserves space in the underlying feature vector.
+ void reserve(int n) { features_.reserve(n); }
+
+ // Returns feature type for an element in the feature vector.
+ FeatureType *type(int index) const { return features_[index].type; }
+
+ // Returns feature value for an element in the feature vector.
+ FeatureValue value(int index) const { return features_[index].value; }
+
+ private:
+ // Structure for holding feature type and value pairs.
+ struct Element {
+ Element() : type(nullptr), value(-1) {}
+ Element(FeatureType *t, FeatureValue v) : type(t), value(v) {}
+
+ FeatureType *type;
+ FeatureValue value;
+ };
+
+ // Array for storing feature vector elements.
+ std::vector<Element> features_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureVector);
+};
+
+// The generic feature extractor is the type-independent part of a feature
+// extractor. This holds the descriptor for the feature extractor and the
+// collection of feature types used in the feature extractor. The feature
+// types are not available until FeatureExtractor<>::Init() has been called.
+class GenericFeatureExtractor {
+ public:
+ GenericFeatureExtractor();
+ virtual ~GenericFeatureExtractor();
+
+ // Initializes the feature extractor from the FEL specification |source|.
+ //
+ // Returns true on success, false otherwise (e.g., FEL syntax error).
+ SAFTM_MUST_USE_RESULT bool Parse(const string &source);
+
+ // Returns the feature extractor descriptor.
+ const FeatureExtractorDescriptor &descriptor() const { return descriptor_; }
+ FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; }
+
+ // Returns the number of feature types in the feature extractor. Invalid
+ // before Init() has been called.
+ int feature_types() const { return feature_types_.size(); }
+
+ protected:
+ // Initializes the feature types used by the extractor. Called from
+ // FeatureExtractor<>::Init().
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool InitializeFeatureTypes();
+
+ private:
+ // Initializes the top-level feature functions.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool InitializeFeatureFunctions() = 0;
+
+ // Returns all feature types used by the extractor. The feature types are
+ // added to the result array.
+ virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0;
+
+ // Descriptor for the feature extractor. This contains all the
+ // information about the feature extractor. The feature
+ // functions are initialized from the information in the descriptor.
+ FeatureExtractorDescriptor descriptor_;
+
+ // All feature types used by the feature extractor. The collection of all the
+ // feature types describes the feature space of the feature set produced by
+ // the feature extractor. Not owned.
+ std::vector<FeatureType *> feature_types_;
+};
+
+// The generic feature function is the type-independent part of a feature
+// function. Each feature function is associated with the descriptor that it is
+// instantiated from. The feature types associated with this feature function
+// will be established by the time FeatureExtractor<>::Init() completes.
+class GenericFeatureFunction {
+ public:
+ // A feature value that represents the absence of a value.
+ static constexpr FeatureValue kNone = -1;
+
+ GenericFeatureFunction();
+ virtual ~GenericFeatureFunction();
+
+ // Sets up the feature function. NB: FeatureTypes of nested functions are not
+ // guaranteed to be available until Init().
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context) {
+ return true;
+ }
+
+ // Initializes the feature function. NB: The FeatureType of this function must
+ // be established when this method completes.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context) { return true; }
+
+ // Requests workspaces from a registry to obtain indices into a WorkspaceSet
+ // for any Workspace objects used by this feature function. NB: This will be
+ // called after Init(), so it can depend on resources and arguments.
+ virtual void RequestWorkspaces(WorkspaceRegistry *registry) {}
+
+ // Appends the feature types produced by the feature function to types. The
+ // default implementation appends feature_type(), if non-null. Invalid
+ // before Init() has been called.
+ virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const;
+
+ // Returns the feature type for feature produced by this feature function. If
+ // the feature function produces features of different types this returns
+ // null. Invalid before Init() has been called.
+ virtual FeatureType *GetFeatureType() const;
+
+ // Returns value of parameter |name| from the feature function descriptor.
+ // If the parameter is not present, returns the indicated |default_value|.
+ string GetParameter(const string &name, const string &default_value) const;
+
+ // Returns value of int parameter |name| from feature function descriptor.
+ // If the parameter is not present, or its value can't be parsed as an int,
+ // returns |default_value|.
+ int GetIntParameter(const string &name, int default_value) const;
+
+ // Returns value of bool parameter |name| from feature function descriptor.
+ // If the parameter is not present, or its value is not "true" or "false",
+ // returns |default_value|. NOTE: this method is case sensitive, it doesn't
+ // do any lower-casing.
+ bool GetBoolParameter(const string &name, bool default_value) const;
+
+ // Returns the FEL function description for the feature function, i.e. the
+ // name and parameters without the nested features.
+ string FunctionName() const {
+ string output;
+ ToFELFunction(*descriptor_, &output);
+ return output;
+ }
+
+ // Returns the prefix for nested feature functions. This is the prefix of this
+ // feature function concatenated with the feature function name.
+ string SubPrefix() const {
+ return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName();
+ }
+
+ // Returns/sets the feature extractor this function belongs to.
+ const GenericFeatureExtractor *extractor() const { return extractor_; }
+ void set_extractor(const GenericFeatureExtractor *extractor) {
+ extractor_ = extractor;
+ }
+
+ // Returns/sets the feature function descriptor.
+ const FeatureFunctionDescriptor *descriptor() const { return descriptor_; }
+ void set_descriptor(const FeatureFunctionDescriptor *descriptor) {
+ descriptor_ = descriptor;
+ }
+
+ // Returns a descriptive name for the feature function. The name is taken from
+ // the descriptor for the feature function. If the name is empty or the
+ // feature function is a variable the name is the FEL representation of the
+ // feature, including the prefix.
+ string name() const;
+
+ // Returns the argument from the feature function descriptor. It defaults to
+ // 0 if the argument has not been specified.
+ int argument() const {
+ return descriptor_->has_argument() ? descriptor_->argument() : 0;
+ }
+
+ // Returns/sets/clears function name prefix.
+ const string &prefix() const { return prefix_; }
+ void set_prefix(const string &prefix) { prefix_ = prefix; }
+
+ protected:
+ // Returns the feature type for single-type feature functions.
+ FeatureType *feature_type() const { return feature_type_; }
+
+ // Sets the feature type for single-type feature functions. This takes
+ // ownership of feature_type. Can only be called once.
+ void set_feature_type(FeatureType *feature_type) {
+ SAFTM_CHECK_EQ(feature_type_, nullptr);
+ feature_type_ = feature_type;
+ }
+
+ private:
+ // Feature extractor this feature function belongs to. Not owned. Set to a
+ // pointer != nullptr as soon as this object is created by Instantiate().
+ // Normal methods can safely assume this is != nullptr.
+ const GenericFeatureExtractor *extractor_ = nullptr;
+
+ // Descriptor for feature function. Not owned. Set to a pointer != nullptr
+ // as soon as this object is created by Instantiate(). Normal methods can
+ // safely assume this is != nullptr.
+ const FeatureFunctionDescriptor *descriptor_ = nullptr;
+
+ // Feature type for features produced by this feature function. If the
+ // feature function produces features of multiple feature types this is null
+ // and the feature function must return its feature types in
+ // GetFeatureTypes(). Owned.
+ FeatureType *feature_type_ = nullptr;
+
+ // Prefix used for sub-feature types of this function.
+ string prefix_;
+};
+
+// Feature function that can extract features from an object. Templated on
+// two type arguments:
+//
+// OBJ: The "object" from which features are extracted; e.g., a sentence. This
+// should be a plain type, rather than a reference or pointer.
+//
+// ARGS: A set of 0 or more types that are used to "index" into some part of the
+// object that should be extracted, e.g. an int token index for a sentence
+// object. This should not be a reference type.
+template <class OBJ, class... ARGS>
+class FeatureFunction
+ : public GenericFeatureFunction,
+ public RegisterableClass<FeatureFunction<OBJ, ARGS...> > {
+ public:
+ using Self = FeatureFunction<OBJ, ARGS...>;
+
+ // Preprocesses the object. This will be called prior to calling Evaluate()
+ // or Compute() on that object.
+ virtual void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const {}
+
+ // Appends features computed from the object and focus to the result. The
+ // default implementation delegates to Compute(), adding a single value if
+ // available. Multi-valued feature functions must override this method.
+ virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
+ ARGS... args, FeatureVector *result) const {
+ FeatureValue value = Compute(workspaces, object, args...);
+ if (value != kNone) result->add(feature_type(), value);
+ }
+
+ // Returns a feature value computed from the object and focus, or kNone if no
+ // value is computed. Single-valued feature functions only need to override
+ // this method.
+ virtual FeatureValue Compute(const WorkspaceSet &workspaces,
+ const OBJ &object, ARGS... args) const {
+ return kNone;
+ }
+
+ // Instantiates a new feature function in a feature extractor from a feature
+ // descriptor.
+ //
+ // Returns a pointer to the newly-created object if everything goes well.
+ // Returns nullptr if the feature function could not be instantiated (e.g., if
+ // the function with that name is not registered; this usually happens because
+ // the relevant cc_library was not linked-in).
+ static Self *Instantiate(const GenericFeatureExtractor *extractor,
+ const FeatureFunctionDescriptor *fd,
+ const string &prefix) {
+ Self *f = Self::Create(fd->type());
+ if (f != nullptr) {
+ f->set_extractor(extractor);
+ f->set_descriptor(fd);
+ f->set_prefix(prefix);
+ }
+ return f;
+ }
+
+ private:
+ // Special feature function class for resolving variable references. The type
+ // of the feature function is used for resolving the variable reference. When
+ // evaluated it will either get the feature value(s) from the variable portion
+ // of the feature vector, if present, or otherwise it will call the referenced
+ // feature extractor function directly to extract the feature(s).
+ class Reference;
+};
+
+// Base class for features with nested feature functions. The nested functions
+// are of type NES, which may be different from the type of the parent function.
+// NB: NestedFeatureFunction will ensure that all initialization of nested
+// functions takes place during Setup() and Init() -- after the nested features
+// are initialized, the parent feature is initialized via SetupNested() and
+// InitNested(). Alternatively, a derived class that overrides Setup() and
+// Init() directly should call Parent::Setup(), Parent::Init(), etc. first.
+//
+// Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or
+// Compute, since the nested functions may be of a different type.
+template <class NES, class OBJ, class... ARGS>
+class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
+ public:
+ using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>;
+
+ // Clean up nested functions.
+ ~NestedFeatureFunction() override { utils::STLDeleteElements(&nested_); }
+
+ // By default, just appends the nested feature types.
+ void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
+ SAFTM_CHECK(!this->nested().empty())
+ << "Nested features require nested features to be defined.";
+ for (auto *function : nested_) function->GetFeatureTypes(types);
+ }
+
+ // Sets up the nested features.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
+ bool success = CreateNested(this->extractor(), this->descriptor(), &nested_,
+ this->SubPrefix());
+ if (!success) return false;
+ for (auto *function : nested_) {
+ if (!function->Setup(context)) return false;
+ }
+ if (!SetupNested(context)) return false;
+ return true;
+ }
+
+ // Sets up this NestedFeatureFunction specifically.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool SetupNested(TaskContext *context) {
+ return true;
+ }
+
+ // Initializes the nested features.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
+ for (auto *function : nested_) {
+ if (!function->Init(context)) return false;
+ }
+ if (!InitNested(context)) return false;
+ return true;
+ }
+
+ // Initializes this NestedFeatureFunction specifically.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool InitNested(TaskContext *context) {
+ return true;
+ }
+
+ // Gets all the workspaces needed for the nested functions.
+ void RequestWorkspaces(WorkspaceRegistry *registry) override {
+ for (auto *function : nested_) function->RequestWorkspaces(registry);
+ }
+
+ // Returns the list of nested feature functions.
+ const std::vector<NES *> &nested() const { return nested_; }
+
+ // Instantiates nested feature functions for a feature function. Creates and
+ // initializes one feature function for each sub-descriptor in the feature
+ // descriptor.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT static bool CreateNested(
+ const GenericFeatureExtractor *extractor,
+ const FeatureFunctionDescriptor *fd, std::vector<NES *> *functions,
+ const string &prefix) {
+ for (int i = 0; i < fd->feature_size(); ++i) {
+ const FeatureFunctionDescriptor &sub = fd->feature(i);
+ NES *f = NES::Instantiate(extractor, &sub, prefix);
+ if (f == nullptr) return false;
+ functions->push_back(f);
+ }
+ return true;
+ }
+
+ protected:
+ // The nested feature functions, if any, in order of declaration in the
+ // feature descriptor. Owned.
+ std::vector<NES *> nested_;
+};
+
+// Base class for a nested feature function that takes nested features with the
+// same signature as these features, i.e. a meta feature. For this class, we can
+// provide preprocessing of the nested features.
+template <class OBJ, class... ARGS>
+class MetaFeatureFunction
+ : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ,
+ ARGS...> {
+ public:
+ // Preprocesses using the nested features.
+ void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override {
+ for (auto *function : this->nested_) {
+ function->Preprocess(workspaces, object);
+ }
+ }
+};
+
+// Template for a special type of locator: The locator of type
+// FeatureFunction<OBJ, ARGS...> calls nested functions of type
+// FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is
+// responsible for translating by providing the following:
+//
+// // Gets the new additional focus.
+// IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object);
+//
+// This is useful to e.g. add a token focus to a parser state based on some
+// desired property of that state.
+template <class DER, class OBJ, class IDX, class... ARGS>
+class FeatureAddFocusLocator
+ : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ,
+ ARGS...> {
+ public:
+ void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override {
+ for (auto *function : this->nested_) {
+ function->Preprocess(workspaces, object);
+ }
+ }
+
+ void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
+ FeatureVector *result) const override {
+ IDX focus =
+ static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
+ for (auto *function : this->nested()) {
+ function->Evaluate(workspaces, object, focus, args..., result);
+ }
+ }
+
+ // Returns the first nested feature's computed value.
+ FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
+ ARGS... args) const override {
+ IDX focus =
+ static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
+ return this->nested()[0]->Compute(workspaces, object, focus, args...);
+ }
+};
+
+// CRTP feature locator class. This is a meta feature that modifies ARGS and
+// then calls the nested feature functions with the modified ARGS. Note that in
+// order for this template to work correctly, all of ARGS must be types for
+// which the reference operator & can be interpreted as a pointer to the
+// argument. The derived class DER must implement the UpdateFocus method which
+// takes pointers to the ARGS arguments:
+//
+// // Updates the current arguments.
+// void UpdateArgs(const OBJ &object, ARGS *...args) const;
+template <class DER, class OBJ, class... ARGS>
+class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> {
+ public:
+ // Feature locators have an additional check that there is no intrinsic type.
+ void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
+ SAFTM_CHECK_EQ(this->feature_type(), nullptr)
+ << "FeatureLocators should not have an intrinsic type.";
+ MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types);
+ }
+
+ // Evaluates the locator.
+ void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
+ FeatureVector *result) const override {
+ static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
+ for (auto *function : this->nested()) {
+ function->Evaluate(workspaces, object, args..., result);
+ }
+ }
+
+ // Returns the first nested feature's computed value.
+ FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
+ ARGS... args) const override {
+ static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
+ return this->nested()[0]->Compute(workspaces, object, args...);
+ }
+};
+
+// Feature extractor for extracting features from objects of a certain class.
+// Template type parameters are as defined for FeatureFunction.
+template <class OBJ, class... ARGS>
+class FeatureExtractor : public GenericFeatureExtractor {
+ public:
+ // Feature function type for top-level functions in the feature extractor.
+ typedef FeatureFunction<OBJ, ARGS...> Function;
+ typedef FeatureExtractor<OBJ, ARGS...> Self;
+
+ // Feature locator type for the feature extractor.
+ template <class DER>
+ using Locator = FeatureLocator<DER, OBJ, ARGS...>;
+
+ // Initializes feature extractor.
+ FeatureExtractor() {}
+
+ ~FeatureExtractor() override { utils::STLDeleteElements(&functions_); }
+
+ // Sets up the feature extractor. Note that only top-level functions exist
+ // until Setup() is called. This does not take ownership over the context,
+ // which must outlive this.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) {
+ for (Function *function : functions_) {
+ if (!function->Setup(context)) return false;
+ }
+ return true;
+ }
+
+ // Initializes the feature extractor. Must be called after Setup(). This
+ // does not take ownership over the context, which must outlive this.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) {
+ for (Function *function : functions_) {
+ if (!function->Init(context)) return false;
+ }
+ if (!this->InitializeFeatureTypes()) return false;
+ return true;
+ }
+
+ // Requests workspaces from the registry. Must be called after Init(), and
+ // before Preprocess(). Does not take ownership over registry. This should be
+ // the same registry used to initialize the WorkspaceSet used in Preprocess()
+ // and ExtractFeatures(). NB: This is a different ordering from that used in
+ // SentenceFeatureRepresentation style feature computation.
+ void RequestWorkspaces(WorkspaceRegistry *registry) {
+ for (auto *function : functions_) function->RequestWorkspaces(registry);
+ }
+
+ // Preprocesses the object using feature functions for the phase. Must be
+ // called before any calls to ExtractFeatures() on that object and phase.
+ void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const {
+ for (Function *function : functions_) {
+ function->Preprocess(workspaces, object);
+ }
+ }
+
+ // Extracts features from an object with a focus. This invokes all the
+ // top-level feature functions in the feature extractor. Only feature
+ // functions belonging to the specified phase are invoked.
+ void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object,
+ ARGS... args, FeatureVector *result) const {
+ result->reserve(this->feature_types());
+
+ // Extract features.
+ for (int i = 0; i < functions_.size(); ++i) {
+ functions_[i]->Evaluate(workspaces, object, args..., result);
+ }
+ }
+
+ private:
+ // Creates and initializes all feature functions in the feature extractor.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool InitializeFeatureFunctions() override {
+ // Create all top-level feature functions.
+ for (int i = 0; i < descriptor().feature_size(); ++i) {
+ const FeatureFunctionDescriptor &fd = descriptor().feature(i);
+ Function *function = Function::Instantiate(this, &fd, "");
+ if (function == nullptr) return false;
+ functions_.push_back(function);
+ }
+ return true;
+ }
+
+ // Collect all feature types used in the feature extractor.
+ void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
+ for (int i = 0; i < functions_.size(); ++i) {
+ functions_[i]->GetFeatureTypes(types);
+ }
+ }
+
+ // Top-level feature functions (and variables) in the feature extractor.
+ // Owned.
+ std::vector<Function *> functions_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
diff --git a/lang_id/common/fel/feature-types.h b/lang_id/common/fel/feature-types.h
new file mode 100644
index 0000000..18cf69a
--- /dev/null
+++ b/lang_id/common/fel/feature-types.h
@@ -0,0 +1,189 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Common feature types for parser components.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
+
+#include <algorithm>
+#include <map>
+#include <string>
+#include <utility>
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/str-cat.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// TODO(djweiss) Clean this up as well.
+// Use the same type for feature values as is used for predicated.
+typedef int64 Predicate;
+typedef Predicate FeatureValue;
+
+// Each feature value in a feature vector has a feature type. The feature type
+// is used for converting feature type and value pairs to predicate values. The
+// feature type can also return names for feature values and calculate the size
+// of the feature value domain. The FeatureType class is abstract and must be
+// specialized for the concrete feature types.
+class FeatureType {
+ public:
+ // Initializes a feature type.
+ explicit FeatureType(const string &name)
+ : name_(name), base_(0),
+ is_continuous_(name.find("continuous") != string::npos) {
+ }
+
+ virtual ~FeatureType() {}
+
+ // Converts a feature value to a name.
+ virtual string GetFeatureValueName(FeatureValue value) const = 0;
+
+ // Returns the size of the feature values domain.
+ virtual int64 GetDomainSize() const = 0;
+
+ // Returns the feature type name.
+ const string &name() const { return name_; }
+
+ Predicate base() const { return base_; }
+ void set_base(Predicate base) { base_ = base; }
+
+ // Returns true iff this feature is continuous; see FloatFeatureValue.
+ bool is_continuous() const { return is_continuous_; }
+
+ private:
+ // Feature type name.
+ string name_;
+
+ // "Base" feature value: i.e. a "slot" in a global ordering of features.
+ Predicate base_;
+
+ // See doc for is_continuous().
+ bool is_continuous_;
+};
+
+// Feature type that is defined using an explicit map from FeatureValue to
+// string values. This can reduce some of the boilerplate when defining
+// features that generate enum values. Example usage:
+//
+// class BeverageSizeFeature : public FeatureFunction<Beverage>
+// enum FeatureValue { SMALL, MEDIUM, LARGE }; // values for this feature
+// void Init(TaskContext *context) override {
+// set_feature_type(new EnumFeatureType("beverage_size",
+// {{SMALL, "SMALL"}, {MEDIUM, "MEDIUM"}, {LARGE, "LARGE"}});
+// }
+// [...]
+// };
+class EnumFeatureType : public FeatureType {
+ public:
+ EnumFeatureType(const string &name,
+ const std::map<FeatureValue, string> &value_names)
+ : FeatureType(name), value_names_(value_names) {
+ for (const auto &pair : value_names) {
+ SAFTM_CHECK_GE(pair.first, 0)
+ << "Invalid feature value: " << pair.first << ", " << pair.second;
+ domain_size_ = std::max(domain_size_, pair.first + 1);
+ }
+ }
+
+ // Returns the feature name for a given feature value.
+ string GetFeatureValueName(FeatureValue value) const override {
+ auto it = value_names_.find(value);
+ if (it == value_names_.end()) {
+ SAFTM_LOG(ERROR) << "Invalid feature value " << value << " for "
+ << name();
+ return "<INVALID>";
+ }
+ return it->second;
+ }
+
+ // Returns the number of possible values for this feature type. This is one
+ // greater than the largest value in the value_names map.
+ FeatureValue GetDomainSize() const override { return domain_size_; }
+
+ protected:
+ // Maximum possible value this feature could take.
+ FeatureValue domain_size_ = 0;
+
+ // Names of feature values.
+ std::map<FeatureValue, string> value_names_;
+};
+
+// Feature type for binary features.
+class BinaryFeatureType : public FeatureType {
+ public:
+ BinaryFeatureType(const string &name, const string &off, const string &on)
+ : FeatureType(name), off_(off), on_(on) {}
+
+ // Returns the feature name for a given feature value.
+ string GetFeatureValueName(FeatureValue value) const override {
+ if (value == 0) return off_;
+ if (value == 1) return on_;
+ return "";
+ }
+
+ // Binary features always have two feature values.
+ FeatureValue GetDomainSize() const override { return 2; }
+
+ private:
+ // Feature value names for on and off.
+ string off_;
+ string on_;
+};
+
+// Feature type for numeric features.
+class NumericFeatureType : public FeatureType {
+ public:
+ // Initializes numeric feature.
+ NumericFeatureType(const string &name, FeatureValue size)
+ : FeatureType(name), size_(size) {}
+
+ // Returns numeric feature value.
+ string GetFeatureValueName(FeatureValue value) const override {
+ if (value < 0) return "";
+ return LiteStrCat(value);
+ }
+
+ // Returns the number of feature values.
+ FeatureValue GetDomainSize() const override { return size_; }
+
+ private:
+ // The underlying size of the numeric feature.
+ FeatureValue size_;
+};
+
+// Feature type for byte features, including an "outside" value.
+class ByteFeatureType : public NumericFeatureType {
+ public:
+ explicit ByteFeatureType(const string &name)
+ : NumericFeatureType(name, 257) {}
+
+ string GetFeatureValueName(FeatureValue value) const override {
+ if (value == 256) {
+ return "<NULL>";
+ }
+ string result;
+ result += static_cast<char>(value);
+ return result;
+ }
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
diff --git a/lang_id/common/fel/fel-parser.cc b/lang_id/common/fel/fel-parser.cc
new file mode 100644
index 0000000..4346fb7
--- /dev/null
+++ b/lang_id/common/fel/fel-parser.cc
@@ -0,0 +1,289 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/fel-parser.h"
+
+#include <ctype.h>
+#include <string>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
namespace {
// Returns true iff character c can appear at the start of an identifier.
//
// Note: the <ctype.h> classification functions require an argument
// representable as unsigned char (or EOF); passing a plain char that happens
// to be negative is undefined behavior. The explicit casts below make these
// helpers safe for all byte values, independent of the -funsigned-char build
// flag.
inline bool IsValidCharAtStartOfIdentifier(char c) {
  return isalpha(static_cast<unsigned char>(c)) || (c == '_') || (c == '/');
}

// Returns true iff character c can appear inside an identifier.
inline bool IsValidCharInsideIdentifier(char c) {
  return isalnum(static_cast<unsigned char>(c)) || (c == '_') || (c == '-') ||
         (c == '/');
}

// Returns true iff character c can appear at the beginning of a number.
inline bool IsValidCharAtStartOfNumber(char c) {
  return isdigit(static_cast<unsigned char>(c)) || (c == '+') || (c == '-');
}

// Returns true iff character c can appear inside a number.
inline bool IsValidCharInsideNumber(char c) {
  return isdigit(static_cast<unsigned char>(c)) || (c == '.');
}
}  // namespace
+
// Resets the parser state over |source| and reads the first item.
// Returns false on a syntax error in the first item.
bool FELParser::Initialize(const string &source) {
  // Initialize parser state. source_ keeps its own copy of the text so the
  // iterators below stay valid for the duration of the parse.
  source_ = source;
  current_ = source_.begin();
  item_start_ = line_start_ = current_;
  line_number_ = item_line_number_ = 1;

  // Read first input item.
  return NextItem();
}
+
// Logs an error message with context: the line number of the current item,
// its 1-based column, and the text of the line up to the current position.
void FELParser::ReportError(const string &error_message) {
  // 1-based column of the item within the current line.
  const int position = item_start_ - line_start_ + 1;
  const string line(line_start_, current_);

  SAFTM_LOG(ERROR) << "Error in feature model, line " << item_line_number_
                   << ", position " << position << ": " << error_message
                   << "\n  " << line << " <--HERE";
}
+
+void FELParser::Next() {
+ // Move to the next input character. If we are at a line break update line
+ // number and line start position.
+ if (CurrentChar() == '\n') {
+ ++line_number_;
+ ++current_;
+ line_start_ = current_;
+ } else {
+ ++current_;
+ }
+}
+
// Moves to the next input item, setting item_text_ and item_type_.
// Returns true on success, false on syntax error (unterminated string).
bool FELParser::NextItem() {
  // Skip white space and comments.
  while (!eos()) {
    if (CurrentChar() == '#') {
      // Skip comment.
      while (!eos() && CurrentChar() != '\n') Next();
    } else if (isspace(CurrentChar())) {
      // Skip whitespace.
      while (!eos() && isspace(CurrentChar())) Next();
    } else {
      break;
    }
  }

  // Record start position for next item.
  item_start_ = current_;
  item_line_number_ = line_number_;

  // Check for end of input.
  if (eos()) {
    item_type_ = END;
    return true;
  }

  // Parse number.
  if (IsValidCharAtStartOfNumber(CurrentChar())) {
    string::iterator start = current_;
    Next();
    while (!eos() && IsValidCharInsideNumber(CurrentChar())) Next();
    item_text_.assign(start, current_);
    item_type_ = NUMBER;
    return true;
  }

  // Parse string. The surrounding quotes are not part of item_text_.
  if (CurrentChar() == '"') {
    Next();
    string::iterator start = current_;
    while (CurrentChar() != '"') {
      if (eos()) {
        ReportError("Unterminated string");
        return false;
      }
      Next();
    }
    item_text_.assign(start, current_);
    item_type_ = STRING;
    // Skip the closing quote.
    Next();
    return true;
  }

  // Parse identifier name.
  if (IsValidCharAtStartOfIdentifier(CurrentChar())) {
    string::iterator start = current_;
    while (!eos() && IsValidCharInsideIdentifier(CurrentChar())) {
      Next();
    }
    item_text_.assign(start, current_);
    item_type_ = NAME;
    return true;
  }

  // Single character item: item_type_ holds the (positive) character itself.
  item_type_ = CurrentChar();
  Next();
  return true;
}
+
// Parses a FEL specification into a feature extractor descriptor, appending
// one top-level feature per extractor spec in |source|.
// Returns true on success, false on syntax error (details are logged).
bool FELParser::Parse(const string &source,
                      FeatureExtractorDescriptor *result) {
  // Initialize parser.
  if (!Initialize(source)) {
    return false;
  }

  while (item_type_ != END) {
    // Current item should be a feature name.
    if (item_type_ != NAME) {
      ReportError("Feature type name expected");
      return false;
    }
    string name = item_text_;
    if (!NextItem()) {
      return false;
    }

    if (item_type_ == '=') {
      // A top-level "name = value" binding is not a valid feature.
      ReportError("Invalid syntax: feature expected");
      return false;
    } else {
      // Parse feature.
      FeatureFunctionDescriptor *descriptor = result->add_feature();
      descriptor->set_type(name);
      if (!ParseFeature(descriptor)) {
        return false;
      }
    }
  }

  return true;
}
+
// Parses a single feature descriptor (the feature type name has already been
// consumed and stored into |result| by the caller): an optional '(...)'
// parameter list, an optional ':'-prefixed feature name, and optional
// sub-features introduced by '.' (single) or '{...}' (block). Recursive.
// Returns true on success, false on syntax error.
bool FELParser::ParseFeature(FeatureFunctionDescriptor *result) {
  // Parse argument and parameters.
  if (item_type_ == '(') {
    if (!NextItem()) return false;
    if (!ParseParameter(result)) return false;
    while (item_type_ == ',') {
      if (!NextItem()) return false;
      if (!ParseParameter(result)) return false;
    }

    if (item_type_ != ')') {
      ReportError(") expected");
      return false;
    }
    if (!NextItem()) return false;
  }

  // Parse feature name.
  if (item_type_ == ':') {
    if (!NextItem()) return false;
    if (item_type_ != NAME && item_type_ != STRING) {
      ReportError("Feature name expected");
      return false;
    }
    string name = item_text_;
    if (!NextItem()) return false;

    // Set feature name.
    result->set_name(name);
  }

  // Parse sub-features.
  if (item_type_ == '.') {
    // Parse dotted sub-feature.
    if (!NextItem()) return false;
    if (item_type_ != NAME) {
      ReportError("Feature type name expected");
      return false;
    }
    string type = item_text_;
    if (!NextItem()) return false;

    // Parse sub-feature.
    FeatureFunctionDescriptor *subfeature = result->add_feature();
    subfeature->set_type(type);
    if (!ParseFeature(subfeature)) return false;
  } else if (item_type_ == '{') {
    // Parse sub-feature block.
    if (!NextItem()) return false;
    while (item_type_ != '}') {
      if (item_type_ != NAME) {
        ReportError("Feature type name expected");
        return false;
      }
      string type = item_text_;
      if (!NextItem()) return false;

      // Parse sub-feature.
      FeatureFunctionDescriptor *subfeature = result->add_feature();
      subfeature->set_type(type);
      if (!ParseFeature(subfeature)) return false;
    }
    // Skip the closing '}'.
    if (!NextItem()) return false;
  }
  return true;
}
+
// Parses one element of a parameter list: either a bare NUMBER (the feature's
// default argument) or a "name = value" pair.
// Returns true on success, false on syntax error.
bool FELParser::ParseParameter(FeatureFunctionDescriptor *result) {
  if (item_type_ == NUMBER) {
    int argument;
    if (!LiteAtoi(item_text_, &argument)) {
      ReportError("Unable to parse number");
      return false;
    }
    if (!NextItem()) return false;

    // Set default argument for feature.
    result->set_argument(argument);
  } else if (item_type_ == NAME) {
    string name = item_text_;
    if (!NextItem()) return false;
    if (item_type_ != '=') {
      ReportError("= expected");
      return false;
    }
    if (!NextItem()) return false;
    if (item_type_ >= END) {
      // Only NAME, NUMBER or STRING items (negative type codes) are valid
      // parameter values; END or a single-character item is an error.
      ReportError("Parameter value expected");
      return false;
    }
    string value = item_text_;
    if (!NextItem()) return false;

    // Add parameter to feature.
    Parameter *parameter;
    parameter = result->add_parameter();
    parameter->set_name(name);
    parameter->set_value(value);
  } else {
    ReportError("Syntax error in parameter list");
    return false;
  }
  return true;
}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/fel/fel-parser.h b/lang_id/common/fel/fel-parser.h
new file mode 100644
index 0000000..eacb442
--- /dev/null
+++ b/lang_id/common/fel/fel-parser.h
@@ -0,0 +1,135 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Feature extraction language (FEL) parser.
+//
+// BNF grammar for FEL:
+//
+// <feature model> ::= { <feature extractor> }
+//
+// <feature extractor> ::= <extractor spec> |
+// <extractor spec> '.' <feature extractor> |
+// <extractor spec> '{' { <feature extractor> } '}'
+//
+// <extractor spec> ::= <extractor type>
+// [ '(' <parameter list> ')' ]
+// [ ':' <extractor name> ]
+//
+// <parameter list> = ( <parameter> | <argument> ) { ',' <parameter> }
+//
+// <parameter> ::= <parameter name> '=' <parameter value>
+//
+// <extractor type> ::= NAME
+// <extractor name> ::= NAME | STRING
+// <argument> ::= NUMBER
+// <parameter name> ::= NAME
+// <parameter value> ::= NUMBER | STRING | NAME
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
+
+#include <string>
+
+#include "lang_id/common/fel/feature-descriptors.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
// Parser for feature extraction language (FEL) specifications; see the
// grammar at the top of this file. Parsing state is kept in member
// variables, so a FELParser instance is not safe for concurrent use.
class FELParser {
 public:
  // Parses fml specification into feature extractor descriptor.
  // Returns true on success, false on error (e.g., syntax errors).
  bool Parse(const string &source, FeatureExtractorDescriptor *result);

 private:
  // Initializes the parser with the source text.
  // Returns true on success, false on syntax error.
  bool Initialize(const string &source);

  // Outputs an error message, with context info.
  void ReportError(const string &error_message);

  // Moves to the next input character.
  void Next();

  // Moves to the next input item. Sets item_text_ and item_type_ accordingly.
  // Returns true on success, false on syntax error.
  bool NextItem();

  // Parses a feature descriptor.
  // Returns true on success, false on syntax error.
  bool ParseFeature(FeatureFunctionDescriptor *result);

  // Parses a parameter specification.
  // Returns true on success, false on syntax error.
  bool ParseParameter(FeatureFunctionDescriptor *result);

  // Returns true if end of source input has been reached.
  bool eos() const { return current_ >= source_.end(); }

  // Returns current character. Other methods should access the current
  // character through this method (instead of using *current_ directly): this
  // method performs extra safety checks.
  //
  // In case of an unsafe access, returns '\0'.
  char CurrentChar() const {
    if ((current_ >= source_.begin()) && (current_ < source_.end())) {
      return *current_;
    } else {
      SAFTM_LOG(ERROR) << "Unsafe char read";
      return '\0';
    }
  }

  // Item types. Values are all <= 0 so they never collide with the positive
  // character codes used for single-character items.
  enum ItemTypes {
    END = 0,
    NAME = -1,
    NUMBER = -2,
    STRING = -3,
  };

  // Source text. A private copy; iterators below point into this string.
  string source_;

  // Current input position.
  string::iterator current_;

  // Line number for current input position.
  int line_number_;

  // Start position for current item.
  string::iterator item_start_;

  // Start position for current line.
  string::iterator line_start_;

  // Line number for current item.
  int item_line_number_;

  // Item type for current item. If this is positive it is interpreted as a
  // character. If it is negative it is interpreted as an item type.
  int item_type_;

  // Text for current item.
  string item_text_;
};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
diff --git a/lang_id/common/fel/task-context.cc b/lang_id/common/fel/task-context.cc
new file mode 100644
index 0000000..f8b0701
--- /dev/null
+++ b/lang_id/common/fel/task-context.cc
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/task-context.h"
+
+#include <string>
+
+#include "lang_id/common/lite_strings/numbers.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+string TaskContext::GetInputPath(const string &name) const {
+ auto it = inputs_.find(name);
+ if (it != inputs_.end()) {
+ return it->second;
+ }
+ return "";
+}
+
// Sets the path for input |name|, overwriting any previous path.
void TaskContext::SetInputPath(const string &name, const string &path) {
  inputs_[name] = path;
}
+
+string TaskContext::Get(const string &name, const char *defval) const {
+ auto it = parameters_.find(name);
+ if (it != parameters_.end()) {
+ return it->second;
+ }
+ return defval;
+}
+
+int TaskContext::Get(const string &name, int defval) const {
+ const string s = Get(name, "");
+ int value = defval;
+ if (LiteAtoi(s, &value)) {
+ return value;
+ }
+ return defval;
+}
+
+float TaskContext::Get(const string &name, float defval) const {
+ const string s = Get(name, "");
+ float value = defval;
+ if (LiteAtof(s, &value)) {
+ return value;
+ }
+ return defval;
+}
+
+bool TaskContext::Get(const string &name, bool defval) const {
+ string value = Get(name, "");
+ return value.empty() ? defval : value == "true";
+}
+
// Sets the value of parameter |name| to |value|, overwriting any previous
// value.
void TaskContext::SetParameter(const string &name, const string &value) {
  parameters_[name] = value;
}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/fel/task-context.h b/lang_id/common/fel/task-context.h
new file mode 100644
index 0000000..ddc8cfe
--- /dev/null
+++ b/lang_id/common/fel/task-context.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
+
+#include <map>
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
// Class that provides access to model parameter and inputs.
//
// Note: This class is related to the server-side nlp_saft::TaskContext, but it
// has been simplified to reduce code dependencies.
class TaskContext {
 public:
  // Returns path for the input named |name|. Returns empty string ("") if
  // there is no input with that name. Note: this can be a standard file path,
  // or a path in a more special file system.
  string GetInputPath(const string &name) const;

  // Sets path for input |name|. Previous path, if any, is overwritten.
  void SetInputPath(const string &name, const string &path);

  // Returns parameter value. If the parameter is not specified in this
  // context, the default value is returned.
  //
  // Note (see task-context.cc): the int/float overloads also fall back to
  // defval when the stored value does not parse as a number, and the bool
  // overload is true only for the exact string "true".
  string Get(const string &name, const char *defval) const;
  int Get(const string &name, int defval) const;
  float Get(const string &name, float defval) const;
  bool Get(const string &name, bool defval) const;

  // Sets value of parameter |name| to |value|.
  void SetParameter(const string &name, const string &value);

 private:
  // Maps input name -> path.
  std::map<string, string> inputs_;

  // Maps parameter name -> value.
  std::map<string, string> parameters_;
};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
diff --git a/lang_id/common/fel/workspace.cc b/lang_id/common/fel/workspace.cc
new file mode 100644
index 0000000..8cab281
--- /dev/null
+++ b/lang_id/common/fel/workspace.cc
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/workspace.h"
+
+#include <atomic>
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// static
+int GetFreshTypeId() {
+ // Static local below is initialized the first time this method is run.
+ static std::atomic<int> counter(0);
+ return counter++;
+}
+
+string WorkspaceRegistry::DebugString() const {
+ string str;
+ for (auto &it : workspace_names_) {
+ const string &type_name = workspace_types_.at(it.first);
+ for (size_t index = 0; index < it.second.size(); ++index) {
+ const string &workspace_name = it.second[index];
+ str.append("\n ");
+ str.append(type_name);
+ str.append(" :: ");
+ str.append(workspace_name);
+ }
+ }
+ return str;
+}
+
+VectorIntWorkspace::VectorIntWorkspace(int size) : elements_(size) {}
+
+VectorIntWorkspace::VectorIntWorkspace(int size, int value)
+ : elements_(size, value) {}
+
+VectorIntWorkspace::VectorIntWorkspace(const std::vector<int> &elements)
+ : elements_(elements) {}
+
+string VectorIntWorkspace::TypeName() { return "Vector"; }
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/fel/workspace.h b/lang_id/common/fel/workspace.h
new file mode 100644
index 0000000..09095e4
--- /dev/null
+++ b/lang_id/common/fel/workspace.h
@@ -0,0 +1,204 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Notes on thread-safety: All of the classes here are thread-compatible. More
+// specifically, the registry machinery is thread-safe, as long as each thread
+// performs feature extraction on a different Sentence object.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
+
+#include <stddef.h>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// A base class for shared workspaces. Derived classes implement a static member
+// function TypeName() which returns a human readable string name for the class.
+class Workspace {
+ public:
+ // Polymorphic destructor.
+ virtual ~Workspace() {}
+
+ protected:
+ // Create an empty workspace.
+ Workspace() {}
+
+ private:
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(Workspace);
+};
+
+// Returns a new, strictly increasing int every time it is invoked.
+int GetFreshTypeId();
+
+// Struct to simulate typeid, but without RTTI.
+template <typename T>
+struct TypeId {
+ static int type_id;
+};
+
+template <typename T>
+int TypeId<T>::type_id = GetFreshTypeId();
+
+// A registry that keeps track of workspaces.
+class WorkspaceRegistry {
+ public:
+ // Create an empty registry.
+ WorkspaceRegistry() {}
+
+ // Returns the index of a named workspace, adding it to the registry first
+ // if necessary.
+ template <class W>
+ int Request(const string &name) {
+ const int id = TypeId<W>::type_id;
+ max_workspace_id_ = std::max(id, max_workspace_id_);
+ workspace_types_[id] = W::TypeName();
+ std::vector<string> &names = workspace_names_[id];
+ for (int i = 0; i < names.size(); ++i) {
+ if (names[i] == name) return i;
+ }
+ names.push_back(name);
+ return names.size() - 1;
+ }
+
+ // Returns the maximum workspace id that has been registered.
+ int MaxId() const {
+ return max_workspace_id_;
+ }
+
+ const std::unordered_map<int, std::vector<string> > &WorkspaceNames()
+ const {
+ return workspace_names_;
+ }
+
+ // Returns a string describing the registered workspaces.
+ string DebugString() const;
+
+ private:
+ // Workspace type names, indexed as workspace_types_[typeid].
+ std::unordered_map<int, string> workspace_types_;
+
+ // Workspace names, indexed as workspace_names_[typeid][workspace].
+ std::unordered_map<int, std::vector<string> > workspace_names_;
+
+ // The maximum workspace id that has been registered.
+ int max_workspace_id_ = 0;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry);
+};
+
+// A typed collected of workspaces. The workspaces are indexed according to an
+// external WorkspaceRegistry. If the WorkspaceSet is const, the contents are
+// also immutable.
+class WorkspaceSet {
+ public:
+ ~WorkspaceSet() { Reset(WorkspaceRegistry()); }
+
+ // Returns true if a workspace has been set.
+ template <class W>
+ bool Has(int index) const {
+ const int id = TypeId<W>::type_id;
+ SAFTM_DCHECK_GE(id, 0);
+ SAFTM_DCHECK_LT(id, workspaces_.size());
+ SAFTM_DCHECK_GE(index, 0);
+ SAFTM_DCHECK_LT(index, workspaces_[id].size());
+ if (id >= workspaces_.size()) return false;
+ return workspaces_[id][index] != nullptr;
+ }
+
+ // Returns an indexed workspace; the workspace must have been set.
+ template <class W>
+ const W &Get(int index) const {
+ SAFTM_DCHECK(Has<W>(index));
+ const int id = TypeId<W>::type_id;
+ const Workspace *w = workspaces_[id][index];
+ return reinterpret_cast<const W &>(*w);
+ }
+
+ // Sets an indexed workspace; this takes ownership of the workspace, which
+ // must have been new-allocated. It is an error to set a workspace twice.
+ template <class W>
+ void Set(int index, W *workspace) {
+ const int id = TypeId<W>::type_id;
+ SAFTM_DCHECK_GE(id, 0);
+ SAFTM_DCHECK_LT(id, workspaces_.size());
+ SAFTM_DCHECK_GE(index, 0);
+ SAFTM_DCHECK_LT(index, workspaces_[id].size());
+ SAFTM_DCHECK(workspaces_[id][index] == nullptr);
+ SAFTM_DCHECK(workspace != nullptr);
+ workspaces_[id][index] = workspace;
+ }
+
+ void Reset(const WorkspaceRegistry ®istry) {
+ // Deallocate current workspaces.
+ for (auto &it : workspaces_) {
+ for (size_t index = 0; index < it.size(); ++index) {
+ delete it[index];
+ }
+ }
+ workspaces_.clear();
+ workspaces_.resize(registry.MaxId() + 1, std::vector<Workspace *>());
+ for (auto &it : registry.WorkspaceNames()) {
+ workspaces_[it.first].resize(it.second.size());
+ }
+ }
+
+ private:
+ // The set of workspaces, indexed as workspaces_[typeid][index].
+ std::vector<std::vector<Workspace *> > workspaces_;
+};
+
+// A workspace that wraps around a vector of int.
+class VectorIntWorkspace : public Workspace {
+ public:
+ // Creates a vector of the given size.
+ explicit VectorIntWorkspace(int size);
+
+ // Creates a vector initialized with the given array.
+ explicit VectorIntWorkspace(const std::vector<int> &elements);
+
+ // Creates a vector of the given size, with each element initialized to the
+ // given value.
+ VectorIntWorkspace(int size, int value);
+
+ // Returns the name of this type of workspace.
+ static string TypeName();
+
+ // Returns the i'th element.
+ int element(int i) const { return elements_[i]; }
+
+ // Sets the i'th element.
+ void set_element(int i, int value) { elements_[i] = value; }
+
+ // Returns the size of the underlying vector.
+ int size() const { return elements_.size(); }
+
+ private:
+ // The enclosed vector.
+ std::vector<int> elements_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
diff --git a/lang_id/common/file/file-utils.cc b/lang_id/common/file/file-utils.cc
new file mode 100644
index 0000000..108c7d5
--- /dev/null
+++ b/lang_id/common/file/file-utils.cc
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/file/file-utils.h"
+
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace file_utils {
+
+bool GetFileContent(const string &filename, string *content) {
+ ScopedMmap scoped_mmap(filename);
+ const MmapHandle &handle = scoped_mmap.handle();
+ if (!handle.ok()) {
+ SAFTM_LOG(ERROR) << "Error opening " << filename;
+ return false;
+ }
+ StringPiece sp = handle.to_stringpiece();
+ content->assign(sp.data(), sp.size());
+ return true;
+}
+
+bool FileExists(const string &filename) {
+ struct stat s = {0};
+ if (!stat(filename.c_str(), &s)) {
+ return s.st_mode & S_IFREG;
+ } else {
+ return false;
+ }
+}
+
+bool DirectoryExists(const string &dirpath) {
+ struct stat s = {0};
+ if (!stat(dirpath.c_str(), &s)) {
+ return s.st_mode & S_IFDIR;
+ } else {
+ return false;
+ }
+}
+
+} // namespace file_utils
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/file/file-utils.h b/lang_id/common/file/file-utils.h
new file mode 100644
index 0000000..6377d7a
--- /dev/null
+++ b/lang_id/common/file/file-utils.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
+
+#include <stddef.h>
+#include <string>
+
+#include "lang_id/common/file/mmap.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace file_utils {
+
+// Reads the entire content of a file into a string. Returns true on success,
+// false on error.
+bool GetFileContent(const string &filename, string *content);
+
+// Parses a proto from its serialized representation in memory. That
+// representation starts at address |data| and should contain exactly
+// |num_bytes| bytes. Returns true on success, false otherwise.
+template <class Proto>
+bool ParseProtoFromMemory(const char *data, size_t num_bytes, Proto *proto) {
+ if (data == nullptr) {
+ // Avoid passing a nullptr to ParseFromArray below.
+ return false;
+ }
+ return proto->ParseFromArray(data, num_bytes);
+}
+
+// Convenience StringPiece-based version of ParseProtoFromMemory.
+template <class Proto>
+inline bool ParseProtoFromMemory(StringPiece sp, Proto *proto) {
+ return ParseProtoFromMemory(sp.data(), sp.size(), proto);
+}
+
+// Parses a proto from a file. Returns true on success, false otherwise.
+//
+// Note: the entire content of the file should be the binary (not
+// human-readable) serialization of a protocol buffer.
+//
+// Note: when we compile for Android, the proto parsing methods need to know the
+// type of the message they are parsing. We use template polymorphism for that.
+template<class Proto>
+bool ReadProtoFromFile(const string &filename, Proto *proto) {
+ ScopedMmap scoped_mmap(filename);
+ const MmapHandle &handle = scoped_mmap.handle();
+ if (!handle.ok()) {
+ return false;
+ }
+ return ParseProtoFromMemory(handle.to_stringpiece(), proto);
+}
+
+// Returns true if filename is the name of an existing file, and false
+// otherwise.
+bool FileExists(const string &filename);
+
+// Returns true if dirpath is the path to an existing directory, and false
+// otherwise.
+bool DirectoryExists(const string &dirpath);
+
+} // namespace file_utils
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
diff --git a/lang_id/common/file/mmap.cc b/lang_id/common/file/mmap.cc
new file mode 100644
index 0000000..89efa99
--- /dev/null
+++ b/lang_id/common/file/mmap.cc
@@ -0,0 +1,133 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/file/mmap.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stdint.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace {
+inline string GetLastSystemError() {
+ return string(strerror(errno));
+}
+
+inline MmapHandle GetErrorMmapHandle() {
+ return MmapHandle(nullptr, 0);
+}
+
+class FileCloser {
+ public:
+ explicit FileCloser(int fd) : fd_(fd) {}
+ ~FileCloser() {
+ int result = close(fd_);
+ if (result != 0) {
+ const string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error closing file descriptor: " << last_error;
+ }
+ }
+ private:
+ const int fd_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(FileCloser);
+};
+} // namespace
+
+MmapHandle MmapFile(const string &filename) {
+ int fd = open(filename.c_str(), O_RDONLY);
+
+ if (fd < 0) {
+ const string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error opening " << filename << ": " << last_error;
+ return GetErrorMmapHandle();
+ }
+
+ // Make sure we close fd no matter how we exit this function. As the man page
+ // for mmap clearly states: "closing the file descriptor does not unmap the
+ // region." Hence, we can close fd as soon as we return from here.
+ FileCloser file_closer(fd);
+
+ return MmapFile(fd);
+}
+
+MmapHandle MmapFile(int fd) {
+ // Get file stats to obtain file size.
+ struct stat sb;
+ if (fstat(fd, &sb) != 0) {
+ const string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Unable to stat fd: " << last_error;
+ return GetErrorMmapHandle();
+ }
+ size_t file_size_in_bytes = static_cast<size_t>(sb.st_size);
+
+ // Perform actual mmap.
+ void *mmap_addr = mmap(
+
+      // Let system pick address for mmapped data.
+ nullptr,
+
+ // Mmap all bytes from the file.
+ file_size_in_bytes,
+
+ // One can read / write the mapped data (but see MAP_PRIVATE below).
+ // Normally, we expect only to read it, but in the future, we may want to
+ // write it, to fix e.g., endianness differences.
+ PROT_READ | PROT_WRITE,
+
+      // Updates to mmapped data are *not* propagated to actual file.
+ // AFAIK(salcianu) that's anyway not possible on Android.
+ MAP_PRIVATE,
+
+ // Descriptor of file to mmap.
+ fd,
+
+ // Map bytes right from the beginning of the file. This, and
+ // file_size_in_bytes (2nd argument) means we map all bytes from the file.
+ 0);
+ if (mmap_addr == MAP_FAILED) {
+ const string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
+ return GetErrorMmapHandle();
+ }
+
+ return MmapHandle(mmap_addr, file_size_in_bytes);
+}
+
+bool Unmap(MmapHandle mmap_handle) {
+ if (!mmap_handle.ok()) {
+ // Unmapping something that hasn't been mapped is trivially successful.
+ return true;
+ }
+ if (munmap(mmap_handle.start(), mmap_handle.num_bytes()) != 0) {
+ const string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error during Unmap / munmap: " << last_error;
+ return false;
+ }
+ return true;
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/file/mmap.h b/lang_id/common/file/mmap.h
new file mode 100644
index 0000000..6131803
--- /dev/null
+++ b/lang_id/common/file/mmap.h
@@ -0,0 +1,120 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
+
+#include <stddef.h>
+
+#include <string>
+
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Handle for a memory area where a file has been mmapped.
+//
+// Similar to a pointer: you "allocate" it using MmapFile(filename) and "delete"
+// it using Unmap(). Just like a pointer, it is passed around by value (see
+// signature of MmapFile and Unmap; fortunately, it's a small class, so there
+// shouldn't be any significant performance penalty) and its usage is not
+// necessarily scoped (that's why the destructor is not performing the unmap).
+//
+// Note: on program termination, each still unmapped file is automatically
+// unmapped. Hence, it is not an error if you don't call Unmap() (provided you
+// are ok keeping that file in memory the whole time).
+class MmapHandle {
+ public:
+ MmapHandle(void *start, size_t num_bytes)
+ : start_(start), num_bytes_(num_bytes) {}
+
+ // Returns start address for the memory area where a file has been mmapped.
+ void *start() const { return start_; }
+
+ // Returns number of bytes of the memory area from start().
+ size_t num_bytes() const { return num_bytes_; }
+
+ // Shortcut to simplify checking success of MmapFile(). See usage example
+ // from the doc of that function.
+ bool ok() const { return start() != nullptr; }
+
+ // Returns a StringPiece pointing to the same underlying bytes.
+ StringPiece to_stringpiece() const {
+ return StringPiece(reinterpret_cast<char *>(start_), num_bytes_);
+ }
+
+ private:
+ // See doc for start(). Not owned.
+ void *const start_;
+
+ // See doc for num_bytes().
+ const size_t num_bytes_;
+};
+
+// Maps the full content of a file in memory (using mmap).
+//
+// When done using the file content, one can unmap using Unmap(). Otherwise,
+// all mapped files are unmapped when the program terminates.
+//
+// Sample usage:
+//
+// MmapHandle mmap_handle = MmapFile(filename);
+// CHECK(mmap_handle.ok()) << "Unable to mmap " << filename;
+//
+// ... use data from addresses
+// ... [mmap_handle.start, mmap_handle.start + mmap_handle.num_bytes)
+//
+// Unmap(mmap_handle); // Unmap logs errors internally.
+//
+// Note: one can read *and* write the num_bytes bytes from start, but those
+// writes are not propagated to the underlying file, nor to other processes that
+// may have mmapped that file (all changes are local to current process).
+MmapHandle MmapFile(const string &filename);
+
+// Like MmapFile(const string &filename), but uses a file descriptor.
+MmapHandle MmapFile(int fd);
+
+// Unmaps a file mapped using MmapFile. Returns true on success, false
+// otherwise.
+bool Unmap(MmapHandle mmap_handle);
+
+// Scoped mmapping of a file. Mmaps a file on construction, unmaps it on
+// destruction.
+class ScopedMmap {
+ public:
+ explicit ScopedMmap(const string &filename)
+ : handle_(MmapFile(filename)) {}
+
+ explicit ScopedMmap(int fd)
+ : handle_(MmapFile(fd)) {}
+
+ ~ScopedMmap() {
+ if (handle_.ok()) {
+ Unmap(handle_);
+ }
+ }
+
+ const MmapHandle &handle() { return handle_; }
+
+ private:
+ MmapHandle handle_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
diff --git a/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc b/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
new file mode 100644
index 0000000..ee22420
--- /dev/null
+++ b/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
@@ -0,0 +1,449 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
+
+#include "lang_id/common/lite_base/endian.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace {
+// Returns true if and only if ptr points to a location inside allowed_range.
+bool IsPointerInRange(const char *ptr, StringPiece allowed_range) {
+ return (ptr >= allowed_range.data()) &&
+ (ptr < (allowed_range.data() + allowed_range.size()));
+}
+
+// Returns true if and only if the memory range [start, start +
+// range_size_in_bytes) is included inside allowed_range.
+//
+// Special case: if range_size_in_bytes == 0 (empty range) then we require that
+// start is nullptr or in the allowed_range.
+bool IsMemoryRangeValid(const void *start, int range_size_in_bytes,
+ StringPiece allowed_range) {
+ const char *begin = reinterpret_cast<const char *>(start);
+ if (range_size_in_bytes < 0) {
+ return false;
+ }
+ if (range_size_in_bytes == 0) {
+ return (start == nullptr) || IsPointerInRange(begin, allowed_range);
+ }
+ const char *inclusive_end = begin + (range_size_in_bytes - 1);
+ return (begin <= inclusive_end) && IsPointerInRange(begin, allowed_range) &&
+ IsPointerInRange(inclusive_end, allowed_range);
+}
+
+bool VerifyQuantizationScales(EmbeddingNetworkParams::Matrix matrix,
+ StringPiece bytes) {
+ if (matrix.quant_scales == nullptr) {
+ SAFTM_LOG(ERROR) << "Quantization type "
+ << static_cast<int>(matrix.quant_type)
+ << "; but no quantization scales";
+ return false;
+ }
+ bool valid_scales = IsMemoryRangeValid(matrix.quant_scales,
+ matrix.rows * sizeof(float16), bytes);
+ if (!valid_scales) {
+ SAFTM_LOG(ERROR) << "quantization scales not fully inside bytes";
+ return false;
+ }
+ return true;
+}
+
+// Returns false if we detect a problem with |matrix|, true otherwise. E.g., we
+// check that the array that starts at pointer matrix.elements is fully inside
+// |bytes| (the range of bytes passed to the
+// EmbeddingNetworkParamsFromFlatbuffer constructor).
+bool VerifyMatrix(EmbeddingNetworkParams::Matrix matrix, StringPiece bytes) {
+ if ((matrix.rows < 0) || (matrix.cols < 0)) {
+ SAFTM_LOG(ERROR) << "Wrong matrix geometry: " << matrix.rows << " x "
+ << matrix.cols;
+ return false;
+ }
+
+ const int num_elements = matrix.rows * matrix.cols;
+
+ // Number of bytes occupied by the num_elements elements that start at address
+ // matrix.elements.
+ int element_range_size_in_bytes = 0;
+ switch (matrix.quant_type) {
+ case QuantizationType::NONE:
+ element_range_size_in_bytes = num_elements * sizeof(float);
+ break;
+ case QuantizationType::UINT8: {
+ element_range_size_in_bytes = num_elements;
+ if (!VerifyQuantizationScales(matrix, bytes)) {
+ return false;
+ }
+ break;
+ }
+ case QuantizationType::UINT4: {
+ if (matrix.cols % 2 != 0) {
+ SAFTM_LOG(ERROR) << "UINT4 doesn't work with odd #cols" << matrix.cols;
+ return false;
+ }
+ element_range_size_in_bytes = num_elements / 2;
+ if (!VerifyQuantizationScales(matrix, bytes)) {
+ return false;
+ }
+ break;
+ }
+ case QuantizationType::FLOAT16: {
+ element_range_size_in_bytes = num_elements * sizeof(float16);
+
+ // No need to verify the scales: FLOAT16 quantization does not use scales.
+ break;
+ }
+ default:
+ SAFTM_LOG(ERROR) << "Unsupported quantization type "
+ << static_cast<int>(matrix.quant_type);
+ return false;
+ }
+ if (matrix.elements == nullptr) {
+ SAFTM_LOG(ERROR) << "matrix.elements == nullptr";
+ return false;
+ }
+ bool valid =
+ IsMemoryRangeValid(matrix.elements, element_range_size_in_bytes, bytes);
+ if (!valid) {
+ SAFTM_LOG(ERROR) << "elements not fully inside bytes";
+ return false;
+ }
+ return true;
+}
+
+// Checks the geometry of the network layer represented by |weights| and |bias|,
+// assuming the input to this layer has size |input_size|. Returns false if we
+// detect any problem, true otherwise.
+bool GoodLayerGeometry(int input_size,
+ const EmbeddingNetworkParams::Matrix &weights,
+ const EmbeddingNetworkParams::Matrix &bias) {
+ if (weights.rows != input_size) {
+ SAFTM_LOG(ERROR) << "#rows " << weights.rows << " != " << input_size;
+ return false;
+ }
+ if ((bias.rows != 1) && (bias.cols != 1)) {
+ SAFTM_LOG(ERROR) << "bad bias vector geometry: " << bias.rows << " x "
+ << bias.cols;
+ return false;
+ }
+ int bias_dimension = bias.rows * bias.cols;
+ if (weights.cols != bias_dimension) {
+ SAFTM_LOG(ERROR) << "#cols " << weights.cols << " != " << bias_dimension;
+ return false;
+ }
+ return true;
+}
+} // namespace
+
+EmbeddingNetworkParamsFromFlatbuffer::EmbeddingNetworkParamsFromFlatbuffer(
+ StringPiece bytes) {
+ // We expect valid_ to be initialized to false at this point. We set it to
+ // true only if we successfully complete all initialization. On error, we
+ // return early, leaving valid_ set to false.
+ SAFTM_DCHECK(!valid_);
+
+ // NOTE: current EmbeddingNetworkParams API works only on little-endian
+ // machines. Fortunately, all modern devices are little-endian so, instead of
+ // a costly API change, we support only the little-endian case.
+ //
+ // Technical explanation: for each Matrix, our API provides a pointer to the
+ // matrix elements (see Matrix field |elements|). For unquantized matrices,
+ // that's a const float *pointer; the client code (e.g., Neurosis) uses those
+ // floats directly. That is correct if the EmbeddingNetworkParams come from a
+ // proto, where the proto parsing already handled the endianness differences.
+ // But in the flatbuffer case, that's a pointer to floats in little-endian
+ // format (flatbuffers always use little-endian). If our API provided access
+ // to only one element at a time, the accessor method could swap the bytes "on
+ // the fly", using temporary variables. Instead, our API provides a pointer
+ // to all elements: as their number is variable (and underlying data is
+ // immutable), we can't ensure the bytes of all those elements are swapped
+ // without extra memory allocation to store the swapped bytes (which is what
+ // using flatbuffers is supposed to prevent).
+ if (!LittleEndian::IsLittleEndian()) {
+ SAFTM_LOG(INFO) << "Not a little-endian machine";
+ return;
+ }
+
+ const uint8_t *start = reinterpret_cast<const uint8_t *>(bytes.data());
+ if (start == nullptr) {
+ // Note: as |bytes| is expected to be a valid EmbeddingNetwork flatbuffer,
+ // it should contain the 4-char identifier "NS00" (or a later version). It
+ // can't be empty; hence StringPiece(nullptr, 0) is not legal here.
+ SAFTM_LOG(ERROR) << "nullptr bytes";
+ return;
+ }
+ flatbuffers::Verifier verifier(start, bytes.size());
+ if (!saft_fbs::VerifyEmbeddingNetworkBuffer(verifier)) {
+ SAFTM_LOG(ERROR) << "Not a valid EmbeddingNetwork flatbuffer";
+ return;
+ }
+ network_ = saft_fbs::GetEmbeddingNetwork(start);
+ if (network_ == nullptr) {
+ SAFTM_LOG(ERROR) << "Unable to interpret bytes as a flatbuffer";
+ return;
+ }
+
+ // Perform a few extra checks before declaring this object valid.
+ valid_ = ValidityChecking(bytes);
+}
+
+bool EmbeddingNetworkParamsFromFlatbuffer::ValidityChecking(
+ StringPiece bytes) const {
+ int input_size = 0;
+ for (int i = 0; i < embeddings_size(); ++i) {
+ Matrix embeddings = GetEmbeddingMatrix(i);
+ if (!VerifyMatrix(embeddings, bytes)) {
+ SAFTM_LOG(ERROR) << "Bad embedding matrix #" << i;
+ return false;
+ }
+ input_size += embedding_num_features(i) * embeddings.cols;
+ }
+ int current_size = input_size;
+ for (int i = 0; i < hidden_size(); ++i) {
+ Matrix weights = GetHiddenLayerMatrix(i);
+ if (!VerifyMatrix(weights, bytes)) {
+ SAFTM_LOG(ERROR) << "Bad weights matrix for hidden layer #" << i;
+ return false;
+ }
+ Matrix bias = GetHiddenLayerBias(i);
+ if (!VerifyMatrix(bias, bytes)) {
+ SAFTM_LOG(ERROR) << "Bad bias vector for hidden layer #" << i;
+ return false;
+ }
+ if (!GoodLayerGeometry(current_size, weights, bias)) {
+ SAFTM_LOG(ERROR) << "Bad geometry for hidden layer #" << i;
+ return false;
+ }
+ current_size = weights.cols;
+ }
+
+ if (HasSoftmax()) {
+ Matrix weights = GetSoftmaxMatrix();
+ if (!VerifyMatrix(weights, bytes)) {
+ SAFTM_LOG(ERROR) << "Bad weights matrix for softmax";
+ return false;
+ }
+ Matrix bias = GetSoftmaxBias();
+ if (!VerifyMatrix(bias, bytes)) {
+ SAFTM_LOG(ERROR) << "Bad bias vector for softmax";
+ return false;
+ }
+ if (!GoodLayerGeometry(current_size, weights, bias)) {
+ SAFTM_LOG(ERROR) << "Bad geometry for softmax layer";
+ return false;
+ }
+ }
+ return true;
+}
+
+// static
+bool EmbeddingNetworkParamsFromFlatbuffer::InRangeIndex(int index, int limit,
+ const char *info) {
+ if ((index >= 0) && (index < limit)) {
+ return true;
+ } else {
+ SAFTM_LOG(ERROR) << info << " index " << index << " outside range [0, "
+ << limit << ")";
+ return false;
+ }
+}
+
+int EmbeddingNetworkParamsFromFlatbuffer::SafeGetNumInputChunks() const {
+ const auto *input_chunks = network_->input_chunks();
+ if (input_chunks == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr input_chunks";
+ return 0;
+ }
+ return input_chunks->size();
+}
+
+const saft_fbs::InputChunk *
+EmbeddingNetworkParamsFromFlatbuffer::SafeGetInputChunk(int i) const {
+ if (!InRangeIndex(i, SafeGetNumInputChunks(), "input chunks")) {
+ return nullptr;
+ }
+ const auto *input_chunks = network_->input_chunks();
+ if (input_chunks == nullptr) {
+ // Execution should not reach this point, due to how SafeGetNumInputChunks()
+ // is implemented. Still, just to be sure:
+ SAFTM_LOG(ERROR) << "nullptr input_chunks";
+ return nullptr;
+ }
+ const saft_fbs::InputChunk *input_chunk = input_chunks->Get(i);
+ if (input_chunk == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr input chunk #" << i;
+ }
+ return input_chunk;
+}
+
+const saft_fbs::Matrix *
+EmbeddingNetworkParamsFromFlatbuffer::SafeGetEmbeddingMatrix(int i) const {
+  // Null-safe accessor for the embedding matrix of input chunk #i; returns
+  // nullptr on any problem.
+  const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i);
+  if (input_chunk == nullptr) return nullptr;
+  const saft_fbs::Matrix *matrix = input_chunk->embedding();
+  if (matrix == nullptr) {
+    // Fixed typo in log message: "embeding" -> "embedding".
+    SAFTM_LOG(ERROR) << "nullptr embedding matrix #" << i;
+  }
+  return matrix;
+}
+
+int EmbeddingNetworkParamsFromFlatbuffer::SafeGetNumLayers() const {
+  // Guard against a malformed flatbuffer with no layers vector.
+  const auto *layers = network_->layers();
+  if (layers != nullptr) return layers->size();
+  SAFTM_LOG(ERROR) << "nullptr layers";
+  return 0;
+}
+
+const saft_fbs::NeuralLayer *EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayer(
+    int i) const {
+  // Null-safe accessor for layer #i; returns nullptr on any problem.
+  // InRangeIndex logs an error (tagged "layer") on a bad index.
+  if (!InRangeIndex(i, SafeGetNumLayers(), "layer")) {
+    return nullptr;
+  }
+  const auto *layers = network_->layers();
+  if (layers == nullptr) {
+    // Execution should not reach this point, due to how SafeGetNumLayers()
+    // is implemented. Still, just to be sure:
+    SAFTM_LOG(ERROR) << "nullptr layers";
+    return nullptr;
+  }
+  const saft_fbs::NeuralLayer *layer = layers->Get(i);
+  if (layer == nullptr) {
+    SAFTM_LOG(ERROR) << "nullptr layer #" << i;
+  }
+  // May still be nullptr here; callers are expected to handle that.
+  return layer;
+}
+
+const saft_fbs::Matrix *
+EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayerWeights(int i) const {
+  // Null-safe chain: layer #i -> its weights matrix.
+  const saft_fbs::NeuralLayer *layer = SafeGetLayer(i);
+  if (layer != nullptr) {
+    const saft_fbs::Matrix *weights = layer->weights();
+    if (weights == nullptr) {
+      SAFTM_LOG(ERROR) << "nullptr weights for layer #" << i;
+    }
+    return weights;
+  }
+  return nullptr;
+}
+
+const saft_fbs::Matrix *EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayerBias(
+    int i) const {
+  // Null-safe chain: layer #i -> its bias matrix.
+  const saft_fbs::NeuralLayer *layer = SafeGetLayer(i);
+  if (layer != nullptr) {
+    const saft_fbs::Matrix *bias = layer->bias();
+    if (bias == nullptr) {
+      SAFTM_LOG(ERROR) << "nullptr bias for layer #" << i;
+    }
+    return bias;
+  }
+  return nullptr;
+}
+
+// static
+const float *EmbeddingNetworkParamsFromFlatbuffer::SafeGetValues(
+    const saft_fbs::Matrix *matrix) {
+  // Returns matrix->values()->data() if all dereferences are safe; otherwise
+  // returns nullptr.
+  if (matrix == nullptr) return nullptr;
+  const flatbuffers::Vector<float> *values = matrix->values();
+  if (values == nullptr) {
+    SAFTM_LOG(ERROR) << "nullptr values";
+    // Bug fix: previously we fell through and called values->data() on a
+    // null pointer, which is undefined behavior.
+    return nullptr;
+  }
+  return values->data();
+}
+
+// static
+const uint8_t *EmbeddingNetworkParamsFromFlatbuffer::SafeGetQuantizedValues(
+    const saft_fbs::Matrix *matrix) {
+  // Returns matrix->quantized_values()->data() if all dereferences are safe;
+  // otherwise returns nullptr.
+  if (matrix == nullptr) return nullptr;
+  const flatbuffers::Vector<uint8_t> *quantized_values =
+      matrix->quantized_values();
+  if (quantized_values == nullptr) {
+    SAFTM_LOG(ERROR) << "nullptr quantized_values";
+    // Bug fix: previously we fell through and called data() on a null
+    // pointer, which is undefined behavior.
+    return nullptr;
+  }
+  return quantized_values->data();
+}
+
+// static
+const float16 *EmbeddingNetworkParamsFromFlatbuffer::SafeGetScales(
+    const saft_fbs::Matrix *matrix) {
+  // Returns matrix->scales()->data() if all dereferences are safe; otherwise
+  // returns nullptr.
+  if (matrix == nullptr) return nullptr;
+  const flatbuffers::Vector<uint16_t> *scales = matrix->scales();
+  if (scales == nullptr) {
+    SAFTM_LOG(ERROR) << "nullptr scales";
+    // Bug fix: previously we fell through and called scales->data() on a
+    // null pointer, which is undefined behavior.
+    return nullptr;
+  }
+  return scales->data();
+}
+
+const saft_fbs::NeuralLayer *
+EmbeddingNetworkParamsFromFlatbuffer::SafeGetSoftmaxLayer() const {
+  // By convention, the softmax layer is the last layer of the network.
+  const int num_layers = SafeGetNumLayers();
+  if (num_layers > 0) return SafeGetLayer(num_layers - 1);
+  SAFTM_LOG(ERROR) << "No softmax layer";
+  return nullptr;
+}
+
+QuantizationType EmbeddingNetworkParamsFromFlatbuffer::SafeGetQuantizationType(
+    const saft_fbs::Matrix *matrix) const {
+  // Defensive mapping: a null matrix or an unrecognized enum value maps to
+  // QuantizationType::NONE (with an error log for the latter).
+  if (matrix == nullptr) {
+    return QuantizationType::NONE;
+  }
+  saft_fbs::QuantizationType quantization_type = matrix->quantization_type();
+
+  // Conversion from nlp_saft::saft_fbs::QuantizationType to
+  // nlp_saft::QuantizationType (due to legacy reasons, we have both).
+  switch (quantization_type) {
+    case saft_fbs::QuantizationType_NONE:
+      return QuantizationType::NONE;
+    case saft_fbs::QuantizationType_UINT8:
+      return QuantizationType::UINT8;
+    case saft_fbs::QuantizationType_UINT4:
+      return QuantizationType::UINT4;
+    case saft_fbs::QuantizationType_FLOAT16:
+      return QuantizationType::FLOAT16;
+    default:
+      SAFTM_LOG(ERROR) << "Unsupported quantization type "
+                       << static_cast<int>(quantization_type);
+      return QuantizationType::NONE;
+  }
+}
+
+const void *EmbeddingNetworkParamsFromFlatbuffer::SafeGetValuesOfMatrix(
+    const saft_fbs::Matrix *matrix) const {
+  // Returns the raw element storage of |matrix|: the float values for
+  // unquantized matrices, or the quantized bytes for all quantized formats.
+  // Returns nullptr on any problem.
+  if (matrix == nullptr) {
+    return nullptr;
+  }
+  saft_fbs::QuantizationType quantization_type = matrix->quantization_type();
+  switch (quantization_type) {
+    case saft_fbs::QuantizationType_NONE:
+      return SafeGetValues(matrix);
+    // All quantized formats store their bytes in |quantized_values|.
+    case saft_fbs::QuantizationType_UINT8:
+      SAFTM_FALLTHROUGH_INTENDED;
+    case saft_fbs::QuantizationType_UINT4:
+      SAFTM_FALLTHROUGH_INTENDED;
+    case saft_fbs::QuantizationType_FLOAT16:
+      return SafeGetQuantizedValues(matrix);
+    default:
+      SAFTM_LOG(ERROR) << "Unsupported quantization type "
+                       << static_cast<int>(quantization_type);
+      return nullptr;
+  }
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h b/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h
new file mode 100644
index 0000000..57d59c5
--- /dev/null
+++ b/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h
@@ -0,0 +1,285 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "lang_id/common/embedding-network-params.h"
+#include "lang_id/common/flatbuffers/embedding-network_generated.h"
+#include "lang_id/common/lite_base/float16.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// EmbeddingNetworkParams implementation backed by a flatbuffer.
+//
+// For info on our flatbuffer schema, see embedding-network.fbs.
+class EmbeddingNetworkParamsFromFlatbuffer : public EmbeddingNetworkParams {
+ public:
+  // Constructs an instance on top of the flatbuffer from |bytes|.
+  //
+  // IMPORTANT #1: |bytes| is not copied (to avoid overhead); the caller must
+  // keep those bytes alive for the entire lifetime of this object.
+  //
+  // IMPORTANT #2: immediately after construction, call is_valid(); if it
+  // returns false, do not call any other method on this object.
+  explicit EmbeddingNetworkParamsFromFlatbuffer(StringPiece bytes);
+
+  bool UpdateTaskContextParameters(mobile::TaskContext *task_context) override {
+    // This class does not provide access to the overall TaskContext. It
+    // provides only parameters for the Neurosis neural network.
+    SAFTM_LOG(DFATAL) << "Not supported";
+    return false;
+  }
+
+  bool is_valid() const override { return valid_; }
+
+  // Embedding (input chunk) accessors.  All of them tolerate malformed
+  // flatbuffers and out-of-range indices, returning 0 / nullptr / NONE.
+
+  int embeddings_size() const override { return SafeGetNumInputChunks(); }
+
+  int embeddings_num_rows(int i) const override {
+    return SafeGetNumRows(SafeGetEmbeddingMatrix(i));
+  }
+
+  int embeddings_num_cols(int i) const override {
+    return SafeGetNumCols(SafeGetEmbeddingMatrix(i));
+  }
+
+  const void *embeddings_weights(int i) const override {
+    return SafeGetValuesOfMatrix(SafeGetEmbeddingMatrix(i));
+  }
+
+  QuantizationType embeddings_quant_type(int i) const override {
+    return SafeGetQuantizationType(SafeGetEmbeddingMatrix(i));
+  }
+
+  const float16 *embeddings_quant_scales(int i) const override {
+    return SafeGetScales(SafeGetEmbeddingMatrix(i));
+  }
+
+  // Hidden layer accessors.
+
+  int hidden_size() const override {
+    // -1 because the last layer is always the softmax layer.
+    const int num_hidden = SafeGetNumLayers() - 1;
+    return (num_hidden > 0) ? num_hidden : 0;
+  }
+
+  int hidden_num_rows(int i) const override {
+    return SafeGetNumRows(SafeGetLayerWeights(i));
+  }
+
+  int hidden_num_cols(int i) const override {
+    return SafeGetNumCols(SafeGetLayerWeights(i));
+  }
+
+  QuantizationType hidden_weights_quant_type(int i) const override {
+    return SafeGetQuantizationType(SafeGetLayerWeights(i));
+  }
+
+  const void *hidden_weights(int i) const override {
+    return SafeGetValuesOfMatrix(SafeGetLayerWeights(i));
+  }
+
+  int hidden_bias_size() const override { return hidden_size(); }
+
+  int hidden_bias_num_rows(int i) const override {
+    return SafeGetNumRows(SafeGetLayerBias(i));
+  }
+
+  int hidden_bias_num_cols(int i) const override {
+    return SafeGetNumCols(SafeGetLayerBias(i));
+  }
+
+  const void *hidden_bias_weights(int i) const override {
+    return SafeGetValues(SafeGetLayerBias(i));
+  }
+
+  // Softmax layer accessors.  There is at most one softmax layer (the last
+  // layer of the network), so the |i| parameter of these methods is unused.
+
+  int softmax_size() const override { return (SafeGetNumLayers() > 0) ? 1 : 0; }
+
+  int softmax_num_rows(int i) const override {
+    return SafeGetNumRows(SafeGetSoftmaxWeights());
+  }
+
+  int softmax_num_cols(int i) const override {
+    return SafeGetNumCols(SafeGetSoftmaxWeights());
+  }
+
+  QuantizationType softmax_weights_quant_type(int i) const override {
+    return SafeGetQuantizationType(SafeGetSoftmaxWeights());
+  }
+
+  const void *softmax_weights(int i) const override {
+    return SafeGetValuesOfMatrix(SafeGetSoftmaxWeights());
+  }
+
+  int softmax_bias_size() const override { return softmax_size(); }
+
+  int softmax_bias_num_rows(int i) const override {
+    return SafeGetNumRows(SafeGetSoftmaxBias());
+  }
+
+  int softmax_bias_num_cols(int i) const override {
+    return SafeGetNumCols(SafeGetSoftmaxBias());
+  }
+
+  const void *softmax_bias_weights(int i) const override {
+    return SafeGetValues(SafeGetSoftmaxBias());
+  }
+
+  int embedding_num_features_size() const override {
+    return SafeGetNumInputChunks();
+  }
+
+  int embedding_num_features(int i) const override {
+    if (!InRangeIndex(i, embedding_num_features_size(),
+                      "embedding num features")) {
+      return 0;
+    }
+    const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i);
+    return (input_chunk == nullptr) ? 0 : input_chunk->num_features();
+  }
+
+  // This implementation never stores precomputed embeddings.
+  bool has_is_precomputed() const override { return false; }
+  bool is_precomputed() const override { return false; }
+
+ private:
+  // Returns true if and only if |index| is in [0, limit).  |info| should be a
+  // zero-terminated array of chars (ideally a literal string, e.g. "layer")
+  // saying what the index refers to; it only makes error logs informative.
+  static bool InRangeIndex(int index, int limit, const char *info);
+
+  // Returns network_->input_chunks()->size() if all dereferences are safe
+  // (i.e., no nullptr); otherwise, returns 0.
+  int SafeGetNumInputChunks() const;
+
+  // Returns network_->input_chunks()->Get(i) if all dereferences are safe
+  // (i.e., no nullptr); otherwise, returns nullptr.
+  const saft_fbs::InputChunk *SafeGetInputChunk(int i) const;
+
+  // Returns network_->input_chunks()->Get(i)->embedding() if all dereferences
+  // are safe (i.e., no nullptr); otherwise, returns nullptr.
+  const saft_fbs::Matrix *SafeGetEmbeddingMatrix(int i) const;
+
+  // Returns network_->layers()->size() if all dereferences are safe (i.e., no
+  // nullptr); otherwise, returns 0.
+  int SafeGetNumLayers() const;
+
+  // Returns network_->layers()->Get(i) if all dereferences are safe
+  // (i.e., no nullptr); otherwise, returns nullptr.
+  const saft_fbs::NeuralLayer *SafeGetLayer(int i) const;
+
+  // Returns network_->layers()->Get(i)->weights() if all dereferences are
+  // safe (i.e., no nullptr); otherwise, returns nullptr.
+  const saft_fbs::Matrix *SafeGetLayerWeights(int i) const;
+
+  // Returns network_->layers()->Get(i)->bias() if all dereferences are safe
+  // (i.e., no nullptr); otherwise, returns nullptr.
+  const saft_fbs::Matrix *SafeGetLayerBias(int i) const;
+
+  // Number of rows of |matrix|, or 0 if |matrix| is nullptr.
+  static int SafeGetNumRows(const saft_fbs::Matrix *matrix) {
+    if (matrix == nullptr) return 0;
+    return matrix->rows();
+  }
+
+  // Number of columns of |matrix|, or 0 if |matrix| is nullptr.
+  static int SafeGetNumCols(const saft_fbs::Matrix *matrix) {
+    if (matrix == nullptr) return 0;
+    return matrix->cols();
+  }
+
+  // Returns matrix->values()->data() if all dereferences are safe (i.e., no
+  // nullptr); otherwise, returns nullptr.
+  static const float *SafeGetValues(const saft_fbs::Matrix *matrix);
+
+  // Returns matrix->quantized_values()->data() if all dereferences are safe
+  // (i.e., no nullptr); otherwise, returns nullptr.
+  static const uint8_t *SafeGetQuantizedValues(const saft_fbs::Matrix *matrix);
+
+  // Returns matrix->scales()->data() if all dereferences are safe (i.e., no
+  // nullptr); otherwise, returns nullptr.
+  static const float16 *SafeGetScales(const saft_fbs::Matrix *matrix);
+
+  // Returns the last layer (by convention, the softmax layer), i.e.,
+  // network_->layers()->Get(SafeGetNumLayers() - 1), if there is at least one
+  // layer and all dereferences are safe; otherwise, returns nullptr.
+  const saft_fbs::NeuralLayer *SafeGetSoftmaxLayer() const;
+
+  // Weights of the softmax layer, or nullptr on any error.
+  const saft_fbs::Matrix *SafeGetSoftmaxWeights() const {
+    const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer();
+    if (layer == nullptr) return nullptr;
+    return layer->weights();
+  }
+
+  // Bias of the softmax layer, or nullptr on any error.
+  const saft_fbs::Matrix *SafeGetSoftmaxBias() const {
+    const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer();
+    if (layer == nullptr) return nullptr;
+    return layer->bias();
+  }
+
+  // Returns the quantization type for |matrix|.  Returns NONE in case of
+  // problems (e.g., |matrix| is nullptr or has an unknown quantization type).
+  QuantizationType SafeGetQuantizationType(
+      const saft_fbs::Matrix *matrix) const;
+
+  // Returns a pointer to the values (float, uint8, or float16, depending on
+  // quantization) of |matrix|, in row-major order.  Returns nullptr in case
+  // of a problem.
+  const void *SafeGetValuesOfMatrix(const saft_fbs::Matrix *matrix) const;
+
+  // Performs validity checks, e.g., that the dimensions of consecutive
+  // network layers match, and that every pointer we hand out lies inside the
+  // |bytes| passed to the constructor, so clients reading through those
+  // pointers will not run into trouble.
+  bool ValidityChecking(StringPiece bytes) const;
+
+  // True if these params are valid.  May be false if the original bytes were
+  // corrupted; we prefer reporting that to CHECK-failing.
+  bool valid_ = false;
+
+  // EmbeddingNetwork flatbuffer pointing inside the bytes passed to the
+  // constructor; see constructor doc.
+  const saft_fbs::EmbeddingNetwork *network_ = nullptr;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
diff --git a/lang_id/common/flatbuffers/embedding-network.fbs b/lang_id/common/flatbuffers/embedding-network.fbs
new file mode 100644
index 0000000..1fde6a3
--- /dev/null
+++ b/lang_id/common/flatbuffers/embedding-network.fbs
@@ -0,0 +1,117 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Flatbuffer schema for Neurosis (FFNN with embeddings) parameters.
+//
+// Contains the same information as an EmbeddingNetworkProto.
+
+namespace libtextclassifier3.saft_fbs;
+
+// NS stands for NeurosiS. The next two digits are meant to identify
+// incompatible versions. Ideally, we'll never have to go beyond 00.
+file_identifier "NS00";
+
+// Should be kept in sync with the C++ enum nlp_saft::QuantizationType.
+enum QuantizationType : byte {
+  // No quantization: elements are stored as plain floats in Matrix.values.
+  NONE = 0,
+  // Quantized elements stored as bytes in Matrix.quantized_values, with
+  // per-row float16 scales in Matrix.scales (see Matrix below).
+  UINT8 = 1,
+  UINT4 = 2,
+  FLOAT16 = 3,
+}
+
+// A matrix of model parameters, stored either as plain floats or in one of
+// the quantized formats from QuantizationType above.
+table Matrix {
+  // Number of rows of this matrix.
+  rows:int;
+
+  // Number of columns of this matrix.
+  cols:int;
+
+  // Type of quantization used for the values from this matrix.
+  //
+  // If this is QuantizationType_NONE, then the unquantized values should be
+  // stored in |values| below. Otherwise, the bytes of the quantized values
+  // should be stored in |quantized_values| and the float16 quantization scales
+  // should be stored in |scales|.
+  quantization_type:QuantizationType = NONE;
+
+  // Non-quantized matrix elements, in row-major order. See comments for
+  // |quantization_type|.
+  values:[float];
+
+  // Quantized matrix elements, in row-major order. See comments for
+  // |quantization_type|.
+  quantized_values:[ubyte];
+
+  // Quantization factors (float16), one per matrix row. There is no float16
+  // primitive type for flatbuffers, we just use another 16 bit type. See
+  // comments for |quantization_type|.
+  scales:[ushort];
+}
+
+// The input layer for a Neurosis network is composed of several parts (named
+// "chunks" below, "embedding spaces" in some other parts, etc). For each
+// chunk, we have |num_features| features that extract feature values in that
+// chunk. All values extracted by a feature get projected via the embedding
+// matrix |embedding| and summed together, producing a vector of
+// |embedding.cols| elements. The resulting vector gets concatenated with the
+// similar vectors for other |num_features| features, producing a "chunk" of
+// |num_features * embedding.cols| elements. This chunk gets concatenated with
+// the other chunks.
+//
+// Note: the specification that indicates what those |num_features| features are
+// is stored elsewhere (usually in a ModelParameter, see model.fbs). But we
+// need to know |num_features| here, in order to specify the geometry of the
+// Neurosis network.
+table InputChunk {
+  // Embedding matrix for this chunk; see the block comment above for how it
+  // is applied to extracted feature values.
+  embedding:Matrix;
+
+  // Number of features that extract values in this chunk; see block comment
+  // above.
+  num_features:int;
+}
+
+// One layer of neurons from the Neurosis network. This table can represent a
+// hidden layer or the final (output / softmax) layer.
+//
+// Our formalism is a bit different, but equivalent to the usual description
+// from the literature:
+//
+// Technically, in Neurosis, each layer takes an input (a vector of floats); if
+// this is not the first layer, we apply a nonlinear function (ReLU); for the
+// first layer, we skip ReLU. Next, we multiply by |weights| and add |bias|,
+// get the input for the next level and so on. The output from the last layer
+// is generally used for softmax classification. That's why we say that the
+// last layer is the "softmax layer".
+table NeuralLayer {
+  // NOTE: both fields are optional at the flatbuffer level; consumers
+  // null-check them (see SafeGetLayerWeights / SafeGetLayerBias).
+
+  // Weight matrix for this layer. Geometry: num_inputs x num_neurons, where
+  // num_inputs is the number of values produced by previous layer (which can be
+  // the input layer, or another hidden layer) and num_neurons is the number of
+  // neurons from this layer.
+  weights:Matrix;
+
+  // Bias vector for this layer.
+  //
+  // NOTE: right now, we accept both 1 x num_neurons and num_neurons x 1
+  // geometries: the layout of the elements is the same in both cases.
+  bias:Matrix;
+}
+
+// Top-level table: the input chunks plus the stack of neural layers.
+table EmbeddingNetwork {
+  // Specification of the chunks that compose the input layer.
+  input_chunks:[InputChunk];
+
+  // Hidden layers, followed by the final (softmax) layer.
+  //
+  // NOTE: consumers treat the last element as the softmax layer (see
+  // SafeGetSoftmaxLayer in embedding-network-params-from-flatbuffer.cc).
+  layers:[NeuralLayer];
+}
+
+root_type EmbeddingNetwork;
diff --git a/lang_id/common/flatbuffers/model-utils.cc b/lang_id/common/flatbuffers/model-utils.cc
new file mode 100644
index 0000000..2c57aa2
--- /dev/null
+++ b/lang_id/common/flatbuffers/model-utils.cc
@@ -0,0 +1,208 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/flatbuffers/model-utils.h"
+
+#include <string.h>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/math/checksum.h"
+
+namespace libtextclassifier3 {
+namespace saft_fbs {
+
+namespace {
+
+// Returns true if we have clear evidence that |model| fails its checksum.
+//
+// If |model| carries a crc32 field and the recomputed checksum does not match
+// it, the model is clearly corrupt and we return true.  Models without a
+// crc32 field give us nothing to compare against, so we return false (no
+// clear evidence of corruption).
+bool ClearlyFailsChecksum(const Model &model) {
+  // Old models may predate the crc32 field entirely.
+  if (!flatbuffers::IsFieldPresent(&model, Model::VT_CRC32)) {
+    SAFTM_LOG(WARNING)
+        << "No CRC32, most likely an old model; skip CRC32 check";
+    return false;
+  }
+  const mobile::uint32 expected_crc32 = model.crc32();
+  const mobile::uint32 actual_crc32 = ComputeCrc2Checksum(&model);
+  if (actual_crc32 == expected_crc32) {
+    SAFTM_LOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
+    return false;
+  }
+  SAFTM_LOG(ERROR) << "Corrupt model: different CRC32: " << actual_crc32
+                   << " vs " << expected_crc32;
+  return true;
+}
+} // namespace
+
+const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) {
+  // Trivially-empty input can't be a Model.
+  if ((data == nullptr) || (num_bytes == 0)) {
+    SAFTM_LOG(ERROR) << "GetModel called on an empty sequence of bytes";
+    return nullptr;
+  }
+  // Step 1: structural verification of the flatbuffer.
+  const uint8_t *start = reinterpret_cast<const uint8_t *>(data);
+  flatbuffers::Verifier verifier(start, num_bytes);
+  if (!VerifyModelBuffer(verifier)) {
+    SAFTM_LOG(ERROR) << "Not a valid Model flatbuffer";
+    return nullptr;
+  }
+  const Model *model = GetModel(start);
+  if (model == nullptr) {
+    return nullptr;
+  }
+  // Step 2: checksum verification (skipped for models without a crc32 field;
+  // see ClearlyFailsChecksum above).
+  if (ClearlyFailsChecksum(*model)) {
+    return nullptr;
+  }
+  return model;
+}
+
+const ModelInput *GetInputByName(const Model *model, const string &name) {
+  if (model == nullptr) {
+    SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr";
+    return nullptr;
+  }
+  const auto *inputs = model->inputs();
+  if (inputs == nullptr) {
+    // We should always have a list of inputs; maybe an empty one, if no inputs,
+    // but the list should be there.
+    SAFTM_LOG(ERROR) << "null inputs";
+    return nullptr;
+  }
+  // Linear scan; returns the first input whose name matches.
+  for (const ModelInput *input : *inputs) {
+    if (input == nullptr) continue;
+    const flatbuffers::String *input_name = input->name();
+    if ((input_name != nullptr) && (input_name->str() == name)) {
+      return input;
+    }
+  }
+  return nullptr;
+}
+
+mobile::StringPiece GetInputBytes(const ModelInput *input) {
+  // Returns a view over |input|'s payload.  A missing input or a missing
+  // data vector both yield an empty StringPiece(nullptr, 0).
+  if ((input == nullptr) || (input->data() == nullptr)) {
+    SAFTM_LOG(ERROR) << "ModelInput has no content";
+    return mobile::StringPiece(nullptr, 0);
+  }
+  // Cleanup: the old code re-checked input->data() for nullptr here; that
+  // branch was unreachable given the check above, so it has been removed.
+  const flatbuffers::Vector<uint8_t> *input_data = input->data();
+  return mobile::StringPiece(reinterpret_cast<const char *>(input_data->data()),
+                             input_data->size());
+}
+
+bool FillParameters(const Model &model, mobile::TaskContext *context) {
+  // Copies every (name, value) parameter pair from |model| into |context|.
+  // Stops and returns false at the first malformed parameter; parameters
+  // processed before the failure remain set in |context|.
+  if (context == nullptr) {
+    SAFTM_LOG(ERROR) << "null context";
+    return false;
+  }
+  const auto *parameters = model.parameters();
+  if (parameters == nullptr) {
+    // We should always have a list of parameters; maybe an empty one, if no
+    // parameters, but the list should be there.
+    SAFTM_LOG(ERROR) << "null list of parameters";
+    return false;
+  }
+  for (const ModelParameter *p : *parameters) {
+    if (p == nullptr) {
+      SAFTM_LOG(ERROR) << "null parameter";
+      return false;
+    }
+    if (p->name() == nullptr) {
+      SAFTM_LOG(ERROR) << "null parameter name";
+      return false;
+    }
+    const string name = p->name()->str();
+    if (name.empty()) {
+      SAFTM_LOG(ERROR) << "empty parameter name";
+      return false;
+    }
+    if (p->value() == nullptr) {
+      // Bug fix: this message used to say "null parameter name" (a copy-paste
+      // of the check above), which made error logs misleading.
+      SAFTM_LOG(ERROR) << "null parameter value";
+      return false;
+    }
+    context->SetParameter(name, p->value()->str());
+  }
+  return true;
+}
+
+namespace {
+// Updates |*crc| with the information from |s|. Auxiliary for
+// ComputeCrc2Checksum.
+//
+// The bytes from |info| are also used to update the CRC32 checksum. |info|
+// should be a brief tag that indicates what |s| represents. The idea is to add
+// some structure to the information that goes into the CRC32 computation.
+//
+// NOTE: the exact byte sequence fed here ("|<info>:" followed by the raw
+// vector bytes, or the literal "empty" for a missing vector) is part of the
+// checksum contract with model builders; do not change it.
+template <typename T>
+void UpdateCrc(mobile::Crc32 *crc, const flatbuffers::Vector<T> *s,
+               mobile::StringPiece info) {
+  crc->Update("|");
+  crc->Update(info.data(), info.size());
+  crc->Update(":");
+  if (s == nullptr) {
+    // A missing vector hashes as the literal string "empty", so it is
+    // distinguishable from a present-but-zero-length vector.
+    crc->Update("empty");
+  } else {
+    crc->Update(reinterpret_cast<const char *>(s->data()),
+                s->size() * sizeof(T));
+  }
+}
+} // namespace
+
+mobile::uint32 ComputeCrc2Checksum(const Model *model) {
+  // Implementation note: originally, I (salcianu@) thought we can just compute
+  // a CRC32 checksum of the model bytes. Unfortunately, the expected checksum
+  // is there too (and because we don't control the flatbuffer format, we can't
+  // "arrange" for it to be placed at the head / tail of those bytes). Instead,
+  // we traverse |model| and feed into the CRC32 computation those parts we are
+  // interested in (which excludes the crc32 field).
+  //
+  // Note: storing the checksum outside the Model would be too disruptive for
+  // the way we currently ship our models.
+  //
+  // NOTE: the traversal order below (parameters, then inputs, each field in
+  // the order listed) defines the checksum; it must match what the model
+  // builder used when filling in the crc32 field.
+  mobile::Crc32 crc;
+  // A null model yields the checksum of the empty byte sequence.
+  if (model == nullptr) {
+    return crc.Get();
+  }
+  crc.Update("|Parameters:");
+  const auto *parameters = model->parameters();
+  if (parameters != nullptr) {
+    for (const ModelParameter *p : *parameters) {
+      if (p != nullptr) {
+        UpdateCrc(&crc, p->name(), "name");
+        UpdateCrc(&crc, p->value(), "value");
+      }
+    }
+  }
+  crc.Update("|Inputs:");
+  const auto *inputs = model->inputs();
+  if (inputs != nullptr) {
+    for (const ModelInput *input : *inputs) {
+      if (input != nullptr) {
+        UpdateCrc(&crc, input->name(), "name");
+        UpdateCrc(&crc, input->type(), "type");
+        UpdateCrc(&crc, input->sub_type(), "sub-type");
+        UpdateCrc(&crc, input->data(), "data");
+      }
+    }
+  }
+  return crc.Get();
+}
+
+} // namespace saft_fbs
+} // namespace libtextclassifier3
diff --git a/lang_id/common/flatbuffers/model-utils.h b/lang_id/common/flatbuffers/model-utils.h
new file mode 100644
index 0000000..5427f70
--- /dev/null
+++ b/lang_id/common/flatbuffers/model-utils.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
+
+#include <stddef.h>
+
+#include <string>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/flatbuffers/model_generated.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace saft_fbs {
+
+// Verifies that the |num_bytes| bytes that start at |data| represent a valid
+// Model flatbuffer. If so, returns that Model. Otherwise, returns nullptr.
+//
+// Note: if the Model has the crc32 field, this method checks that the Model
+// checksum matches that field; if they don't match, the Model is considered
+// invalid, and this function returns nullptr. The checksum test is in addition
+// to the standard flatbuffer validity checking.
+const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes);
+
+// Convenience overload: verifies the Model stored in the |bytes| StringPiece.
+inline const Model *GetVerifiedModelFromBytes(mobile::StringPiece bytes) {
+  const char *data = bytes.data();
+  const size_t num_bytes = bytes.size();
+  return GetVerifiedModelFromBytes(data, num_bytes);
+}
+
+// Returns the |model| input with specified |name|. Returns nullptr if no such
+// input exists. If |model| contains multiple inputs with that |name|, returns
+// the first one (model builders should avoid building such models).
+const ModelInput *GetInputByName(const Model *model, const string &name);
+
+// Returns a StringPiece pointing to the bytes for the content of |input|. In
+// case of errors, returns StringPiece(nullptr, 0).
+mobile::StringPiece GetInputBytes(const ModelInput *input);
+
+// Fills parameters from |context|, based on the parameters from |model|.
+// Returns false if any error is encountered, true otherwise. In the case of an
+// error, some parameters may have been added to |context| (e.g., if we find a
+// problem with the 3rd parameter, the first 2 have been added).
+bool FillParameters(const Model &model, mobile::TaskContext *context);
+
+// Returns the CRC32 checksum of |model|. This checksum is computed over the
+// entire information from the model (including the bytes of the inputs),
+// *except* the crc32 field. Hence, when a model is build, one can store the
+// result of this function into that field; on the user side, one can check that
+// the result of this function matches the crc32 field, to guard against model
+// corruption. GetVerifiedModelFromBytes performs this check.
+mobile::uint32 ComputeCrc2Checksum(const Model *model);
+
+} // namespace saft_fbs
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
diff --git a/lang_id/common/flatbuffers/model.fbs b/lang_id/common/flatbuffers/model.fbs
new file mode 100644
index 0000000..41251e1
--- /dev/null
+++ b/lang_id/common/flatbuffers/model.fbs
@@ -0,0 +1,79 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Flatbuffer schema for SAFT models.
+//
+// For info on flatbuffers, see http://go/flatbuffers and
+// http://google.github.io/flatbuffers/, including info on writing schemas:
+// http://google.github.io/flatbuffers/flatbuffers_guide_writing_schema.html
+
+namespace libtextclassifier3.saft_fbs;
+
+// SM stands for Saft Model. The next two digits are meant to identify
+// incompatible versions. Ideally, we'll never have to go beyond 00.
+file_identifier "SM00";
+
+// Extension stands for Saft Model in FlatBuffer format.
+file_extension "smfb";
+
+table ModelParameter {
+ // Parameter name.
+ name:string;
+
+ // Parameter value.
+ value:string;
+}
+
+// Input for a SAFT model. Inputs usually provide extra resources: e.g., the
+// parameters for a Neurosis FFNN with embeddings, or a word cluster structure,
+// etc.
+table ModelInput {
+  // Name of this input. Different inputs of the same model should have
+  // different names, such that we can non-ambiguously look them up.
+ name:string;
+
+ // General description of the type of this input. Required to parse the
+ // content of this input (see |data| below). If |data| is a flatbuffer, use
+ // "flatbuffer". If |data| is a proto, use "proto". Otherwise, use your best
+ // judgment: use something human-readable, and look around to make sure you
+ // don't invent a new name for something that already exists.
+ type:string;
+
+ // More specific information about the type of this input. E.g., if |type| is
+ // "flatbuffer", this should be the name of the root_type we should parse from
+  // the input bytes, e.g., "EmbeddingNetwork". If |type| is "proto", this
+ // should be the name of the proto serialized as |data|, e.g.,
+ // "EmbeddingNetworkProto".
+ sub_type:string;
+
+ // The content of this input. With a generous alignment, such that we can
+ // accommodate mmap-friendly data structures. E.g., the word clusters used by
+ // the Translate team require 8-byte alignment.
+ data:[ubyte] (force_align: 16);
+}
+
+// A Saft model. A list of parameters with model settings (e.g., the
+// specification of the features to use) and a list of inputs.
+table Model {
+ parameters:[ModelParameter];
+ inputs:[ModelInput];
+
+ // Crc32 checksum of all parameters and inputs (including the bytes of the
+ // inputs). Used to check that the model has not been corrupted.
+ crc32:uint32;
+}
+
+root_type Model;
diff --git a/lang_id/common/lite_base/attributes.h b/lang_id/common/lite_base/attributes.h
new file mode 100644
index 0000000..f29e48f
--- /dev/null
+++ b/lang_id/common/lite_base/attributes.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Various macros for function attributes (inlining, unused-result warnings).
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ATTRIBUTES_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ATTRIBUTES_H_
+
+// SAFTM_HAVE_ATTRIBUTE
+//
+// A function-like feature checking macro that is a wrapper around
+// `__has_attribute`, which is defined by GCC 5+ and Clang and evaluates to a
+// nonzero constant integer if the attribute is supported or 0 if not.
+//
+// It evaluates to zero if `__has_attribute` is not defined by the compiler.
+//
+// GCC: https://gcc.gnu.org/gcc-5/changes.html
+// Clang: https://clang.llvm.org/docs/LanguageExtensions.html
+#ifdef __has_attribute
+#define SAFTM_HAVE_ATTRIBUTE(x) __has_attribute(x)
+#else
+#define SAFTM_HAVE_ATTRIBUTE(x) 0
+#endif
+
+// SAFTM_MUST_USE_RESULT
+//
+// Tells the compiler to warn about unused return values for functions declared
+// with this macro. The macro must appear as the very first part of a function
+// declaration or definition:
+//
+// Example:
+//
+// SAFTM_MUST_USE_RESULT Sprocket* AllocateSprocket();
+//
+// This placement has the broadest compatibility with GCC, Clang, and MSVC, with
+// both defs and decls, and with GCC-style attributes, MSVC declspec, C++11
+// and C++17 attributes.
+//
+// SAFTM_MUST_USE_RESULT allows using cast-to-void to suppress the unused result
+// warning. For that, warn_unused_result is used only for clang but not for gcc.
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=66425
+#if SAFTM_HAVE_ATTRIBUTE(nodiscard)
+#define SAFTM_MUST_USE_RESULT [[nodiscard]]
+#elif defined(__clang__) && SAFTM_HAVE_ATTRIBUTE(warn_unused_result)
+#define SAFTM_MUST_USE_RESULT __attribute__((warn_unused_result))
+#else
+#define SAFTM_MUST_USE_RESULT
+#endif
+
+#if defined(__GNUC__) && \
+ (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1))
+
+// For functions we want to force inline.
+// Introduced in gcc 3.1.
+#define SAFTM_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline))
+
+// For functions we don't want to inline, e.g., to keep code size small.
+#define SAFTM_ATTRIBUTE_NOINLINE __attribute__((noinline))
+
+#elif defined(_MSC_VER)
+#define SAFTM_ATTRIBUTE_ALWAYS_INLINE __forceinline
+#else
+
+// Other compilers will have to figure it out for themselves.
+#define SAFTM_ATTRIBUTE_ALWAYS_INLINE
+#define SAFTM_ATTRIBUTE_NOINLINE
+#endif // big condition on two lines.
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ATTRIBUTES_H_
diff --git a/lang_id/common/lite_base/casts.h b/lang_id/common/lite_base/casts.h
new file mode 100644
index 0000000..11a4ba2
--- /dev/null
+++ b/lang_id/common/lite_base/casts.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_CASTS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_CASTS_H_
+
+#include <string.h> // for memcpy
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// bit_cast<Dest, Source> is a template function that implements the equivalent
+// of "*reinterpret_cast<Dest*>(&source)". We need this in very low-level
+// functions like fast math support.
+//
+// float f = 3.14159265358979;
+// int i = bit_cast<int32>(f);
+// // i = 0x40490fdb
+//
+// The classical address-casting method is:
+//
+// // WRONG
+// float f = 3.14159265358979; // WRONG
+// int i = * reinterpret_cast<int*>(&f); // WRONG
+//
+// The address-casting method actually produces undefined behavior
+// according to ISO C++ specification section 3.10 -15 -. Roughly, this
+// section says: if an object in memory has one type, and a program
+// accesses it with a different type, then the result is undefined
+// behavior for most values of "different type".
+//
+// This is true for any cast syntax, either *(int*)&f or
+// *reinterpret_cast<int*>(&f). And it is particularly true for
+// conversions between integral lvalues and floating-point lvalues.
+//
+// The purpose of 3.10 -15- is to allow optimizing compilers to assume
+// that expressions with different types refer to different memory. gcc
+// 4.0.1 has an optimizer that takes advantage of this. So a
+// non-conforming program quietly produces wildly incorrect output.
+//
+// The problem is not the use of reinterpret_cast. The problem is type
+// punning: holding an object in memory of one type and reading its bits
+// back using a different type.
+//
+// The C++ standard is more subtle and complex than this, but that
+// is the basic idea.
+//
+// Anyways ...
+//
+// bit_cast<> calls memcpy() which is blessed by the standard, especially by the
+// example in section 3.9 . Also, of course, bit_cast<> wraps up the nasty
+// logic in one place.
+//
+// Fortunately memcpy() is very fast. In optimized mode, with a
+// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
+// code with the minimal amount of data movement. On a 32-bit system,
+// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
+// compiles to two loads and two stores.
+//
+// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
+//
+// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
+// is likely to surprise you.
+//
+// Props to Bill Gibbons for the compile time assertion technique and
+// Art Komninos and Igor Tandetnik for the msvc experiments.
+//
+// -- mec 2005-10-17
+
+template <class Dest, class Source>
+inline Dest bit_cast(const Source &source) {
+ static_assert(sizeof(Dest) == sizeof(Source), "Sizes do not match");
+
+ Dest dest;
+ memcpy(&dest, &source, sizeof(dest));
+ return dest;
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_CASTS_H_
diff --git a/lang_id/common/lite_base/compact-logging-levels.h b/lang_id/common/lite_base/compact-logging-levels.h
new file mode 100644
index 0000000..977f4da
--- /dev/null
+++ b/lang_id/common/lite_base/compact-logging-levels.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_LEVELS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_LEVELS_H_
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+enum LogSeverity {
+ FATAL = 0,
+ ERROR,
+ WARNING,
+ INFO,
+
+ // In debug mode, DFATAL has the same semantics as FATAL. Otherwise, it
+ // behaves like ERROR.
+ DFATAL,
+};
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_LEVELS_H_
diff --git a/lang_id/common/lite_base/compact-logging-raw.cc b/lang_id/common/lite_base/compact-logging-raw.cc
new file mode 100644
index 0000000..53dfc8e
--- /dev/null
+++ b/lang_id/common/lite_base/compact-logging-raw.cc
@@ -0,0 +1,103 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/lite_base/compact-logging-raw.h"
+
+#include <stdio.h>
+#include <string>
+
+// NOTE: this file contains two implementations: one for Android, one for all
+// other cases. We always build exactly one implementation.
+#if defined(__ANDROID__)
+
+// Compiled as part of Android.
+#include <android/log.h>
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+namespace {
+// Converts LogSeverity to level for __android_log_write.
+int GetAndroidLogLevel(LogSeverity severity) {
+ switch (severity) {
+ case FATAL:
+ return ANDROID_LOG_FATAL;
+ case ERROR:
+ return ANDROID_LOG_ERROR;
+ case WARNING:
+ return ANDROID_LOG_WARN;
+ case INFO:
+ return ANDROID_LOG_INFO;
+ default:
+ return ANDROID_LOG_DEBUG;
+ }
+}
+} // namespace
+
+void LowLevelLogging(LogSeverity severity, const string &tag,
+ const string &message) {
+ const int android_log_level = GetAndroidLogLevel(severity);
+#if !defined(SAFTM_DEBUG_LOGGING)
+ if (android_log_level != ANDROID_LOG_ERROR &&
+ android_log_level != ANDROID_LOG_FATAL) {
+ return;
+ }
+#endif
+ __android_log_write(android_log_level, tag.c_str(), message.c_str());
+}
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#else // if defined(__ANDROID__)
+
+// Not on Android: implement LowLevelLogging to print to stderr (see below).
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+namespace {
+// Converts LogSeverity to human-readable text.
+const char *LogSeverityToString(LogSeverity severity) {
+ switch (severity) {
+ case INFO:
+ return "INFO";
+ case WARNING:
+ return "WARNING";
+ case ERROR:
+ return "ERROR";
+ case FATAL:
+ return "FATAL";
+ default:
+ return "UNKNOWN";
+ }
+}
+} // namespace
+
+void LowLevelLogging(LogSeverity severity, const string &tag,
+ const string &message) {
+ fprintf(stderr, "[%s] %s : %s\n", LogSeverityToString(severity), tag.c_str(),
+ message.c_str());
+ fflush(stderr);
+}
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // if defined(__ANDROID__)
diff --git a/lang_id/common/lite_base/compact-logging-raw.h b/lang_id/common/lite_base/compact-logging-raw.h
new file mode 100644
index 0000000..f67287c
--- /dev/null
+++ b/lang_id/common/lite_base/compact-logging-raw.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
+
+#include <string>
+
+#include "lang_id/common/lite_base/compact-logging-levels.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+// Low-level logging primitive. Logs a message, with the indicated log
+// severity. From android/log.h: "the tag normally corresponds to the component
+// that emits the log message, and should be reasonably small".
+void LowLevelLogging(LogSeverity severity, const string &tag,
+ const string &message);
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
diff --git a/lang_id/common/lite_base/compact-logging.cc b/lang_id/common/lite_base/compact-logging.cc
new file mode 100644
index 0000000..99d60a3
--- /dev/null
+++ b/lang_id/common/lite_base/compact-logging.cc
@@ -0,0 +1,84 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/lite_base/compact-logging.h"
+
+#include <stdlib.h>
+
+#include <iostream>
+
+#include "lang_id/common/lite_base/compact-logging-raw.h"
+
+#ifndef SAFTM_LOGGING_TAG
+
+// Tag inserted in the prefix of the generated log messages. The user can
+// override this by defining this macro on the blaze build command-line.
+#define SAFTM_LOGGING_TAG "saftm"
+#endif
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+namespace {
+// Returns pointer to beginning of last /-separated token from file_name.
+// file_name should be a pointer to a zero-terminated array of chars.
+// E.g., "foo/bar.cc" -> "bar.cc", "foo/" -> "", "foo" -> "foo".
+const char *JumpToBasename(const char *file_name) {
+ if (file_name == nullptr) {
+ return nullptr;
+ }
+
+ // Points to the beginning of the last encountered token.
+ const char *last_token_start = file_name;
+ while (*file_name != '\0') {
+ if (*file_name == '/') {
+ // Found token separator. A new (potentially empty) token starts after
+ // this position. Notice that if file_name is a valid zero-terminated
+ // string, file_name + 1 is a valid pointer (there is at least one char
+ // after address file_name, the zero terminator).
+ last_token_start = file_name + 1;
+ }
+ file_name++;
+ }
+ return last_token_start;
+}
+} // namespace
+
+LogMessage::LogMessage(LogSeverity severity, const char *file_name,
+ int line_number)
+ : severity_(severity) {
+ stream_ << JumpToBasename(file_name) << ":" << line_number << ": ";
+}
+
+LogMessage::~LogMessage() {
+ LogSeverity level = severity_;
+ if (level == DFATAL) {
+#ifdef NDEBUG
+ level = ERROR;
+#else
+ level = FATAL;
+#endif
+ }
+ LowLevelLogging(level, /* tag = */ SAFTM_LOGGING_TAG, stream_.message);
+ if (level == FATAL) {
+ exit(1);
+ }
+}
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/lite_base/compact-logging.h b/lang_id/common/lite_base/compact-logging.h
new file mode 100644
index 0000000..eccb7d1
--- /dev/null
+++ b/lang_id/common/lite_base/compact-logging.h
@@ -0,0 +1,177 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
+
+#include <cassert>
+#include <string>
+
+#include "lang_id/common/lite_base/attributes.h"
+#include "lang_id/common/lite_base/compact-logging-levels.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+// A tiny code footprint string stream for assembling log messages.
+struct LoggingStringStream {
+ LoggingStringStream() {}
+ LoggingStringStream &stream() { return *this; }
+
+ // Needed for invocation in SAFTM_CHECK macro.
+ explicit operator bool() const { return true; }
+
+ string message;
+};
+
+template <typename T>
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const T &entry) {
+ stream.message.append(std::to_string(entry));
+ return stream;
+}
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const char *message) {
+ stream.message.append(message);
+ return stream;
+}
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const string &message) {
+ stream.message.append(message);
+ return stream;
+}
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ StringPiece sp) {
+ stream.message.append(sp.data(), sp.size());
+ return stream;
+}
+
+// The class that does all the work behind our SAFTM_LOG(severity) macros. Each
+// SAFTM_LOG(severity) << obj1 << obj2 << ...; logging statement creates a
+// LogMessage temporary object containing a stringstream. Each operator<< adds
+// info to that stringstream and the LogMessage destructor performs the actual
+// logging. The reason this works is that in C++, "all temporary objects are
+// destroyed as the last step in evaluating the full-expression that (lexically)
+// contains the point where they were created." For more info, see
+// http://en.cppreference.com/w/cpp/language/lifetime. Hence, the destructor is
+// invoked after the last << from that logging statement.
+class LogMessage {
+ public:
+ LogMessage(LogSeverity severity, const char *file_name,
+ int line_number) SAFTM_ATTRIBUTE_NOINLINE;
+
+ ~LogMessage() SAFTM_ATTRIBUTE_NOINLINE;
+
+ // Returns the stream associated with the logger object.
+ LoggingStringStream &stream() { return stream_; }
+
+ private:
+ const LogSeverity severity_;
+
+ // Stream that "prints" all info into a string (not to a file). We construct
+ // here the entire logging message and next print it in one operation.
+ LoggingStringStream stream_;
+};
+
+// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing
+// anything.
+class NullStream {
+ public:
+ NullStream() {}
+ NullStream &stream() { return *this; }
+};
+template <typename T>
+inline NullStream &operator<<(NullStream &str, const T &) {
+ return str;
+}
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#define SAFTM_LOG(severity) \
+ ::libtextclassifier3::mobile::internal_logging::LogMessage( \
+ ::libtextclassifier3::mobile::internal_logging::severity, __FILE__, __LINE__) \
+ .stream()
+
+// If condition x is true, does nothing. Otherwise, crashes the program (like
+// LOG(FATAL)) with an informative message. Can be continued with extra
+// messages, via <<, like any logging macro, e.g.,
+//
+// SAFTM_CHECK(my_cond) << "I think we hit a problem";
+#define SAFTM_CHECK(x) \
+ (x) || SAFTM_LOG(FATAL) << __FILE__ << ":" << __LINE__ \
+ << ": check failed: \"" << #x
+
+#define SAFTM_CHECK_EQ(x, y) SAFTM_CHECK((x) == (y))
+#define SAFTM_CHECK_LT(x, y) SAFTM_CHECK((x) < (y))
+#define SAFTM_CHECK_GT(x, y) SAFTM_CHECK((x) > (y))
+#define SAFTM_CHECK_LE(x, y) SAFTM_CHECK((x) <= (y))
+#define SAFTM_CHECK_GE(x, y) SAFTM_CHECK((x) >= (y))
+#define SAFTM_CHECK_NE(x, y) SAFTM_CHECK((x) != (y))
+
+#define SAFTM_NULLSTREAM \
+ ::libtextclassifier3::mobile::internal_logging::NullStream().stream()
+
+// Debug checks: a SAFTM_DCHECK<suffix> macro should behave like
+// SAFTM_CHECK<suffix> in debug mode and don't check / don't print anything in
+// non-debug mode.
+#ifdef NDEBUG
+
+#define SAFTM_DCHECK(x) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_EQ(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_LT(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_GT(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_LE(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_GE(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_NE(x, y) SAFTM_NULLSTREAM
+
+// In non-debug mode, SAFTM_DLOG statements do not generate any logging.
+#define SAFTM_DLOG(severity) SAFTM_NULLSTREAM
+
+#else // NDEBUG
+
+// In debug mode, each SAFTM_DCHECK<suffix> is equivalent to
+// SAFTM_CHECK<suffix>, i.e., a real check that crashes when the condition is
+// not true.
+#define SAFTM_DCHECK(x) SAFTM_CHECK(x)
+#define SAFTM_DCHECK_EQ(x, y) SAFTM_CHECK_EQ(x, y)
+#define SAFTM_DCHECK_LT(x, y) SAFTM_CHECK_LT(x, y)
+#define SAFTM_DCHECK_GT(x, y) SAFTM_CHECK_GT(x, y)
+#define SAFTM_DCHECK_LE(x, y) SAFTM_CHECK_LE(x, y)
+#define SAFTM_DCHECK_GE(x, y) SAFTM_CHECK_GE(x, y)
+#define SAFTM_DCHECK_NE(x, y) SAFTM_CHECK_NE(x, y)
+
+// In debug mode, SAFTM_DLOG statements are like SAFTM_LOG.
+#define SAFTM_DLOG SAFTM_LOG
+
+#endif // NDEBUG
+
+#ifdef LIBTEXTCLASSIFIER_VLOG
+#define SAFTM_VLOG(severity) \
+ ::libtextclassifier3::mobile::internal_logging::LogMessage( \
+ ::libtextclassifier3::mobile::internal_logging::INFO, __FILE__, __LINE__) \
+ .stream()
+#else
+#define SAFTM_VLOG(severity) SAFTM_NULLSTREAM
+#endif
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
diff --git a/lang_id/common/lite_base/endian.h b/lang_id/common/lite_base/endian.h
new file mode 100644
index 0000000..16c2dca
--- /dev/null
+++ b/lang_id/common/lite_base/endian.h
@@ -0,0 +1,126 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ENDIAN_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ENDIAN_H_
+
+#include "lang_id/common/lite_base/integral-types.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+#if defined OS_LINUX || defined OS_CYGWIN || defined OS_ANDROID || \
+ defined(__ANDROID__)
+#include <endian.h>
+#endif
+
+// The following guarantees declaration of the byte swap functions, and
+// defines __BYTE_ORDER for MSVC
+#if defined(__GLIBC__) || defined(__CYGWIN__)
+#include <byteswap.h> // IWYU pragma: export
+
+#else
+#ifndef bswap_16
+static inline uint16 bswap_16(uint16 x) {
+ return (uint16)(((x & 0xFF) << 8) | ((x & 0xFF00) >> 8)); // NOLINT
+}
+#define bswap_16(x) bswap_16(x)
+#endif // bswap_16
+
+#ifndef bswap_32
+static inline uint32 bswap_32(uint32 x) {
+ return (((x & 0xFF) << 24) | ((x & 0xFF00) << 8) | ((x & 0xFF0000) >> 8) |
+ ((x & 0xFF000000) >> 24));
+}
+#define bswap_32(x) bswap_32(x)
+#endif // bswap_32
+
+#ifndef bswap_64
+#define SAFTM_GG_ULONGLONG(x) x##ULL
+static inline uint64 bswap_64(uint64 x) {
+ return (((x & SAFTM_GG_ULONGLONG(0xFF)) << 56) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF00)) << 40) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF0000)) << 24) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF000000)) << 8) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF00000000)) >> 8) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF0000000000)) >> 24) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF000000000000)) >> 40) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF00000000000000)) >> 56));
+}
+#define bswap_64(x) bswap_64(x)
+#endif // bswap_64
+
+#endif
+
+// define the macros SAFTM_IS_LITTLE_ENDIAN or SAFTM_IS_BIG_ENDIAN using the
+// above endian definitions from endian.h if endian.h was included
+#ifdef __BYTE_ORDER
+#if __BYTE_ORDER == __LITTLE_ENDIAN
+#define SAFTM_IS_LITTLE_ENDIAN
+#endif
+
+#if __BYTE_ORDER == __BIG_ENDIAN
+#define SAFTM_IS_BIG_ENDIAN
+#endif
+
+#else // __BYTE_ORDER
+
+#if defined(__LITTLE_ENDIAN__)
+#define SAFTM_IS_LITTLE_ENDIAN
+#elif defined(__BIG_ENDIAN__)
+#define SAFTM_IS_BIG_ENDIAN
+#endif
+
+// there is also PDP endian ...
+
+#endif // __BYTE_ORDER
+
+// Host <-> little-endian conversion helpers. On little-endian hosts these are
+// no-ops; on big-endian hosts they byte-swap via the bswap_* helpers above.
+class LittleEndian {
+ public:
+// Conversion functions.
+#ifdef SAFTM_IS_LITTLE_ENDIAN
+
+  static uint16 FromHost16(uint16 x) { return x; }
+  static uint16 ToHost16(uint16 x) { return x; }
+
+  static uint32 FromHost32(uint32 x) { return x; }
+  static uint32 ToHost32(uint32 x) { return x; }
+
+  static uint64 FromHost64(uint64 x) { return x; }
+  static uint64 ToHost64(uint64 x) { return x; }
+
+  static bool IsLittleEndian() { return true; }
+
+#elif defined SAFTM_IS_BIG_ENDIAN
+
+  // NOTE: fixed to call bswap_* (declared above / in <byteswap.h>); gbswap_*
+  // is not defined anywhere in this file, so this branch could not compile.
+  static uint16 FromHost16(uint16 x) { return bswap_16(x); }
+  static uint16 ToHost16(uint16 x) { return bswap_16(x); }
+
+  static uint32 FromHost32(uint32 x) { return bswap_32(x); }
+  static uint32 ToHost32(uint32 x) { return bswap_32(x); }
+
+  static uint64 FromHost64(uint64 x) { return bswap_64(x); }
+  static uint64 ToHost64(uint64 x) { return bswap_64(x); }
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ENDIAN_H_
diff --git a/lang_id/common/lite_base/float16.h b/lang_id/common/lite_base/float16.h
new file mode 100644
index 0000000..bc3fd21
--- /dev/null
+++ b/lang_id/common/lite_base/float16.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_FLOAT16_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_FLOAT16_H_
+
+#include "lang_id/common/lite_base/casts.h"
+#include "lang_id/common/lite_base/integral-types.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// 16 bit encoding of a float. NOTE: can't be used directly for computation:
+// one first needs to convert it to a normal float, using Float16To32.
+//
+// Compact 16-bit encoding of floating point numbers. This
+// representation uses 1 bit for the sign, 8 bits for the exponent and
+// 7 bits for the mantissa. It is assumed that floats are in IEEE 754
+// format so a float16 is just bits 16-31 of a single precision float.
+//
+// NOTE: The IEEE floating point standard defines a float16 format that
+// is different than this format (it has fewer bits of exponent and more
+// bits of mantissa). We don't use that format here because conversion
+// to/from 32-bit floats is more complex for that format, and the
+// conversion for this format is very simple.
+//
+// <---------float16------------>
+// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f
+// <------------------------------float-------------------------->
+// 3 3 2 2 1 1 0
+// 1 0 3 2 5 4 0
+
+typedef uint16 float16;
+
+static inline float16 Float32To16(float f) {
+ // Note that we just truncate the mantissa bits: we make no effort to
+ // do any smarter rounding.
+ return (bit_cast<uint32>(f) >> 16) & 0xffff;
+}
+
+static inline float Float16To32(float16 f) {
+  // Fill new mantissa bits with 0. Widen to uint32 before shifting: a bare
+  // `f << 16` promotes to signed int, and shifting a bit into its sign
+  // position is not portable before C++20.
+  return bit_cast<float>(static_cast<uint32>(f) << 16);
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_FLOAT16_H_
diff --git a/lang_id/common/lite_base/integral-types.h b/lang_id/common/lite_base/integral-types.h
new file mode 100644
index 0000000..4c3038c
--- /dev/null
+++ b/lang_id/common/lite_base/integral-types.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Basic integer type definitions.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_INTEGRAL_TYPES_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_INTEGRAL_TYPES_H_
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+typedef unsigned int uint32;
+typedef unsigned long long uint64;
+
+#ifndef SWIG
+typedef int int32;
+typedef unsigned char uint8; // NOLINT
+typedef unsigned short uint16; // NOLINT
+
+// A type to represent a Unicode code-point value. As of Unicode 4.0,
+// such values require up to 21 bits.
+// (For type-checking on pointers, make this explicitly signed,
+// and it should always be the signed version of whatever int32 is.)
+typedef signed int char32;
+#endif // SWIG
+
+#ifdef COMPILER_MSVC
+typedef __int64 int64;
+#else
+typedef long long int64; // NOLINT
+#endif // COMPILER_MSVC
+
+// Some compile-time assertions that our new types have the intended size.
+static_assert(sizeof(int) == 4, "Our typedefs depend on int being 32 bits");
+static_assert(sizeof(uint32) == 4, "wrong size");
+static_assert(sizeof(int32) == 4, "wrong size");
+static_assert(sizeof(uint8) == 1, "wrong size");
+static_assert(sizeof(uint16) == 2, "wrong size");
+static_assert(sizeof(char32) == 4, "wrong size");
+static_assert(sizeof(int64) == 8, "wrong size");
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_INTEGRAL_TYPES_H_
diff --git a/lang_id/common/lite_base/logging.h b/lang_id/common/lite_base/logging.h
new file mode 100644
index 0000000..88797cb
--- /dev/null
+++ b/lang_id/common/lite_base/logging.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_LOGGING_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_LOGGING_H_
+
+#ifdef SAFTM_COMPACT_LOGGING
+
+// One gets the compact logging only if one requests it explicitly, by passing
+// --define saftm_compact_logging=true on the blaze command-line.
+#include "lang_id/common/lite_base/compact-logging.h"
+
+#else
+
+// Otherwise, one gets the standard base/logging.h. You should do so, unless
+// you have a really good reason to switch to the compact logging.
+#include "base/logging.h"
+
+#define SAFTM_LOG LOG
+#define SAFTM_CHECK CHECK
+#define SAFTM_CHECK_EQ CHECK_EQ
+#define SAFTM_CHECK_LT CHECK_LT
+#define SAFTM_CHECK_LE CHECK_LE
+#define SAFTM_CHECK_GT CHECK_GT
+#define SAFTM_CHECK_GE CHECK_GE
+#define SAFTM_CHECK_NE CHECK_NE
+
+#define SAFTM_DLOG DLOG
+#define SAFTM_DCHECK DCHECK
+#define SAFTM_DCHECK_EQ DCHECK_EQ
+#define SAFTM_DCHECK_LT DCHECK_LT
+#define SAFTM_DCHECK_LE DCHECK_LE
+#define SAFTM_DCHECK_GT DCHECK_GT
+#define SAFTM_DCHECK_GE DCHECK_GE
+#define SAFTM_DCHECK_NE DCHECK_NE
+
+#endif // SAFTM_COMPACT_LOGGING
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_LOGGING_H_
diff --git a/lang_id/common/lite_base/macros.h b/lang_id/common/lite_base/macros.h
new file mode 100644
index 0000000..8fe5e8a
--- /dev/null
+++ b/lang_id/common/lite_base/macros.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_MACROS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_MACROS_H_
+
+#define SAFTM_DISALLOW_COPY_AND_ASSIGN(TypeName) \
+ TypeName(const TypeName &) = delete; \
+ TypeName &operator=(const TypeName &) = delete
+
+// The SAFTM_FALLTHROUGH_INTENDED macro can be used to annotate implicit
+// fall-through between switch labels:
+//
+// switch (x) {
+// case 40:
+// case 41:
+// if (truth_is_out_there) {
+// ++x;
+// SAFTM_FALLTHROUGH_INTENDED; // Use instead of/along with annotations
+// // in comments.
+// } else {
+// return x;
+// }
+// case 42:
+// ...
+//
+// As shown in the example above, the SAFTM_FALLTHROUGH_INTENDED macro should
+// be followed by a semicolon. It is designed to mimic control-flow statements
+// like 'break;', so it can be placed in most places where 'break;' can, but
+// only if there are no statements on the execution path between it and the
+// next switch label.
+//
+// When compiled with clang, the SAFTM_FALLTHROUGH_INTENDED macro is expanded
+// to the [[clang::fallthrough]] attribute, which is analysed when performing
+// the switch-label fall-through diagnostic ('-Wimplicit-fallthrough'). See clang
+// documentation on language extensions for details:
+// http://clang.llvm.org/docs/AttributeReference.html#fallthrough-clang-fallthrough
+//
+// When used with unsupported compilers, the SAFTM_FALLTHROUGH_INTENDED macro
+// has no effect on diagnostics.
+//
+// In either case this macro has no effect on runtime behavior and performance
+// of code.
+#if defined(__clang__) && defined(__has_warning)
+#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough")
+#define SAFTM_FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT
+#endif
+#endif
+
+#ifndef SAFTM_FALLTHROUGH_INTENDED
+#define SAFTM_FALLTHROUGH_INTENDED \
+ do { \
+ } while (0)
+#endif
+
+// SAFTM_UNIQUE_ID(prefix) expands to a unique id that starts with prefix.
+//
+// The current implementation expands to prefix_<line_number>; hence, multiple
+// uses of this macro with the same prefix and on the same line will result in
+// the same identifier name. In those cases, if you need different ids, we
+// suggest you use different prefixes.
+//
+// Implementation is tricky; for more info, see
+// https://stackoverflow.com/questions/1597007/creating-c-macro-with-and-line-token-concatenation-with-positioning-macr
+#define SAFTM_UNIQUE_ID_INTERNAL2(x, y) x ## y
+#define SAFTM_UNIQUE_ID_INTERNAL(x, y) SAFTM_UNIQUE_ID_INTERNAL2(x, y)
+#define SAFTM_UNIQUE_ID(prefix) SAFTM_UNIQUE_ID_INTERNAL(prefix ## _, __LINE__)
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_MACROS_H_
diff --git a/lang_id/common/lite_strings/numbers.cc b/lang_id/common/lite_strings/numbers.cc
new file mode 100644
index 0000000..e0c66f3
--- /dev/null
+++ b/lang_id/common/lite_strings/numbers.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/lite_strings/numbers.h"
+
+#include <ctype.h>
+#include <stdlib.h>
+#include <climits>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Returns true if the characters that start at address ptr (inclusive) and stop
+// at the first '\0' consist of only whitespaces, as determined by isspace().
+// Note: this function returns false if ptr is nullptr.
+static bool OnlyWhitespaces(const char *ptr) {
+ if (ptr == nullptr) {
+ return false;
+ }
+ for (; *ptr != '\0'; ++ptr) {
+ if (!isspace(*ptr)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool LiteAtoi(const char *c_str, int *value) {
+ if (c_str == nullptr) {
+ return false;
+ }
+
+ // Short version of man strtol:
+ //
+ // strtol parses some optional whitespaces, an optional +/- sign, and next a
+ // succession of digits. If it finds some digits, it sets temp to point to
+ // the first character after that succession of digits and returns the parsed
+ // integer.
+ //
+ // If there were no digits at all, strtol() sets temp to be c_str (the start
+ // address) and returns 0.
+ char *temp = nullptr;
+ const long int parsed_value = strtol(c_str, &temp, 0); // NOLINT
+
+ // Check for overflow. Note: to simplify the code, we assume that LONG_MIN /
+ // LONG_MAX means that strtol encountered an overflow (normally, in that case,
+ // one should also inspect errno). Hence, we maybe give up the possibility to
+ // parse one extreme value on each side (min/max). That should be ok.
+ if ((parsed_value == LONG_MIN) || (parsed_value == LONG_MAX) ||
+ (parsed_value < INT_MIN) || (parsed_value > INT_MAX)) {
+ return false;
+ }
+ *value = static_cast<int>(parsed_value);
+
+ // First part of the expression below means that the input string contained at
+ // least one digit. The other part checks that what remains after the number
+ // (if anything) consists only of whitespaces.
+ return (temp != c_str) && OnlyWhitespaces(temp);
+}
+
+bool LiteAtof(const char *c_str, float *value) {
+ if (c_str == nullptr) {
+ return false;
+ }
+
+ // strtof is similar to strtol, see more detailed comments inside LiteAtoi.
+ char *temp = nullptr;
+ *value = strtof(c_str, &temp);
+ return (temp != c_str) && OnlyWhitespaces(temp);
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/lite_strings/numbers.h b/lang_id/common/lite_strings/numbers.h
new file mode 100644
index 0000000..4b3c93c
--- /dev/null
+++ b/lang_id/common/lite_strings/numbers.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
+
+#include <string>
+
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Parses an int from a C-style string; similar to absl::SimpleAtoi.
+//
+// c_str should point to a zero-terminated array of chars that contains the
+// number representation as (a) "<radix-10-number>" (e.g., "721"), (b)
+// "0x<radix-16-number>" (e.g., "0xa1"), or (c) "0<radix-8-number>" (e.g.,
+// "017201"). Whitespaces (as determined by isspace()) are allowed before and
+// after the number representation (but obviously not in the middle).
+//
+// Stores parsed number into *value. Returns true on success, false on error.
+// Note: presence of extra non-whitespace characters after the number counts as
+// an error: e.g., parsing "123a" will return false due to the extra "a" (which
+// is not a valid radix-10 digit). This function also returns false for strings
+// that do not contain any digit (e.g., ""), as well as for overflows /
+// underflows.
+bool LiteAtoi(const char *c_str, int *value);
+
+inline bool LiteAtoi(const string &s, int *value) {
+ return LiteAtoi(s.c_str(), value);
+}
+
+inline bool LiteAtoi(StringPiece sp, int *value) {
+ // Unfortunately, we can't directly call LiteAtoi(sp.data()): LiteAtoi(const
+ // char *) needs a zero-terminated string.
+ const string temp(sp.data(), sp.size());
+ return LiteAtoi(temp.c_str(), value);
+}
+
+// Like LiteAtoi, but for float; similar to absl::SimpleAtof.
+//
+// NOTE: currently, does not properly handle overflow / underflow.
+// TODO(salcianu): fix that.
+bool LiteAtof(const char *c_str, float *value);
+
+inline bool LiteAtof(const string &s, float *value) {
+ return LiteAtof(s.c_str(), value);
+}
+
+inline bool LiteAtof(StringPiece sp, float *value) {
+ // Unfortunately, we can't directly call LiteAtof(sp.data()): LiteAtof(const
+ // char *) needs a zero-terminated string.
+ const string temp(sp.data(), sp.size());
+ return LiteAtof(temp.c_str(), value);
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
diff --git a/lang_id/common/lite_strings/str-cat.h b/lang_id/common/lite_strings/str-cat.h
new file mode 100644
index 0000000..f0c1682
--- /dev/null
+++ b/lang_id/common/lite_strings/str-cat.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
+
+// Less efficient but more compact versions of several absl string utils.
+//
+// "More compact" means "pulls in fewer code dependencies". That's useful if
+// one tries to minimize the code size.
+//
+// Note: the name and the signature of the functions from this header were
+// chosen to minimize the effort of converting code that uses absl::StrCat &
+// co to our more compact functions.
+
+#include <string>
+
+#ifdef COMPILER_MSVC
+#include <sstream>
+#endif // COMPILER_MSVC
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Less efficient but more compact version of absl::StrCat().
+//
+// Given a value v (see supported types below) LiteStrCat(v) returns a new
+// string that contains the representation of v. For examples, see
+// str-cat_test.cc.
+template <typename T>
+inline string LiteStrCat(T v) {
+#ifdef COMPILER_MSVC
+ std::stringstream stream;
+ stream << v;
+ return stream.str();
+#else
+ return std::to_string(v);
+#endif
+}
+
+template <>
+inline string LiteStrCat(const char *v) {
+ return string(v);
+}
+
+// TODO(salcianu): use a reference type (const string &). For some reason, I
+// couldn't get that to work on a first try.
+template <>
+inline string LiteStrCat(string v) {
+ return v;
+}
+
+template <>
+inline string LiteStrCat(char v) {
+ return string(1, v);
+}
+
+// Less efficient but more compact version of absl::StrAppend().
+template <typename T>
+inline void LiteStrAppend(string *dest, T v) {
+ dest->append(LiteStrCat(v)); // NOLINT
+}
+
+template <typename T1, typename T2>
+inline void LiteStrAppend(string *dest, T1 v1, T2 v2) {
+ dest->append(LiteStrCat(v1)); // NOLINT
+ dest->append(LiteStrCat(v2)); // NOLINT
+}
+
+template <typename T1, typename T2, typename T3>
+inline void LiteStrAppend(string *dest, T1 v1, T2 v2, T3 v3) {
+ LiteStrAppend(dest, v1, v2);
+ dest->append(LiteStrCat(v3)); // NOLINT
+}
+
+template <typename T1, typename T2, typename T3, typename T4>
+inline void LiteStrAppend(string *dest, T1 v1, T2 v2, T3 v3, T4 v4) {
+ LiteStrAppend(dest, v1, v2, v3);
+ dest->append(LiteStrCat(v4)); // NOLINT
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
diff --git a/lang_id/common/lite_strings/str-split.cc b/lang_id/common/lite_strings/str-split.cc
new file mode 100644
index 0000000..199bb69
--- /dev/null
+++ b/lang_id/common/lite_strings/str-split.cc
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/lite_strings/str-split.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+std::vector<StringPiece> LiteStrSplit(StringPiece text, char delim) {
+ std::vector<StringPiece> result;
+ size_t token_start = 0; // same type as text.size(): avoids signed/unsigned narrowing
+ if (!text.empty()) {
+ for (size_t i = 0; i < text.size() + 1; ++i) {
+ if ((i == text.size()) || (text[i] == delim)) {
+ result.emplace_back(text.data() + token_start, i - token_start);
+ token_start = i + 1;
+ }
+ }
+ }
+ return result;
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
new file mode 100644
index 0000000..300bc9f
--- /dev/null
+++ b/lang_id/common/lite_strings/str-split.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_SPLIT_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_SPLIT_H_
+
+#include <vector>
+
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Splits |text| on |delim|; similar to absl::StrSplit.
+//
+// Returns a list of tokens. Each token is represented by a StringPiece that
+// indicates a range of chars from |text|.
+//
+// Example: StrSplit("apple,orange", ',') returns two tokens: a StringPiece that
+// points to "apple", and another one for "orange".
+//
+// If one concatenates all returned tokens with |delim| in between, one gets the
+// original |text|. E.g., If we split "apple,orange," on ',', we get three
+// tokens: "apple", "orange" and "" (an empty token). We do not filter out
+// empty tokens. If necessary, the caller can do that.
+//
+// Note: if the input text is empty, we return an empty list of tokens. In
+// general, the number of returned tokens is 1 + the number of occurrences of
+// |delim| inside |text|.
+std::vector<StringPiece> LiteStrSplit(StringPiece text, char delim);
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_SPLIT_H_
diff --git a/lang_id/common/lite_strings/stringpiece.h b/lang_id/common/lite_strings/stringpiece.h
new file mode 100644
index 0000000..59a2176
--- /dev/null
+++ b/lang_id/common/lite_strings/stringpiece.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
+
+#include <stddef.h>
+#include <string.h>
+
+#include <ostream>
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Read-only "view" of a piece of data. Does not own the underlying data.
+class StringPiece {
+ public:
+ StringPiece() : StringPiece(nullptr, 0) {}
+
+ StringPiece(const char *str) // NOLINT
+ : start_(str), size_(strlen(str)) {}
+
+ StringPiece(const char *start, size_t size) : start_(start), size_(size) {}
+
+ // Intentionally no "explicit" keyword: in function calls, we want strings to
+ // be converted to StringPiece implicitly.
+ StringPiece(const string &s) // NOLINT
+ : StringPiece(s.data(), s.size()) {}
+
+ StringPiece(const string &s, int offset, int len)
+ : StringPiece(s.data() + offset, len) {}
+
+ char operator[](size_t i) const { return start_[i]; }
+
+ // Returns start address of underlying data.
+ const char *data() const { return start_; }
+
+ // Returns number of bytes of underlying data.
+ size_t size() const { return size_; }
+
+ // Returns true if this StringPiece does not refer to any characters.
+ bool empty() const { return size() == 0; }
+
+ template <typename A>
+ explicit operator basic_string<char, std::char_traits<char>, A>() const {
+ if (!data()) return {};
+ return basic_string<char, std::char_traits<char>, A>(data(), size());
+ }
+
+ private:
+ const char *start_; // Not owned.
+ size_t size_;
+};
+
+inline std::ostream &operator<<(std::ostream &out, StringPiece sp) {
+ return out.write(sp.data(), sp.size());
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
diff --git a/lang_id/common/math/algorithm.h b/lang_id/common/math/algorithm.h
new file mode 100644
index 0000000..a963807
--- /dev/null
+++ b/lang_id/common/math/algorithm.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Generic utils similar to those from the C++ header <algorithm>.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
+
+#include <algorithm>
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Returns index of max element from the vector |elements|. Returns 0 if
+// |elements| is empty. T should be a type that can be compared by operator<.
+template<typename T>
+inline int GetArgMax(const std::vector<T> &elements) {
+ return std::distance(
+ elements.begin(),
+ std::max_element(elements.begin(), elements.end()));
+}
+
+// Returns index of min element from the vector |elements|. Returns 0 if
+// |elements| is empty. T should be a type that can be compared by operator<.
+template<typename T>
+inline int GetArgMin(const std::vector<T> &elements) {
+ return std::distance(
+ elements.begin(),
+ std::min_element(elements.begin(), elements.end()));
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
diff --git a/lang_id/common/math/checksum.cc b/lang_id/common/math/checksum.cc
new file mode 100644
index 0000000..23d88bc
--- /dev/null
+++ b/lang_id/common/math/checksum.cc
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/math/checksum.h"
+
+// Though we use the same zlib header on all platforms, the implementation used
+// is from NDK on android and from third_party/zlib on iOS/linux. See BUILD
+// rule.
+#include <zlib.h>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// static
+uint32 Crc32::GetInitialCrc32() {
+ static const uint32 kCrcInitZero = crc32(0L, nullptr, 0);
+ return kCrcInitZero;
+}
+
+void Crc32::Update(const char *str, int len) {
+ if (str == nullptr || len == 0) {
+ return;
+ }
+ current_ = crc32(current_,
+ reinterpret_cast<const unsigned char *>(str),
+ len);
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/math/checksum.h b/lang_id/common/math/checksum.h
new file mode 100644
index 0000000..d62893f
--- /dev/null
+++ b/lang_id/common/math/checksum.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_CHECKSUM_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_CHECKSUM_H_
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Class to compute a 32bit Cyclic Redundancy Check (CRC) in a cumulative way.
+//
+// To use, create an instance of this class, repeatedly call Update() to "feed"
+// it your pieces of data, and, when done, call Get().
+class Crc32 {
+ public:
+ Crc32() : current_(GetInitialCrc32()) {}
+
+ // Updates current CRC32 code to also take into account the |len| bytes that
+ // start at address |str|.
+ void Update(const char *str, int len);
+
+ // Updates current CRC32 code to also take into account the bytes from |s|.
+ void Update(StringPiece s) { Update(s.data(), s.size()); }
+
+ // Returns the CRC32 code for the data so far.
+ uint32 Get() const { return current_; }
+
+ private:
+ // Returns the initial value for current_.
+ static uint32 GetInitialCrc32();
+
+ // CRC32 for the data so far.
+ uint32 current_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_CHECKSUM_H_
diff --git a/lang_id/common/math/fastexp.cc b/lang_id/common/math/fastexp.cc
new file mode 100644
index 0000000..44df91f
--- /dev/null
+++ b/lang_id/common/math/fastexp.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/math/fastexp.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+const int FastMathClass::kBits;
+const int FastMathClass::kMask1;
+const int FastMathClass::kMask2;
+constexpr float FastMathClass::kLogBase2OfE;
+
+FastMathClass FastMathInstance;
+
+// Data taken from util/math/fastmath.cc
+const FastMathClass::Table FastMathClass::cache_ = {
+ {0, 45549, 91345, 137391, 183686, 230233, 277032, 324086, 371395, 418961,
+ 466785, 514869, 563214, 611822, 660693, 709830, 759233, 808905, 858847,
+ 909060, 959545, 1010305, 1061340, 1112652, 1164243, 1216114, 1268267,
+ 1320703, 1373423, 1426430, 1479725, 1533309, 1587184, 1641351, 1695813,
+ 1750570, 1805625, 1860979, 1916633, 1972590, 2028850, 2085416, 2142289,
+ 2199470, 2256963, 2314767, 2372885, 2431319, 2490070, 2549140, 2608531,
+ 2668245, 2728282, 2788646, 2849337, 2910358, 2971710, 3033396, 3095416,
+ 3157773, 3220469, 3283505, 3346884, 3410606, 3474675, 3539091, 3603857,
+ 3668975, 3734447, 3800274, 3866458, 3933002, 3999907, 4067176, 4134809,
+ 4202810, 4271180, 4339922, 4409036, 4478526, 4548394, 4618640, 4689268,
+ 4760280, 4831677, 4903462, 4975636, 5048203, 5121164, 5194520, 5268275,
+ 5342431, 5416989, 5491952, 5567322, 5643101, 5719292, 5795897, 5872917,
+ 5950356, 6028215, 6106497, 6185204, 6264338, 6343902, 6423898, 6504329,
+ 6585196, 6666502, 6748250, 6830442, 6913080, 6996166, 7079704, 7163696,
+ 7248143, 7333049, 7418416, 7504247, 7590543, 7677309, 7764545, 7852255,
+ 7940441, 8029106, 8118253, 8207884, 8298001}
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/math/fastexp.h b/lang_id/common/math/fastexp.h
new file mode 100644
index 0000000..05b654a
--- /dev/null
+++ b/lang_id/common/math/fastexp.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Fast approximation for exp.
+//
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
+
+#include <cassert>
+#include <cmath>
+#include <limits>
+
+#include "lang_id/common/lite_base/casts.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Fast approximation of 2^x / e^x via a 128-entry mantissa lookup table and
+// IEEE-754 bit manipulation.  Trades accuracy for speed.
+class FastMathClass {
+ private:
+  // Number of fractional bits used to index the mantissa table (2^7 entries).
+  static const int kBits = 7;
+  // Mask selecting the kBits table-index bits.
+  static const int kMask1 = (1 << kBits) - 1;
+  // Mask selecting the 8 exponent bits located just above the index bits.
+  static const int kMask2 = 0xFF << kBits;
+  // log2(e); converts e^f into 2^(f * log2(e)).
+  static constexpr float kLogBase2OfE = 1.44269504088896340736f;
+
+  struct Table {
+    int32 exp1[1 << kBits];
+  };
+
+ public:
+  // Returns a fast approximation of 2^f.
+  // Precondition (checked in debug builds only): |f| <= 126, so the
+  // constructed exponent stays within the finite float range.
+  float VeryFastExp2(float f) const {
+    SAFTM_DCHECK_LE(fabs(f), 126);
+    // Adding this constant aligns the integer part of f with the IEEE-754
+    // exponent field and the top kBits of its fraction with the mantissa.
+    const float g = f + (127 + (1 << (23 - kBits)));
+    const int32 x = bit_cast<int32>(g);
+    // Assemble the result: exponent bits from x, mantissa bits from the table.
+    int32 ret = ((x & kMask2) << (23 - kBits))
+        | cache_.exp1[x & kMask1];
+    return bit_cast<float>(ret);
+  }
+
+  // Returns a fast approximation of e^f, using e^f == 2^(f * log2(e)).
+  float VeryFastExp(float f) const {
+    return VeryFastExp2(f * kLogBase2OfE);
+  }
+
+ private:
+  // Precomputed mantissa table; defined in fastexp.cc.
+  static const Table cache_;
+};
+
+extern FastMathClass FastMathInstance;
+
+// Convenience wrapper over the global FastMathInstance.
+inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); }
+
+}  // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
diff --git a/lang_id/common/math/hash.cc b/lang_id/common/math/hash.cc
new file mode 100644
index 0000000..d320428
--- /dev/null
+++ b/lang_id/common/math/hash.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/math/hash.h"
+
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+namespace {
+// Lower-level versions of Get... that read directly from a character buffer
+// without any bounds checking.
+//
+// Reads 4 bytes as a little-endian uint32, independent of host byte order.
+inline uint32 DecodeFixed32(const char *ptr) {
+  return ((static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) |
+          (static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) |
+          (static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) |
+          (static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24));
+}
+
+// 0xff is in case char is signed.
+static inline uint32 ByteAs32(char c) { return static_cast<uint32>(c) & 0xff; }
+}  // namespace
+
+// 32-bit hash using the MurmurHash2 mixing scheme (constants 0x5bd1e995,
+// r = 24); see hash.h for the contract.
+uint32 Hash32(const char *data, size_t n, uint32 seed) {
+  // 'm' and 'r' are mixing constants generated offline.
+  // They're not really 'magic', they just happen to work well.
+  const uint32 m = 0x5bd1e995;
+  const int r = 24;
+
+  // Initialize the hash to a 'random' value
+  uint32 h = seed ^ n;
+
+  // Mix 4 bytes at a time into the hash
+  while (n >= 4) {
+    uint32 k = DecodeFixed32(data);
+    k *= m;
+    k ^= k >> r;
+    k *= m;
+    h *= m;
+    h ^= k;
+    data += 4;
+    n -= 4;
+  }
+
+  // Handle the last few bytes of the input array.  The cases intentionally
+  // fall through, folding the 1-3 trailing bytes into h.
+  switch (n) {
+    case 3:
+      h ^= ByteAs32(data[2]) << 16;
+      SAFTM_FALLTHROUGH_INTENDED;
+    case 2:
+      h ^= ByteAs32(data[1]) << 8;
+      SAFTM_FALLTHROUGH_INTENDED;
+    case 1:
+      h ^= ByteAs32(data[0]);
+      h *= m;
+  }
+
+  // Do a few final mixes of the hash to ensure the last few
+  // bytes are well-incorporated.
+  h ^= h >> 13;
+  h *= m;
+  h ^= h >> 15;
+  return h;
+}
+
+}  // namespace utils
+}  // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/math/hash.h b/lang_id/common/math/hash.h
new file mode 100644
index 0000000..08c32be
--- /dev/null
+++ b/lang_id/common/math/hash.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+// Compatibility shim: makes the unqualified names "string" / "basic_string"
+// used throughout this library aliases for their std:: counterparts.
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+          class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+}  // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
+
+#include <string>
+
+#include "lang_id/common/lite_base/integral-types.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+// Returns a 32 bit hash of the |n| bytes that start at |data|, using |seed| for
+// internal initialization.  By changing the seed, one effectively gets
+// different hash functions.
+//
+// NOTE: this function is guaranteed not to change in the future.
+//
+// IMPORTANT: for speed reasons, this method does not check its parameters
+// |data| and |n|.  The caller should ensure that n >= 0 and that one can read
+// from the memory area [data, data + n).
+uint32 Hash32(const char *data, size_t n, uint32 seed);
+
+// Like Hash32 above, but with a fixed seed (0xBEEF): convenient when the
+// caller does not care about the particular hash function used.
+static inline uint32 Hash32WithDefaultSeed(const char *data, size_t n) {
+  return Hash32(data, n, 0xBEEF);
+}
+
+// Convenience overload: hashes the raw bytes of |input|.
+static inline uint32 Hash32WithDefaultSeed(const string &input) {
+  return Hash32WithDefaultSeed(input.data(), input.size());
+}
+
+}  // namespace utils
+}  // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
diff --git a/lang_id/common/math/softmax.cc b/lang_id/common/math/softmax.cc
new file mode 100644
index 0000000..750341d
--- /dev/null
+++ b/lang_id/common/math/softmax.cc
@@ -0,0 +1,103 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/math/softmax.h"
+
+#include <algorithm>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/math/fastexp.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Computes the softmax probability of |label| given the logit vector |scores|;
+// returns 0.0f when |label| is out of range (see softmax.h).
+float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) {
+  // NOTE: the label < 0 check runs first, so the signed/unsigned comparison
+  // against scores.size() below only happens for non-negative labels.
+  if ((label < 0) || (label >= scores.size())) {
+    SAFTM_LOG(ERROR) << "label " << label << " outside range "
+                     << "[0, " << scores.size() << ")";
+    return 0.0f;
+  }
+
+  // Standard softmax formula for label's probability is
+  //
+  //   exp(scores[label]) / sum_i exp(scores[i])
+  //
+  // We compute the mathematically equivalent
+  //
+  //   1 / (1 + sum_{i != label} exp(scores[i] - scores[label]))
+  //
+  // which saves two calls to exp().
+  const float label_score = scores[label];
+  float denominator = 1.0f;  // Contribution of i == label.
+  for (int i = 0; i < scores.size(); ++i) {
+    if (i == label) continue;
+    const float delta_score = scores[i] - label_score;
+
+    // TODO(salcianu): one can optimize the test below, to avoid any float
+    // operation: extract exponent (via bit mask + shift) and check it's >= 4.
+    if (fabs(delta_score) >= 16.0f) {
+      if (delta_score > 0.0f) {
+        // If delta_score >= 16, the denominator (e^delta_score + other positive
+        // terms) is very big and its inverse can be approximated with 0.
+        return 0.0f;
+      } else {
+        // If delta_score <= -16, then e^delta_score < 1.2e-7.  Even if we have
+        // 1000 such labels i, their sum is < 1.2e-4 (which gets summed with
+        // 1.0f for i == label).  Hence, we can approximate each such label with
+        // 0 and skip the call to VeryFastExp and the update to denominator.
+        continue;
+      }
+    }
+
+    // At this point, delta_score is in (-16.0, 16.0).  For such values, vfexp
+    // works fine: no under/overflows (we have tests for that in fastexp_test).
+    // Also, even for 1000 labels, denominator will not overflow.
+    denominator += VeryFastExp(delta_score);
+  }
+  return 1.0f / denominator;
+}
+
+// Computes the full softmax distribution over |scores|, with logits scaled by
+// |alpha|; returns an empty vector for empty input (see softmax.h).
+std::vector<float> ComputeSoftmax(const std::vector<float> &scores,
+                                  float alpha) {
+  std::vector<float> softmax;
+  softmax.reserve(scores.size());
+  if (scores.empty()) {
+    return softmax;
+  }
+
+  std::vector<float> exp_scores;
+  exp_scores.reserve(scores.size());
+
+  // Find max value in "scores" vector and rescale to avoid overflows.
+  const float max_score = *std::max_element(scores.begin(), scores.end());
+  float denominator = 0;
+  for (const float score : scores) {
+    // See comments above in ComputeSoftmaxProbability for the reasoning behind
+    // this approximation.
+    const float delta_score = alpha * (score - max_score);
+    const float exp_score = delta_score < -16.0f ? 0 : VeryFastExp(delta_score);
+    exp_scores.push_back(exp_score);
+    denominator += exp_score;
+  }
+
+  // NOTE: denominator >= 1.0f here: the max-scoring element contributes
+  // exp(0) == 1, so the division below is safe.
+  for (int i = 0; i < scores.size(); ++i) {
+    softmax.push_back(exp_scores[i] / denominator);
+  }
+  return softmax;
+}
+
+}  // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/math/softmax.h b/lang_id/common/math/softmax.h
new file mode 100644
index 0000000..0100e59
--- /dev/null
+++ b/lang_id/common/math/softmax.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_SOFTMAX_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_SOFTMAX_H_
+
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Computes probability of a softmax label.  Parameter "scores" is the vector of
+// softmax logits.  Returns 0.0f if "label" is outside the range [0,
+// scores.size()).
+float ComputeSoftmaxProbability(const std::vector<float> &scores, int label);
+
+// Computes and returns a softmax for a given vector of floats.  Parameter
+// "scores" is the vector of softmax logits.
+//
+// The alpha parameter is a scaling factor on the logits.
+std::vector<float> ComputeSoftmax(const std::vector<float> &scores,
+                                  float alpha = 1.0f);
+
+}  // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_SOFTMAX_H_
diff --git a/lang_id/common/registry.h b/lang_id/common/registry.h
new file mode 100644
index 0000000..d2c5271
--- /dev/null
+++ b/lang_id/common/registry.h
@@ -0,0 +1,321 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Mechanism to instantiate classes by name.
+//
+// This mechanism is useful if the concrete classes to be instantiated are not
+// statically known (e.g., if their names are read from a dynamically-provided
+// config).
+//
+// In that case, the first step is to define the API implemented by the
+// instantiated classes. E.g.,
+//
+// // In a header file function.h:
+//
+// // Abstract function that takes a double and returns a double.
+// class Function : public RegisterableClass<Function> {
+// public:
+// virtual ~Function() {}
+// virtual double Evaluate(double x) = 0;
+// };
+//
+// // Should be inside namespace libtextclassifier3::mobile.
+// SAFTM_DECLARE_CLASS_REGISTRY_NAME(Function);
+//
+// Notice the inheritance from RegisterableClass<Function>. RegisterableClass
+// is defined by this file (registry.h).  Under the hood, this inheritance
+// defines a "registry" that maps names (zero-terminated arrays of chars) to
+// factory methods that create Functions. You should give a human-readable name
+// to this registry. To do that, use the following macro in a .cc file (it has
+// to be a .cc file, as it defines some static data):
+//
+// // Inside function.cc
+// // Should be inside namespace libtextclassifier3::mobile.
+// SAFTM_DEFINE_CLASS_REGISTRY_NAME("function", Function);
+//
+// Now, let's define a few concrete Functions: e.g.,
+//
+// class Cos : public Function {
+// public:
+// double Evaluate(double x) override { return cos(x); }
+// SAFTM_DEFINE_REGISTRATION_METHOD("cos", Cos);
+// };
+//
+// class Exp : public Function {
+// public:
+// double Evaluate(double x) override { return exp(x); }
+//   SAFTM_DEFINE_REGISTRATION_METHOD("exp", Exp);
+// };
+//
+// Each concrete Function implementation should have (in the public section) the
+// macro
+//
+// SAFTM_DEFINE_REGISTRATION_METHOD("name", implementation_class);
+//
+// This defines a RegisterClass static method that, when invoked, associates
+// "name" with a factory method that creates instances of implementation_class.
+//
+// Before instantiating Functions by name, we need to tell our system which
+// Functions we may be interested in. This is done by calling the
+// Foo::RegisterClass() for each relevant Foo implementation of Function. It is
+// ok to call Foo::RegisterClass() multiple times (even in parallel): only the
+// first call will perform something, the others will return immediately.
+//
+// Cos::RegisterClass();
+// Exp::RegisterClass();
+//
+// Now, let's instantiate a Function based on its name.  This gets a lot more
+// interesting if the Function name is not statically known (e.g.,
+// read from an input proto):
+//
+// std::unique_ptr<Function> f(Function::Create("cos"));
+// double result = f->Evaluate(arg);
+//
+// NOTE: the same binary can use this mechanism for different APIs. E.g., one
+// can also have (in the binary with Function, Sin, Cos, etc):
+//
+// class IntFunction : public RegisterableClass<IntFunction> {
+// public:
+// virtual ~IntFunction() {}
+// virtual int Evaluate(int k) = 0;
+// };
+//
+// SAFTM_DECLARE_CLASS_REGISTRY_NAME(IntFunction);
+//
+// SAFTM_DEFINE_CLASS_REGISTRY_NAME("int function", IntFunction);
+//
+// class Inc : public IntFunction {
+// public:
+// int Evaluate(int k) override { return k + 1; }
+// SAFTM_DEFINE_REGISTRATION_METHOD("inc", Inc);
+// };
+//
+// RegisterableClass<Function> and RegisterableClass<IntFunction> define their
+// own registries: each maps string names to implementation of the corresponding
+// API.
+//
+// NOTE: the mechanism described above requires you to explicitly call
+// RegisterClass() for all relevant classes before instantiating them. You can
+// do this in the main() function or in any other function that is guaranteed to
+// run before the code that instantiates those classes. Alternatively, you can
+// use the macro SAFTM_STATIC_REGISTRATION to perform this registration in a
+// decentralized fashion. Just use that macro in a .cc file, outside any
+// function / class, e.g.,
+//
+// SAFTM_STATIC_REGISTRATION(Cos);
+//
+// and make sure you link in all symbols from that .cc file; e.g., in bazel, use
+// alwayslink = 1 for the corresponding cc_library. Still, please be aware that
+// using alwayslink = 1 limits the ability of the linker to perform dead code
+// elimination.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
+
+#include <stdlib.h>
+#include <string.h>
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace internal {
+// Registry that associates keys (zero-terminated array of chars) with values.
+// Values are pointers to type T (the template parameter).  This is used to
+// store the association between component names and factory methods that
+// produce those components; the error messages are focused on that case.
+//
+// Internally, this registry uses a linked list of (key, value) pairs.  We do
+// not use an STL map, list, etc because we aim for small code size.
+template <class T>
+class ComponentRegistry {
+ public:
+  explicit ComponentRegistry(const char *name) : name_(name), head_(nullptr) {}
+
+  // Adds the (key, value) pair to this registry (if the key does not already
+  // exist in this registry) and returns true.  If the registry already has a
+  // mapping for key, returns false and does not modify the registry.  NOTE: the
+  // error (false) case happens even if the existing value for key is equal with
+  // the new one.
+  //
+  // This method does not take ownership of key, nor of value.
+  bool Add(const char *key, T *value) {
+    const Cell *old_cell = FindCell(key);
+    if (old_cell != nullptr) {
+      SAFTM_LOG(ERROR) << "Duplicate component: " << key;
+      return false;
+    }
+    // NOTE: cells are never freed; registries are expected to live for the
+    // entire lifetime of the program.
+    Cell *new_cell = new Cell(key, value, head_);
+    head_ = new_cell;
+    return true;
+  }
+
+  // Returns the value attached to a key in this registry.  Returns nullptr on
+  // error (e.g., unknown key).
+  T *Lookup(const char *key) const {
+    const Cell *cell = FindCell(key);
+    if (cell == nullptr) {
+      SAFTM_LOG(ERROR) << "Unknown " << name() << " component: " << key;
+    }
+    return (cell == nullptr) ? nullptr : cell->value();
+  }
+
+  T *Lookup(const string &key) const { return Lookup(key.c_str()); }
+
+  // Returns name of this ComponentRegistry.
+  const char *name() const { return name_; }
+
+  // Fills *names with names of all components registered in this
+  // ComponentRegistry.  Previous content of *names is cleared out.
+  void GetComponentNames(std::vector<string> *names) {
+    names->clear();
+    for (const Cell *c = head_; c!= nullptr; c = c->next()) {
+      names->emplace_back(c->key());
+    }
+  }
+
+ private:
+  // Cell for the singly-linked list underlying this ComponentRegistry.  Each
+  // cell contains a key, the value for that key, as well as a pointer to the
+  // next Cell from the list.
+  class Cell {
+   public:
+    // Constructs a new Cell.
+    Cell(const char *key, T *value, Cell *next)
+        : key_(key), value_(value), next_(next) {}
+
+    const char *key() const { return key_; }
+    T *value() const { return value_; }
+    Cell *next() const { return next_; }
+
+   private:
+    const char *const key_;
+    T *const value_;
+    Cell *const next_;
+  };
+
+  // Finds Cell for indicated key in the singly-linked list pointed to by head_.
+  // Returns pointer to that first Cell with that key, or nullptr if no such
+  // Cell (i.e., unknown key).
+  //
+  // Caller does NOT own the returned pointer.
+  const Cell *FindCell(const char *key) const {
+    const Cell *c = head_;
+    while (c != nullptr && strcmp(key, c->key()) != 0) {
+      c = c->next();
+    }
+    return c;
+  }
+
+  // Human-readable description for this ComponentRegistry.  For debug purposes.
+  const char *const name_;
+
+  // Pointer to the first Cell from the underlying list of (key, value) pairs.
+  Cell *head_;
+};
+}  // namespace internal
+
+// Base class for registerable classes.
+//
+// The template parameter T is the API base class (e.g., Function from the
+// file-level comment); each distinct T gets its own registry.
+template <class T>
+class RegisterableClass {
+ public:
+  // Factory function type.
+  typedef T *(Factory)();
+
+  // Registry type.
+  typedef internal::ComponentRegistry<Factory> Registry;
+
+  // Creates a new instance of T.  Returns pointer to new instance or nullptr in
+  // case of errors (e.g., unknown component).
+  //
+  // Passes ownership of the returned pointer to the caller.
+  static T *Create(const string &name) {  // NOLINT
+    auto *factory = registry()->Lookup(name);
+    if (factory == nullptr) {
+      SAFTM_LOG(ERROR) << "Unknown RegisterableClass " << name;
+      return nullptr;
+    }
+    return factory();
+  }
+
+  // Returns registry for class.  The registry is created on first use and
+  // intentionally never deleted, which avoids destruction-order problems at
+  // program exit.
+  static Registry *registry() {
+    static Registry *registry_for_type_t = new Registry(kRegistryName);
+    return registry_for_type_t;
+  }
+
+ protected:
+  // Factory method for subclass ComponentClass.  Used internally by the static
+  // method RegisterClass() defined by SAFTM_DEFINE_REGISTRATION_METHOD.
+  template <class ComponentClass>
+  static T *_internal_component_factory() {
+    return new ComponentClass();
+  }
+
+ private:
+  // Human-readable name for the registry for this class.  Declared here;
+  // defined via SAFTM_DEFINE_CLASS_REGISTRY_NAME in a .cc file.
+  static const char kRegistryName[];
+};
+
+// Defines the static method component_class::RegisterClass() that should be
+// called before trying to instantiate component_class by name.  Should be used
+// inside the public section of the declaration of component_class.  See
+// comments at the top-level of this file.
+//
+// NOTE: the function-local static guarantees that Add() runs at most once,
+// even if RegisterClass() is called multiple times (thread-safe since C++11).
+#define SAFTM_DEFINE_REGISTRATION_METHOD(component_name, component_class) \
+  static void RegisterClass() { \
+    static bool once = registry()->Add( \
+        component_name, &_internal_component_factory<component_class>); \
+    if (!once) { \
+      SAFTM_LOG(ERROR) << "Problem registering " << component_name; \
+    } \
+    SAFTM_DCHECK(once); \
+  }
+
+// Declares the human-readable name of the registry associated with base_class.
+#define SAFTM_DECLARE_CLASS_REGISTRY_NAME(base_class) \
+  template <> \
+  const char ::libtextclassifier3::mobile::RegisterableClass<base_class>::kRegistryName[]
+
+// Defines the human-readable name of the registry associated with base_class.
+#define SAFTM_DEFINE_CLASS_REGISTRY_NAME(registry_name, base_class) \
+  template <> \
+  const char \
+  ::libtextclassifier3::mobile::RegisterableClass<base_class>::kRegistryName[] \
+      = registry_name
+
+// Register component_name, by calling component_class::RegisterClass() on
+// program start-up, before main.  NOTE: this macro should be used in
+// conjunction with something like alwayslink = 1 from bazel.  That is
+// discouraged, as it prevents the linker from doing dead code elimination, so
+// please use this macro only in special cases.  Instead, if you care about code
+// size, then you should aim to explicitly call RegisterClass from your code
+// (e.g., from the main method, or from the constructor of the class that may
+// need those registered components).
+#define SAFTM_STATIC_REGISTRATION(component_class) \
+  static bool SAFTM_UNIQUE_ID(_kRegistrationDummy) = [] { \
+    component_class::RegisterClass(); \
+    return true; \
+  }()
+
+}  // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
diff --git a/lang_id/common/stl-util.h b/lang_id/common/stl-util.h
new file mode 100644
index 0000000..95d8d3b
--- /dev/null
+++ b/lang_id/common/stl-util.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_STL_UTIL_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_STL_UTIL_H_
+
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+// Deletes all the elements in an STL container and clears the container.  This
+// function is suitable for use with a vector, set, hash_set, or any other STL
+// container which defines sensible begin(), end(), and clear() methods.  If
+// container is NULL, this function is a no-op.
+template <typename T>
+void STLDeleteElements(T *container) {
+  if (!container) return;
+  auto it = container->begin();
+  while (it != container->end()) {
+    // Advance the iterator before deleting, so the deletion cannot disturb
+    // the element the live iterator refers to.
+    auto temp = it;
+    ++it;
+    delete *temp;
+  }
+  container->clear();
+}
+
+}  // namespace utils
+}  // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_STL_UTIL_H_
diff --git a/lang_id/common/utf8.cc b/lang_id/common/utf8.cc
new file mode 100644
index 0000000..ef00145
--- /dev/null
+++ b/lang_id/common/utf8.cc
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/utf8.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+// Returns the largest prefix end of [data, data + size) that stops before any
+// '\0' byte and before any UTF8 character whose declared length would run past
+// the buffer; see utf8.h for the full contract.
+const char *GetSafeEndOfUtf8String(const char *data, size_t size) {
+  const char *const hard_end = data + size;
+  const char *curr = data;
+  while (curr < hard_end && *curr) {
+    int num_bytes = utils::OneCharLen(curr);
+    const char *new_curr = curr + num_bytes;
+    if (new_curr > hard_end) {
+      // Truncated character at the end of the buffer: stop before it.
+      return curr;
+    }
+    curr = new_curr;
+  }
+  return curr;
+}
+
+}  // namespace utils
+}  // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/utf8.h b/lang_id/common/utf8.h
new file mode 100644
index 0000000..2365429
--- /dev/null
+++ b/lang_id/common/utf8.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+// Compatibility shim: makes the unqualified names "string" / "basic_string"
+// used throughout this library aliases for their std:: counterparts.
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+          class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+}  // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
+
+#include <stddef.h>
+
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+// Returns the length (number of bytes) of the UTF8 code point starting at src,
+// by reading only the byte from address src.
+//
+// The result is a number from the set {1, 2, 3, 4}.
+//
+// The lookup table below is indexed by the high nibble of the lead byte:
+// nibbles 0x0-0xB (ASCII and continuation bytes) map to 1, 0xC-0xD to 2
+// (2-byte sequences), 0xE to 3, and 0xF to 4.
+static inline int OneCharLen(const char *src) {
+  // On most platforms, char is unsigned by default, but iOS is an exception.
+  // The cast below makes sure we always interpret *src as an unsigned char.
+  return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"
+      [(*(reinterpret_cast<const unsigned char *>(src)) & 0xFF) >> 4];
+}
+
+// Returns a pointer "end" inside [data, data + size) such that the prefix from
+// [data, end) is the largest one that does not contain '\0' and offers the
+// following guarantee: if one starts with
+//
+//   curr = text.data()
+//
+// and keeps executing
+//
+//   curr += OneCharLen(curr)
+//
+// one would eventually reach curr == end (the pointer returned by this
+// function) without accessing data outside the string.  This guards against
+// scenarios like a broken UTF8 string which has only e.g., the first 2 bytes
+// from a 3-byte UTF8 sequence.
+//
+// Preconditions: data != nullptr.
+const char *GetSafeEndOfUtf8String(const char *data, size_t size);
+
+// Convenience overload: operates on the full contents of |text|.
+static inline const char *GetSafeEndOfUtf8String(const string &text) {
+  return GetSafeEndOfUtf8String(text.data(), text.size());
+}
+
+}  // namespace utils
+}  // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
diff --git a/lang_id/custom-tokenizer.cc b/lang_id/custom-tokenizer.cc
new file mode 100644
index 0000000..f77ad53
--- /dev/null
+++ b/lang_id/custom-tokenizer.cc
@@ -0,0 +1,162 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/custom-tokenizer.h"
+
+#include <ctype.h>
+
+#include <string>
+
+#include "lang_id/common/lite_base/attributes.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/utf8.h"
+#include "utf.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+namespace {
// Returns true iff the |num_bytes|-byte UTF8 character starting at |curr| is
// a token separator.  Only 1-byte (ASCII-range) characters can separate
// tokens: any such byte that is not a letter (space, newline, tab,
// punctuation, ...).
inline bool IsTokenSeparator(int num_bytes, const char *curr) {
  if (num_bytes != 1) {
    return false;
  }
  // Cast through unsigned char: num_bytes == 1 admits bytes up to 0xBF (e.g.,
  // stray continuation bytes), and passing a negative char to isalpha() is
  // undefined behavior on platforms where plain char is signed.  This build
  // uses -funsigned-char, but the explicit cast keeps the code correct
  // everywhere.
  return !isalpha(static_cast<unsigned char>(*curr));
}
+
// Appends to *word the UTF8 encoding for the lowercase version of the UTF8
// character that starts at |curr| and has |num_bytes| bytes.
//
// NOTE: if the current UTF8 character does not have a lowercase version, then
// we append the original UTF8 character.
inline SAFTM_ATTRIBUTE_ALWAYS_INLINE void AppendLowerCase(const char *curr,
                                                          int num_bytes,
                                                          string *word) {
  if (num_bytes == 1) {
    // Optimize the ASCII case.
    // NOTE(review): with -funsigned-char *curr is never negative; on
    // signed-char platforms a static_cast<unsigned char> would be needed
    // here for tolower() — confirm if this code is built elsewhere.
    word->push_back(tolower(*curr));
    return;
  }

  // Harder, general case.
  //
  // NOTE: for lowercasing, we use the utils from utf.h:
  // charntorune + tolowerrune + runetochar. Unfortunately, that library does
  // not contain any fast util for determining the number of bytes for the UTF8
  // character that starts at a given address *without* converting to a full
  // codepoint (like our utils::OneCharLen, which is used intensively by the
  // rest of our code, including by the performance-critical char ngram
  // feature). Hence, the rest of our code continues to use utils::OneCharLen,
  // and here, when we append the bytes to *word, we make sure that's consistent
  // with utils::OneCharLen.

  // charntorune() below reads the UTF8 character that starts at curr (using at
  // most num_bytes bytes) and stores the corresponding codepoint into rune.
  Rune rune;
  charntorune(&rune, curr, num_bytes);
  if (rune != Runeerror) {
    Rune lower = tolowerrune(rune);
    char lower_buf[UTFmax];
    runetochar(lower_buf, &lower);

    // When appending the UTF8 bytes to word, we do not use the number of bytes
    // returned by runetochar(); instead, we use utils::OneCharLen(), the same
    // method used by the char ngram feature. We expect them to be equal, but
    // just in case.
    int lower_num_bytes = utils::OneCharLen(lower_buf);

    // Using lower_num_bytes below is safe, because, by definition of UTFmax,
    SAFTM_DCHECK_GE(UTFmax, 4);

    // And, by implementation of utils::OneCharLen():
    SAFTM_DCHECK_GT(lower_num_bytes, 0);
    SAFTM_DCHECK_LE(lower_num_bytes, 4);
    word->append(lower_buf, lower_num_bytes);
  } else {
    // There are sequences of bytes that charntorune() can't convert into a
    // valid Rune (a special case is [0xEF, 0xBF, 0xBD], the UTF8 encoding for
    // the U+FFFD special Unicode character, which is also the value of
    // Runeerror). We keep those bytes unchanged.
    word->append(curr, num_bytes);
  }
}
+} // namespace
+
// Reads the tokenizer's only parameter from the task |context|: whether to
// lowercase each character while tokenizing (see lowercase_input_).  Defaults
// to false (no lowercasing).
void TokenizerForLangId::Setup(TaskContext *context) {
  lowercase_input_ = context->Get("lang_id_lowercase_input", false);
}
+
// Tokenizes |text| into *sentence.  Walks the UTF8 characters of the safe
// prefix of |text| (see utils::GetSafeEndOfUtf8String), splitting on 1-byte
// non-letter characters (see IsTokenSeparator); each emitted token is wrapped
// between the special markers '^' (token start) and '$' (token end), and is
// lowercased iff lowercase_input_ is set.
void TokenizerForLangId::Tokenize(StringPiece text,
                                  LightSentence *sentence) const {
  const char *const start = text.data();
  const char *curr = start;
  const char *end = utils::GetSafeEndOfUtf8String(start, text.size());

  // Corner case: the safe part of the text is empty ("").
  if (curr >= end) {
    return;
  }

  // Number of bytes for UTF8 character starting at *curr. Note: the loop below
  // is guaranteed to terminate because in each iteration, we move curr by at
  // least num_bytes, and num_bytes is guaranteed to be > 0.
  int num_bytes = utils::OneCharLen(curr);
  while (curr < end) {
    // Jump over consecutive token separators.
    while (IsTokenSeparator(num_bytes, curr)) {
      curr += num_bytes;
      if (curr >= end) {
        return;
      }
      num_bytes = utils::OneCharLen(curr);
    }

    // If control reaches this point, we are at beginning of a non-empty token.
    sentence->emplace_back();
    string *word = &(sentence->back());

    // Add special token-start character.
    word->push_back('^');

    // Add UTF8 characters to word, until we hit the end of the safe text or a
    // token separator.
    while (true) {
      if (lowercase_input_) {
        AppendLowerCase(curr, num_bytes, word);
      } else {
        word->append(curr, num_bytes);
      }
      curr += num_bytes;
      if (curr >= end) {
        break;
      }
      num_bytes = utils::OneCharLen(curr);
      if (IsTokenSeparator(num_bytes, curr)) {
        // Consume the separator itself before ending the current token, so
        // the outer loop resumes scanning right after it.
        curr += num_bytes;
        if (curr >= end) {
          break;
        }
        num_bytes = utils::OneCharLen(curr);
        break;
      }
    }
    word->push_back('$');
  }
}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace nlp_saft
diff --git a/lang_id/custom-tokenizer.h b/lang_id/custom-tokenizer.h
new file mode 100644
index 0000000..6fab796
--- /dev/null
+++ b/lang_id/custom-tokenizer.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_CUSTOM_TOKENIZER_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_CUSTOM_TOKENIZER_H_
+
+#include <string>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+#include "lang_id/light-sentence.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
// Custom tokenizer for the LangId model.
class TokenizerForLangId {
 public:
  // Reads tokenizer parameters from |context|; currently only the boolean
  // "lang_id_lowercase_input" (default false), stored in lowercase_input_.
  void Setup(TaskContext *context);

  // Tokenizes |text|, placing the tokens into |sentence|. Customized for
  // LangId. Currently (Sep 15, 2016) we tokenize on space, newline, tab, and
  // any other 1-byte UTF8 character which is not a letter, ignore all empty
  // tokens, and (for each of the remaining tokens) prepend "^" (special token
  // begin marker) and append "$" (special token end marker).
  //
  // Tokens are stored into the "repeated Token token;" field of *sentence.
  void Tokenize(StringPiece text, LightSentence *sentence) const;

 private:
  // If true, during tokenization, we use the lowercase version of each Unicode
  // character from the text to tokenize. E.g., if this is true, the text "Foo
  // bar" is tokenized as ["foo", "bar"]; otherwise, we get ["Foo", "bar"].
  bool lowercase_input_ = false;
};
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace nlp_saft
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_CUSTOM_TOKENIZER_H_
diff --git a/lang_id/fb_model/lang-id-from-fb.cc b/lang_id/fb_model/lang-id-from-fb.cc
new file mode 100644
index 0000000..f8e39d7
--- /dev/null
+++ b/lang_id/fb_model/lang-id-from-fb.cc
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/fb_model/lang-id-from-fb.h"
+
+#include "lang_id/fb_model/model-provider-from-fb.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(const string &filename) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(filename));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(int fd) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(fd));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
+std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
+ size_t num_bytes) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(data, num_bytes));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace nlp_saft
diff --git a/lang_id/fb_model/lang-id-from-fb.h b/lang_id/fb_model/lang-id-from-fb.h
new file mode 100644
index 0000000..51bcffe
--- /dev/null
+++ b/lang_id/fb_model/lang-id-from-fb.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
+
+#include <stddef.h>
+
+#include <memory>
+#include <string>
+
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
// Returns a LangId built using the SAFT model in flatbuffer format from
// |filename|.  The file is mmapped by the underlying model provider for the
// lifetime of the returned LangId.
std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(const string &filename);

// Returns a LangId built using the SAFT model in flatbuffer format from
// given file descriptor.
std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(int fd);

// Returns a LangId built using the SAFT model in flatbuffer format from
// the |num_bytes| bytes that start at address |data|.
//
// IMPORTANT: the model bytes must be alive during the lifetime of the returned
// LangId. To avoid overhead (e.g., heap allocation), this method does not make
// a private copy of the model bytes. Avoiding overhead is the main reason we
// use flatbuffers.
std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
                                                     size_t num_bytes);

// Convenience string-based version of GetLangIdFromFlatbufferBytes.
//
// IMPORTANT: |bytes| must be alive during the lifetime of the returned LangId.
inline std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(
    const string &bytes) {
  return GetLangIdFromFlatbufferBytes(bytes.data(), bytes.size());
}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace nlp_saft
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
diff --git a/lang_id/fb_model/model-provider-from-fb.cc b/lang_id/fb_model/model-provider-from-fb.cc
new file mode 100644
index 0000000..3357963
--- /dev/null
+++ b/lang_id/fb_model/model-provider-from-fb.cc
@@ -0,0 +1,102 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/fb_model/model-provider-from-fb.h"
+
+#include "lang_id/common/file/file-utils.h"
+#include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
+#include "lang_id/common/flatbuffers/model-utils.h"
+#include "lang_id/common/lite_strings/str-split.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
// Constructs a provider from the flatbuffer-format SAFT model in |filename|.
// On any failure, Initialize() leaves the object marked invalid rather than
// aborting.
ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(const string &filename)

    // Using mmap as a fast way to read the model bytes. As the file is
    // unmapped only when the field scoped_mmap_ is destructed, the model bytes
    // stay alive for the entire lifetime of this object.
    : scoped_mmap_(new ScopedMmap(filename)) {
  Initialize(scoped_mmap_->handle().to_stringpiece());
}
+
// Constructs a provider from the flatbuffer-format SAFT model referred to by
// file descriptor |fd|.  On any failure, Initialize() leaves the object
// marked invalid rather than aborting.
ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(int fd)

    // Using mmap as a fast way to read the model bytes. As the file is
    // unmapped only when the field scoped_mmap_ is destructed, the model bytes
    // stay alive for the entire lifetime of this object.
    : scoped_mmap_(new ScopedMmap(fd)) {
  Initialize(scoped_mmap_->handle().to_stringpiece());
}
+
// Initializes all fields of this provider from the Model flatbuffer serialized
// in |model_bytes|: verifies and binds model_, fills context_, parses the
// comma-separated "supported_languages" parameter into languages_, and builds
// nn_params_.  The bytes must stay alive for this object's lifetime.
void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
  // Note: valid_ was initialized to false. In the code below, we set valid_ to
  // true only if all initialization steps completed successfully. Otherwise,
  // we return early, leaving valid_ to its default value false.
  model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
  if (model_ == nullptr) {
    SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
    return;
  }

  // Initialize context_ parameters.
  if (!saft_fbs::FillParameters(*model_, &context_)) {
    // FillParameters already performs error logging.
    return;
  }

  // Init languages_.
  const string known_languages_str = context_.Get("supported_languages", "");
  for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
    languages_.emplace_back(sp);
  }
  if (languages_.empty()) {
    SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
    return;
  }

  // Init nn_params_.
  if (!InitNetworkParams()) {
    // InitNetworkParams already performs error logging.
    return;
  }

  // Everything looks fine.
  valid_ = true;
}
+
+bool ModelProviderFromFlatbuffer::InitNetworkParams() {
+ const string kInputName = "language-identifier-network";
+ StringPiece bytes =
+ saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
+ if ((bytes.data() == nullptr) || bytes.empty()) {
+ SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
+ return false;
+ }
+ std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
+ new EmbeddingNetworkParamsFromFlatbuffer(bytes));
+ if (!nn_params_from_fb->is_valid()) {
+ SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
+ return false;
+ }
+ nn_params_ = std::move(nn_params_from_fb);
+ return true;
+}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace nlp_saft
diff --git a/lang_id/fb_model/model-provider-from-fb.h b/lang_id/fb_model/model-provider-from-fb.h
new file mode 100644
index 0000000..d25c903
--- /dev/null
+++ b/lang_id/fb_model/model-provider-from-fb.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
+
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/file/mmap.h"
+#include "lang_id/common/flatbuffers/model_generated.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+#include "lang_id/model-provider.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
// ModelProvider for LangId, based on a SAFT model in flatbuffer format.
class ModelProviderFromFlatbuffer : public ModelProvider {
 public:
  // Constructs a model provider based on a flatbuffer-format SAFT model from
  // |filename|.
  explicit ModelProviderFromFlatbuffer(const string &filename);

  // Constructs a model provider based on a flatbuffer-format SAFT model from
  // file descriptor |fd|.
  explicit ModelProviderFromFlatbuffer(int fd);

  // Constructs a model provider from a flatbuffer-format SAFT model the bytes
  // of which are already in RAM (size bytes starting from address data).
  // Useful if you "transport" these bytes otherwise than via a normal file
  // (e.g., if you embed them somehow in your binary).
  //
  // IMPORTANT: |data| should be alive during the lifetime of the
  // newly-constructed ModelProviderFromFlatbuffer. This is trivial to ensure
  // for data that's statically embedded in your binary, but more complex in
  // other cases. To avoid overhead (e.g., heap allocation), this method does
  // not make a private copy of the data. In general, the ownership of the
  // newly-constructed ModelProviderFromFlatbuffer is immediately passed to a
  // LangId object (which doesn't pass it further); hence, one needs to make
  // sure |data| is alive during the lifetime of that LangId object.
  ModelProviderFromFlatbuffer(const char *data, std::size_t size) {
    StringPiece model_bytes(data, size);
    Initialize(model_bytes);
  }

  ~ModelProviderFromFlatbuffer() override = default;

  // Returns the task context whose parameters were filled from the model
  // flatbuffer (see Initialize()).
  const TaskContext *GetTaskContext() const override {
    return &context_;
  }

  // Returns the network parameters built by InitNetworkParams(), or nullptr
  // if initialization failed.
  const EmbeddingNetworkParams *GetNnParams() const override {
    return nn_params_.get();
  }

  // Returns a copy of the supported-languages list parsed at initialization.
  std::vector<string> GetLanguages() const override {
    return languages_;
  }

 private:
  // Initializes the fields of this class based on the flatbuffer from
  // |model_bytes|. These bytes are supposed to be the representation of a
  // Model flatbuffer and should be alive during the lifetime of this object.
  void Initialize(StringPiece model_bytes);

  // Initializes nn_params_ based on model_.
  bool InitNetworkParams();

  // If filename-based constructor is used, scoped_mmap_ keeps the file mmapped
  // during the lifetime of this object, such that references inside the Model
  // flatbuffer from those bytes remain valid.
  const std::unique_ptr<ScopedMmap> scoped_mmap_;

  // Pointer to the flatbuffer from
  //
  // (a) [if filename constructor was used:] the bytes mmapped by scoped_mmap_
  // (for safety considerations, see comment for that field), or
  //
  // (b) [if (data, size) constructor was used:] the bytes from [data,
  // data+size). Please read carefully the doc for that constructor.
  const saft_fbs::Model *model_;

  // Context returned by this model provider. We set its parameters based on
  // model_, at construction time.
  TaskContext context_;

  // List of supported languages, see GetLanguages(). We expect this list to be
  // specified by the ModelParameter named "supported_languages" from model_.
  std::vector<string> languages_;

  // EmbeddingNetworkParams, see GetNnParams(). Set based on the ModelInput
  // named "language-identifier-network" from model_.
  std::unique_ptr<EmbeddingNetworkParams> nn_params_;
};
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace nlp_saft
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
diff --git a/lang_id/features/char-ngram-feature.cc b/lang_id/features/char-ngram-feature.cc
new file mode 100644
index 0000000..83d7588
--- /dev/null
+++ b/lang_id/features/char-ngram-feature.cc
@@ -0,0 +1,156 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/features/char-ngram-feature.h"
+
+#include <utility>
+#include <vector>
+
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/math/hash.h"
+#include "lang_id/common/utf8.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
+ // Parameters in the feature function descriptor.
+ bool include_terminators = GetBoolParameter("include_terminators", false);
+ if (!include_terminators) {
+ SAFTM_LOG(ERROR) << "No support for include_terminators=true";
+ return false;
+ }
+
+ bool include_spaces = GetBoolParameter("include_spaces", false);
+ if (include_spaces) {
+ SAFTM_LOG(ERROR) << "No support for include_spaces=true";
+ return false;
+ }
+
+ bool use_equal_ngram_weight = GetBoolParameter("use_equal_weight", false);
+ if (use_equal_ngram_weight) {
+ SAFTM_LOG(ERROR) << "No support for use_equal_weight=true";
+ return false;
+ }
+
+ ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
+ ngram_size_ = GetIntParameter("size", 3);
+
+ counts_.assign(ngram_id_dimension_, 0);
+ return true;
+}
+
// Declares this feature's type: a numeric feature with one dimension per
// possible ngram id (ngram_id_dimension_, set in Setup()).
bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
  set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
  return true;
}
+
// Counts all character ngrams of ngram_size_ UTF8 characters from |sentence|,
// accumulating per-ngram-id counts into counts_ and recording each id that
// becomes non-zero in non_zero_count_indices_.  Returns the total number of
// ngrams counted.  Relies on counts_ being all zeros on entry (established by
// Setup() and restored by Evaluate() after each use); callers must hold
// state_mutex_.
int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
    const LightSentence &sentence) const {
  SAFTM_CHECK_EQ(counts_.size(), ngram_id_dimension_);
  SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0);

  int total_count = 0;

  for (const string &word : sentence) {
    const char *const word_end = word.data() + word.size();

    // Set ngram_start at the start of the current token (word).
    const char *ngram_start = word.data();

    // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each
    // UTF8 character contains between 1 and 4 bytes.
    const char *ngram_end = ngram_start;
    int num_utf8_chars = 0;
    do {
      ngram_end += utils::OneCharLen(ngram_end);
      num_utf8_chars++;
    } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));

    if (num_utf8_chars < ngram_size_) {
      // Current token is so small, it does not contain a single ngram of
      // ngram_size UTF8 characters. Not much we can do in this case ...
      continue;
    }

    // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
    // UTF8 characters from current token.
    while (true) {
      // Compute ngram id: hash(ngram) % ngram_id_dimension
      int ngram_id = (
          utils::Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start)
          % ngram_id_dimension_);

      // Use a reference to the actual count, such that we can both test whether
      // the count was 0 and increment it without performing two lookups.
      int &ref_to_count_for_ngram = counts_[ngram_id];
      if (ref_to_count_for_ngram == 0) {
        non_zero_count_indices_.push_back(ngram_id);
      }
      ref_to_count_for_ngram++;
      total_count++;
      if (ngram_end >= word_end) {
        break;
      }

      // Advance both ngram_start and ngram_end by one UTF8 character. This
      // way, the number of UTF8 characters between them remains constant
      // (ngram_size).
      ngram_start += utils::OneCharLen(ngram_start);
      ngram_end += utils::OneCharLen(ngram_end);
    }
  }  // end of loop over tokens.

  return total_count;
}
+
// Appends to *result one (ngram_id, weight) feature per distinct ngram in
// |sentence|, with weight = (count of that ngram) / (total ngram count).
// The entire computation runs under state_mutex_, which serializes access to
// the scratch members counts_ and non_zero_count_indices_; both are restored
// to their empty/all-zero state before returning.
void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
                                             const LightSentence &sentence,
                                             FeatureVector *result) const {
  // NOTE: we use std::* constructs (instead of absl::Mutex & co) to simplify
  // porting to Android and to avoid pulling in absl (which increases our code
  // size).
  std::lock_guard<std::mutex> mlock(state_mutex_);

  // Find the char ngram counts.
  int total_count = ComputeNgramCounts(sentence);

  // Populate the feature vector.
  const float norm = static_cast<float>(total_count);

  // TODO(salcianu): explore treating dense vectors (i.e., many non-zero
  // elements) separately.
  for (int ngram_id : non_zero_count_indices_) {
    const float weight = counts_[ngram_id] / norm;
    FloatFeatureValue value(ngram_id, weight);
    result->add(feature_type(), value.discrete_value);

    // Clear up counts_, for the next invocation of Evaluate().
    counts_[ngram_id] = 0;
  }

  // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
  non_zero_count_indices_.clear();
}
+
+SAFTM_STATIC_REGISTRATION(ContinuousBagOfNgramsFunction);
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace nlp_saft
diff --git a/lang_id/features/char-ngram-feature.h b/lang_id/features/char-ngram-feature.h
new file mode 100644
index 0000000..db0f83e
--- /dev/null
+++ b/lang_id/features/char-ngram-feature.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_CHAR_NGRAM_FEATURE_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_CHAR_NGRAM_FEATURE_H_
+
+#include <mutex> // NOLINT: see comments for state_mutex_
+#include <string>
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/features/light-sentence-features.h"
+#include "lang_id/light-sentence.h"
+
+// TODO(abakalov): Add a test.
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
// Class for computing continuous char ngram features.
//
// Feature function descriptor parameters:
//   include_terminators(bool, false):
//     If 'true', then splits the text based on spaces to get tokens, adds "^"
//     to the beginning of each token, and adds "$" to the end of each token.
//     NOTE: currently, we support only include_terminators=true.
//   include_spaces(bool, false):
//     If 'true', then includes char ngrams containing spaces.
//     NOTE: currently, we support only include_spaces=false.
//   use_equal_weight(bool, false):
//     If 'true', then weighs each unique ngram by 1.0 / (number of unique
//     ngrams in the input). Otherwise, weighs each unique ngram by (ngram
//     count) / (total number of ngrams).
//     NOTE: currently, we support only use_equal_weight=false.
//   id_dim(int, 10000):
//     The integer id of each char ngram is computed as follows:
//     Hash32WithDefault(char ngram) % id_dim.
//   size(int, 3):
//     Only ngrams of this size will be extracted.
//
// NOTE on thread-safety: Evaluate() serializes access to the internal scratch
// state (counts_, non_zero_count_indices_) via state_mutex_, so concurrent
// Evaluate() calls are safe with respect to each other.  Setup()/Init() are
// not synchronized and must complete before any concurrent use.
class ContinuousBagOfNgramsFunction : public LightSentenceFeature {
 public:
  bool Setup(TaskContext *context) override;
  bool Init(TaskContext *context) override;

  // Appends the features computed from the sentence to the feature vector.
  void Evaluate(const WorkspaceSet &workspaces, const LightSentence &sentence,
                FeatureVector *result) const override;

  SAFTM_DEFINE_REGISTRATION_METHOD("continuous-bag-of-ngrams",
                                   ContinuousBagOfNgramsFunction);

 private:
  // Auxiliary for Evaluate(). Fills counts_ and non_zero_count_indices_ (see
  // below), and returns the total ngram count.
  int ComputeNgramCounts(const LightSentence &sentence) const;

  // Guards counts_ and non_zero_count_indices_. NOTE: we use std::* constructs
  // (instead of absl::Mutex & co) to simplify porting to Android and to avoid
  // pulling in absl (which increases our code size).
  mutable std::mutex state_mutex_;

  // counts_[i] is the count of all ngrams with id i. Work data for Evaluate().
  // NOTE: we declare this vector as a field, such that its underlying capacity
  // stays allocated in between calls to Evaluate().
  mutable std::vector<int> counts_;

  // Indices of non-zero elements of counts_. See comments for counts_.
  mutable std::vector<int> non_zero_count_indices_;

  // The integer id of each char ngram is computed as follows:
  // Hash32WithDefaultSeed(char_ngram) % ngram_id_dimension_.
  int ngram_id_dimension_;

  // Only ngrams of size ngram_size_ will be extracted.
  int ngram_size_;
};
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace nlp_saft
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_CHAR_NGRAM_FEATURE_H_
diff --git a/lang_id/features/light-sentence-features.cc b/lang_id/features/light-sentence-features.cc
new file mode 100644
index 0000000..7f1d878
--- /dev/null
+++ b/lang_id/features/light-sentence-features.cc
@@ -0,0 +1,27 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/features/light-sentence-features.h"
+
namespace libtextclassifier3 {
namespace mobile {

// Registry for the features on whole light sentences.  This is the one
// translation unit that defines the registry name declared via
// SAFTM_DECLARE_CLASS_REGISTRY_NAME in light-sentence-features.h.
SAFTM_DEFINE_CLASS_REGISTRY_NAME("light sentence feature function",
                                 lang_id::LightSentenceFeature);

}  // namespace mobile
}  // namespace libtextclassifier3
diff --git a/lang_id/features/light-sentence-features.h b/lang_id/features/light-sentence-features.h
new file mode 100644
index 0000000..cc85878
--- /dev/null
+++ b/lang_id/features/light-sentence-features.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_LIGHT_SENTENCE_FEATURES_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_LIGHT_SENTENCE_FEATURES_H_
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/registry.h"
+#include "lang_id/light-sentence.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Feature function that extracts features from LightSentences.
+typedef FeatureFunction<LightSentence> LightSentenceFeature;
+
+// Feature extractor for LightSentences.
+typedef FeatureExtractor<LightSentence> LightSentenceExtractor;
+
+} // namespace lang_id
+
+SAFTM_DECLARE_CLASS_REGISTRY_NAME(lang_id::LightSentenceFeature);
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_LIGHT_SENTENCE_FEATURES_H_
diff --git a/lang_id/features/relevant-script-feature.cc b/lang_id/features/relevant-script-feature.cc
new file mode 100644
index 0000000..0fde87b
--- /dev/null
+++ b/lang_id/features/relevant-script-feature.cc
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/features/relevant-script-feature.h"
+
+#include <string>
+
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/utf8.h"
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+bool RelevantScriptFeature::Setup(TaskContext *context) {
+ string script_detector_name = GetParameter(
+ "script_detector_name", /* default_value = */ "tiny-script-detector");
+
+ // We don't use absl::WrapUnique, nor the rest of absl, see http://b/71873194
+ script_detector_.reset(ScriptDetector::Create(script_detector_name));
+ if (script_detector_ == nullptr) {
+ // This means ScriptDetector::Create() could not find the requested
+ // script_detector_name. In that case, Create() already logged an error
+ // message.
+ return false;
+ }
+
+ // We use default value 172 because this is the number of scripts supported by
+ // the first model we trained with this feature. See http://b/70617713.
+ // Newer models may support more scripts.
+ num_supported_scripts_ = GetIntParameter("num_supported_scripts", 172);
+ return true;
+}
+
+bool RelevantScriptFeature::Init(TaskContext *context) {
+ set_feature_type(new NumericFeatureType(name(), num_supported_scripts_));
+ return true;
+}
+
+void RelevantScriptFeature::Evaluate(
+ const WorkspaceSet &workspaces, const LightSentence &sentence,
+ FeatureVector *result) const {
+ // counts[s] is the number of characters with script s.
+ std::vector<int> counts(num_supported_scripts_);
+ int total_count = 0;
+ for (const string &word : sentence) {
+ const char *const word_end = word.data() + word.size();
+ const char *curr = word.data();
+
+ // Skip over token start '^'.
+ SAFTM_DCHECK_EQ(*curr, '^');
+ curr += utils::OneCharLen(curr);
+ while (true) {
+ const int num_bytes = utils::OneCharLen(curr);
+
+ int script = script_detector_->GetScript(curr, num_bytes);
+
+ // We do this update and the if (...) break below *before* incrementing
+ // counts[script] in order to skip the token end '$'.
+ curr += num_bytes;
+ if (curr >= word_end) {
+ SAFTM_DCHECK_EQ(*(curr - num_bytes), '$');
+ break;
+ }
+ SAFTM_DCHECK_GE(script, 0);
+
+ if (script < num_supported_scripts_) {
+ counts[script]++;
+ total_count++;
+ } else {
+ // Unsupported script: this usually indicates a script that is
+ // recognized by newer versions of the code, after the model was
+ // trained. E.g., new code running with old model.
+ }
+ }
+ }
+
+ for (int script_id = 0; script_id < num_supported_scripts_; ++script_id) {
+ int count = counts[script_id];
+ if (count > 0) {
+ const float weight = static_cast<float>(count) / total_count;
+ FloatFeatureValue value(script_id, weight);
+ result->add(feature_type(), value.discrete_value);
+ }
+ }
+}
+
+SAFTM_STATIC_REGISTRATION(RelevantScriptFeature);
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/features/relevant-script-feature.h b/lang_id/features/relevant-script-feature.h
new file mode 100644
index 0000000..57c5a1f
--- /dev/null
+++ b/lang_id/features/relevant-script-feature.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_RELEVANT_SCRIPT_FEATURE_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_RELEVANT_SCRIPT_FEATURE_H_
+
+#include <memory>
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/features/light-sentence-features.h"
+#include "lang_id/light-sentence.h"
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Given a sentence, generates one FloatFeatureValue for each "relevant" Unicode
+// script (see below): each such feature indicates the script and the ratio of
+// UTF8 characters in that script, in the given sentence.
+//
+// What is a relevant script? Recognizing all 100+ Unicode scripts would
+// require too much code size and runtime. Instead, we focus only on a few
+// scripts that communicate a lot of language information: e.g., the use of
+// Hiragana characters almost always indicates Japanese, so Hiragana is a
+// "relevant" script for us. The Latin script is used by dozens of languages, so
+// Latin is not relevant in this context.
+class RelevantScriptFeature : public LightSentenceFeature {
+ public:
+ bool Setup(TaskContext *context) override;
+ bool Init(TaskContext *context) override;
+
+ // Appends the features computed from the sentence to the feature vector.
+ void Evaluate(const WorkspaceSet &workspaces,
+ const LightSentence &sentence,
+ FeatureVector *result) const override;
+
+ SAFTM_DEFINE_REGISTRATION_METHOD("continuous-bag-of-relevant-scripts",
+ RelevantScriptFeature);
+
+ private:
+ // Detects script of individual UTF8 characters.
+ std::unique_ptr<ScriptDetector> script_detector_;
+
+ // Current model supports scripts in [0, num_supported_scripts_).
+ int num_supported_scripts_ = 0;
+};
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_RELEVANT_SCRIPT_FEATURE_H_
diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc
new file mode 100644
index 0000000..c892329
--- /dev/null
+++ b/lang_id/lang-id.cc
@@ -0,0 +1,288 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/lang-id.h"
+
+#include <stdio.h>
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "lang_id/common/embedding-feature-interface.h"
+#include "lang_id/common/embedding-network-params.h"
+#include "lang_id/common/embedding-network.h"
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+#include "lang_id/common/lite_strings/str-split.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+#include "lang_id/common/math/algorithm.h"
+#include "lang_id/common/math/softmax.h"
+#include "lang_id/custom-tokenizer.h"
+#include "lang_id/features/light-sentence-features.h"
+#include "lang_id/light-sentence.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+namespace {
+// Default value for the confidence threshold. If the confidence of the top
+// prediction is below this threshold, then FindLanguage() returns
+// LangId::kUnknownLanguageCode. Note: this is just a default value; if the
+// TaskSpec from the model specifies a "reliability_thresh" parameter, then we
+// use that value instead. Note: for legacy reasons, our code and comments use
+// the terms "confidence", "probability" and "reliability" equivalently.
+static const float kDefaultConfidenceThreshold = 0.50f;
+} // namespace
+
+// Class that performs all work behind LangId.
+class LangIdImpl {
+ public:
+ explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
+ : model_provider_(std::move(model_provider)),
+ lang_id_brain_interface_("language_identifier") {
+ // Note: in the code below, we set valid_ to true only if all initialization
+ // steps completed successfully. Otherwise, we return early, leaving valid_
+ // to its default value false.
+ if (!model_provider_ || !model_provider_->is_valid()) {
+ SAFTM_LOG(ERROR) << "Invalid model provider";
+ return;
+ }
+
+ auto *nn_params = model_provider_->GetNnParams();
+ if (!nn_params) {
+ SAFTM_LOG(ERROR) << "No NN params";
+ return;
+ }
+ network_.reset(new EmbeddingNetwork(nn_params));
+
+ languages_ = model_provider_->GetLanguages();
+ if (languages_.empty()) {
+ SAFTM_LOG(ERROR) << "No known languages";
+ return;
+ }
+
+ TaskContext context = *model_provider_->GetTaskContext();
+ if (!Setup(&context)) {
+ SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
+ return;
+ }
+ if (!Init(&context)) {
+ SAFTM_LOG(ERROR) << "Unable to Init() LangId";
+ return;
+ }
+ valid_ = true;
+ }
+
+ string FindLanguage(StringPiece text) const {
+ // NOTE: it would be wasteful to implement this method in terms of
+ // FindLanguages(). We just need the most likely language and its
+ // probability; no need to compute (and allocate) a vector of pairs for all
+ // languages, nor to compute probabilities for all non-top languages.
+ if (!is_valid()) {
+ return LangId::kUnknownLanguageCode;
+ }
+
+ std::vector<float> scores;
+ ComputeScores(text, &scores);
+
+ int prediction_id = GetArgMax(scores);
+ const string language = GetLanguageForSoftmaxLabel(prediction_id);
+ float probability = ComputeSoftmaxProbability(scores, prediction_id);
+ SAFTM_DLOG(INFO) << "Predicted " << language
+ << " with prob: " << probability << " for \"" << text
+ << "\"";
+
+ // Find confidence threshold for language.
+ float threshold = default_threshold_;
+ auto it = per_lang_thresholds_.find(language);
+ if (it != per_lang_thresholds_.end()) {
+ threshold = it->second;
+ }
+ if (probability < threshold) {
+ SAFTM_DLOG(INFO) << " below threshold => "
+ << LangId::kUnknownLanguageCode;
+ return LangId::kUnknownLanguageCode;
+ }
+ return language;
+ }
+
+ void FindLanguages(StringPiece text, LangIdResult *result) const {
+ if (result == nullptr) return;
+
+ result->predictions.clear();
+ if (!is_valid()) {
+ result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
+ return;
+ }
+
+ std::vector<float> scores;
+ ComputeScores(text, &scores);
+
+ // Compute and sort softmax in descending order by probability and convert
+ // IDs to language code strings. When probabilities are equal, we sort by
+ // language code string in ascending order.
+ std::vector<float> softmax = ComputeSoftmax(scores);
+
+ for (int i = 0; i < softmax.size(); ++i) {
+ result->predictions.emplace_back(GetLanguageForSoftmaxLabel(i),
+ softmax[i]);
+ }
+
+ // Sort the resulting language predictions by probability in descending
+ // order.
+ std::sort(result->predictions.begin(), result->predictions.end(),
+ [](const std::pair<string, float> &a,
+ const std::pair<string, float> &b) {
+ if (a.second == b.second) {
+ return a.first.compare(b.first) < 0;
+ } else {
+ return a.second > b.second;
+ }
+ });
+ }
+
+ bool is_valid() const { return valid_; }
+
+ int GetModelVersion() const { return model_version_; }
+
+ // Returns a property stored in the model file.
+ template <typename T, typename R>
+ R GetProperty(const string &property, T default_value) const {
+ return model_provider_->GetTaskContext()->Get(property, default_value);
+ }
+
+ private:
+ bool Setup(TaskContext *context) {
+ tokenizer_.Setup(context);
+ if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
+ default_threshold_ =
+ context->Get("reliability_thresh", kDefaultConfidenceThreshold);
+
+ // Parse task parameter "per_lang_reliability_thresholds", fill
+ // per_lang_thresholds_.
+ const string thresholds_str =
+ context->Get("per_lang_reliability_thresholds", "");
+ std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
+ for (const auto &token : tokens) {
+ if (token.empty()) continue;
+ std::vector<StringPiece> parts = LiteStrSplit(token, '=');
+ float threshold = 0.0f;
+ if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
+ per_lang_thresholds_[string(parts[0])] = threshold;
+ } else {
+ SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
+ }
+ }
+ model_version_ = context->Get("model_version", model_version_);
+ return true;
+ }
+
+ bool Init(TaskContext *context) {
+ return lang_id_brain_interface_.InitForProcessing(context);
+ }
+
+ // Extracts features for |text|, runs them through the feed-forward neural
+ // network, and computes the output scores (activations from the last layer).
+ // These scores can be used to compute the softmax probabilities for our
+ // labels (in this case, the languages).
+ void ComputeScores(StringPiece text, std::vector<float> *scores) const {
+ // Create a Sentence storing the input text.
+ LightSentence sentence;
+ tokenizer_.Tokenize(text, &sentence);
+
+ std::vector<FeatureVector> features =
+ lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
+
+ // Run feed-forward neural network to compute scores.
+ network_->ComputeFinalScores(features, scores);
+ }
+
+ // Returns language code for a softmax label. See comments for languages_
+ // field. If label is out of range, returns LangId::kUnknownLanguageCode.
+ string GetLanguageForSoftmaxLabel(int label) const {
+ if ((label >= 0) && (label < languages_.size())) {
+ return languages_[label];
+ } else {
+ SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
+ << languages_.size() << ")";
+ return LangId::kUnknownLanguageCode;
+ }
+ }
+
+ std::unique_ptr<ModelProvider> model_provider_;
+
+ TokenizerForLangId tokenizer_;
+
+ EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
+ lang_id_brain_interface_;
+
+ // Neural network to use for scoring.
+ std::unique_ptr<EmbeddingNetwork> network_;
+
+ // True if this object is ready to perform language predictions.
+ bool valid_ = false;
+
+ // Only predictions with a probability (confidence) above this threshold are
+ // reported. Otherwise, we report LangId::kUnknownLanguageCode.
+ float default_threshold_ = kDefaultConfidenceThreshold;
+
+ std::unordered_map<string, float> per_lang_thresholds_;
+
+ // Recognized languages: softmax label i means languages_[i] (something like
+ // "en", "fr", "ru", etc).
+ std::vector<string> languages_;
+
+ // Version of the model used by this LangIdImpl object. Zero means that the
+ // model version could not be determined.
+ int model_version_ = 0;
+};
+
+const char LangId::kUnknownLanguageCode[] = "und";
+
+LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
+ : pimpl_(new LangIdImpl(std::move(model_provider))) {}
+
+LangId::~LangId() = default;
+
+string LangId::FindLanguage(const char *data, size_t num_bytes) const {
+ StringPiece text(data, num_bytes);
+ return pimpl_->FindLanguage(text);
+}
+
+void LangId::FindLanguages(const char *data, size_t num_bytes,
+ LangIdResult *result) const {
+ SAFTM_DCHECK(result) << "LangIdResult must not be null.";
+ StringPiece text(data, num_bytes);
+ pimpl_->FindLanguages(text, result);
+}
+
+bool LangId::is_valid() const { return pimpl_->is_valid(); }
+
+int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }
+
+float LangId::GetFloatProperty(const string &property,
+ float default_value) const {
+ return pimpl_->GetProperty<float, float>(property, default_value);
+}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/lang-id.h b/lang_id/lang-id.h
new file mode 100644
index 0000000..94af0c3
--- /dev/null
+++ b/lang_id/lang-id.h
@@ -0,0 +1,137 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
+
+
+#include <stddef.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lang_id/common/lite_base/macros.h"
+#include "lang_id/model-provider.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Forward-declaration of the class that performs all underlying work.
+class LangIdImpl;
+
+struct LangIdResult {
+ // An n-best list of possible language codes for a given input sorted in
+ // descending order according to each code's respective probability.
+ //
+ // This list is guaranteed to be non-empty after calling
+ // LangId::FindLanguages. The most likely language code is always the first
+ // item in this array.
+ //
+ // If the model cannot make a prediction, this array contains a single result:
+ // a language code LangId::kUnknownLanguageCode with probability 1.
+ std::vector<std::pair<string, float>> predictions;
+};
+
+// Class for detecting the language of a document.
+//
+// Note: this class does not handle the details of loading the actual model.
+// Those details have been "outsourced" to the ModelProvider class.
+//
+// This class is thread safe.
+class LangId {
+ public:
+ // Standard BCP-47 language code for Unknown/Undetermined language.
+ static const char kUnknownLanguageCode[];
+
+ // Constructs a LangId object, based on |model_provider|.
+ //
+ // Note: we don't crash if we detect a problem at construction time (e.g., the
+ // model provider can't read an underlying file). Instead, we mark the
+ // newly-constructed object as invalid; clients can invoke FindLanguage() on
+ // an invalid object: nothing crashes, but accuracy will be bad.
+ explicit LangId(std::unique_ptr<ModelProvider> model_provider);
+
+ virtual ~LangId();
+
+ // Computes an n-best list of language codes and probabilities
+ // corresponding to the most likely languages the given input text is written
+ // in. The list is sorted in descending order by language probability.
+ //
+ // The input text consists of the |num_bytes| bytes that starts at |data|.
+ //
+ // Note: If this LangId object is not valid (see is_valid()) or if this LangId
+ // object can't make a prediction, this method sets the LangIdResult to
+ // contain a single entry with kUnknownLanguageCode with probability 1.
+ void FindLanguages(const char *data, size_t num_bytes,
+ LangIdResult *result) const;
+
+ // Convenience version of FindLanguages(const char *, size_t, LangIdResult *).
+ void FindLanguages(const string &text, LangIdResult *result) const {
+ FindLanguages(text.data(), text.size(), result);
+ }
+
+ // Returns language code for the most likely language for a piece of text.
+ //
+ // The input text consists of the |num_bytes| bytes that start at |data|.
+ //
+ // Note: this method reports the most likely (1-best) language only if its
+ // probability is high enough; otherwise, it returns
+ // LangId::kUnknownLanguageCode. The specific probability threshold is tuned
+ // to the needs of an early client. If you need a different threshold, you
+ // can use FindLanguages (plural) to get the full LangIdResult, and apply your
+ // own threshold.
+ //
+ // Note: if this LangId object is not valid (see is_valid()) or if this LangId
+ // object can't make a prediction, then this method returns
+ // LangId::kUnknownLanguageCode.
+ //
+ string FindLanguage(const char *data, size_t num_bytes) const;
+
+ // Convenience version of FindLanguage(const char *, size_t).
+ string FindLanguage(const string &text) const {
+ return FindLanguage(text.data(), text.size());
+ }
+
+ // Returns true if this object has been correctly initialized and is ready to
+ // perform predictions. For more info, see doc for LangId
+ // constructor above.
+ bool is_valid() const;
+
+ // Returns the version of the model used by this LangId object. On success,
+ // the returned version number is a strictly positive integer. Returns 0 if
+ // the model version can not be determined (e.g., for old models that do not
+ // specify a version number).
+ int GetModelVersion() const;
+
+ // Returns a typed property stored in the model file.
+ float GetFloatProperty(const string &property, float default_value) const;
+
+ private:
+ // Pimpl ("pointer to implementation") pattern, to hide all internals from our
+ // clients.
+ std::unique_ptr<LangIdImpl> pimpl_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(LangId);
+};
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
diff --git a/lang_id/lang-id_jni.cc b/lang_id/lang-id_jni.cc
new file mode 100644
index 0000000..6696298
--- /dev/null
+++ b/lang_id/lang-id_jni.cc
@@ -0,0 +1,134 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/lang-id_jni.h"
+
+#include <jni.h>
+#include <type_traits>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/java/scoped_local_ref.h"
+#include "lang_id/fb_model/lang-id-from-fb.h"
+#include "lang_id/lang-id.h"
+
+using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::ToStlString;
+using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
+using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
+using libtextclassifier3::mobile::lang_id::LangId;
+using libtextclassifier3::mobile::lang_id::LangIdResult;
+
+namespace {
+jobjectArray LangIdResultToJObjectArray(JNIEnv* env,
+ const LangIdResult& lang_id_result) {
+ const ScopedLocalRef<jclass> result_class(
+ env->FindClass(TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR
+ "$LanguageResult"),
+ env);
+ if (!result_class) {
+ TC3_LOG(ERROR) << "Couldn't find LanguageResult class.";
+ return nullptr;
+ }
+
+ // clang-format off
+ const std::vector<std::pair<std::string, float>>& predictions =
+ lang_id_result.predictions;
+ // clang-format on
+ const jmethodID result_class_constructor =
+ env->GetMethodID(result_class.get(), "<init>", "(Ljava/lang/String;F)V");
+ const jobjectArray results =
+ env->NewObjectArray(predictions.size(), result_class.get(), nullptr);
+ for (int i = 0; i < predictions.size(); i++) {
+ ScopedLocalRef<jobject> result(
+ env->NewObject(result_class.get(), result_class_constructor,
+ env->NewStringUTF(predictions[i].first.c_str()),
+ static_cast<jfloat>(predictions[i].second)));
+ env->SetObjectArrayElement(results, i, result.get());
+ }
+ return results;
+}
+} // namespace
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
+(JNIEnv* env, jobject thiz, jint fd) {
+ std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
+ if (!lang_id->is_valid()) {
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ return reinterpret_cast<jlong>(lang_id.release());
+}
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
+(JNIEnv* env, jobject thiz, jstring path) {
+ const std::string path_str = ToStlString(env, path);
+ std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
+ if (!lang_id->is_valid()) {
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ return reinterpret_cast<jlong>(lang_id.release());
+}
+
+TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
+(JNIEnv* env, jobject clazz, jlong ptr, jstring text) {
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ if (!model) {
+ return nullptr;
+ }
+
+ const std::string text_str = ToStlString(env, text);
+ LangIdResult result;
+ model->FindLanguages(text_str, &result);
+
+ return LangIdResultToJObjectArray(env, result);
+}
+
+TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
+(JNIEnv* env, jobject clazz, jlong ptr) {
+ if (!ptr) {
+ TC3_LOG(ERROR) << "Trying to close null LangId.";
+ return;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ delete model;
+}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jlong ptr) {
+ if (!ptr) {
+ return -1;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ return model->GetModelVersion();
+}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
+(JNIEnv* env, jobject clazz, jint fd) {
+ std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
+ if (!lang_id->is_valid()) {
+ return -1;
+ }
+ return lang_id->GetModelVersion();
+}
+
+TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
+(JNIEnv* env, jobject thizz, jlong ptr) {
+ if (!ptr) {
+ return -1.0;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ return model->GetFloatProperty("text_classifier_langid_threshold", -1.0);
+}
diff --git a/lang_id/lang-id_jni.h b/lang_id/lang-id_jni.h
new file mode 100644
index 0000000..cd67a4c
--- /dev/null
+++ b/lang_id/lang-id_jni.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// JNI wrapper for LangId.
+
+#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
+#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
+
+#include <jni.h>
+#include <string>
+#include "utils/java/jni-base.h"
+
+#ifndef TC3_LANG_ID_CLASS_NAME
+#define TC3_LANG_ID_CLASS_NAME LangIdModel
+#endif
+
+#define TC3_LANG_ID_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_LANG_ID_CLASS_NAME)
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
+(JNIEnv* env, jobject clazz, jstring path);
+
+TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
+(JNIEnv* env, jobject clazz, jlong ptr, jstring text);
+
+TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
+(JNIEnv* env, jobject clazz, jlong ptr);
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jlong ptr);
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
+(JNIEnv* env, jobject thizz, jlong ptr);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
diff --git a/lang_id/light-sentence.h b/lang_id/light-sentence.h
new file mode 100644
index 0000000..2937549
--- /dev/null
+++ b/lang_id/light-sentence.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
+
+#include <string>
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Very simplified alternative to heavy sentence.proto, for the purpose of
+// LangId. It turns out that in this case, all we need is a vector of strings,
+// which uses a lot less code size than a Sentence proto.
+using LightSentence = std::vector<string>;
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
diff --git a/lang_id/model-provider.h b/lang_id/model-provider.h
new file mode 100644
index 0000000..a076871
--- /dev/null
+++ b/lang_id/model-provider.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/embedding-network-params.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Interface for accessing parameters for the LangId model.
+//
+// Note: some clients prefer to include the model parameters in the binary,
+// others prefer loading them from a separate file. This file provides a common
+// interface for these alternative mechanisms.
+//
+// NOTE(review): TaskContext is assumed to be made visible transitively via
+// embedding-network-params.h -- confirm, or include its header directly.
+class ModelProvider {
+ public:
+  virtual ~ModelProvider() = default;
+
+  // Returns true if this ModelProvider has been successfully constructed
+  // (e.g., can return false if an underlying model file could not be read).
+  // Clients should not use invalid ModelProviders.
+  bool is_valid() const { return valid_; }
+
+  // Returns the TaskContext with parameters for the LangId model. E.g., one
+  // important parameter specifies the features to use.
+  virtual const TaskContext *GetTaskContext() const = 0;
+
+  // Returns parameters for the underlying Neurosis feed-forward neural network.
+  virtual const EmbeddingNetworkParams *GetNnParams() const = 0;
+
+  // Returns list of languages recognized by the model. Each element of the
+  // returned vector should be a BCP-47 language code (e.g., "en", "ro", etc).
+  // Language at index i from the returned vector corresponds to softmax label
+  // i.
+  virtual std::vector<string> GetLanguages() const = 0;
+
+ protected:
+  // Derived classes set this to true once construction has succeeded.
+  bool valid_ = false;
+};
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
diff --git a/lang_id/script/approx-script-data.cc b/lang_id/script/approx-script-data.cc
new file mode 100755
index 0000000..e11d7b7
--- /dev/null
+++ b/lang_id/script/approx-script-data.cc
@@ -0,0 +1,1146 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Internal data for approx-script.cc; see approx-script-data.h
+//
+// DO NOT EDIT BY HAND
+//
+// Generated by
+// lang_id/script/update-script-data.sh
+
+#include "lang_id/script/approx-script-data.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace approx_script_internal {
+
+const int kNumRanges = 367;
+
+const uint32 kRangeFirst[] = {
+ 65, // Range #0: [65, 90, Latin]
+ 97, // Range #1: [97, 122, Latin]
+ 170, // Range #2: [170, 170, Latin]
+ 186, // Range #3: [186, 186, Latin]
+ 192, // Range #4: [192, 214, Latin]
+ 216, // Range #5: [216, 246, Latin]
+ 248, // Range #6: [248, 696, Latin]
+ 736, // Range #7: [736, 740, Latin]
+ 746, // Range #8: [746, 747, Bopomofo]
+ 880, // Range #9: [880, 883, Greek]
+ 885, // Range #10: [885, 893, Greek]
+ 895, // Range #11: [895, 900, Greek]
+ 902, // Range #12: [902, 902, Greek]
+ 904, // Range #13: [904, 993, Greek]
+ 994, // Range #14: [994, 1007, Coptic]
+ 1008, // Range #15: [1008, 1023, Greek]
+ 1024, // Range #16: [1024, 1156, Cyrillic]
+ 1159, // Range #17: [1159, 1327, Cyrillic]
+ 1329, // Range #18: [1329, 1416, Armenian]
+ 1418, // Range #19: [1418, 1423, Armenian]
+ 1425, // Range #20: [1425, 1479, Hebrew]
+ 1488, // Range #21: [1488, 1524, Hebrew]
+ 1536, // Range #22: [1536, 1540, Arabic]
+ 1542, // Range #23: [1542, 1547, Arabic]
+ 1549, // Range #24: [1549, 1562, Arabic]
+ 1564, // Range #25: [1564, 1566, Arabic]
+ 1568, // Range #26: [1568, 1599, Arabic]
+ 1601, // Range #27: [1601, 1610, Arabic]
+ 1622, // Range #28: [1622, 1647, Arabic]
+ 1649, // Range #29: [1649, 1756, Arabic]
+ 1758, // Range #30: [1758, 1791, Arabic]
+ 1792, // Range #31: [1792, 1871, Syriac]
+ 1872, // Range #32: [1872, 1919, Arabic]
+ 1920, // Range #33: [1920, 1969, Thaana]
+ 1984, // Range #34: [1984, 2047, Nko]
+ 2048, // Range #35: [2048, 2110, Samaritan]
+ 2112, // Range #36: [2112, 2142, Mandaic]
+ 2144, // Range #37: [2144, 2154, Syriac]
+ 2208, // Range #38: [2208, 2237, Arabic]
+ 2259, // Range #39: [2259, 2273, Arabic]
+ 2275, // Range #40: [2275, 2303, Arabic]
+ 2304, // Range #41: [2304, 2384, Devanagari]
+ 2389, // Range #42: [2389, 2403, Devanagari]
+ 2406, // Range #43: [2406, 2431, Devanagari]
+ 2432, // Range #44: [2432, 2510, Bengali]
+ 2519, // Range #45: [2519, 2558, Bengali]
+ 2561, // Range #46: [2561, 2641, Gurmukhi]
+ 2649, // Range #47: [2649, 2654, Gurmukhi]
+ 2662, // Range #48: [2662, 2678, Gurmukhi]
+ 2689, // Range #49: [2689, 2768, Gujarati]
+ 2784, // Range #50: [2784, 2801, Gujarati]
+ 2809, // Range #51: [2809, 2815, Gujarati]
+ 2817, // Range #52: [2817, 2893, Oriya]
+ 2902, // Range #53: [2902, 2935, Oriya]
+ 2946, // Range #54: [2946, 3024, Tamil]
+ 3031, // Range #55: [3031, 3031, Tamil]
+ 3046, // Range #56: [3046, 3066, Tamil]
+ 3072, // Range #57: [3072, 3149, Telugu]
+ 3157, // Range #58: [3157, 3162, Telugu]
+ 3168, // Range #59: [3168, 3183, Telugu]
+ 3191, // Range #60: [3191, 3199, Telugu]
+ 3200, // Range #61: [3200, 3277, Kannada]
+ 3285, // Range #62: [3285, 3286, Kannada]
+ 3294, // Range #63: [3294, 3314, Kannada]
+ 3328, // Range #64: [3328, 3455, Malayalam]
+ 3458, // Range #65: [3458, 3551, Sinhala]
+ 3558, // Range #66: [3558, 3572, Sinhala]
+ 3585, // Range #67: [3585, 3642, Thai]
+ 3648, // Range #68: [3648, 3675, Thai]
+ 3713, // Range #69: [3713, 3807, Lao]
+ 3840, // Range #70: [3840, 4052, Tibetan]
+ 4057, // Range #71: [4057, 4058, Tibetan]
+ 4096, // Range #72: [4096, 4255, Myanmar]
+ 4256, // Range #73: [4256, 4295, Georgian]
+ 4301, // Range #74: [4301, 4346, Georgian]
+ 4348, // Range #75: [4348, 4351, Georgian]
+ 4352, // Range #76: [4352, 4607, Hangul]
+ 4608, // Range #77: [4608, 5017, Ethiopic]
+ 5024, // Range #78: [5024, 5117, Cherokee]
+ 5120, // Range #79: [5120, 5759, Canadian_Aboriginal]
+ 5760, // Range #80: [5760, 5788, Ogham]
+ 5792, // Range #81: [5792, 5866, Runic]
+ 5870, // Range #82: [5870, 5880, Runic]
+ 5888, // Range #83: [5888, 5908, Tagalog]
+ 5920, // Range #84: [5920, 5940, Hanunoo]
+ 5952, // Range #85: [5952, 5971, Buhid]
+ 5984, // Range #86: [5984, 6003, Tagbanwa]
+ 6016, // Range #87: [6016, 6121, Khmer]
+ 6128, // Range #88: [6128, 6137, Khmer]
+ 6144, // Range #89: [6144, 6145, Mongolian]
+ 6148, // Range #90: [6148, 6148, Mongolian]
+ 6150, // Range #91: [6150, 6169, Mongolian]
+ 6176, // Range #92: [6176, 6264, Mongolian]
+ 6272, // Range #93: [6272, 6314, Mongolian]
+ 6320, // Range #94: [6320, 6389, Canadian_Aboriginal]
+ 6400, // Range #95: [6400, 6479, Limbu]
+ 6480, // Range #96: [6480, 6516, Tai_Le]
+ 6528, // Range #97: [6528, 6601, New_Tai_Lue]
+ 6608, // Range #98: [6608, 6623, New_Tai_Lue]
+ 6624, // Range #99: [6624, 6655, Khmer]
+ 6656, // Range #100: [6656, 6687, Buginese]
+ 6688, // Range #101: [6688, 6793, Tai_Tham]
+ 6800, // Range #102: [6800, 6809, Tai_Tham]
+ 6816, // Range #103: [6816, 6829, Tai_Tham]
+ 6912, // Range #104: [6912, 7036, Balinese]
+ 7040, // Range #105: [7040, 7103, Sundanese]
+ 7104, // Range #106: [7104, 7155, Batak]
+ 7164, // Range #107: [7164, 7167, Batak]
+ 7168, // Range #108: [7168, 7247, Lepcha]
+ 7248, // Range #109: [7248, 7295, Ol_Chiki]
+ 7296, // Range #110: [7296, 7304, Cyrillic]
+ 7312, // Range #111: [7312, 7359, Georgian]
+ 7360, // Range #112: [7360, 7367, Sundanese]
+ 7424, // Range #113: [7424, 7461, Latin]
+ 7462, // Range #114: [7462, 7466, Greek]
+ 7467, // Range #115: [7467, 7467, Cyrillic]
+ 7468, // Range #116: [7468, 7516, Latin]
+ 7517, // Range #117: [7517, 7521, Greek]
+ 7522, // Range #118: [7522, 7525, Latin]
+ 7526, // Range #119: [7526, 7530, Greek]
+ 7531, // Range #120: [7531, 7543, Latin]
+ 7544, // Range #121: [7544, 7544, Cyrillic]
+ 7545, // Range #122: [7545, 7614, Latin]
+ 7615, // Range #123: [7615, 7615, Greek]
+ 7680, // Range #124: [7680, 7935, Latin]
+ 7936, // Range #125: [7936, 8190, Greek]
+ 8305, // Range #126: [8305, 8305, Latin]
+ 8319, // Range #127: [8319, 8319, Latin]
+ 8336, // Range #128: [8336, 8348, Latin]
+ 8486, // Range #129: [8486, 8486, Greek]
+ 8490, // Range #130: [8490, 8491, Latin]
+ 8498, // Range #131: [8498, 8498, Latin]
+ 8526, // Range #132: [8526, 8526, Latin]
+ 8544, // Range #133: [8544, 8584, Latin]
+ 10240, // Range #134: [10240, 10495, Braille]
+ 11264, // Range #135: [11264, 11358, Glagolitic]
+ 11360, // Range #136: [11360, 11391, Latin]
+ 11392, // Range #137: [11392, 11507, Coptic]
+ 11513, // Range #138: [11513, 11519, Coptic]
+ 11520, // Range #139: [11520, 11559, Georgian]
+ 11565, // Range #140: [11565, 11565, Georgian]
+ 11568, // Range #141: [11568, 11623, Tifinagh]
+ 11631, // Range #142: [11631, 11632, Tifinagh]
+ 11647, // Range #143: [11647, 11647, Tifinagh]
+ 11648, // Range #144: [11648, 11670, Ethiopic]
+ 11680, // Range #145: [11680, 11742, Ethiopic]
+ 11744, // Range #146: [11744, 11775, Cyrillic]
+ 11904, // Range #147: [11904, 12019, Han]
+ 12032, // Range #148: [12032, 12245, Han]
+ 12293, // Range #149: [12293, 12293, Han]
+ 12295, // Range #150: [12295, 12295, Han]
+ 12321, // Range #151: [12321, 12329, Han]
+ 12334, // Range #152: [12334, 12335, Hangul]
+ 12344, // Range #153: [12344, 12347, Han]
+ 12353, // Range #154: [12353, 12438, Hiragana]
+ 12445, // Range #155: [12445, 12447, Hiragana]
+ 12449, // Range #156: [12449, 12538, Katakana]
+ 12541, // Range #157: [12541, 12543, Katakana]
+ 12549, // Range #158: [12549, 12591, Bopomofo]
+ 12593, // Range #159: [12593, 12686, Hangul]
+ 12704, // Range #160: [12704, 12730, Bopomofo]
+ 12784, // Range #161: [12784, 12799, Katakana]
+ 12800, // Range #162: [12800, 12830, Hangul]
+ 12896, // Range #163: [12896, 12926, Hangul]
+ 13008, // Range #164: [13008, 13054, Katakana]
+ 13056, // Range #165: [13056, 13143, Katakana]
+ 13312, // Range #166: [13312, 19893, Han]
+ 19968, // Range #167: [19968, 40943, Han]
+ 40960, // Range #168: [40960, 42182, Yi]
+ 42192, // Range #169: [42192, 42239, Lisu]
+ 42240, // Range #170: [42240, 42539, Vai]
+ 42560, // Range #171: [42560, 42655, Cyrillic]
+ 42656, // Range #172: [42656, 42743, Bamum]
+ 42786, // Range #173: [42786, 42887, Latin]
+ 42891, // Range #174: [42891, 42950, Latin]
+ 42999, // Range #175: [42999, 43007, Latin]
+ 43008, // Range #176: [43008, 43051, Syloti_Nagri]
+ 43072, // Range #177: [43072, 43127, Phags_Pa]
+ 43136, // Range #178: [43136, 43205, Saurashtra]
+ 43214, // Range #179: [43214, 43225, Saurashtra]
+ 43232, // Range #180: [43232, 43263, Devanagari]
+ 43264, // Range #181: [43264, 43309, Kayah_Li]
+ 43311, // Range #182: [43311, 43311, Kayah_Li]
+ 43312, // Range #183: [43312, 43347, Rejang]
+ 43359, // Range #184: [43359, 43359, Rejang]
+ 43360, // Range #185: [43360, 43388, Hangul]
+ 43392, // Range #186: [43392, 43469, Javanese]
+ 43472, // Range #187: [43472, 43487, Javanese]
+ 43488, // Range #188: [43488, 43518, Myanmar]
+ 43520, // Range #189: [43520, 43574, Cham]
+ 43584, // Range #190: [43584, 43615, Cham]
+ 43616, // Range #191: [43616, 43647, Myanmar]
+ 43648, // Range #192: [43648, 43714, Tai_Viet]
+ 43739, // Range #193: [43739, 43743, Tai_Viet]
+ 43744, // Range #194: [43744, 43766, Meetei_Mayek]
+ 43777, // Range #195: [43777, 43798, Ethiopic]
+ 43808, // Range #196: [43808, 43822, Ethiopic]
+ 43824, // Range #197: [43824, 43866, Latin]
+ 43868, // Range #198: [43868, 43876, Latin]
+ 43877, // Range #199: [43877, 43877, Greek]
+ 43878, // Range #200: [43878, 43879, Latin]
+ 43888, // Range #201: [43888, 43967, Cherokee]
+ 43968, // Range #202: [43968, 44025, Meetei_Mayek]
+ 44032, // Range #203: [44032, 55203, Hangul]
+ 55216, // Range #204: [55216, 55291, Hangul]
+ 63744, // Range #205: [63744, 64217, Han]
+ 64256, // Range #206: [64256, 64262, Latin]
+ 64275, // Range #207: [64275, 64279, Armenian]
+ 64285, // Range #208: [64285, 64335, Hebrew]
+ 64336, // Range #209: [64336, 64449, Arabic]
+ 64467, // Range #210: [64467, 64829, Arabic]
+ 64848, // Range #211: [64848, 64967, Arabic]
+ 65008, // Range #212: [65008, 65021, Arabic]
+ 65070, // Range #213: [65070, 65071, Cyrillic]
+ 65136, // Range #214: [65136, 65276, Arabic]
+ 65313, // Range #215: [65313, 65338, Latin]
+ 65345, // Range #216: [65345, 65370, Latin]
+ 65382, // Range #217: [65382, 65391, Katakana]
+ 65393, // Range #218: [65393, 65437, Katakana]
+ 65440, // Range #219: [65440, 65500, Hangul]
+ 65536, // Range #220: [65536, 65629, Linear_B]
+ 65664, // Range #221: [65664, 65786, Linear_B]
+ 65856, // Range #222: [65856, 65934, Greek]
+ 65952, // Range #223: [65952, 65952, Greek]
+ 66176, // Range #224: [66176, 66204, Lycian]
+ 66208, // Range #225: [66208, 66256, Carian]
+ 66304, // Range #226: [66304, 66339, Old_Italic]
+ 66349, // Range #227: [66349, 66351, Old_Italic]
+ 66352, // Range #228: [66352, 66378, Gothic]
+ 66384, // Range #229: [66384, 66426, Old_Permic]
+ 66432, // Range #230: [66432, 66463, Ugaritic]
+ 66464, // Range #231: [66464, 66517, Old_Persian]
+ 66560, // Range #232: [66560, 66639, Deseret]
+ 66640, // Range #233: [66640, 66687, Shavian]
+ 66688, // Range #234: [66688, 66729, Osmanya]
+ 66736, // Range #235: [66736, 66811, Osage]
+ 66816, // Range #236: [66816, 66855, Elbasan]
+ 66864, // Range #237: [66864, 66915, Caucasian_Albanian]
+ 66927, // Range #238: [66927, 66927, Caucasian_Albanian]
+ 67072, // Range #239: [67072, 67382, Linear_A]
+ 67392, // Range #240: [67392, 67413, Linear_A]
+ 67424, // Range #241: [67424, 67431, Linear_A]
+ 67584, // Range #242: [67584, 67647, Cypriot]
+ 67648, // Range #243: [67648, 67679, Imperial_Aramaic]
+ 67680, // Range #244: [67680, 67711, Palmyrene]
+ 67712, // Range #245: [67712, 67742, Nabataean]
+ 67751, // Range #246: [67751, 67759, Nabataean]
+ 67808, // Range #247: [67808, 67829, Hatran]
+ 67835, // Range #248: [67835, 67839, Hatran]
+ 67840, // Range #249: [67840, 67871, Phoenician]
+ 67872, // Range #250: [67872, 67897, Lydian]
+ 67903, // Range #251: [67903, 67903, Lydian]
+ 67968, // Range #252: [67968, 67999, Meroitic_Hieroglyphs]
+ 68000, // Range #253: [68000, 68095, Meroitic_Cursive]
+ 68096, // Range #254: [68096, 68102, Kharoshthi]
+ 68108, // Range #255: [68108, 68168, Kharoshthi]
+ 68176, // Range #256: [68176, 68184, Kharoshthi]
+ 68192, // Range #257: [68192, 68223, Old_South_Arabian]
+ 68224, // Range #258: [68224, 68255, Old_North_Arabian]
+ 68288, // Range #259: [68288, 68342, Manichaean]
+ 68352, // Range #260: [68352, 68415, Avestan]
+ 68416, // Range #261: [68416, 68447, Inscriptional_Parthian]
+ 68448, // Range #262: [68448, 68466, Inscriptional_Pahlavi]
+ 68472, // Range #263: [68472, 68479, Inscriptional_Pahlavi]
+ 68480, // Range #264: [68480, 68497, Psalter_Pahlavi]
+ 68505, // Range #265: [68505, 68508, Psalter_Pahlavi]
+ 68521, // Range #266: [68521, 68527, Psalter_Pahlavi]
+ 68608, // Range #267: [68608, 68680, Old_Turkic]
+ 68736, // Range #268: [68736, 68786, Old_Hungarian]
+ 68800, // Range #269: [68800, 68850, Old_Hungarian]
+ 68858, // Range #270: [68858, 68863, Old_Hungarian]
+ 68864, // Range #271: [68864, 68903, Hanifi_Rohingya]
+ 68912, // Range #272: [68912, 68921, Hanifi_Rohingya]
+ 69216, // Range #273: [69216, 69246, Arabic]
+ 69376, // Range #274: [69376, 69415, Old_Sogdian]
+ 69424, // Range #275: [69424, 69465, Sogdian]
+ 69600, // Range #276: [69600, 69622, Elymaic]
+ 69632, // Range #277: [69632, 69743, Brahmi]
+ 69759, // Range #278: [69759, 69759, Brahmi]
+ 69760, // Range #279: [69760, 69825, Kaithi]
+ 69837, // Range #280: [69837, 69837, Kaithi]
+ 69840, // Range #281: [69840, 69864, Sora_Sompeng]
+ 69872, // Range #282: [69872, 69881, Sora_Sompeng]
+ 69888, // Range #283: [69888, 69958, Chakma]
+ 69968, // Range #284: [69968, 70006, Mahajani]
+ 70016, // Range #285: [70016, 70111, Sharada]
+ 70113, // Range #286: [70113, 70132, Sinhala]
+ 70144, // Range #287: [70144, 70206, Khojki]
+ 70272, // Range #288: [70272, 70313, Multani]
+ 70320, // Range #289: [70320, 70378, Khudawadi]
+ 70384, // Range #290: [70384, 70393, Khudawadi]
+ 70400, // Range #291: [70400, 70457, Grantha]
+ 70460, // Range #292: [70460, 70480, Grantha]
+ 70487, // Range #293: [70487, 70487, Grantha]
+ 70493, // Range #294: [70493, 70516, Grantha]
+ 70656, // Range #295: [70656, 70751, Newa]
+ 70784, // Range #296: [70784, 70855, Tirhuta]
+ 70864, // Range #297: [70864, 70873, Tirhuta]
+ 71040, // Range #298: [71040, 71133, Siddham]
+ 71168, // Range #299: [71168, 71236, Modi]
+ 71248, // Range #300: [71248, 71257, Modi]
+ 71264, // Range #301: [71264, 71276, Mongolian]
+ 71296, // Range #302: [71296, 71352, Takri]
+ 71360, // Range #303: [71360, 71369, Takri]
+ 71424, // Range #304: [71424, 71487, Ahom]
+ 71680, // Range #305: [71680, 71739, Dogra]
+ 71840, // Range #306: [71840, 71922, Warang_Citi]
+ 71935, // Range #307: [71935, 71935, Warang_Citi]
+ 72096, // Range #308: [72096, 72164, Nandinagari]
+ 72192, // Range #309: [72192, 72263, Zanabazar_Square]
+ 72272, // Range #310: [72272, 72354, Soyombo]
+ 72384, // Range #311: [72384, 72440, Pau_Cin_Hau]
+ 72704, // Range #312: [72704, 72773, Bhaiksuki]
+ 72784, // Range #313: [72784, 72812, Bhaiksuki]
+ 72816, // Range #314: [72816, 72886, Marchen]
+ 72960, // Range #315: [72960, 73031, Masaram_Gondi]
+ 73040, // Range #316: [73040, 73049, Masaram_Gondi]
+ 73056, // Range #317: [73056, 73112, Gunjala_Gondi]
+ 73120, // Range #318: [73120, 73129, Gunjala_Gondi]
+ 73440, // Range #319: [73440, 73464, Makasar]
+ 73664, // Range #320: [73664, 73713, Tamil]
+ 73727, // Range #321: [73727, 73727, Tamil]
+ 73728, // Range #322: [73728, 74649, Cuneiform]
+ 74752, // Range #323: [74752, 74868, Cuneiform]
+ 74880, // Range #324: [74880, 75075, Cuneiform]
+ 77824, // Range #325: [77824, 78904, Egyptian_Hieroglyphs]
+ 82944, // Range #326: [82944, 83526, Anatolian_Hieroglyphs]
+ 92160, // Range #327: [92160, 92728, Bamum]
+ 92736, // Range #328: [92736, 92783, Mro]
+ 92880, // Range #329: [92880, 92917, Bassa_Vah]
+ 92928, // Range #330: [92928, 92997, Pahawh_Hmong]
+ 93008, // Range #331: [93008, 93047, Pahawh_Hmong]
+ 93053, // Range #332: [93053, 93071, Pahawh_Hmong]
+ 93760, // Range #333: [93760, 93850, Medefaidrin]
+ 93952, // Range #334: [93952, 94087, Miao]
+ 94095, // Range #335: [94095, 94111, Miao]
+ 94176, // Range #336: [94176, 94176, Tangut]
+ 94177, // Range #337: [94177, 94177, Nushu]
+ 94208, // Range #338: [94208, 100343, Tangut]
+ 100352, // Range #339: [100352, 101106, Tangut]
+ 110592, // Range #340: [110592, 110592, Katakana]
+ 110593, // Range #341: [110593, 110878, Hiragana]
+ 110928, // Range #342: [110928, 110930, Hiragana]
+ 110948, // Range #343: [110948, 110951, Katakana]
+ 110960, // Range #344: [110960, 111355, Nushu]
+ 113664, // Range #345: [113664, 113770, Duployan]
+ 113776, // Range #346: [113776, 113800, Duployan]
+ 113808, // Range #347: [113808, 113823, Duployan]
+ 119296, // Range #348: [119296, 119365, Greek]
+ 120832, // Range #349: [120832, 121483, SignWriting]
+ 121499, // Range #350: [121499, 121519, SignWriting]
+ 122880, // Range #351: [122880, 122922, Glagolitic]
+ 123136, // Range #352: [123136, 123215, Nyiakeng_Puachue_Hmong]
+ 123584, // Range #353: [123584, 123641, Wancho]
+ 123647, // Range #354: [123647, 123647, Wancho]
+ 124928, // Range #355: [124928, 125142, Mende_Kikakui]
+ 125184, // Range #356: [125184, 125279, Adlam]
+ 126464, // Range #357: [126464, 126523, Arabic]
+ 126530, // Range #358: [126530, 126619, Arabic]
+ 126625, // Range #359: [126625, 126651, Arabic]
+ 126704, // Range #360: [126704, 126705, Arabic]
+ 127488, // Range #361: [127488, 127488, Hiragana]
+ 131072, // Range #362: [131072, 173782, Han]
+ 173824, // Range #363: [173824, 177972, Han]
+ 177984, // Range #364: [177984, 183969, Han]
+ 183984, // Range #365: [183984, 191456, Han]
+ 194560, // Range #366: [194560, 195101, Han]
+};
+
+const uint16 kRangeSizeMinusOne[] = {
+ 25, // Range #0: [65, 90, Latin]
+ 25, // Range #1: [97, 122, Latin]
+ 0, // Range #2: [170, 170, Latin]
+ 0, // Range #3: [186, 186, Latin]
+ 22, // Range #4: [192, 214, Latin]
+ 30, // Range #5: [216, 246, Latin]
+ 448, // Range #6: [248, 696, Latin]
+ 4, // Range #7: [736, 740, Latin]
+ 1, // Range #8: [746, 747, Bopomofo]
+ 3, // Range #9: [880, 883, Greek]
+ 8, // Range #10: [885, 893, Greek]
+ 5, // Range #11: [895, 900, Greek]
+ 0, // Range #12: [902, 902, Greek]
+ 89, // Range #13: [904, 993, Greek]
+ 13, // Range #14: [994, 1007, Coptic]
+ 15, // Range #15: [1008, 1023, Greek]
+ 132, // Range #16: [1024, 1156, Cyrillic]
+ 168, // Range #17: [1159, 1327, Cyrillic]
+ 87, // Range #18: [1329, 1416, Armenian]
+ 5, // Range #19: [1418, 1423, Armenian]
+ 54, // Range #20: [1425, 1479, Hebrew]
+ 36, // Range #21: [1488, 1524, Hebrew]
+ 4, // Range #22: [1536, 1540, Arabic]
+ 5, // Range #23: [1542, 1547, Arabic]
+ 13, // Range #24: [1549, 1562, Arabic]
+ 2, // Range #25: [1564, 1566, Arabic]
+ 31, // Range #26: [1568, 1599, Arabic]
+ 9, // Range #27: [1601, 1610, Arabic]
+ 25, // Range #28: [1622, 1647, Arabic]
+ 107, // Range #29: [1649, 1756, Arabic]
+ 33, // Range #30: [1758, 1791, Arabic]
+ 79, // Range #31: [1792, 1871, Syriac]
+ 47, // Range #32: [1872, 1919, Arabic]
+ 49, // Range #33: [1920, 1969, Thaana]
+ 63, // Range #34: [1984, 2047, Nko]
+ 62, // Range #35: [2048, 2110, Samaritan]
+ 30, // Range #36: [2112, 2142, Mandaic]
+ 10, // Range #37: [2144, 2154, Syriac]
+ 29, // Range #38: [2208, 2237, Arabic]
+ 14, // Range #39: [2259, 2273, Arabic]
+ 28, // Range #40: [2275, 2303, Arabic]
+ 80, // Range #41: [2304, 2384, Devanagari]
+ 14, // Range #42: [2389, 2403, Devanagari]
+ 25, // Range #43: [2406, 2431, Devanagari]
+ 78, // Range #44: [2432, 2510, Bengali]
+ 39, // Range #45: [2519, 2558, Bengali]
+ 80, // Range #46: [2561, 2641, Gurmukhi]
+ 5, // Range #47: [2649, 2654, Gurmukhi]
+ 16, // Range #48: [2662, 2678, Gurmukhi]
+ 79, // Range #49: [2689, 2768, Gujarati]
+ 17, // Range #50: [2784, 2801, Gujarati]
+ 6, // Range #51: [2809, 2815, Gujarati]
+ 76, // Range #52: [2817, 2893, Oriya]
+ 33, // Range #53: [2902, 2935, Oriya]
+ 78, // Range #54: [2946, 3024, Tamil]
+ 0, // Range #55: [3031, 3031, Tamil]
+ 20, // Range #56: [3046, 3066, Tamil]
+ 77, // Range #57: [3072, 3149, Telugu]
+ 5, // Range #58: [3157, 3162, Telugu]
+ 15, // Range #59: [3168, 3183, Telugu]
+ 8, // Range #60: [3191, 3199, Telugu]
+ 77, // Range #61: [3200, 3277, Kannada]
+ 1, // Range #62: [3285, 3286, Kannada]
+ 20, // Range #63: [3294, 3314, Kannada]
+ 127, // Range #64: [3328, 3455, Malayalam]
+ 93, // Range #65: [3458, 3551, Sinhala]
+ 14, // Range #66: [3558, 3572, Sinhala]
+ 57, // Range #67: [3585, 3642, Thai]
+ 27, // Range #68: [3648, 3675, Thai]
+ 94, // Range #69: [3713, 3807, Lao]
+ 212, // Range #70: [3840, 4052, Tibetan]
+ 1, // Range #71: [4057, 4058, Tibetan]
+ 159, // Range #72: [4096, 4255, Myanmar]
+ 39, // Range #73: [4256, 4295, Georgian]
+ 45, // Range #74: [4301, 4346, Georgian]
+ 3, // Range #75: [4348, 4351, Georgian]
+ 255, // Range #76: [4352, 4607, Hangul]
+ 409, // Range #77: [4608, 5017, Ethiopic]
+ 93, // Range #78: [5024, 5117, Cherokee]
+ 639, // Range #79: [5120, 5759, Canadian_Aboriginal]
+ 28, // Range #80: [5760, 5788, Ogham]
+ 74, // Range #81: [5792, 5866, Runic]
+ 10, // Range #82: [5870, 5880, Runic]
+ 20, // Range #83: [5888, 5908, Tagalog]
+ 20, // Range #84: [5920, 5940, Hanunoo]
+ 19, // Range #85: [5952, 5971, Buhid]
+ 19, // Range #86: [5984, 6003, Tagbanwa]
+ 105, // Range #87: [6016, 6121, Khmer]
+ 9, // Range #88: [6128, 6137, Khmer]
+ 1, // Range #89: [6144, 6145, Mongolian]
+ 0, // Range #90: [6148, 6148, Mongolian]
+ 19, // Range #91: [6150, 6169, Mongolian]
+ 88, // Range #92: [6176, 6264, Mongolian]
+ 42, // Range #93: [6272, 6314, Mongolian]
+ 69, // Range #94: [6320, 6389, Canadian_Aboriginal]
+ 79, // Range #95: [6400, 6479, Limbu]
+ 36, // Range #96: [6480, 6516, Tai_Le]
+ 73, // Range #97: [6528, 6601, New_Tai_Lue]
+ 15, // Range #98: [6608, 6623, New_Tai_Lue]
+ 31, // Range #99: [6624, 6655, Khmer]
+ 31, // Range #100: [6656, 6687, Buginese]
+ 105, // Range #101: [6688, 6793, Tai_Tham]
+ 9, // Range #102: [6800, 6809, Tai_Tham]
+ 13, // Range #103: [6816, 6829, Tai_Tham]
+ 124, // Range #104: [6912, 7036, Balinese]
+ 63, // Range #105: [7040, 7103, Sundanese]
+ 51, // Range #106: [7104, 7155, Batak]
+ 3, // Range #107: [7164, 7167, Batak]
+ 79, // Range #108: [7168, 7247, Lepcha]
+ 47, // Range #109: [7248, 7295, Ol_Chiki]
+ 8, // Range #110: [7296, 7304, Cyrillic]
+ 47, // Range #111: [7312, 7359, Georgian]
+ 7, // Range #112: [7360, 7367, Sundanese]
+ 37, // Range #113: [7424, 7461, Latin]
+ 4, // Range #114: [7462, 7466, Greek]
+ 0, // Range #115: [7467, 7467, Cyrillic]
+ 48, // Range #116: [7468, 7516, Latin]
+ 4, // Range #117: [7517, 7521, Greek]
+ 3, // Range #118: [7522, 7525, Latin]
+ 4, // Range #119: [7526, 7530, Greek]
+ 12, // Range #120: [7531, 7543, Latin]
+ 0, // Range #121: [7544, 7544, Cyrillic]
+ 69, // Range #122: [7545, 7614, Latin]
+ 0, // Range #123: [7615, 7615, Greek]
+ 255, // Range #124: [7680, 7935, Latin]
+ 254, // Range #125: [7936, 8190, Greek]
+ 0, // Range #126: [8305, 8305, Latin]
+ 0, // Range #127: [8319, 8319, Latin]
+ 12, // Range #128: [8336, 8348, Latin]
+ 0, // Range #129: [8486, 8486, Greek]
+ 1, // Range #130: [8490, 8491, Latin]
+ 0, // Range #131: [8498, 8498, Latin]
+ 0, // Range #132: [8526, 8526, Latin]
+ 40, // Range #133: [8544, 8584, Latin]
+ 255, // Range #134: [10240, 10495, Braille]
+ 94, // Range #135: [11264, 11358, Glagolitic]
+ 31, // Range #136: [11360, 11391, Latin]
+ 115, // Range #137: [11392, 11507, Coptic]
+ 6, // Range #138: [11513, 11519, Coptic]
+ 39, // Range #139: [11520, 11559, Georgian]
+ 0, // Range #140: [11565, 11565, Georgian]
+ 55, // Range #141: [11568, 11623, Tifinagh]
+ 1, // Range #142: [11631, 11632, Tifinagh]
+ 0, // Range #143: [11647, 11647, Tifinagh]
+ 22, // Range #144: [11648, 11670, Ethiopic]
+ 62, // Range #145: [11680, 11742, Ethiopic]
+ 31, // Range #146: [11744, 11775, Cyrillic]
+ 115, // Range #147: [11904, 12019, Han]
+ 213, // Range #148: [12032, 12245, Han]
+ 0, // Range #149: [12293, 12293, Han]
+ 0, // Range #150: [12295, 12295, Han]
+ 8, // Range #151: [12321, 12329, Han]
+ 1, // Range #152: [12334, 12335, Hangul]
+ 3, // Range #153: [12344, 12347, Han]
+ 85, // Range #154: [12353, 12438, Hiragana]
+ 2, // Range #155: [12445, 12447, Hiragana]
+ 89, // Range #156: [12449, 12538, Katakana]
+ 2, // Range #157: [12541, 12543, Katakana]
+ 42, // Range #158: [12549, 12591, Bopomofo]
+ 93, // Range #159: [12593, 12686, Hangul]
+ 26, // Range #160: [12704, 12730, Bopomofo]
+ 15, // Range #161: [12784, 12799, Katakana]
+ 30, // Range #162: [12800, 12830, Hangul]
+ 30, // Range #163: [12896, 12926, Hangul]
+ 46, // Range #164: [13008, 13054, Katakana]
+ 87, // Range #165: [13056, 13143, Katakana]
+ 6581, // Range #166: [13312, 19893, Han]
+ 20975, // Range #167: [19968, 40943, Han]
+ 1222, // Range #168: [40960, 42182, Yi]
+ 47, // Range #169: [42192, 42239, Lisu]
+ 299, // Range #170: [42240, 42539, Vai]
+ 95, // Range #171: [42560, 42655, Cyrillic]
+ 87, // Range #172: [42656, 42743, Bamum]
+ 101, // Range #173: [42786, 42887, Latin]
+ 59, // Range #174: [42891, 42950, Latin]
+ 8, // Range #175: [42999, 43007, Latin]
+ 43, // Range #176: [43008, 43051, Syloti_Nagri]
+ 55, // Range #177: [43072, 43127, Phags_Pa]
+ 69, // Range #178: [43136, 43205, Saurashtra]
+ 11, // Range #179: [43214, 43225, Saurashtra]
+ 31, // Range #180: [43232, 43263, Devanagari]
+ 45, // Range #181: [43264, 43309, Kayah_Li]
+ 0, // Range #182: [43311, 43311, Kayah_Li]
+ 35, // Range #183: [43312, 43347, Rejang]
+ 0, // Range #184: [43359, 43359, Rejang]
+ 28, // Range #185: [43360, 43388, Hangul]
+ 77, // Range #186: [43392, 43469, Javanese]
+ 15, // Range #187: [43472, 43487, Javanese]
+ 30, // Range #188: [43488, 43518, Myanmar]
+ 54, // Range #189: [43520, 43574, Cham]
+ 31, // Range #190: [43584, 43615, Cham]
+ 31, // Range #191: [43616, 43647, Myanmar]
+ 66, // Range #192: [43648, 43714, Tai_Viet]
+ 4, // Range #193: [43739, 43743, Tai_Viet]
+ 22, // Range #194: [43744, 43766, Meetei_Mayek]
+ 21, // Range #195: [43777, 43798, Ethiopic]
+ 14, // Range #196: [43808, 43822, Ethiopic]
+ 42, // Range #197: [43824, 43866, Latin]
+ 8, // Range #198: [43868, 43876, Latin]
+ 0, // Range #199: [43877, 43877, Greek]
+ 1, // Range #200: [43878, 43879, Latin]
+ 79, // Range #201: [43888, 43967, Cherokee]
+ 57, // Range #202: [43968, 44025, Meetei_Mayek]
+ 11171, // Range #203: [44032, 55203, Hangul]
+ 75, // Range #204: [55216, 55291, Hangul]
+ 473, // Range #205: [63744, 64217, Han]
+ 6, // Range #206: [64256, 64262, Latin]
+ 4, // Range #207: [64275, 64279, Armenian]
+ 50, // Range #208: [64285, 64335, Hebrew]
+ 113, // Range #209: [64336, 64449, Arabic]
+ 362, // Range #210: [64467, 64829, Arabic]
+ 119, // Range #211: [64848, 64967, Arabic]
+ 13, // Range #212: [65008, 65021, Arabic]
+ 1, // Range #213: [65070, 65071, Cyrillic]
+ 140, // Range #214: [65136, 65276, Arabic]
+ 25, // Range #215: [65313, 65338, Latin]
+ 25, // Range #216: [65345, 65370, Latin]
+ 9, // Range #217: [65382, 65391, Katakana]
+ 44, // Range #218: [65393, 65437, Katakana]
+ 60, // Range #219: [65440, 65500, Hangul]
+ 93, // Range #220: [65536, 65629, Linear_B]
+ 122, // Range #221: [65664, 65786, Linear_B]
+ 78, // Range #222: [65856, 65934, Greek]
+ 0, // Range #223: [65952, 65952, Greek]
+ 28, // Range #224: [66176, 66204, Lycian]
+ 48, // Range #225: [66208, 66256, Carian]
+ 35, // Range #226: [66304, 66339, Old_Italic]
+ 2, // Range #227: [66349, 66351, Old_Italic]
+ 26, // Range #228: [66352, 66378, Gothic]
+ 42, // Range #229: [66384, 66426, Old_Permic]
+ 31, // Range #230: [66432, 66463, Ugaritic]
+ 53, // Range #231: [66464, 66517, Old_Persian]
+ 79, // Range #232: [66560, 66639, Deseret]
+ 47, // Range #233: [66640, 66687, Shavian]
+ 41, // Range #234: [66688, 66729, Osmanya]
+ 75, // Range #235: [66736, 66811, Osage]
+ 39, // Range #236: [66816, 66855, Elbasan]
+ 51, // Range #237: [66864, 66915, Caucasian_Albanian]
+ 0, // Range #238: [66927, 66927, Caucasian_Albanian]
+ 310, // Range #239: [67072, 67382, Linear_A]
+ 21, // Range #240: [67392, 67413, Linear_A]
+ 7, // Range #241: [67424, 67431, Linear_A]
+ 63, // Range #242: [67584, 67647, Cypriot]
+ 31, // Range #243: [67648, 67679, Imperial_Aramaic]
+ 31, // Range #244: [67680, 67711, Palmyrene]
+ 30, // Range #245: [67712, 67742, Nabataean]
+ 8, // Range #246: [67751, 67759, Nabataean]
+ 21, // Range #247: [67808, 67829, Hatran]
+ 4, // Range #248: [67835, 67839, Hatran]
+ 31, // Range #249: [67840, 67871, Phoenician]
+ 25, // Range #250: [67872, 67897, Lydian]
+ 0, // Range #251: [67903, 67903, Lydian]
+ 31, // Range #252: [67968, 67999, Meroitic_Hieroglyphs]
+ 95, // Range #253: [68000, 68095, Meroitic_Cursive]
+ 6, // Range #254: [68096, 68102, Kharoshthi]
+ 60, // Range #255: [68108, 68168, Kharoshthi]
+ 8, // Range #256: [68176, 68184, Kharoshthi]
+ 31, // Range #257: [68192, 68223, Old_South_Arabian]
+ 31, // Range #258: [68224, 68255, Old_North_Arabian]
+ 54, // Range #259: [68288, 68342, Manichaean]
+ 63, // Range #260: [68352, 68415, Avestan]
+ 31, // Range #261: [68416, 68447, Inscriptional_Parthian]
+ 18, // Range #262: [68448, 68466, Inscriptional_Pahlavi]
+ 7, // Range #263: [68472, 68479, Inscriptional_Pahlavi]
+ 17, // Range #264: [68480, 68497, Psalter_Pahlavi]
+ 3, // Range #265: [68505, 68508, Psalter_Pahlavi]
+ 6, // Range #266: [68521, 68527, Psalter_Pahlavi]
+ 72, // Range #267: [68608, 68680, Old_Turkic]
+ 50, // Range #268: [68736, 68786, Old_Hungarian]
+ 50, // Range #269: [68800, 68850, Old_Hungarian]
+ 5, // Range #270: [68858, 68863, Old_Hungarian]
+ 39, // Range #271: [68864, 68903, Hanifi_Rohingya]
+ 9, // Range #272: [68912, 68921, Hanifi_Rohingya]
+ 30, // Range #273: [69216, 69246, Arabic]
+ 39, // Range #274: [69376, 69415, Old_Sogdian]
+ 41, // Range #275: [69424, 69465, Sogdian]
+ 22, // Range #276: [69600, 69622, Elymaic]
+ 111, // Range #277: [69632, 69743, Brahmi]
+ 0, // Range #278: [69759, 69759, Brahmi]
+ 65, // Range #279: [69760, 69825, Kaithi]
+ 0, // Range #280: [69837, 69837, Kaithi]
+ 24, // Range #281: [69840, 69864, Sora_Sompeng]
+ 9, // Range #282: [69872, 69881, Sora_Sompeng]
+ 70, // Range #283: [69888, 69958, Chakma]
+ 38, // Range #284: [69968, 70006, Mahajani]
+ 95, // Range #285: [70016, 70111, Sharada]
+ 19, // Range #286: [70113, 70132, Sinhala]
+ 62, // Range #287: [70144, 70206, Khojki]
+ 41, // Range #288: [70272, 70313, Multani]
+ 58, // Range #289: [70320, 70378, Khudawadi]
+ 9, // Range #290: [70384, 70393, Khudawadi]
+ 57, // Range #291: [70400, 70457, Grantha]
+ 20, // Range #292: [70460, 70480, Grantha]
+ 0, // Range #293: [70487, 70487, Grantha]
+ 23, // Range #294: [70493, 70516, Grantha]
+ 95, // Range #295: [70656, 70751, Newa]
+ 71, // Range #296: [70784, 70855, Tirhuta]
+ 9, // Range #297: [70864, 70873, Tirhuta]
+ 93, // Range #298: [71040, 71133, Siddham]
+ 68, // Range #299: [71168, 71236, Modi]
+ 9, // Range #300: [71248, 71257, Modi]
+ 12, // Range #301: [71264, 71276, Mongolian]
+ 56, // Range #302: [71296, 71352, Takri]
+ 9, // Range #303: [71360, 71369, Takri]
+ 63, // Range #304: [71424, 71487, Ahom]
+ 59, // Range #305: [71680, 71739, Dogra]
+ 82, // Range #306: [71840, 71922, Warang_Citi]
+ 0, // Range #307: [71935, 71935, Warang_Citi]
+ 68, // Range #308: [72096, 72164, Nandinagari]
+ 71, // Range #309: [72192, 72263, Zanabazar_Square]
+ 82, // Range #310: [72272, 72354, Soyombo]
+ 56, // Range #311: [72384, 72440, Pau_Cin_Hau]
+ 69, // Range #312: [72704, 72773, Bhaiksuki]
+ 28, // Range #313: [72784, 72812, Bhaiksuki]
+ 70, // Range #314: [72816, 72886, Marchen]
+ 71, // Range #315: [72960, 73031, Masaram_Gondi]
+ 9, // Range #316: [73040, 73049, Masaram_Gondi]
+ 56, // Range #317: [73056, 73112, Gunjala_Gondi]
+ 9, // Range #318: [73120, 73129, Gunjala_Gondi]
+ 24, // Range #319: [73440, 73464, Makasar]
+ 49, // Range #320: [73664, 73713, Tamil]
+ 0, // Range #321: [73727, 73727, Tamil]
+ 921, // Range #322: [73728, 74649, Cuneiform]
+ 116, // Range #323: [74752, 74868, Cuneiform]
+ 195, // Range #324: [74880, 75075, Cuneiform]
+ 1080, // Range #325: [77824, 78904, Egyptian_Hieroglyphs]
+ 582, // Range #326: [82944, 83526, Anatolian_Hieroglyphs]
+ 568, // Range #327: [92160, 92728, Bamum]
+ 47, // Range #328: [92736, 92783, Mro]
+ 37, // Range #329: [92880, 92917, Bassa_Vah]
+ 69, // Range #330: [92928, 92997, Pahawh_Hmong]
+ 39, // Range #331: [93008, 93047, Pahawh_Hmong]
+ 18, // Range #332: [93053, 93071, Pahawh_Hmong]
+ 90, // Range #333: [93760, 93850, Medefaidrin]
+ 135, // Range #334: [93952, 94087, Miao]
+ 16, // Range #335: [94095, 94111, Miao]
+ 0, // Range #336: [94176, 94176, Tangut]
+ 0, // Range #337: [94177, 94177, Nushu]
+ 6135, // Range #338: [94208, 100343, Tangut]
+ 754, // Range #339: [100352, 101106, Tangut]
+ 0, // Range #340: [110592, 110592, Katakana]
+ 285, // Range #341: [110593, 110878, Hiragana]
+ 2, // Range #342: [110928, 110930, Hiragana]
+ 3, // Range #343: [110948, 110951, Katakana]
+ 395, // Range #344: [110960, 111355, Nushu]
+ 106, // Range #345: [113664, 113770, Duployan]
+ 24, // Range #346: [113776, 113800, Duployan]
+ 15, // Range #347: [113808, 113823, Duployan]
+ 69, // Range #348: [119296, 119365, Greek]
+ 651, // Range #349: [120832, 121483, SignWriting]
+ 20, // Range #350: [121499, 121519, SignWriting]
+ 42, // Range #351: [122880, 122922, Glagolitic]
+ 79, // Range #352: [123136, 123215, Nyiakeng_Puachue_Hmong]
+ 57, // Range #353: [123584, 123641, Wancho]
+ 0, // Range #354: [123647, 123647, Wancho]
+ 214, // Range #355: [124928, 125142, Mende_Kikakui]
+ 95, // Range #356: [125184, 125279, Adlam]
+ 59, // Range #357: [126464, 126523, Arabic]
+ 89, // Range #358: [126530, 126619, Arabic]
+ 26, // Range #359: [126625, 126651, Arabic]
+ 1, // Range #360: [126704, 126705, Arabic]
+ 0, // Range #361: [127488, 127488, Hiragana]
+ 42710, // Range #362: [131072, 173782, Han]
+ 4148, // Range #363: [173824, 177972, Han]
+ 5985, // Range #364: [177984, 183969, Han]
+ 7472, // Range #365: [183984, 191456, Han]
+ 541, // Range #366: [194560, 195101, Han]
+};
+
+const uint8 kRangeScript[] = {
+ 25, // Range #0: [65, 90, Latin]
+ 25, // Range #1: [97, 122, Latin]
+ 25, // Range #2: [170, 170, Latin]
+ 25, // Range #3: [186, 186, Latin]
+ 25, // Range #4: [192, 214, Latin]
+ 25, // Range #5: [216, 246, Latin]
+ 25, // Range #6: [248, 696, Latin]
+ 25, // Range #7: [736, 740, Latin]
+ 5, // Range #8: [746, 747, Bopomofo]
+ 14, // Range #9: [880, 883, Greek]
+ 14, // Range #10: [885, 893, Greek]
+ 14, // Range #11: [895, 900, Greek]
+ 14, // Range #12: [902, 902, Greek]
+ 14, // Range #13: [904, 993, Greek]
+ 7, // Range #14: [994, 1007, Coptic]
+ 14, // Range #15: [1008, 1023, Greek]
+ 8, // Range #16: [1024, 1156, Cyrillic]
+ 8, // Range #17: [1159, 1327, Cyrillic]
+ 3, // Range #18: [1329, 1416, Armenian]
+ 3, // Range #19: [1418, 1423, Armenian]
+ 19, // Range #20: [1425, 1479, Hebrew]
+ 19, // Range #21: [1488, 1524, Hebrew]
+ 2, // Range #22: [1536, 1540, Arabic]
+ 2, // Range #23: [1542, 1547, Arabic]
+ 2, // Range #24: [1549, 1562, Arabic]
+ 2, // Range #25: [1564, 1566, Arabic]
+ 2, // Range #26: [1568, 1599, Arabic]
+ 2, // Range #27: [1601, 1610, Arabic]
+ 2, // Range #28: [1622, 1647, Arabic]
+ 2, // Range #29: [1649, 1756, Arabic]
+ 2, // Range #30: [1758, 1791, Arabic]
+ 34, // Range #31: [1792, 1871, Syriac]
+ 2, // Range #32: [1872, 1919, Arabic]
+ 37, // Range #33: [1920, 1969, Thaana]
+ 87, // Range #34: [1984, 2047, Nko]
+ 126, // Range #35: [2048, 2110, Samaritan]
+ 84, // Range #36: [2112, 2142, Mandaic]
+ 34, // Range #37: [2144, 2154, Syriac]
+ 2, // Range #38: [2208, 2237, Arabic]
+ 2, // Range #39: [2259, 2273, Arabic]
+ 2, // Range #40: [2275, 2303, Arabic]
+ 10, // Range #41: [2304, 2384, Devanagari]
+ 10, // Range #42: [2389, 2403, Devanagari]
+ 10, // Range #43: [2406, 2431, Devanagari]
+ 4, // Range #44: [2432, 2510, Bengali]
+ 4, // Range #45: [2519, 2558, Bengali]
+ 16, // Range #46: [2561, 2641, Gurmukhi]
+ 16, // Range #47: [2649, 2654, Gurmukhi]
+ 16, // Range #48: [2662, 2678, Gurmukhi]
+ 15, // Range #49: [2689, 2768, Gujarati]
+ 15, // Range #50: [2784, 2801, Gujarati]
+ 15, // Range #51: [2809, 2815, Gujarati]
+ 31, // Range #52: [2817, 2893, Oriya]
+ 31, // Range #53: [2902, 2935, Oriya]
+ 35, // Range #54: [2946, 3024, Tamil]
+ 35, // Range #55: [3031, 3031, Tamil]
+ 35, // Range #56: [3046, 3066, Tamil]
+ 36, // Range #57: [3072, 3149, Telugu]
+ 36, // Range #58: [3157, 3162, Telugu]
+ 36, // Range #59: [3168, 3183, Telugu]
+ 36, // Range #60: [3191, 3199, Telugu]
+ 21, // Range #61: [3200, 3277, Kannada]
+ 21, // Range #62: [3285, 3286, Kannada]
+ 21, // Range #63: [3294, 3314, Kannada]
+ 26, // Range #64: [3328, 3455, Malayalam]
+ 33, // Range #65: [3458, 3551, Sinhala]
+ 33, // Range #66: [3558, 3572, Sinhala]
+ 38, // Range #67: [3585, 3642, Thai]
+ 38, // Range #68: [3648, 3675, Thai]
+ 24, // Range #69: [3713, 3807, Lao]
+ 39, // Range #70: [3840, 4052, Tibetan]
+ 39, // Range #71: [4057, 4058, Tibetan]
+ 28, // Range #72: [4096, 4255, Myanmar]
+ 12, // Range #73: [4256, 4295, Georgian]
+ 12, // Range #74: [4301, 4346, Georgian]
+ 12, // Range #75: [4348, 4351, Georgian]
+ 18, // Range #76: [4352, 4607, Hangul]
+ 11, // Range #77: [4608, 5017, Ethiopic]
+ 6, // Range #78: [5024, 5117, Cherokee]
+ 40, // Range #79: [5120, 5759, Canadian_Aboriginal]
+ 29, // Range #80: [5760, 5788, Ogham]
+ 32, // Range #81: [5792, 5866, Runic]
+ 32, // Range #82: [5870, 5880, Runic]
+ 42, // Range #83: [5888, 5908, Tagalog]
+ 43, // Range #84: [5920, 5940, Hanunoo]
+ 44, // Range #85: [5952, 5971, Buhid]
+ 45, // Range #86: [5984, 6003, Tagbanwa]
+ 23, // Range #87: [6016, 6121, Khmer]
+ 23, // Range #88: [6128, 6137, Khmer]
+ 27, // Range #89: [6144, 6145, Mongolian]
+ 27, // Range #90: [6148, 6148, Mongolian]
+ 27, // Range #91: [6150, 6169, Mongolian]
+ 27, // Range #92: [6176, 6264, Mongolian]
+ 27, // Range #93: [6272, 6314, Mongolian]
+ 40, // Range #94: [6320, 6389, Canadian_Aboriginal]
+ 48, // Range #95: [6400, 6479, Limbu]
+ 52, // Range #96: [6480, 6516, Tai_Le]
+ 59, // Range #97: [6528, 6601, New_Tai_Lue]
+ 59, // Range #98: [6608, 6623, New_Tai_Lue]
+ 23, // Range #99: [6624, 6655, Khmer]
+ 55, // Range #100: [6656, 6687, Buginese]
+ 106, // Range #101: [6688, 6793, Tai_Tham]
+ 106, // Range #102: [6800, 6809, Tai_Tham]
+ 106, // Range #103: [6816, 6829, Tai_Tham]
+ 62, // Range #104: [6912, 7036, Balinese]
+ 113, // Range #105: [7040, 7103, Sundanese]
+ 63, // Range #106: [7104, 7155, Batak]
+ 63, // Range #107: [7164, 7167, Batak]
+ 82, // Range #108: [7168, 7247, Lepcha]
+ 109, // Range #109: [7248, 7295, Ol_Chiki]
+ 8, // Range #110: [7296, 7304, Cyrillic]
+ 12, // Range #111: [7312, 7359, Georgian]
+ 113, // Range #112: [7360, 7367, Sundanese]
+ 25, // Range #113: [7424, 7461, Latin]
+ 14, // Range #114: [7462, 7466, Greek]
+ 8, // Range #115: [7467, 7467, Cyrillic]
+ 25, // Range #116: [7468, 7516, Latin]
+ 14, // Range #117: [7517, 7521, Greek]
+ 25, // Range #118: [7522, 7525, Latin]
+ 14, // Range #119: [7526, 7530, Greek]
+ 25, // Range #120: [7531, 7543, Latin]
+ 8, // Range #121: [7544, 7544, Cyrillic]
+ 25, // Range #122: [7545, 7614, Latin]
+ 14, // Range #123: [7615, 7615, Greek]
+ 25, // Range #124: [7680, 7935, Latin]
+ 14, // Range #125: [7936, 8190, Greek]
+ 25, // Range #126: [8305, 8305, Latin]
+ 25, // Range #127: [8319, 8319, Latin]
+ 25, // Range #128: [8336, 8348, Latin]
+ 14, // Range #129: [8486, 8486, Greek]
+ 25, // Range #130: [8490, 8491, Latin]
+ 25, // Range #131: [8498, 8498, Latin]
+ 25, // Range #132: [8526, 8526, Latin]
+ 25, // Range #133: [8544, 8584, Latin]
+ 46, // Range #134: [10240, 10495, Braille]
+ 56, // Range #135: [11264, 11358, Glagolitic]
+ 25, // Range #136: [11360, 11391, Latin]
+ 7, // Range #137: [11392, 11507, Coptic]
+ 7, // Range #138: [11513, 11519, Coptic]
+ 12, // Range #139: [11520, 11559, Georgian]
+ 12, // Range #140: [11565, 11565, Georgian]
+ 60, // Range #141: [11568, 11623, Tifinagh]
+ 60, // Range #142: [11631, 11632, Tifinagh]
+ 60, // Range #143: [11647, 11647, Tifinagh]
+ 11, // Range #144: [11648, 11670, Ethiopic]
+ 11, // Range #145: [11680, 11742, Ethiopic]
+ 8, // Range #146: [11744, 11775, Cyrillic]
+ 17, // Range #147: [11904, 12019, Han]
+ 17, // Range #148: [12032, 12245, Han]
+ 17, // Range #149: [12293, 12293, Han]
+ 17, // Range #150: [12295, 12295, Han]
+ 17, // Range #151: [12321, 12329, Han]
+ 18, // Range #152: [12334, 12335, Hangul]
+ 17, // Range #153: [12344, 12347, Han]
+ 20, // Range #154: [12353, 12438, Hiragana]
+ 20, // Range #155: [12445, 12447, Hiragana]
+ 22, // Range #156: [12449, 12538, Katakana]
+ 22, // Range #157: [12541, 12543, Katakana]
+ 5, // Range #158: [12549, 12591, Bopomofo]
+ 18, // Range #159: [12593, 12686, Hangul]
+ 5, // Range #160: [12704, 12730, Bopomofo]
+ 22, // Range #161: [12784, 12799, Katakana]
+ 18, // Range #162: [12800, 12830, Hangul]
+ 18, // Range #163: [12896, 12926, Hangul]
+ 22, // Range #164: [13008, 13054, Katakana]
+ 22, // Range #165: [13056, 13143, Katakana]
+ 17, // Range #166: [13312, 19893, Han]
+ 17, // Range #167: [19968, 40943, Han]
+ 41, // Range #168: [40960, 42182, Yi]
+ 131, // Range #169: [42192, 42239, Lisu]
+ 99, // Range #170: [42240, 42539, Vai]
+ 8, // Range #171: [42560, 42655, Cyrillic]
+ 130, // Range #172: [42656, 42743, Bamum]
+ 25, // Range #173: [42786, 42887, Latin]
+ 25, // Range #174: [42891, 42950, Latin]
+ 25, // Range #175: [42999, 43007, Latin]
+ 58, // Range #176: [43008, 43051, Syloti_Nagri]
+ 90, // Range #177: [43072, 43127, Phags_Pa]
+ 111, // Range #178: [43136, 43205, Saurashtra]
+ 111, // Range #179: [43214, 43225, Saurashtra]
+ 10, // Range #180: [43232, 43263, Devanagari]
+ 79, // Range #181: [43264, 43309, Kayah_Li]
+ 79, // Range #182: [43311, 43311, Kayah_Li]
+ 110, // Range #183: [43312, 43347, Rejang]
+ 110, // Range #184: [43359, 43359, Rejang]
+ 18, // Range #185: [43360, 43388, Hangul]
+ 78, // Range #186: [43392, 43469, Javanese]
+ 78, // Range #187: [43472, 43487, Javanese]
+ 28, // Range #188: [43488, 43518, Myanmar]
+ 66, // Range #189: [43520, 43574, Cham]
+ 66, // Range #190: [43584, 43615, Cham]
+ 28, // Range #191: [43616, 43647, Myanmar]
+ 127, // Range #192: [43648, 43714, Tai_Viet]
+ 127, // Range #193: [43739, 43743, Tai_Viet]
+ 115, // Range #194: [43744, 43766, Meetei_Mayek]
+ 11, // Range #195: [43777, 43798, Ethiopic]
+ 11, // Range #196: [43808, 43822, Ethiopic]
+ 25, // Range #197: [43824, 43866, Latin]
+ 25, // Range #198: [43868, 43876, Latin]
+ 14, // Range #199: [43877, 43877, Greek]
+ 25, // Range #200: [43878, 43879, Latin]
+ 6, // Range #201: [43888, 43967, Cherokee]
+ 115, // Range #202: [43968, 44025, Meetei_Mayek]
+ 18, // Range #203: [44032, 55203, Hangul]
+ 18, // Range #204: [55216, 55291, Hangul]
+ 17, // Range #205: [63744, 64217, Han]
+ 25, // Range #206: [64256, 64262, Latin]
+ 3, // Range #207: [64275, 64279, Armenian]
+ 19, // Range #208: [64285, 64335, Hebrew]
+ 2, // Range #209: [64336, 64449, Arabic]
+ 2, // Range #210: [64467, 64829, Arabic]
+ 2, // Range #211: [64848, 64967, Arabic]
+ 2, // Range #212: [65008, 65021, Arabic]
+ 8, // Range #213: [65070, 65071, Cyrillic]
+ 2, // Range #214: [65136, 65276, Arabic]
+ 25, // Range #215: [65313, 65338, Latin]
+ 25, // Range #216: [65345, 65370, Latin]
+ 22, // Range #217: [65382, 65391, Katakana]
+ 22, // Range #218: [65393, 65437, Katakana]
+ 18, // Range #219: [65440, 65500, Hangul]
+ 49, // Range #220: [65536, 65629, Linear_B]
+ 49, // Range #221: [65664, 65786, Linear_B]
+ 14, // Range #222: [65856, 65934, Greek]
+ 14, // Range #223: [65952, 65952, Greek]
+ 107, // Range #224: [66176, 66204, Lycian]
+ 104, // Range #225: [66208, 66256, Carian]
+ 30, // Range #226: [66304, 66339, Old_Italic]
+ 30, // Range #227: [66349, 66351, Old_Italic]
+ 13, // Range #228: [66352, 66378, Gothic]
+ 89, // Range #229: [66384, 66426, Old_Permic]
+ 53, // Range #230: [66432, 66463, Ugaritic]
+ 61, // Range #231: [66464, 66517, Old_Persian]
+ 9, // Range #232: [66560, 66639, Deseret]
+ 51, // Range #233: [66640, 66687, Shavian]
+ 50, // Range #234: [66688, 66729, Osmanya]
+ 171, // Range #235: [66736, 66811, Osage]
+ 136, // Range #236: [66816, 66855, Elbasan]
+ 159, // Range #237: [66864, 66915, Caucasian_Albanian]
+ 159, // Range #238: [66927, 66927, Caucasian_Albanian]
+ 83, // Range #239: [67072, 67382, Linear_A]
+ 83, // Range #240: [67392, 67413, Linear_A]
+ 83, // Range #241: [67424, 67431, Linear_A]
+ 47, // Range #242: [67584, 67647, Cypriot]
+ 116, // Range #243: [67648, 67679, Imperial_Aramaic]
+ 144, // Range #244: [67680, 67711, Palmyrene]
+ 143, // Range #245: [67712, 67742, Nabataean]
+ 143, // Range #246: [67751, 67759, Nabataean]
+ 162, // Range #247: [67808, 67829, Hatran]
+ 162, // Range #248: [67835, 67839, Hatran]
+ 91, // Range #249: [67840, 67871, Phoenician]
+ 108, // Range #250: [67872, 67897, Lydian]
+ 108, // Range #251: [67903, 67903, Lydian]
+ 86, // Range #252: [67968, 67999, Meroitic_Hieroglyphs]
+ 141, // Range #253: [68000, 68095, Meroitic_Cursive]
+ 57, // Range #254: [68096, 68102, Kharoshthi]
+ 57, // Range #255: [68108, 68168, Kharoshthi]
+ 57, // Range #256: [68176, 68184, Kharoshthi]
+ 133, // Range #257: [68192, 68223, Old_South_Arabian]
+ 142, // Range #258: [68224, 68255, Old_North_Arabian]
+ 121, // Range #259: [68288, 68342, Manichaean]
+ 117, // Range #260: [68352, 68415, Avestan]
+ 125, // Range #261: [68416, 68447, Inscriptional_Parthian]
+ 122, // Range #262: [68448, 68466, Inscriptional_Pahlavi]
+ 122, // Range #263: [68472, 68479, Inscriptional_Pahlavi]
+ 123, // Range #264: [68480, 68497, Psalter_Pahlavi]
+ 123, // Range #265: [68505, 68508, Psalter_Pahlavi]
+ 123, // Range #266: [68521, 68527, Psalter_Pahlavi]
+ 88, // Range #267: [68608, 68680, Old_Turkic]
+ 76, // Range #268: [68736, 68786, Old_Hungarian]
+ 76, // Range #269: [68800, 68850, Old_Hungarian]
+ 76, // Range #270: [68858, 68863, Old_Hungarian]
+ 182, // Range #271: [68864, 68903, Hanifi_Rohingya]
+ 182, // Range #272: [68912, 68921, Hanifi_Rohingya]
+ 2, // Range #273: [69216, 69246, Arabic]
+ 184, // Range #274: [69376, 69415, Old_Sogdian]
+ 183, // Range #275: [69424, 69465, Sogdian]
+ 185, // Range #276: [69600, 69622, Elymaic]
+ 65, // Range #277: [69632, 69743, Brahmi]
+ 65, // Range #278: [69759, 69759, Brahmi]
+ 120, // Range #279: [69760, 69825, Kaithi]
+ 120, // Range #280: [69837, 69837, Kaithi]
+ 152, // Range #281: [69840, 69864, Sora_Sompeng]
+ 152, // Range #282: [69872, 69881, Sora_Sompeng]
+ 118, // Range #283: [69888, 69958, Chakma]
+ 160, // Range #284: [69968, 70006, Mahajani]
+ 151, // Range #285: [70016, 70111, Sharada]
+ 33, // Range #286: [70113, 70132, Sinhala]
+ 157, // Range #287: [70144, 70206, Khojki]
+ 164, // Range #288: [70272, 70313, Multani]
+ 145, // Range #289: [70320, 70378, Khudawadi]
+ 145, // Range #290: [70384, 70393, Khudawadi]
+ 137, // Range #291: [70400, 70457, Grantha]
+ 137, // Range #292: [70460, 70480, Grantha]
+ 137, // Range #293: [70487, 70487, Grantha]
+ 137, // Range #294: [70493, 70516, Grantha]
+ 170, // Range #295: [70656, 70751, Newa]
+ 158, // Range #296: [70784, 70855, Tirhuta]
+ 158, // Range #297: [70864, 70873, Tirhuta]
+ 166, // Range #298: [71040, 71133, Siddham]
+ 163, // Range #299: [71168, 71236, Modi]
+ 163, // Range #300: [71248, 71257, Modi]
+ 27, // Range #301: [71264, 71276, Mongolian]
+ 153, // Range #302: [71296, 71352, Takri]
+ 153, // Range #303: [71360, 71369, Takri]
+ 161, // Range #304: [71424, 71487, Ahom]
+ 178, // Range #305: [71680, 71739, Dogra]
+ 146, // Range #306: [71840, 71922, Warang_Citi]
+ 146, // Range #307: [71935, 71935, Warang_Citi]
+ 187, // Range #308: [72096, 72164, Nandinagari]
+ 177, // Range #309: [72192, 72263, Zanabazar_Square]
+ 176, // Range #310: [72272, 72354, Soyombo]
+ 165, // Range #311: [72384, 72440, Pau_Cin_Hau]
+ 168, // Range #312: [72704, 72773, Bhaiksuki]
+ 168, // Range #313: [72784, 72812, Bhaiksuki]
+ 169, // Range #314: [72816, 72886, Marchen]
+ 175, // Range #315: [72960, 73031, Masaram_Gondi]
+ 175, // Range #316: [73040, 73049, Masaram_Gondi]
+ 179, // Range #317: [73056, 73112, Gunjala_Gondi]
+ 179, // Range #318: [73120, 73129, Gunjala_Gondi]
+ 180, // Range #319: [73440, 73464, Makasar]
+ 35, // Range #320: [73664, 73713, Tamil]
+ 35, // Range #321: [73727, 73727, Tamil]
+ 101, // Range #322: [73728, 74649, Cuneiform]
+ 101, // Range #323: [74752, 74868, Cuneiform]
+ 101, // Range #324: [74880, 75075, Cuneiform]
+ 71, // Range #325: [77824, 78904, Egyptian_Hieroglyphs]
+ 156, // Range #326: [82944, 83526, Anatolian_Hieroglyphs]
+ 130, // Range #327: [92160, 92728, Bamum]
+ 149, // Range #328: [92736, 92783, Mro]
+ 134, // Range #329: [92880, 92917, Bassa_Vah]
+ 75, // Range #330: [92928, 92997, Pahawh_Hmong]
+ 75, // Range #331: [93008, 93047, Pahawh_Hmong]
+ 75, // Range #332: [93053, 93071, Pahawh_Hmong]
+ 181, // Range #333: [93760, 93850, Medefaidrin]
+ 92, // Range #334: [93952, 94087, Miao]
+ 92, // Range #335: [94095, 94111, Miao]
+ 154, // Range #336: [94176, 94176, Tangut]
+ 150, // Range #337: [94177, 94177, Nushu]
+ 154, // Range #338: [94208, 100343, Tangut]
+ 154, // Range #339: [100352, 101106, Tangut]
+ 22, // Range #340: [110592, 110592, Katakana]
+ 20, // Range #341: [110593, 110878, Hiragana]
+ 20, // Range #342: [110928, 110930, Hiragana]
+ 22, // Range #343: [110948, 110951, Katakana]
+ 150, // Range #344: [110960, 111355, Nushu]
+ 135, // Range #345: [113664, 113770, Duployan]
+ 135, // Range #346: [113776, 113800, Duployan]
+ 135, // Range #347: [113808, 113823, Duployan]
+ 14, // Range #348: [119296, 119365, Greek]
+ 112, // Range #349: [120832, 121483, SignWriting]
+ 112, // Range #350: [121499, 121519, SignWriting]
+ 56, // Range #351: [122880, 122922, Glagolitic]
+ 186, // Range #352: [123136, 123215, Nyiakeng_Puachue_Hmong]
+ 188, // Range #353: [123584, 123641, Wancho]
+ 188, // Range #354: [123647, 123647, Wancho]
+ 140, // Range #355: [124928, 125142, Mende_Kikakui]
+ 167, // Range #356: [125184, 125279, Adlam]
+ 2, // Range #357: [126464, 126523, Arabic]
+ 2, // Range #358: [126530, 126619, Arabic]
+ 2, // Range #359: [126625, 126651, Arabic]
+ 2, // Range #360: [126704, 126705, Arabic]
+ 20, // Range #361: [127488, 127488, Hiragana]
+ 17, // Range #362: [131072, 173782, Han]
+ 17, // Range #363: [173824, 177972, Han]
+ 17, // Range #364: [177984, 183969, Han]
+ 17, // Range #365: [183984, 191456, Han]
+ 17, // Range #366: [194560, 195101, Han]
+};
+
+const uint8 kMaxScript = 188;
+
+} // namespace approx_script_internal
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/script/approx-script-data.h b/lang_id/script/approx-script-data.h
new file mode 100644
index 0000000..3eceed8
--- /dev/null
+++ b/lang_id/script/approx-script-data.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_DATA_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_DATA_H_
+
+#include "lang_id/common/lite_base/integral-types.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace approx_script_internal {
+
+// Number of contiguous ranges of same-script codepoints (see below).
+extern const int kNumRanges;
+
+// Non-overlapping ranges of unicode characters. Characters from each range
+// have the same script (see kRangeScript below). Multiple ranges may have the
+// same script. Note: we represent the kNumRanges ranges as an array with their
+// first codepoints, and a separate array with their sizes (see
+// kRangeSizeMinusOne below). This leads to better memory locality during the
+// binary search (which uses only the first codepoints, up until the very end).
+//
+// kRangeFirst[i] = first codepoint from range #i, \forall 0 <= i < kNumRanges.
+extern const uint32 kRangeFirst[];
+
+// kRangeSizeMinusOne[i] is the number of consecutive codepoints in range #i
+// *minus* 1, \forall 0 <= i < kNumRanges. I.e., 0 means that the range contains
+// 1 codepoint. Since we don't have empty ranges, this "minus one" convention
+// allows us to use all 2^16 values here.
+extern const uint16 kRangeSizeMinusOne[];
+
+// Scripts for the ranges from kRangeFirst. For each i such that 0 <= i <
+// kNumRanges, the range #i has the script kRangeScript[i]. Each uint8 element
+// can be cast to a UScriptCode enum value (see
+// unicode/uscript.h).
+//
+// NOTE: we don't use directly UScriptCode here, as that requires a full int
+// (due to USCRIPT_INVALID_CODE = -1). uint8 is enough for us (and shorter!)
+extern const uint8 kRangeScript[];
+
+// Max value from kRangeScript[]. Scripts are guaranteed to be in the interval
+// [0, kMaxScript] (inclusive on both sides). Can be used to e.g., set the
+// number of rows in an embedding table for a script-based feature.
+extern const uint8 kMaxScript;
+
+} // namespace approx_script_internal
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_DATA_H_
diff --git a/lang_id/script/approx-script.cc b/lang_id/script/approx-script.cc
new file mode 100644
index 0000000..6aa3594
--- /dev/null
+++ b/lang_id/script/approx-script.cc
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/script/approx-script.h"
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/utf8.h"
+#include "lang_id/script/approx-script-data.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// int value of USCRIPT_UNKNOWN from enum UScriptCode (from
+// unicode/uscript.h). Note: we do have a test that
+// USCRIPT_UNKNOWN evaluates to 103.
+const int kUnknownUscript = 103;
+
+namespace {
+using approx_script_internal::kNumRanges;
+using approx_script_internal::kRangeFirst;
+using approx_script_internal::kRangeScript;
+using approx_script_internal::kRangeSizeMinusOne;
+
+uint32 Utf8ToCodepoint(const unsigned char *s, int num_bytes) {
+ switch (num_bytes) {
+ case 1:
+ return s[0];
+ case 2:
+ return ((s[0] & 0x1F) << 6) | (s[1] & 0x3F);
+ case 3:
+ return (((s[0] & 0x0F) << 12) | ((s[1] & 0x3F) << 6) | (s[2] & 0x3F));
+ case 4:
+ return (((s[0] & 0x07) << 18) | ((s[1] & 0x3F) << 12) |
+ ((s[2] & 0x3F) << 6) | (s[3] & 0x3F));
+ default:
+ SAFTM_DLOG(FATAL) << "Illegal num_bytes: " << num_bytes;
+ return 0;
+ }
+}
+
+inline int BinarySearch(uint32 codepoint, int start, int end) {
+ while (end > start + 1) {
+ // Due to the while loop condition, middle > start and middle < end. Hence,
+ // on both branches of the if below, we strictly reduce the end - start
+ // value, so we eventually get that difference below 1 and complete the
+ // while loop.
+ int middle = (start + end) / 2;
+ if (codepoint < kRangeFirst[middle]) {
+ end = middle;
+ } else {
+ start = middle;
+ }
+ }
+
+ if (end == start + 1) {
+ const uint32 range_start = kRangeFirst[start];
+ if ((codepoint >= range_start) &&
+ (codepoint <= range_start + kRangeSizeMinusOne[start])) {
+ return kRangeScript[start];
+ }
+ }
+
+ return kUnknownUscript;
+}
+} // namespace
+
+int GetApproxScript(const unsigned char *s, int num_bytes) {
+ SAFTM_DCHECK_NE(s, nullptr);
+ SAFTM_DCHECK_EQ(num_bytes,
+ utils::OneCharLen(reinterpret_cast<const char *>(s)));
+ uint32 codepoint = Utf8ToCodepoint(s, num_bytes);
+ return BinarySearch(codepoint, 0, kNumRanges);
+}
+
+int GetMaxApproxScriptResult() { return approx_script_internal::kMaxScript; }
+
+SAFTM_STATIC_REGISTRATION(ApproxScriptDetector);
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/script/approx-script.h b/lang_id/script/approx-script.h
new file mode 100644
index 0000000..2472e86
--- /dev/null
+++ b/lang_id/script/approx-script.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_H_
+
+#include "lang_id/common/utf8.h"
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Returns script for the UTF-8 character that starts at address |s| and has
+// |num_bytes| bytes. Note: behavior is unspecified if s points to a UTF-8
+// character that has a different number of bytes. If you don't know
+// |num_bytes|, call GetApproxScript(const char *s).
+//
+// NOTE: to keep BUILD deps small, this function returns an int, but you can
+// assume it's an enum UScriptCode (unicode/uscript.h)
+//
+// If unable to determine the script, this function returns kUnknownUscript, the
+// int value of USCRIPT_UNKNOWN from enum UScriptCode.
+int GetApproxScript(const unsigned char *s, int num_bytes);
+
+// See comments for GetApproxScript() above.
+extern const int kUnknownUscript;
+
+// Same as before, but s is a const char *pointer (no unsigned). Internally, we
+// prefer "unsigned char" (the signed status of char is ambiguous), so we cast
+// and call the previous version (with const unsigned char *).
+inline int GetApproxScript(const char *s, int num_bytes) {
+ return GetApproxScript(reinterpret_cast<const unsigned char *>(s), num_bytes);
+}
+
+// Returns script for the UTF-8 character that starts at address |s|. NOTE:
+// UTF-8 is a var-length encoding, taking between 1 and 4 bytes per Unicode
+// character. We infer the number of bytes based on s[0]. If that number is k,
+// we expect to be able to read k bytes starting from address |s|. I.e., do not
+// call this function on broken UTF-8.
+inline int GetApproxScript(const char *s) {
+ return GetApproxScript(s, utils::OneCharLen(s));
+}
+
+// Returns max value returned by the GetApproxScript() functions.
+int GetMaxApproxScriptResult();
+
+class ApproxScriptDetector : public ScriptDetector {
+ public:
+ ~ApproxScriptDetector() override = default;
+
+ // Note: the int result of this method is actually a UScriptCode enum value.
+ // We return int to match the general case from the base class ScriptDetector
+ // (some script detectors do not use UScriptCode).
+ int GetScript(const char *s, int num_bytes) const override {
+ return GetApproxScript(s, num_bytes);
+ }
+
+ int GetMaxScript() const override {
+ return GetMaxApproxScriptResult();
+ }
+
+ SAFTM_DEFINE_REGISTRATION_METHOD("approx-unicode-script-detector",
+ ApproxScriptDetector);
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_H_
diff --git a/lang_id/script/script-detector.cc b/lang_id/script/script-detector.cc
new file mode 100644
index 0000000..6c19883
--- /dev/null
+++ b/lang_id/script/script-detector.cc
@@ -0,0 +1,25 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+SAFTM_DEFINE_CLASS_REGISTRY_NAME("script detector", ScriptDetector);
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/script/script-detector.h b/lang_id/script/script-detector.h
new file mode 100644
index 0000000..12a7888
--- /dev/null
+++ b/lang_id/script/script-detector.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_SCRIPT_DETECTOR_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_SCRIPT_DETECTOR_H_
+
+#include "lang_id/common/registry.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Base class for Unicode script detectors. Individual detectors may differ in
+// code size, speed, precision, etc. You can use the registration mechanism to
+// get the ScriptDetector that's most appropriate to your application.
+class ScriptDetector : public RegisterableClass<ScriptDetector> {
+ public:
+ virtual ~ScriptDetector() = default;
+
+ // Returns a number between 0 and GetMaxScript() (inclusive on both ends) that
+ // indicates the script of the UTF8 character that starts at address |s| and
+ // has |num_bytes|.
+ virtual int GetScript(const char *s, int num_bytes) const = 0;
+
+ // Returns max result that can be returned by GetScript().
+ virtual int GetMaxScript() const = 0;
+};
+
+SAFTM_DECLARE_CLASS_REGISTRY_NAME(ScriptDetector);
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_SCRIPT_DETECTOR_H_
diff --git a/lang_id/script/tiny-script-detector.cc b/lang_id/script/tiny-script-detector.cc
new file mode 100644
index 0000000..2f0dd98
--- /dev/null
+++ b/lang_id/script/tiny-script-detector.cc
@@ -0,0 +1,27 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/script/tiny-script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+SAFTM_STATIC_REGISTRATION(TinyScriptDetector);
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/script/tiny-script-detector.h b/lang_id/script/tiny-script-detector.h
new file mode 100644
index 0000000..a55da04
--- /dev/null
+++ b/lang_id/script/tiny-script-detector.h
@@ -0,0 +1,181 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_TINY_SCRIPT_DETECTOR_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_TINY_SCRIPT_DETECTOR_H_
+
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Unicode scripts we care about. To get compact and fast code, we detect only
+// a few Unicode scripts that offer a strong indication about the language of
+// the text (e.g., Hiragana -> Japanese).
+enum Script {
+ // Special value to indicate internal errors in the script detection code.
+ kScriptError,
+
+ // Special values for all Unicode scripts that we do not detect. One special
+ // value for Unicode characters of 1, 2, 3, respectively 4 bytes (as we
+ // already have that information, we use it). kScriptOtherUtf8OneByte means
+ // ~Latin and kScriptOtherUtf8FourBytes means ~Han.
+ kScriptOtherUtf8OneByte,
+ kScriptOtherUtf8TwoBytes,
+ kScriptOtherUtf8ThreeBytes,
+ kScriptOtherUtf8FourBytes,
+
+ kScriptGreek,
+ kScriptCyrillic,
+ kScriptHebrew,
+ kScriptArabic,
+ kScriptHangulJamo, // Used primarily for Korean.
+ kScriptHiragana, // Used primarily for Japanese.
+ kScriptKatakana, // Used primarily for Japanese.
+
+ // Add new scripts here.
+
+ // Do not add any script after kNumRelevantScripts. This value indicates the
+ // number of elements in this enum Script (except this value) such that we can
+ // easily iterate over the scripts.
+ kNumRelevantScripts,
+};
+
+template<typename IntType>
+inline bool InRange(IntType value, IntType low, IntType hi) {
+ return (value >= low) && (value <= hi);
+}
+
+// Returns Script for the UTF8 character that starts at address p.
+// Precondition: p points to a valid UTF8 character of num_bytes bytes.
+inline Script GetScript(const unsigned char *p, int num_bytes) {
+ switch (num_bytes) {
+ case 1:
+ return kScriptOtherUtf8OneByte;
+
+ case 2: {
+ // 2-byte UTF8 characters have 11 bits of information. unsigned int has
+ // at least 16 bits (http://en.cppreference.com/w/cpp/language/types) so
+ // it's enough. It's also usually the fastest int type on the current
+ // CPU, so it's better to use than int32.
+ static const unsigned int kGreekStart = 0x370;
+
+ // Commented out (unused in the code): kGreekEnd = 0x3FF;
+ static const unsigned int kCyrillicStart = 0x400;
+ static const unsigned int kCyrillicEnd = 0x4FF;
+ static const unsigned int kHebrewStart = 0x590;
+
+ // Commented out (unused in the code): kHebrewEnd = 0x5FF;
+ static const unsigned int kArabicStart = 0x600;
+ static const unsigned int kArabicEnd = 0x6FF;
+ const unsigned int codepoint = ((p[0] & 0x1F) << 6) | (p[1] & 0x3F);
+ if (codepoint > kCyrillicEnd) {
+ if (codepoint >= kArabicStart) {
+ if (codepoint <= kArabicEnd) {
+ return kScriptArabic;
+ }
+ } else {
+ // At this point, codepoint < kArabicStart = kHebrewEnd + 1, so
+ // codepoint <= kHebrewEnd.
+ if (codepoint >= kHebrewStart) {
+ return kScriptHebrew;
+ }
+ }
+ } else {
+ if (codepoint >= kCyrillicStart) {
+ return kScriptCyrillic;
+ } else {
+ // At this point, codepoint < kCyrillicStart = kGreekEnd + 1, so
+ // codepoint <= kGreekEnd.
+ if (codepoint >= kGreekStart) {
+ return kScriptGreek;
+ }
+ }
+ }
+ return kScriptOtherUtf8TwoBytes;
+ }
+
+ case 3: {
+ // 3-byte UTF8 characters have 16 bits of information. unsigned int has
+ // at least 16 bits.
+ static const unsigned int kHangulJamoStart = 0x1100;
+ static const unsigned int kHangulJamoEnd = 0x11FF;
+ static const unsigned int kHiraganaStart = 0x3041;
+ static const unsigned int kHiraganaEnd = 0x309F;
+
+ // Commented out (unused in the code): kKatakanaStart = 0x30A0;
+ static const unsigned int kKatakanaEnd = 0x30FF;
+ const unsigned int codepoint =
+ ((p[0] & 0x0F) << 12) | ((p[1] & 0x3F) << 6) | (p[2] & 0x3F);
+ if (codepoint > kHiraganaEnd) {
+ // On this branch, codepoint > kHiraganaEnd = kKatakanaStart - 1, so
+ // codepoint >= kKatakanaStart.
+ if (codepoint <= kKatakanaEnd) {
+ return kScriptKatakana;
+ }
+ } else {
+ if (codepoint >= kHiraganaStart) {
+ return kScriptHiragana;
+ } else {
+ if (InRange(codepoint, kHangulJamoStart, kHangulJamoEnd)) {
+ return kScriptHangulJamo;
+ }
+ }
+ }
+ return kScriptOtherUtf8ThreeBytes;
+ }
+
+ case 4:
+ return kScriptOtherUtf8FourBytes;
+
+ default:
+ return kScriptError;
+ }
+}
+
+// Returns Script for the UTF8 character that starts at address p. Similar to
+// the previous version of GetScript, except for "char" vs "unsigned char".
+// Most code works with "char *" pointers, ignoring the fact that char is
+// unsigned (by default) on most platforms, but signed on iOS. This code takes
+// care of making sure we always treat chars as unsigned.
+inline Script GetScript(const char *p, int num_bytes) {
+ return GetScript(reinterpret_cast<const unsigned char *>(p),
+ num_bytes);
+}
+
+class TinyScriptDetector : public ScriptDetector {
+ public:
+ ~TinyScriptDetector() override = default;
+
+ int GetScript(const char *s, int num_bytes) const override {
+ // Add the namespace to indicate that we want to call the method outside
+ // this class, instead of performing an infinite recursive call.
+ return libtextclassifier3::mobile::lang_id::GetScript(s, num_bytes);
+ }
+
+ int GetMaxScript() const override {
+ return kNumRelevantScripts - 1;
+ }
+
+ SAFTM_DEFINE_REGISTRATION_METHOD("tiny-script-detector", TinyScriptDetector);
+};
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_TINY_SCRIPT_DETECTOR_H_
diff --git a/models/actions_suggestions.en.model b/models/actions_suggestions.en.model
new file mode 100644
index 0000000..6cec2b7
--- /dev/null
+++ b/models/actions_suggestions.en.model
Binary files differ
diff --git a/models/actions_suggestions.universal.model b/models/actions_suggestions.universal.model
new file mode 100644
index 0000000..60f10e6
--- /dev/null
+++ b/models/actions_suggestions.universal.model
Binary files differ
diff --git a/models/lang_id.model b/models/lang_id.model
new file mode 100644
index 0000000..49b4b07
--- /dev/null
+++ b/models/lang_id.model
Binary files differ
diff --git a/models/textclassifier.ar.model b/models/textclassifier.ar.model
index 4153026..9d8e2eb 100644
--- a/models/textclassifier.ar.model
+++ b/models/textclassifier.ar.model
Binary files differ
diff --git a/models/textclassifier.en.model b/models/textclassifier.en.model
index 887d1df..917db91 100644
--- a/models/textclassifier.en.model
+++ b/models/textclassifier.en.model
Binary files differ
diff --git a/models/textclassifier.es.model b/models/textclassifier.es.model
index 2093b41..94b7835 100644
--- a/models/textclassifier.es.model
+++ b/models/textclassifier.es.model
Binary files differ
diff --git a/models/textclassifier.fr.model b/models/textclassifier.fr.model
index b54345b..19081e5 100644
--- a/models/textclassifier.fr.model
+++ b/models/textclassifier.fr.model
Binary files differ
diff --git a/models/textclassifier.it.model b/models/textclassifier.it.model
index e05d2db..2f72c36 100644
--- a/models/textclassifier.it.model
+++ b/models/textclassifier.it.model
Binary files differ
diff --git a/models/textclassifier.ja.model b/models/textclassifier.ja.model
index de10271..92d7cef 100644
--- a/models/textclassifier.ja.model
+++ b/models/textclassifier.ja.model
Binary files differ
diff --git a/models/textclassifier.ko.model b/models/textclassifier.ko.model
index 00d1bf3..7e88f54 100644
--- a/models/textclassifier.ko.model
+++ b/models/textclassifier.ko.model
Binary files differ
diff --git a/models/textclassifier.nl.model b/models/textclassifier.nl.model
index a733938..b2e3923 100644
--- a/models/textclassifier.nl.model
+++ b/models/textclassifier.nl.model
Binary files differ
diff --git a/models/textclassifier.pl.model b/models/textclassifier.pl.model
index 3947dc2..7231c49 100644
--- a/models/textclassifier.pl.model
+++ b/models/textclassifier.pl.model
Binary files differ
diff --git a/models/textclassifier.pt.model b/models/textclassifier.pt.model
index b7bb298..cae8692 100644
--- a/models/textclassifier.pt.model
+++ b/models/textclassifier.pt.model
Binary files differ
diff --git a/models/textclassifier.ru.model b/models/textclassifier.ru.model
index 377f73f..5be2ecc 100644
--- a/models/textclassifier.ru.model
+++ b/models/textclassifier.ru.model
Binary files differ
diff --git a/models/textclassifier.th.model b/models/textclassifier.th.model
index 41a3a3b..321edd7 100644
--- a/models/textclassifier.th.model
+++ b/models/textclassifier.th.model
Binary files differ
diff --git a/models/textclassifier.tr.model b/models/textclassifier.tr.model
index e284388..6d11cef 100644
--- a/models/textclassifier.tr.model
+++ b/models/textclassifier.tr.model
Binary files differ
diff --git a/models/textclassifier.universal.model b/models/textclassifier.universal.model
index 7856747..af19e67 100644
--- a/models/textclassifier.universal.model
+++ b/models/textclassifier.universal.model
Binary files differ
diff --git a/models/textclassifier.zh-Hant.model b/models/textclassifier.zh-Hant.model
index dd04f09..366c923 100644
--- a/models/textclassifier.zh-Hant.model
+++ b/models/textclassifier.zh-Hant.model
Binary files differ
diff --git a/models/textclassifier.zh.model b/models/textclassifier.zh.model
index 4e5f525..22f2777 100644
--- a/models/textclassifier.zh.model
+++ b/models/textclassifier.zh.model
Binary files differ
diff --git a/models/update.sh b/models/update.sh
index 8b60d2f..e756859 100755
--- a/models/update.sh
+++ b/models/update.sh
@@ -3,10 +3,22 @@
set -e
-BASE_URL=https://www.gstatic.com/android/text_classifier/p/live
+ANNOTATOR_BASE_URL=https://www.gstatic.com/android/text_classifier/q/live
+ACTIONS_BASE_URL=https://www.gstatic.com/android/text_classifier/actions/q/live
+LANGID_BASE_URL=https://www.gstatic.com/android/text_classifier/langid/q/live
+
+download() {
+ echo "$1/FILELIST"
+ for f in $(wget -O- "$1/FILELIST"); do
+ destination="$(basename -- $f)"
+ wget "$1/$f" -O "$destination"
+ done
+}
cd "$(dirname "$0")"
-for f in $(wget -O- "$BASE_URL/FILELIST"); do
- wget "$BASE_URL/$f" -O "$f"
-done
+download $ANNOTATOR_BASE_URL
+download $ACTIONS_BASE_URL
+download $LANGID_BASE_URL
+
+echo "You may want to edit the file name of downloaded files, see external/libtextclassifier/Android.bp"
diff --git a/utils/calendar/CalendarJavaIcuLocalTest.java b/utils/calendar/CalendarJavaIcuLocalTest.java
new file mode 100644
index 0000000..9beb36e
--- /dev/null
+++ b/utils/calendar/CalendarJavaIcuLocalTest.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier.utils.calendar;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.thirdparty.robolectric.GoogleRobolectricTestRunner;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
+@RunWith(GoogleRobolectricTestRunner.class)
+public class CalendarJavaIcuLocalTest {
+
+ @Before
+ public void setUp() throws Exception {
+ System.loadLibrary("calendar-javaicu_test-lib");
+ }
+
+ private native boolean testsMain();
+
+ @Test
+ public void testNative() {
+ assertThat(testsMain()).isTrue();
+ }
+}
diff --git a/utils/calendar/CalendarJavaIcuTest.java b/utils/calendar/CalendarJavaIcuTest.java
new file mode 100644
index 0000000..ab1f00a
--- /dev/null
+++ b/utils/calendar/CalendarJavaIcuTest.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier.utils.calendar;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
+@RunWith(JUnit4.class)
+public class CalendarJavaIcuTest {
+
+ @Before
+ public void setUp() throws Exception {
+ System.loadLibrary("calendar-javaicu_test-lib");
+ }
+
+ private native boolean testsMain();
+
+ @Test
+ public void testNative() {
+ assertThat(testsMain()).isTrue();
+ }
+}
diff --git a/utils/calendar/calendar-common.h b/utils/calendar/calendar-common.h
index 7e606de..5c91e22 100644
--- a/utils/calendar/calendar-common.h
+++ b/utils/calendar/calendar-common.h
@@ -41,8 +41,10 @@
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& reference_locale,
- DatetimeGranularity granularity,
- TCalendar* calendar) const;
+ TCalendar* calendar,
+ DatetimeGranularity* granularity) const;
+
+ DatetimeGranularity GetGranularity(const DateParseData& data) const;
private:
// Adjusts the calendar's time instant according to a relative date reference
@@ -70,16 +72,12 @@
bool CalendarLibTempl<TCalendar>::InterpretParseData(
const DateParseData& parse_data, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& reference_locale,
- DatetimeGranularity granularity, TCalendar* calendar) const {
+ TCalendar* calendar, DatetimeGranularity* granularity) const {
TC3_CALENDAR_CHECK(calendar->Initialize(reference_timezone, reference_locale,
reference_time_ms_utc))
- // By default, the parsed time is interpreted to be on the reference day.
- // But a parsed date should have time 0:00:00 unless specified.
- TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0))
- TC3_CALENDAR_CHECK(calendar->SetMinute(0))
- TC3_CALENDAR_CHECK(calendar->SetSecond(0))
- TC3_CALENDAR_CHECK(calendar->SetMillisecond(0))
+ bool should_round_to_granularity = true;
+ *granularity = GetGranularity(parse_data);
// Apply each of the parsed fields in order of increasing granularity.
static const int64 kMillisInHour = 1000 * 60 * 60;
@@ -93,6 +91,20 @@
}
if (parse_data.field_set_mask & DateParseData::Fields::RELATION_FIELD) {
TC3_CALENDAR_CHECK(ApplyRelationField(parse_data, calendar));
+ // Don't round to the granularity for relative expressions that specify the
+ // distance. So that, e.g. "in 2 hours" when it's 8:35:03 will result in
+ // 10:35:03.
+ if (parse_data.field_set_mask &
+ DateParseData::Fields::RELATION_DISTANCE_FIELD) {
+ should_round_to_granularity = false;
+ }
+ } else {
+ // By default, the parsed time is interpreted to be on the reference day.
+ // But a parsed date should have time 0:00:00 unless specified.
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0))
+ TC3_CALENDAR_CHECK(calendar->SetMinute(0))
+ TC3_CALENDAR_CHECK(calendar->SetSecond(0))
+ TC3_CALENDAR_CHECK(calendar->SetMillisecond(0))
}
if (parse_data.field_set_mask & DateParseData::Fields::YEAR_FIELD) {
TC3_CALENDAR_CHECK(calendar->SetYear(parse_data.year))
@@ -109,6 +121,9 @@
if (parse_data.field_set_mask & DateParseData::Fields::AMPM_FIELD &&
parse_data.ampm == DateParseData::AMPM::PM && parse_data.hour < 12) {
TC3_CALENDAR_CHECK(calendar->SetHourOfDay(parse_data.hour + 12))
+ } else if (parse_data.ampm == DateParseData::AMPM::AM &&
+ parse_data.hour == 12) {
+ // Do nothing. 12am == 0.
} else {
TC3_CALENDAR_CHECK(calendar->SetHourOfDay(parse_data.hour))
}
@@ -120,7 +135,9 @@
TC3_CALENDAR_CHECK(calendar->SetSecond(parse_data.second))
}
- TC3_CALENDAR_CHECK(RoundToGranularity(granularity, calendar))
+ if (should_round_to_granularity) {
+ TC3_CALENDAR_CHECK(RoundToGranularity(*granularity, calendar))
+ }
return true;
}
@@ -131,6 +148,9 @@
constexpr int relation_distance_mask =
DateParseData::Fields::RELATION_DISTANCE_FIELD;
switch (parse_data.relation) {
+ case DateParseData::Relation::UNSPECIFIED:
+ TC3_LOG(ERROR) << "UNSPECIFIED RelationType.";
+ return false;
case DateParseData::Relation::NEXT:
if (parse_data.field_set_mask & relation_type_mask) {
TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
@@ -226,13 +246,13 @@
TCalendar* calendar) const {
const int distance_sign = distance < 0 ? -1 : 1;
switch (relation_type) {
- case DateParseData::MONDAY:
- case DateParseData::TUESDAY:
- case DateParseData::WEDNESDAY:
- case DateParseData::THURSDAY:
- case DateParseData::FRIDAY:
- case DateParseData::SATURDAY:
- case DateParseData::SUNDAY:
+ case DateParseData::RelationType::MONDAY:
+ case DateParseData::RelationType::TUESDAY:
+ case DateParseData::RelationType::WEDNESDAY:
+ case DateParseData::RelationType::THURSDAY:
+ case DateParseData::RelationType::FRIDAY:
+ case DateParseData::RelationType::SATURDAY:
+ case DateParseData::RelationType::SUNDAY:
if (!allow_today) {
// If we're not including the same day as the reference, skip it.
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign))
@@ -241,34 +261,98 @@
while (distance != 0) {
int day_of_week;
TC3_CALENDAR_CHECK(calendar->GetDayOfWeek(&day_of_week))
- if (day_of_week == relation_type) {
+ if (day_of_week == static_cast<int>(relation_type)) {
distance += -distance_sign;
if (distance == 0) break;
}
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign))
}
return true;
- case DateParseData::DAY:
+ case DateParseData::RelationType::SECOND:
+ TC3_CALENDAR_CHECK(calendar->AddSecond(distance));
+ return true;
+ case DateParseData::RelationType::MINUTE:
+ TC3_CALENDAR_CHECK(calendar->AddMinute(distance));
+ return true;
+ case DateParseData::RelationType::HOUR:
+ TC3_CALENDAR_CHECK(calendar->AddHourOfDay(distance));
+ return true;
+ case DateParseData::RelationType::DAY:
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance));
return true;
- case DateParseData::WEEK:
+ case DateParseData::RelationType::WEEK:
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(7 * distance))
TC3_CALENDAR_CHECK(calendar->SetDayOfWeek(1))
return true;
- case DateParseData::MONTH:
+ case DateParseData::RelationType::MONTH:
TC3_CALENDAR_CHECK(calendar->AddMonth(distance))
TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(1))
return true;
- case DateParseData::YEAR:
+ case DateParseData::RelationType::YEAR:
TC3_CALENDAR_CHECK(calendar->AddYear(distance))
TC3_CALENDAR_CHECK(calendar->SetDayOfYear(1))
return true;
default:
+ TC3_LOG(ERROR) << "Unknown relation type: "
+ << static_cast<int>(relation_type);
return false;
}
return false;
}
+template <class TCalendar>
+DatetimeGranularity CalendarLibTempl<TCalendar>::GetGranularity(
+ const DateParseData& data) const {
+ DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
+ if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::YEAR))) {
+ granularity = DatetimeGranularity::GRANULARITY_YEAR;
+ }
+ if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::MONTH))) {
+ granularity = DatetimeGranularity::GRANULARITY_MONTH;
+ }
+ if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::WEEK)) {
+ granularity = DatetimeGranularity::GRANULARITY_WEEK;
+ }
+ if (data.field_set_mask & DateParseData::DAY_FIELD ||
+ (data.field_set_mask & DateParseData::RELATION_FIELD &&
+ (data.relation == DateParseData::Relation::NOW ||
+ data.relation == DateParseData::Relation::TOMORROW ||
+ data.relation == DateParseData::Relation::YESTERDAY)) ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::MONDAY ||
+ data.relation_type == DateParseData::RelationType::TUESDAY ||
+ data.relation_type == DateParseData::RelationType::WEDNESDAY ||
+ data.relation_type == DateParseData::RelationType::THURSDAY ||
+ data.relation_type == DateParseData::RelationType::FRIDAY ||
+ data.relation_type == DateParseData::RelationType::SATURDAY ||
+ data.relation_type == DateParseData::RelationType::SUNDAY ||
+ data.relation_type == DateParseData::RelationType::DAY))) {
+ granularity = DatetimeGranularity::GRANULARITY_DAY;
+ }
+ if (data.field_set_mask & DateParseData::HOUR_FIELD ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::HOUR))) {
+ granularity = DatetimeGranularity::GRANULARITY_HOUR;
+ }
+ if (data.field_set_mask & DateParseData::MINUTE_FIELD ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ data.relation_type == DateParseData::RelationType::MINUTE)) {
+ granularity = DatetimeGranularity::GRANULARITY_MINUTE;
+ }
+ if (data.field_set_mask & DateParseData::SECOND_FIELD ||
+ (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
+ (data.relation_type == DateParseData::RelationType::SECOND))) {
+ granularity = DatetimeGranularity::GRANULARITY_SECOND;
+ }
+
+ return granularity;
+}
+
}; // namespace calendar
#undef TC3_CALENDAR_CHECK
diff --git a/utils/calendar/calendar-javaicu.cc b/utils/calendar/calendar-javaicu.cc
index 7b7f2fa..ac09979 100644
--- a/utils/calendar/calendar-javaicu.cc
+++ b/utils/calendar/calendar-javaicu.cc
@@ -67,13 +67,20 @@
}
// We'll assume the day indices match later on, so verify it here.
- if (jni_cache_->calendar_sunday != DateParseData::SUNDAY ||
- jni_cache_->calendar_monday != DateParseData::MONDAY ||
- jni_cache_->calendar_tuesday != DateParseData::TUESDAY ||
- jni_cache_->calendar_wednesday != DateParseData::WEDNESDAY ||
- jni_cache_->calendar_thursday != DateParseData::THURSDAY ||
- jni_cache_->calendar_friday != DateParseData::FRIDAY ||
- jni_cache_->calendar_saturday != DateParseData::SATURDAY) {
+ if (jni_cache_->calendar_sunday !=
+ static_cast<int>(DateParseData::RelationType::SUNDAY) ||
+ jni_cache_->calendar_monday !=
+ static_cast<int>(DateParseData::RelationType::MONDAY) ||
+ jni_cache_->calendar_tuesday !=
+ static_cast<int>(DateParseData::RelationType::TUESDAY) ||
+ jni_cache_->calendar_wednesday !=
+ static_cast<int>(DateParseData::RelationType::WEDNESDAY) ||
+ jni_cache_->calendar_thursday !=
+ static_cast<int>(DateParseData::RelationType::THURSDAY) ||
+ jni_cache_->calendar_friday !=
+ static_cast<int>(DateParseData::RelationType::FRIDAY) ||
+ jni_cache_->calendar_saturday !=
+ static_cast<int>(DateParseData::RelationType::SATURDAY)) {
TC3_LOG(ERROR) << "day of the week indices mismatch";
return false;
}
@@ -166,6 +173,9 @@
#define TC3_DEFINE_GET(NAME, CONST) \
TC3_DEFINE_FIELD_ACCESSOR(NAME, CONST, Get, int*)
+TC3_DEFINE_ADD(Second, second)
+TC3_DEFINE_ADD(Minute, minute)
+TC3_DEFINE_ADD(HourOfDay, hour_of_day)
TC3_DEFINE_ADD(DayOfMonth, day_of_month)
TC3_DEFINE_ADD(Year, year)
TC3_DEFINE_ADD(Month, month)
diff --git a/utils/calendar/calendar-javaicu.h b/utils/calendar/calendar-javaicu.h
index 88e696a..02673cc 100644
--- a/utils/calendar/calendar-javaicu.h
+++ b/utils/calendar/calendar-javaicu.h
@@ -34,6 +34,9 @@
explicit Calendar(JniCache* jni_cache);
bool Initialize(const std::string& time_zone, const std::string& locale,
int64 time_ms_utc);
+ bool AddSecond(int value) const;
+ bool AddMinute(int value) const;
+ bool AddHourOfDay(int value) const;
bool AddDayOfMonth(int value) const;
bool AddYear(int value) const;
bool AddMonth(int value) const;
@@ -68,20 +71,24 @@
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& reference_locale,
- DatetimeGranularity granularity,
- int64* interpreted_time_ms_utc) const {
+ int64* interpreted_time_ms_utc,
+ DatetimeGranularity* granularity) const {
Calendar calendar(jni_cache_.get());
- calendar::CalendarLibTempl<Calendar> impl;
- if (!impl.InterpretParseData(parse_data, reference_time_ms_utc,
- reference_timezone, reference_locale,
- granularity, &calendar)) {
+ if (!impl_.InterpretParseData(parse_data, reference_time_ms_utc,
+ reference_timezone, reference_locale,
+ &calendar, granularity)) {
return false;
}
return calendar.GetTimeInMillis(interpreted_time_ms_utc);
}
+ DatetimeGranularity GetGranularity(const DateParseData& data) const {
+ return impl_.GetGranularity(data);
+ }
+
private:
std::shared_ptr<JniCache> jni_cache_;
+ calendar::CalendarLibTempl<Calendar> impl_;
};
} // namespace libtextclassifier3
diff --git a/utils/calendar/calendar_test-include.cc b/utils/calendar/calendar_test-include.cc
new file mode 100644
index 0000000..70520a2
--- /dev/null
+++ b/utils/calendar/calendar_test-include.cc
@@ -0,0 +1,309 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/calendar/calendar_test-include.h"
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+TEST_F(CalendarTest, Interface) {
+ int64 time;
+ DatetimeGranularity granularity;
+ std::string timezone;
+ bool result = calendarlib_.InterpretParseData(
+ DateParseData{/*field_set_mask=*/0, /*year=*/0, /*month=*/0,
+ /*day_of_month=*/0, /*hour=*/0, /*minute=*/0, /*second=*/0,
+ /*ampm=*/static_cast<DateParseData::AMPM>(0),
+ /*zone_offset=*/0, /*dst_offset=*/0,
+ static_cast<DateParseData::Relation>(0),
+ static_cast<DateParseData::RelationType>(0),
+ /*relation_distance=*/0},
+ 0L, "Zurich", "en-CH", &time, &granularity);
+ TC3_LOG(INFO) << result;
+}
+
+TEST_F(CalendarTest, SetsZeroTimeWhenNotRelative) {
+ int64 time;
+ DatetimeGranularity granularity;
+ DateParseData data;
+
+ data.year = 2018;
+ data.field_set_mask = DateParseData::YEAR_FIELD;
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH", &time, &granularity));
+ EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
+
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/1L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH", &time, &granularity));
+ EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
+}
+
+TEST_F(CalendarTest, RoundingToGranularityBasic) {
+ int64 time;
+ DatetimeGranularity granularity;
+ DateParseData data;
+
+ data.year = 2018;
+ data.field_set_mask = DateParseData::YEAR_FIELD;
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH", &time, &granularity));
+ EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
+
+ data.month = 4;
+ data.field_set_mask |= DateParseData::MONTH_FIELD;
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH", &time, &granularity));
+ EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */);
+
+ data.day_of_month = 25;
+ data.field_set_mask |= DateParseData::DAY_FIELD;
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH", &time, &granularity));
+ EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */);
+
+ data.hour = 9;
+ data.field_set_mask |= DateParseData::HOUR_FIELD;
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH", &time, &granularity));
+ EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */);
+
+ data.minute = 33;
+ data.field_set_mask |= DateParseData::MINUTE_FIELD;
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH", &time, &granularity));
+ EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */);
+
+ data.second = 59;
+ data.field_set_mask |= DateParseData::SECOND_FIELD;
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH", &time, &granularity));
+ EXPECT_EQ(time, 1524641639000 /* Apr 25 2018 09:33:59 */);
+}
+
+TEST_F(CalendarTest, RoundingToGranularityWeek) {
+ int64 time;
+ DatetimeGranularity granularity;
+ // Prepare data structure that means: "next week"
+ DateParseData data;
+ data.field_set_mask =
+ DateParseData::RELATION_FIELD | DateParseData::RELATION_TYPE_FIELD;
+ data.relation = DateParseData::Relation::NEXT;
+ data.relation_type = DateParseData::RelationType::WEEK;
+
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"de-CH", &time, &granularity));
+ EXPECT_EQ(time, 342000000L /* Mon Jan 05 1970 00:00:00 */);
+
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-US", &time, &granularity));
+ EXPECT_EQ(time, 255600000L /* Sun Jan 04 1970 00:00:00 */);
+}
+
+TEST_F(CalendarTest, RelativeTime) {
+ const int field_mask = DateParseData::RELATION_FIELD |
+ DateParseData::RELATION_TYPE_FIELD |
+ DateParseData::RELATION_DISTANCE_FIELD;
+ const int64 ref_time = 1524648839000L; /* 25 April 2018 09:33:59 */
+ int64 time;
+ DatetimeGranularity granularity;
+
+ // Two Weds from now.
+ const DateParseData future_wed_parse = {
+ field_mask,
+ /*year=*/0,
+ /*month=*/0,
+ /*day_of_month=*/0,
+ /*hour=*/0,
+ /*minute=*/0,
+ /*second=*/0,
+ static_cast<DateParseData::AMPM>(0),
+ /*zone_offset=*/0,
+ /*dst_offset=*/0,
+ DateParseData::Relation::FUTURE,
+ DateParseData::RelationType::WEDNESDAY,
+ /*relation_distance=*/2};
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ future_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-US", &time, &granularity));
+ EXPECT_EQ(time, 1525858439000L /* Wed May 09 2018 11:33:59 */);
+ EXPECT_EQ(granularity, GRANULARITY_DAY);
+
+ // Next Wed.
+ const DateParseData next_wed_parse = {field_mask,
+ /*year=*/0,
+ /*month=*/0,
+ /*day_of_month=*/0,
+ /*hour=*/0,
+ /*minute=*/0,
+ /*second=*/0,
+ static_cast<DateParseData::AMPM>(0),
+ /*zone_offset=*/0,
+ /*dst_offset=*/0,
+ DateParseData::Relation::NEXT,
+ DateParseData::RelationType::WEDNESDAY,
+ /*relation_distance=*/0};
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ next_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-US", &time, &granularity));
+ EXPECT_EQ(time, 1525253639000L /* Wed May 02 2018 11:33:59 */);
+ EXPECT_EQ(granularity, GRANULARITY_DAY);
+
+ // Same Wed.
+ const DateParseData same_wed_parse = {field_mask,
+ /*year=*/0,
+ /*month=*/0,
+ /*day_of_month=*/0,
+ /*hour=*/0,
+ /*minute=*/0,
+ /*second=*/0,
+ static_cast<DateParseData::AMPM>(0),
+ /*zone_offset=*/0,
+ /*dst_offset=*/0,
+ DateParseData::Relation::NEXT_OR_SAME,
+ DateParseData::RelationType::WEDNESDAY,
+ /*relation_distance=*/0};
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ same_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-US", &time, &granularity));
+ EXPECT_EQ(time, 1524648839000L /* Wed Apr 25 2018 11:33:59 */);
+ EXPECT_EQ(granularity, GRANULARITY_DAY);
+
+ // Previous Wed.
+ const DateParseData last_wed_parse = {field_mask,
+ /*year=*/0,
+ /*month=*/0,
+ /*day_of_month=*/0,
+ /*hour=*/0,
+ /*minute=*/0,
+ /*second=*/0,
+ static_cast<DateParseData::AMPM>(0),
+ /*zone_offset=*/0,
+ /*dst_offset=*/0,
+ DateParseData::Relation::LAST,
+ DateParseData::RelationType::WEDNESDAY,
+ /*relation_distance=*/0};
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ last_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-US", &time, &granularity));
+ EXPECT_EQ(time, 1524044039000L /* Wed Apr 18 2018 11:33:59 */);
+ EXPECT_EQ(granularity, GRANULARITY_DAY);
+
+ // Two Weds ago.
+ const DateParseData past_wed_parse = {field_mask,
+ /*year=*/0,
+ /*month=*/0,
+ /*day_of_month=*/0,
+ /*hour=*/0,
+ /*minute=*/0,
+ /*second=*/0,
+ static_cast<DateParseData::AMPM>(0),
+ /*zone_offset=*/0,
+ /*dst_offset=*/0,
+ DateParseData::Relation::PAST,
+ DateParseData::RelationType::WEDNESDAY,
+ /*relation_distance=*/2};
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ past_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-US", &time, &granularity));
+ EXPECT_EQ(time, 1523439239000L /* Wed Apr 11 2018 11:33:59 */);
+ EXPECT_EQ(granularity, GRANULARITY_DAY);
+
+ // In 3 hours.
+ const DateParseData in_3_hours_parse = {
+ field_mask,
+ /*year=*/0,
+ /*month=*/0,
+ /*day_of_month=*/0,
+ /*hour=*/0,
+ /*minute=*/0,
+ /*second=*/0,
+ /*ampm=*/static_cast<DateParseData::AMPM>(0),
+ /*zone_offset=*/0,
+ /*dst_offset=*/0,
+ DateParseData::Relation::FUTURE,
+ DateParseData::RelationType::HOUR,
+ /*relation_distance=*/3};
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ in_3_hours_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-US", &time, &granularity));
+ EXPECT_EQ(time, 1524659639000L /* Wed Apr 25 2018 14:33:59 */);
+ EXPECT_EQ(granularity, GRANULARITY_HOUR);
+
+ // In 5 minutes.
+ const DateParseData in_5_minutes_parse = {
+ field_mask,
+ /*year=*/0,
+ /*month=*/0,
+ /*day_of_month=*/0,
+ /*hour=*/0,
+ /*minute=*/0,
+ /*second=*/0,
+ /*ampm=*/static_cast<DateParseData::AMPM>(0),
+ /*zone_offset=*/0,
+ /*dst_offset=*/0,
+ DateParseData::Relation::FUTURE,
+ DateParseData::RelationType::MINUTE,
+ /*relation_distance=*/5};
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ in_5_minutes_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-US", &time, &granularity));
+ EXPECT_EQ(time, 1524649139000L /* Wed Apr 25 2018 11:38:59 */);
+ EXPECT_EQ(granularity, GRANULARITY_MINUTE);
+
+ // In 10 seconds.
+ const DateParseData in_10_seconds_parse = {
+ field_mask,
+ /*year=*/0,
+ /*month=*/0,
+ /*day_of_month=*/0,
+ /*hour=*/0,
+ /*minute=*/0,
+ /*second=*/0,
+ /*ampm=*/static_cast<DateParseData::AMPM>(0),
+ /*zone_offset=*/0,
+ /*dst_offset=*/0,
+ DateParseData::Relation::FUTURE,
+ DateParseData::RelationType::SECOND,
+ /*relation_distance=*/10};
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ in_10_seconds_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-US", &time, &granularity));
+ EXPECT_EQ(time, 1524648849000L /* Wed Apr 25 2018 11:34:09 */);
+ EXPECT_EQ(granularity, GRANULARITY_SECOND);
+}
+
+} // namespace test_internal
+} // namespace libtextclassifier3
diff --git a/utils/calendar/calendar_test-include.h b/utils/calendar/calendar_test-include.h
new file mode 100644
index 0000000..169a4ed
--- /dev/null
+++ b/utils/calendar/calendar_test-include.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This is a shared test between icu and javaicu calendar implementations.
+// It is meant to be #include'd.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
+#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
+
+#if defined TC3_CALENDAR_ICU
+#include "utils/calendar/calendar-icu.h"
+#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) VAR()
+#elif defined TC3_CALENDAR_JAVAICU
+#include <jni.h>
+extern JNIEnv* g_jenv;
+#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) \
+ VAR(JniCache::Create(g_jenv))
+#include "utils/calendar/calendar-javaicu.h"
+#else
+#error Unsupported calendar implementation.
+#endif
+#include "utils/base/logging.h"
+
+#include "gtest/gtest.h"
+
+// This can get overridden in the javaicu version, which needs to pass a
+// JNIEnv* argument to the constructor.
+#ifndef TC3_TESTING_CREATE_CALENDARLIB_INSTANCE
+
+#endif
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+class CalendarTest : public ::testing::Test {
+ protected:
+ CalendarTest() : TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(calendarlib_) {}
+ CalendarLib calendarlib_;
+};
+
+} // namespace test_internal
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
diff --git a/utils/calendar/calendar_test.cc b/utils/calendar/calendar_test.cc
deleted file mode 100644
index a8c3af8..0000000
--- a/utils/calendar/calendar_test.cc
+++ /dev/null
@@ -1,244 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// This test serves the purpose of making sure all the different implementations
-// of the unspoken CalendarLib interface support the same methods.
-
-#include "utils/calendar/calendar.h"
-#include "utils/base/logging.h"
-
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class CalendarTest : public ::testing::Test {
- protected:
- CalendarTest() : INIT_CALENDARLIB_FOR_TESTING(calendarlib_) {}
- CalendarLib calendarlib_;
-};
-
-TEST_F(CalendarTest, Interface) {
- int64 time;
- std::string timezone;
- bool result = calendarlib_.InterpretParseData(
- DateParseData{/*field_set_mask=*/0, /*year=*/0, /*month=*/0,
- /*day_of_month=*/0, /*hour=*/0, /*minute=*/0, /*second=*/0,
- /*ampm=*/0, /*zone_offset=*/0, /*dst_offset=*/0,
- static_cast<DateParseData::Relation>(0),
- static_cast<DateParseData::RelationType>(0),
- /*relation_distance=*/0},
- 0L, "Zurich", "en-CH", GRANULARITY_UNKNOWN, &time);
- TC3_LOG(INFO) << result;
-}
-
-#ifdef TC3_CALENDAR_ICU
-TEST_F(CalendarTest, RoundingToGranularity) {
- int64 time;
- DateParseData data;
- data.year = 2018;
- data.month = 4;
- data.day_of_month = 25;
- data.hour = 9;
- data.minute = 33;
- data.second = 59;
- data.field_set_mask = DateParseData::YEAR_FIELD | DateParseData::MONTH_FIELD |
- DateParseData::DAY_FIELD | DateParseData::HOUR_FIELD |
- DateParseData::MINUTE_FIELD |
- DateParseData::SECOND_FIELD;
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*granularity=*/GRANULARITY_YEAR, &time));
- EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*granularity=*/GRANULARITY_MONTH, &time));
- EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*granularity=*/GRANULARITY_WEEK, &time));
- EXPECT_EQ(time, 1524434400000L /* Mon Apr 23 2018 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"*-CH",
- /*granularity=*/GRANULARITY_WEEK, &time));
- EXPECT_EQ(time, 1524434400000L /* Mon Apr 23 2018 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*granularity=*/GRANULARITY_WEEK, &time));
- EXPECT_EQ(time, 1524348000000L /* Sun Apr 22 2018 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"*-US",
- /*granularity=*/GRANULARITY_WEEK, &time));
- EXPECT_EQ(time, 1524348000000L /* Sun Apr 22 2018 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*granularity=*/GRANULARITY_DAY, &time));
- EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*granularity=*/GRANULARITY_HOUR, &time));
- EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*granularity=*/GRANULARITY_MINUTE, &time));
- EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH",
- /*granularity=*/GRANULARITY_SECOND, &time));
- EXPECT_EQ(time, 1524641639000 /* Apr 25 2018 09:33:59 */);
-}
-
-TEST_F(CalendarTest, RelativeTimeWeekday) {
- const int field_mask = DateParseData::RELATION_FIELD |
- DateParseData::RELATION_TYPE_FIELD |
- DateParseData::RELATION_DISTANCE_FIELD;
- const int64 ref_time = 1524648839000L; /* 25 April 2018 09:33:59 */
- int64 time;
-
- // Two Weds from now.
- const DateParseData future_wed_parse = {
- field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/0,
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::FUTURE,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/2};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- future_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*granularity=*/GRANULARITY_DAY, &time));
- EXPECT_EQ(time, 1525816800000L /* 9 May 2018 00:00:00 */);
-
- // Next Wed.
- const DateParseData next_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/0,
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::NEXT,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/0};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- next_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*granularity=*/GRANULARITY_DAY, &time));
- EXPECT_EQ(time, 1525212000000L /* 1 May 2018 00:00:00 */);
-
- // Same Wed.
- const DateParseData same_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/0,
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::NEXT_OR_SAME,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/0};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- same_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*granularity=*/GRANULARITY_DAY, &time));
- EXPECT_EQ(time, 1524607200000L /* 25 April 2018 00:00:00 */);
-
- // Previous Wed.
- const DateParseData last_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/0,
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::LAST,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/0};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- last_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*granularity=*/GRANULARITY_DAY, &time));
- EXPECT_EQ(time, 1524002400000L /* 18 April 2018 00:00:00 */);
-
- // Two Weds ago.
- const DateParseData past_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/0,
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::PAST,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/2};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- past_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US",
- /*granularity=*/GRANULARITY_DAY, &time));
- EXPECT_EQ(time, 1523397600000L /* 11 April 2018 00:00:00 */);
-}
-#endif // TC3_UNILIB_DUMMY
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/codepoint-range.cc b/utils/codepoint-range.cc
new file mode 100644
index 0000000..e26b160
--- /dev/null
+++ b/utils/codepoint-range.cc
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/codepoint-range.h"
+
+#include <algorithm>
+
+namespace libtextclassifier3 {
+
+// Returns a sorted list of the codepoint ranges.
+void SortCodepointRanges(
+ const std::vector<const CodepointRange*>& codepoint_ranges,
+ std::vector<CodepointRangeStruct>* sorted_codepoint_ranges) {
+ sorted_codepoint_ranges->clear();
+ sorted_codepoint_ranges->reserve(codepoint_ranges.size());
+ for (const CodepointRange* range : codepoint_ranges) {
+ sorted_codepoint_ranges->push_back(
+ CodepointRangeStruct(range->start(), range->end()));
+ }
+
+ std::sort(sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(),
+ [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) {
+ return a.start < b.start;
+ });
+}
+
+// Returns true if given codepoint is covered by the given sorted vector of
+// codepoint ranges.
+bool IsCodepointInRanges(
+ int codepoint, const std::vector<CodepointRangeStruct>& codepoint_ranges) {
+ auto it = std::lower_bound(
+ codepoint_ranges.begin(), codepoint_ranges.end(), codepoint,
+ [](const CodepointRangeStruct& range, int codepoint) {
+ // This function compares range with the
+ // codepoint for the purpose of finding the first
+ // greater or equal range. Because of the use of
+ // std::lower_bound it needs to return true when
+ // range < codepoint; the first time it will
+ // return false the lower bound is found and
+ // returned.
+ //
+ // It might seem weird that the condition is
+ // range.end <= codepoint here but when codepoint
+ // == range.end it means it's actually just
+ // outside of the range, thus the range is less
+ // than the codepoint.
+ return range.end <= codepoint;
+ });
+ if (it != codepoint_ranges.end() && it->start <= codepoint &&
+ it->end > codepoint) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/codepoint-range.fbs b/utils/codepoint-range.fbs
new file mode 100755
index 0000000..135ce30
--- /dev/null
+++ b/utils/codepoint-range.fbs
@@ -0,0 +1,23 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Range of codepoints start - end, where end is exclusive.
+namespace libtextclassifier3;
+table CodepointRange {
+ start:int;
+ end:int;
+}
+
diff --git a/utils/codepoint-range.h b/utils/codepoint-range.h
new file mode 100644
index 0000000..af94d35
--- /dev/null
+++ b/utils/codepoint-range.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_CODEPOINT_RANGE_H_
+#define LIBTEXTCLASSIFIER_UTILS_CODEPOINT_RANGE_H_
+
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/codepoint-range_generated.h"
+
+namespace libtextclassifier3 {
+
+// Represents a codepoint range [start, end).
+struct CodepointRangeStruct {
+ int32 start;
+ int32 end;
+
+ CodepointRangeStruct(int32 arg_start, int32 arg_end)
+ : start(arg_start), end(arg_end) {}
+};
+
+// Produces a sorted list of the codepoint ranges (also converts the
+// flatbuffer representation to structs).
+void SortCodepointRanges(
+ const std::vector<const CodepointRange*>& codepoint_ranges,
+ std::vector<CodepointRangeStruct>* sorted_codepoint_ranges);
+
+// Returns true if given codepoint is covered by the given sorted vector of
+// codepoint ranges.
+bool IsCodepointInRanges(
+ int codepoint, const std::vector<CodepointRangeStruct>& codepoint_ranges);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_CODEPOINT_RANGE_H_
diff --git a/utils/flatbuffers.cc b/utils/flatbuffers.cc
index c1c2625..a4dbabd 100644
--- a/utils/flatbuffers.cc
+++ b/utils/flatbuffers.cc
@@ -16,11 +16,406 @@
#include "utils/flatbuffers.h"
+#include <vector>
+#include "utils/strings/numbers.h"
+#include "utils/variant.h"
+
namespace libtextclassifier3 {
+namespace {
+bool CreateRepeatedField(
+ const reflection::Schema* schema, const reflection::Type* type,
+ std::unique_ptr<ReflectiveFlatbuffer::RepeatedField>* repeated_field) {
+ switch (type->element()) {
+ case reflection::Bool:
+ repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<bool>);
+ return true;
+ case reflection::Int:
+ repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<int>);
+ return true;
+ case reflection::Long:
+ repeated_field->reset(
+ new ReflectiveFlatbuffer::TypedRepeatedField<int64>);
+ return true;
+ case reflection::Float:
+ repeated_field->reset(
+ new ReflectiveFlatbuffer::TypedRepeatedField<float>);
+ return true;
+ case reflection::Double:
+ repeated_field->reset(
+ new ReflectiveFlatbuffer::TypedRepeatedField<double>);
+ return true;
+ case reflection::String:
+ repeated_field->reset(
+ new ReflectiveFlatbuffer::TypedRepeatedField<std::string>);
+ return true;
+ case reflection::Obj:
+ repeated_field->reset(
+ new ReflectiveFlatbuffer::TypedRepeatedField<ReflectiveFlatbuffer>(
+ schema, type));
+ return true;
+ default:
+ TC3_LOG(ERROR) << "Unsupported type: " << type->element();
+ return false;
+ }
+}
+} // namespace
template <>
const char* FlatbufferFileIdentifier<Model>() {
return ModelIdentifier();
}
+std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewRoot()
+ const {
+ if (!schema_->root_table()) {
+ TC3_LOG(ERROR) << "No root table specified.";
+ return nullptr;
+ }
+ return std::unique_ptr<ReflectiveFlatbuffer>(
+ new ReflectiveFlatbuffer(schema_, schema_->root_table()));
+}
+
+std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewTable(
+ StringPiece table_name) const {
+ for (const reflection::Object* object : *schema_->objects()) {
+ if (table_name.Equals(object->name()->str())) {
+ return std::unique_ptr<ReflectiveFlatbuffer>(
+ new ReflectiveFlatbuffer(schema_, object));
+ }
+ }
+ return nullptr;
+}
+
+const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+ const StringPiece field_name) const {
+ return type_->fields()->LookupByKey(field_name.data());
+}
+
+const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+ const FlatbufferField* field) const {
+ // Lookup by name might be faster as the fields are sorted by name in the
+ // schema data, so try that first.
+ if (field->field_name() != nullptr) {
+ return GetFieldOrNull(field->field_name()->str());
+ }
+ return GetFieldByOffsetOrNull(field->field_offset());
+}
+
+bool ReflectiveFlatbuffer::GetFieldWithParent(
+ const FlatbufferFieldPath* field_path, ReflectiveFlatbuffer** parent,
+ reflection::Field const** field) {
+ const auto* path = field_path->field();
+ if (path == nullptr || path->size() == 0) {
+ return false;
+ }
+
+ for (int i = 0; i < path->size(); i++) {
+ *parent = (i == 0 ? this : (*parent)->Mutable(*field));
+ if (*parent == nullptr) {
+ return false;
+ }
+ *field = (*parent)->GetFieldOrNull(path->Get(i));
+ if (*field == nullptr) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+const reflection::Field* ReflectiveFlatbuffer::GetFieldByOffsetOrNull(
+ const int field_offset) const {
+ if (type_->fields() == nullptr) {
+ return nullptr;
+ }
+ for (const reflection::Field* field : *type_->fields()) {
+ if (field->offset() == field_offset) {
+ return field;
+ }
+ }
+ return nullptr;
+}
+
+bool ReflectiveFlatbuffer::IsMatchingType(const reflection::Field* field,
+ const Variant& value) const {
+ switch (field->type()->base_type()) {
+ case reflection::Bool:
+ return value.HasBool();
+ case reflection::Int:
+ return value.HasInt();
+ case reflection::Long:
+ return value.HasInt64();
+ case reflection::Float:
+ return value.HasFloat();
+ case reflection::Double:
+ return value.HasDouble();
+ case reflection::String:
+ return value.HasString();
+ default:
+ return false;
+ }
+}
+
+bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
+ const std::string& value) {
+ switch (field->type()->base_type()) {
+ case reflection::String:
+ return Set(field, value);
+ case reflection::Int: {
+ int32 int_value;
+ if (!ParseInt32(value.data(), &int_value)) {
+ TC3_LOG(ERROR) << "Could not parse '" << value << "' as int32.";
+ return false;
+ }
+ return Set(field, int_value);
+ }
+ case reflection::Long: {
+ int64 int_value;
+ if (!ParseInt64(value.data(), &int_value)) {
+ TC3_LOG(ERROR) << "Could not parse '" << value << "' as int64.";
+ return false;
+ }
+ return Set(field, int_value);
+ }
+ case reflection::Float: {
+ double double_value;
+ if (!ParseDouble(value.data(), &double_value)) {
+ TC3_LOG(ERROR) << "Could not parse '" << value << "' as float.";
+ return false;
+ }
+ return Set(field, static_cast<float>(double_value));
+ }
+ case reflection::Double: {
+ double double_value;
+ if (!ParseDouble(value.data(), &double_value)) {
+ TC3_LOG(ERROR) << "Could not parse '" << value << "' as double.";
+ return false;
+ }
+ return Set(field, double_value);
+ }
+ default:
+ TC3_LOG(ERROR) << "Unhandled field type: " << field->type()->base_type();
+ return false;
+ }
+}
+
+bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
+ const std::string& value) {
+ ReflectiveFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!GetFieldWithParent(path, &parent, &field)) {
+ return false;
+ }
+ return parent->ParseAndSet(field, value);
+}
+
+ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
+ const StringPiece field_name) {
+ if (const reflection::Field* field = GetFieldOrNull(field_name)) {
+ return Mutable(field);
+ }
+ TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
+ return nullptr;
+}
+
+ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
+ const reflection::Field* field) {
+ if (field->type()->base_type() != reflection::Obj) {
+ TC3_LOG(ERROR) << "Field is not of type Object.";
+ return nullptr;
+ }
+ const auto entry = children_.find(field);
+ if (entry != children_.end()) {
+ return entry->second.get();
+ }
+ const auto it = children_.insert(
+ /*hint=*/entry,
+ std::make_pair(
+ field,
+ std::unique_ptr<ReflectiveFlatbuffer>(new ReflectiveFlatbuffer(
+ schema_, schema_->objects()->Get(field->type()->index())))));
+ return it->second.get();
+}
+
+ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
+ StringPiece field_name) {
+ if (const reflection::Field* field = GetFieldOrNull(field_name)) {
+ return Repeated(field);
+ }
+ TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
+ return nullptr;
+}
+
+ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
+ const reflection::Field* field) {
+ if (field->type()->base_type() != reflection::Vector) {
+ TC3_LOG(ERROR) << "Field is not of type Vector.";
+ return nullptr;
+ }
+
+ // If the repeated field was already set, return its instance.
+ const auto entry = repeated_fields_.find(field);
+ if (entry != repeated_fields_.end()) {
+ return entry->second.get();
+ }
+
+ // Otherwise, create a new instance and store it.
+ std::unique_ptr<RepeatedField> repeated_field;
+ if (!CreateRepeatedField(schema_, field->type(), &repeated_field)) {
+ TC3_LOG(ERROR) << "Could not create repeated field.";
+ return nullptr;
+ }
+ const auto it = repeated_fields_.insert(
+ /*hint=*/entry, std::make_pair(field, std::move(repeated_field)));
+ return it->second.get();
+}
+
+flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const {
+ // Build all children before we can start with this table.
+ std::vector<
+ std::pair</* field vtable offset */ int,
+ /* field data offset in buffer */ flatbuffers::uoffset_t>>
+ offsets;
+ offsets.reserve(children_.size() + repeated_fields_.size());
+ for (const auto& it : children_) {
+ offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
+ }
+
+ // Create strings.
+ for (const auto& it : fields_) {
+ if (it.second.HasString()) {
+ offsets.push_back({it.first->offset(),
+ builder->CreateString(it.second.StringValue()).o});
+ }
+ }
+
+ // Build the repeated fields.
+ for (const auto& it : repeated_fields_) {
+ offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
+ }
+
+ // Build the table now.
+ const flatbuffers::uoffset_t table_start = builder->StartTable();
+
+ // Add scalar fields.
+ for (const auto& it : fields_) {
+ switch (it.second.GetType()) {
+ case Variant::TYPE_BOOL_VALUE:
+ builder->AddElement<uint8_t>(
+ it.first->offset(), static_cast<uint8_t>(it.second.BoolValue()),
+ static_cast<uint8_t>(it.first->default_integer()));
+ continue;
+ case Variant::TYPE_INT_VALUE:
+ builder->AddElement<int32>(
+ it.first->offset(), it.second.IntValue(),
+ static_cast<int32>(it.first->default_integer()));
+ continue;
+ case Variant::TYPE_INT64_VALUE:
+ builder->AddElement<int64>(it.first->offset(), it.second.Int64Value(),
+ it.first->default_integer());
+ continue;
+ case Variant::TYPE_FLOAT_VALUE:
+ builder->AddElement<float>(
+ it.first->offset(), it.second.FloatValue(),
+ static_cast<float>(it.first->default_real()));
+ continue;
+ case Variant::TYPE_DOUBLE_VALUE:
+ builder->AddElement<double>(it.first->offset(), it.second.DoubleValue(),
+ it.first->default_real());
+ continue;
+ default:
+ continue;
+ }
+ }
+
+ // Add strings, subtables and repeated fields.
+ for (const auto& it : offsets) {
+ builder->AddOffset(it.first, flatbuffers::Offset<void>(it.second));
+ }
+
+ return builder->EndTable(table_start);
+}
+
+std::string ReflectiveFlatbuffer::Serialize() const {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(flatbuffers::Offset<void>(Serialize(&builder)));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
+ // No fields to set.
+ if (type_->fields() == nullptr) {
+ return true;
+ }
+
+ for (const reflection::Field* field : *type_->fields()) {
+ // Skip fields that are not explicitly set.
+ if (!from->CheckField(field->offset())) {
+ continue;
+ }
+ const reflection::BaseType type = field->type()->base_type();
+ switch (type) {
+ case reflection::Bool:
+ Set<bool>(field, from->GetField<uint8_t>(field->offset(),
+ field->default_integer()));
+ break;
+ case reflection::Int:
+ Set<int32>(field, from->GetField<int32>(field->offset(),
+ field->default_integer()));
+ break;
+ case reflection::Long:
+ Set<int64>(field, from->GetField<int64>(field->offset(),
+ field->default_integer()));
+ break;
+ case reflection::Float:
+ Set<float>(field, from->GetField<float>(field->offset(),
+ field->default_real()));
+ break;
+ case reflection::Double:
+ Set<double>(field, from->GetField<double>(field->offset(),
+ field->default_real()));
+ break;
+ case reflection::String:
+ Set<std::string>(
+ field, from->GetPointer<const flatbuffers::String*>(field->offset())
+ ->str());
+ break;
+ case reflection::Obj:
+ if (!Mutable(field)->MergeFrom(
+ from->GetPointer<const flatbuffers::Table* const>(
+ field->offset()))) {
+ return false;
+ }
+ break;
+ default:
+ TC3_LOG(ERROR) << "Unsupported type: " << type;
+ return false;
+ }
+ }
+ return true;
+}
+
+bool ReflectiveFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
+ return MergeFrom(flatbuffers::GetAnyRoot(
+ reinterpret_cast<const unsigned char*>(from.data())));
+}
+
+void ReflectiveFlatbuffer::AsFlatMap(
+ const std::string& key_separator, const std::string& key_prefix,
+ std::map<std::string, Variant>* result) const {
+ // Add direct fields.
+ for (auto it : fields_) {
+ (*result)[key_prefix + it.first->name()->str()] = it.second;
+ }
+
+ // Add nested messages.
+ for (auto& it : children_) {
+ it.second->AsFlatMap(key_separator,
+ key_prefix + it.first->name()->str() + key_separator,
+ result);
+ }
+}
+
} // namespace libtextclassifier3
diff --git a/utils/flatbuffers.fbs b/utils/flatbuffers.fbs
new file mode 100755
index 0000000..584b885
--- /dev/null
+++ b/utils/flatbuffers.fbs
@@ -0,0 +1,32 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Specifies a field in a flatbuffer message.
+namespace libtextclassifier3;
+table FlatbufferField {
+ // Name of the field.
+ field_name:string;
+
+  // Offset of the field.
+ field_offset:int;
+}
+
+// Specifies a (nested) field in a flatbuffer message.
+namespace libtextclassifier3;
+table FlatbufferFieldPath {
+ field:[FlatbufferField];
+}
+
diff --git a/utils/flatbuffers.h b/utils/flatbuffers.h
index 4031f89..76b095f 100644
--- a/utils/flatbuffers.h
+++ b/utils/flatbuffers.h
@@ -19,11 +19,15 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
+#include <map>
#include <memory>
#include <string>
#include "annotator/model_generated.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/reflection.h"
namespace libtextclassifier3 {
@@ -93,6 +97,241 @@
builder.GetSize());
}
+// A flatbuffer that can be built using flatbuffer reflection data of the
+// schema.
+// Normally, field information is hard-coded in code generated from a flatbuffer
+// schema. Here we lookup the necessary information for building a flatbuffer
+// from the provided reflection meta data.
+// When serializing a flatbuffer, the library requires that the sub messages
+// are already serialized, therefore we explicitly keep the field values and
+// serialize the message in (reverse) topological dependency order.
+class ReflectiveFlatbuffer {
+ public:
+ ReflectiveFlatbuffer(const reflection::Schema* schema,
+ const reflection::Object* type)
+ : schema_(schema), type_(type) {}
+
+ // Encapsulates a repeated field.
+ // Serves as a common base class for repeated fields.
+ class RepeatedField {
+ public:
+ virtual ~RepeatedField() {}
+
+ virtual flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const = 0;
+ };
+
+ // Represents a repeated field of particular type.
+ template <typename T>
+ class TypedRepeatedField : public RepeatedField {
+ public:
+ void Add(const T value) { items_.push_back(value); }
+
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return builder->CreateVector(items_).o;
+ }
+
+ private:
+ std::vector<T> items_;
+ };
+
+ // Specialization for strings.
+ template <>
+ class TypedRepeatedField<std::string> : public RepeatedField {
+ public:
+ void Add(const std::string& value) { items_.push_back(value); }
+
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(
+ items_.size());
+ for (int i = 0; i < items_.size(); i++) {
+ offsets[i] = builder->CreateString(items_[i]);
+ }
+ return builder->CreateVector(offsets).o;
+ }
+
+ private:
+ std::vector<std::string> items_;
+ };
+
+ // Specialization for repeated sub-messages.
+ template <>
+ class TypedRepeatedField<ReflectiveFlatbuffer> : public RepeatedField {
+ public:
+ TypedRepeatedField<ReflectiveFlatbuffer>(
+ const reflection::Schema* const schema,
+ const reflection::Type* const type)
+ : schema_(schema), type_(type) {}
+
+ ReflectiveFlatbuffer* Add() {
+ items_.emplace_back(new ReflectiveFlatbuffer(
+ schema_, schema_->objects()->Get(type_->index())));
+ return items_.back().get();
+ }
+
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ std::vector<flatbuffers::Offset<void>> offsets(items_.size());
+ for (int i = 0; i < items_.size(); i++) {
+ offsets[i] = items_[i]->Serialize(builder);
+ }
+ return builder->CreateVector(offsets).o;
+ }
+
+ private:
+ const reflection::Schema* const schema_;
+ const reflection::Type* const type_;
+ std::vector<std::unique_ptr<ReflectiveFlatbuffer>> items_;
+ };
+
+ // Gets the field information for a field name, returns nullptr if the
+ // field was not defined.
+ const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
+ const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
+ const reflection::Field* GetFieldByOffsetOrNull(const int field_offset) const;
+
+ // Gets a nested field and the message it is defined on.
+ bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
+ ReflectiveFlatbuffer** parent,
+ reflection::Field const** field);
+
+ // Checks whether a variant value type agrees with a field type.
+ bool IsMatchingType(const reflection::Field* field,
+ const Variant& value) const;
+
+ // Sets a (primitive) field to a specific value.
+ // Returns true if successful, and false if the field was not found or the
+ // expected type doesn't match.
+ template <typename T>
+ bool Set(StringPiece field_name, T value) {
+ if (const reflection::Field* field = GetFieldOrNull(field_name)) {
+ return Set<T>(field, value);
+ }
+ return false;
+ }
+
+ // Sets a (primitive) field to a specific value.
+ // Returns true if successful, and false if the expected type doesn't match.
+ // Expects `field` to be non-null.
+ template <typename T>
+ bool Set(const reflection::Field* field, T value) {
+ if (field == nullptr) {
+ TC3_LOG(ERROR) << "Expected non-null field.";
+ return false;
+ }
+ Variant variant_value(value);
+ if (!IsMatchingType(field, variant_value)) {
+ TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
+ << "`, expected: " << field->type()->base_type()
+ << ", got: " << variant_value.GetType();
+ return false;
+ }
+ fields_[field] = variant_value;
+ return true;
+ }
+
+ template <typename T>
+ bool Set(const FlatbufferFieldPath* path, T value) {
+ ReflectiveFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!GetFieldWithParent(path, &parent, &field)) {
+ return false;
+ }
+ return parent->Set<T>(field, value);
+ }
+
+ // Sets a (primitive) field to a specific value.
+ // Parses the string value according to the field type.
+ bool ParseAndSet(const reflection::Field* field, const std::string& value);
+ bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
+
+ // Gets the reflective flatbuffer for a table field.
+ // Returns nullptr if the field was not found, or the field type was not a
+ // table.
+ ReflectiveFlatbuffer* Mutable(StringPiece field_name);
+ ReflectiveFlatbuffer* Mutable(const reflection::Field* field);
+
+ // Gets the reflective flatbuffer for a repeated field.
+ // Returns nullptr if the field was not found, or the field type was not a
+ // vector.
+ RepeatedField* Repeated(StringPiece field_name);
+ RepeatedField* Repeated(const reflection::Field* field);
+
+ template <typename T>
+ TypedRepeatedField<T>* Repeated(const reflection::Field* field) {
+ return static_cast<TypedRepeatedField<T>*>(Repeated(field));
+ }
+
+ template <typename T>
+ TypedRepeatedField<T>* Repeated(StringPiece field_name) {
+ return static_cast<TypedRepeatedField<T>*>(Repeated(field_name));
+ }
+
+ // Serializes the flatbuffer.
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const;
+ std::string Serialize() const;
+
+ // Merges the fields from the given flatbuffer table into this flatbuffer.
+ // Scalar fields will be overwritten, if present in `from`.
+ // Embedded messages will be merged.
+ bool MergeFrom(const flatbuffers::Table* from);
+ bool MergeFromSerializedFlatbuffer(StringPiece from);
+
+ // Flattens the flatbuffer as a flat map.
+ // (Nested) fields names are joined by `key_separator`.
+ std::map<std::string, Variant> AsFlatMap(
+ const std::string& key_separator = ".") const {
+ std::map<std::string, Variant> result;
+ AsFlatMap(key_separator, /*key_prefix=*/"", &result);
+ return result;
+ }
+
+ private:
+ const reflection::Schema* const schema_;
+ const reflection::Object* const type_;
+
+ // Cached primitive fields (scalars and strings).
+ std::map<const reflection::Field*, Variant> fields_;
+
+ // Cached sub-messages.
+ std::map<const reflection::Field*, std::unique_ptr<ReflectiveFlatbuffer>>
+ children_;
+
+ // Cached repeated fields.
+ std::map<const reflection::Field*, std::unique_ptr<RepeatedField>>
+ repeated_fields_;
+
+ // Flattens the flatbuffer as a flat map.
+ // (Nested) fields names are joined by `key_separator` and prefixed by
+ // `key_prefix`.
+ void AsFlatMap(const std::string& key_separator,
+ const std::string& key_prefix,
+ std::map<std::string, Variant>* result) const;
+};
+
+// A helper class to build flatbuffers based on schema reflection data.
+// Can be used to create a `ReflectiveFlatbuffer` for the root message of the
+// schema, or any defined table via name.
+class ReflectiveFlatbufferBuilder {
+ public:
+ explicit ReflectiveFlatbufferBuilder(const reflection::Schema* schema)
+ : schema_(schema) {}
+
+ // Starts a new root table message.
+ std::unique_ptr<ReflectiveFlatbuffer> NewRoot() const;
+
+ // Starts a new table message. Returns nullptr if no table with given name is
+ // found in the schema.
+ std::unique_ptr<ReflectiveFlatbuffer> NewTable(
+ const StringPiece table_name) const;
+
+ private:
+ const reflection::Schema* const schema_;
+};
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
diff --git a/utils/flatbuffers_test.cc b/utils/flatbuffers_test.cc
new file mode 100644
index 0000000..348ca73
--- /dev/null
+++ b/utils/flatbuffers_test.cc
@@ -0,0 +1,311 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <map>
+#include <memory>
+#include <string>
+
+#include "utils/flatbuffers.h"
+#include "utils/flatbuffers_generated.h"
+#include "utils/flatbuffers_test_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+#include "flatbuffers/reflection_generated.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+std::string GetTestMetadataPath() {
+ return "flatbuffers_test.bfbs";
+}
+
+std::string LoadTestMetadata() {
+ std::ifstream test_config_stream(GetTestMetadataPath());
+ return std::string((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+}
+
+TEST(FlatbuffersTest, PrimitiveFieldsAreCorrectlySet) {
+ std::string metadata_buffer = LoadTestMetadata();
+ ReflectiveFlatbufferBuilder reflective_builder(
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
+
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
+ EXPECT_TRUE(buffer != nullptr);
+ EXPECT_TRUE(buffer->Set("an_int_field", 42));
+ EXPECT_TRUE(buffer->Set("a_long_field", 84ll));
+ EXPECT_TRUE(buffer->Set("a_bool_field", true));
+ EXPECT_TRUE(buffer->Set("a_float_field", 1.f));
+ EXPECT_TRUE(buffer->Set("a_double_field", 1.0));
+
+ // Try to parse with the generated code.
+ std::string serialized_entity_data = buffer->Serialize();
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(
+ serialized_entity_data.data(), serialized_entity_data.size());
+ EXPECT_TRUE(entity_data != nullptr);
+ EXPECT_EQ(entity_data->an_int_field, 42);
+ EXPECT_EQ(entity_data->a_long_field, 84);
+ EXPECT_EQ(entity_data->a_bool_field, true);
+ EXPECT_NEAR(entity_data->a_float_field, 1.f, 1e-4);
+ EXPECT_NEAR(entity_data->a_double_field, 1.f, 1e-4);
+}
+
+TEST(FlatbuffersTest, HandlesUnknownFields) {
+ std::string metadata_buffer = LoadTestMetadata();
+ const reflection::Schema* schema =
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
+ ReflectiveFlatbufferBuilder reflective_builder(schema);
+
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
+ EXPECT_TRUE(buffer != nullptr);
+
+ // Add a field that is not known to the (statically generated) code.
+ EXPECT_TRUE(buffer->Set("mystic", "this is an unknown field."));
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(flatbuffers::Offset<void>(buffer->Serialize(&builder)));
+
+ // Try to read the field again.
+ const flatbuffers::Table* extra =
+ flatbuffers::GetAnyRoot(builder.GetBufferPointer());
+ EXPECT_EQ(extra
+ ->GetPointer<const flatbuffers::String*>(
+ buffer->GetFieldOrNull("mystic")->offset())
+ ->str(),
+ "this is an unknown field.");
+}
+
+TEST(FlatbuffersTest, HandlesNestedFields) {
+ std::string metadata_buffer = LoadTestMetadata();
+ const reflection::Schema* schema =
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
+ ReflectiveFlatbufferBuilder reflective_builder(schema);
+
+ FlatbufferFieldPathT path;
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "flight_number";
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "carrier_code";
+ flatbuffers::FlatBufferBuilder path_builder;
+ path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
+
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
+
+ ReflectiveFlatbuffer* parent = nullptr;
+ reflection::Field const* field = nullptr;
+ EXPECT_TRUE(
+ buffer->GetFieldWithParent(flatbuffers::GetRoot<FlatbufferFieldPath>(
+ path_builder.GetBufferPointer()),
+ &parent, &field));
+ EXPECT_EQ(parent, buffer->Mutable("flight_number"));
+ EXPECT_EQ(field,
+ buffer->Mutable("flight_number")->GetFieldOrNull("carrier_code"));
+}
+
+TEST(FlatbuffersTest, HandlesMultipleNestedFields) {
+ std::string metadata_buffer = LoadTestMetadata();
+ ReflectiveFlatbufferBuilder reflective_builder(
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
+
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
+ ReflectiveFlatbuffer* flight_info = buffer->Mutable("flight_number");
+ flight_info->Set("carrier_code", "LX");
+ flight_info->Set("flight_code", 38);
+
+ ReflectiveFlatbuffer* contact_info = buffer->Mutable("contact_info");
+ EXPECT_TRUE(contact_info->Set("first_name", "Barack"));
+ EXPECT_TRUE(contact_info->Set("last_name", "Obama"));
+ EXPECT_TRUE(contact_info->Set("phone_number", "1-800-TEST"));
+ EXPECT_TRUE(contact_info->Set("score", 1.f));
+
+ // Try to parse with the generated code.
+ std::string serialized_entity_data = buffer->Serialize();
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(
+ serialized_entity_data.data(), serialized_entity_data.size());
+ EXPECT_TRUE(entity_data != nullptr);
+ EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+ EXPECT_EQ(entity_data->flight_number->flight_code, 38);
+ EXPECT_EQ(entity_data->contact_info->first_name, "Barack");
+ EXPECT_EQ(entity_data->contact_info->last_name, "Obama");
+ EXPECT_EQ(entity_data->contact_info->phone_number, "1-800-TEST");
+ EXPECT_NEAR(entity_data->contact_info->score, 1.f, 1e-4);
+}
+
+TEST(FlatbuffersTest, HandlesFieldsSetWithNamePath) {
+ std::string metadata_buffer = LoadTestMetadata();
+ ReflectiveFlatbufferBuilder reflective_builder(
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
+
+ FlatbufferFieldPathT path;
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "flight_number";
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "carrier_code";
+ flatbuffers::FlatBufferBuilder path_builder;
+ path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
+
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
+ // Test setting value using Set function.
+ buffer->Mutable("flight_number")->Set("flight_code", 38);
+ // Test setting value using FlatbufferFieldPath.
+ buffer->Set(flatbuffers::GetRoot<FlatbufferFieldPath>(
+ path_builder.GetBufferPointer()),
+ "LX");
+
+ // Try to parse with the generated code.
+ std::string serialized_entity_data = buffer->Serialize();
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(
+ serialized_entity_data.data(), serialized_entity_data.size());
+ EXPECT_TRUE(entity_data != nullptr);
+ EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+ EXPECT_EQ(entity_data->flight_number->flight_code, 38);
+}
+
+TEST(FlatbuffersTest, HandlesFieldsSetWithOffsetPath) {
+ std::string metadata_buffer = LoadTestMetadata();
+ ReflectiveFlatbufferBuilder reflective_builder(
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
+
+ FlatbufferFieldPathT path;
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_offset = 14;
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_offset = 4;
+ flatbuffers::FlatBufferBuilder path_builder;
+ path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
+
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
+ // Test setting value using Set function.
+ buffer->Mutable("flight_number")->Set("flight_code", 38);
+ // Test setting value using FlatbufferFieldPath.
+ buffer->Set(flatbuffers::GetRoot<FlatbufferFieldPath>(
+ path_builder.GetBufferPointer()),
+ "LX");
+
+ // Try to parse with the generated code.
+ std::string serialized_entity_data = buffer->Serialize();
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(
+ serialized_entity_data.data(), serialized_entity_data.size());
+ EXPECT_TRUE(entity_data != nullptr);
+ EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+ EXPECT_EQ(entity_data->flight_number->flight_code, 38);
+}
+
+TEST(FlatbuffersTest, PartialBuffersAreCorrectlyMerged) {
+ std::string metadata_buffer = LoadTestMetadata();
+ ReflectiveFlatbufferBuilder reflective_builder(
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
+ buffer->Set("an_int_field", 42);
+ buffer->Set("a_long_field", 84ll);
+ ReflectiveFlatbuffer* flight_info = buffer->Mutable("flight_number");
+ flight_info->Set("carrier_code", "LX");
+ flight_info->Set("flight_code", 38);
+
+ // Create message to merge.
+ test::EntityDataT additional_entity_data;
+ additional_entity_data.an_int_field = 43;
+ additional_entity_data.flight_number.reset(new test::FlightNumberInfoT);
+ additional_entity_data.flight_number->flight_code = 39;
+ additional_entity_data.contact_info.reset(new test::ContactInfoT);
+ additional_entity_data.contact_info->first_name = "Barack";
+ flatbuffers::FlatBufferBuilder to_merge_builder;
+ to_merge_builder.Finish(
+ test::EntityData::Pack(to_merge_builder, &additional_entity_data));
+
+ // Merge it.
+ EXPECT_TRUE(buffer->MergeFrom(
+ flatbuffers::GetAnyRoot(to_merge_builder.GetBufferPointer())));
+
+ // Try to parse it with the generated code.
+ std::string serialized_entity_data = buffer->Serialize();
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(
+ serialized_entity_data.data(), serialized_entity_data.size());
+ EXPECT_TRUE(entity_data != nullptr);
+ EXPECT_EQ(entity_data->an_int_field, 43);
+ EXPECT_EQ(entity_data->a_long_field, 84);
+ EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+ EXPECT_EQ(entity_data->flight_number->flight_code, 39);
+ EXPECT_EQ(entity_data->contact_info->first_name, "Barack");
+}
+
+TEST(FlatbuffersTest, PrimitiveAndNestedFieldsAreCorrectlyFlattened) {
+ std::string metadata_buffer = LoadTestMetadata();
+ ReflectiveFlatbufferBuilder reflective_builder(
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
+ buffer->Set("an_int_field", 42);
+ buffer->Set("a_long_field", 84ll);
+ ReflectiveFlatbuffer* flight_info = buffer->Mutable("flight_number");
+ flight_info->Set("carrier_code", "LX");
+ flight_info->Set("flight_code", 38);
+
+ std::map<std::string, Variant> entity_data_map = buffer->AsFlatMap();
+ EXPECT_EQ(4, entity_data_map.size());
+ EXPECT_EQ(42, entity_data_map["an_int_field"].IntValue());
+ EXPECT_EQ(84, entity_data_map["a_long_field"].Int64Value());
+ EXPECT_EQ("LX", entity_data_map["flight_number.carrier_code"].StringValue());
+ EXPECT_EQ(38, entity_data_map["flight_number.flight_code"].IntValue());
+}
+
+TEST(FlatbuffersTest, RepeatedFieldSetThroughReflectionCanBeRead) {
+ std::string metadata_buffer = LoadTestMetadata();
+ const reflection::Schema* schema =
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
+ ReflectiveFlatbufferBuilder reflective_builder(schema);
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
+
+ auto reminders = buffer->Repeated<ReflectiveFlatbuffer>("reminders");
+ {
+ auto reminder = reminders->Add();
+ reminder->Set("title", "test reminder");
+ auto notes = reminder->Repeated<std::string>("notes");
+ notes->Add("note A");
+ notes->Add("note B");
+ }
+ {
+ auto reminder = reminders->Add();
+ reminder->Set("title", "test reminder 2");
+ auto notes = reminder->Repeated<std::string>("notes");
+ notes->Add("note i");
+ notes->Add("note ii");
+ notes->Add("note iii");
+ }
+ const std::string serialized_entity_data = buffer->Serialize();
+
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(
+ serialized_entity_data.data(), serialized_entity_data.size());
+ EXPECT_TRUE(entity_data != nullptr);
+ EXPECT_EQ(2, entity_data->reminders.size());
+ EXPECT_EQ("test reminder", entity_data->reminders[0]->title);
+ EXPECT_THAT(entity_data->reminders[0]->notes,
+ testing::ElementsAreArray({"note A", "note B"}));
+ EXPECT_EQ("test reminder 2", entity_data->reminders[1]->title);
+ EXPECT_THAT(entity_data->reminders[1]->notes,
+ testing::ElementsAreArray({"note i", "note ii", "note iii"}));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/flatbuffers_test.fbs b/utils/flatbuffers_test.fbs
new file mode 100644
index 0000000..0d5b09b
--- /dev/null
+++ b/utils/flatbuffers_test.fbs
@@ -0,0 +1,47 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.test;
+
+table FlightNumberInfo {
+ carrier_code: string;
+ flight_code: int;
+}
+
+table ContactInfo {
+ first_name: string;
+ last_name: string;
+ phone_number: string;
+ score: float;
+}
+
+table Reminder {
+ title: string;
+ notes: [string];
+}
+
+table EntityData {
+ an_int_field: int;
+ a_long_field: int64;
+ a_bool_field: bool;
+ a_float_field: float;
+ a_double_field: double;
+ flight_number: FlightNumberInfo;
+ contact_info: ContactInfo;
+ reminders: [Reminder];
+}
+
+root_type libtextclassifier3.test.EntityData;
diff --git a/utils/i18n/locale.cc b/utils/i18n/locale.cc
index acd0379..6349d63 100644
--- a/utils/i18n/locale.cc
+++ b/utils/i18n/locale.cc
@@ -21,8 +21,16 @@
namespace libtextclassifier3 {
namespace {
+constexpr const char* kAnyMatch = "*";
+
+// BCP 47 code for "Undetermined Language".
+constexpr const char* kUnknownLanguageCode = "und";
bool CheckLanguage(StringPiece language) {
+ if (language.size() == 1 && language.data()[0] == '*') {
+ return true;
+ }
+
if (language.size() != 2 && language.size() != 3) {
return false;
}
@@ -107,4 +115,78 @@
return Locale(language.ToString(), script.ToString(), region.ToString());
}
+bool Locale::IsUnknown() const {
+ return is_valid_ && language_ == kUnknownLanguageCode;
+}
+
+bool Locale::IsLocaleSupported(const Locale& locale,
+ const std::vector<Locale>& supported_locales,
+ bool default_value) {
+ if (!locale.IsValid()) {
+ return false;
+ }
+ if (locale.IsUnknown()) {
+ return default_value;
+ }
+ for (const Locale& supported_locale : supported_locales) {
+ if (!supported_locale.IsValid()) {
+ continue;
+ }
+ const bool language_matches =
+ supported_locale.Language().empty() ||
+ supported_locale.Language() == kAnyMatch ||
+ supported_locale.Language() == locale.Language();
+ const bool script_matches = supported_locale.Script().empty() ||
+ supported_locale.Script() == kAnyMatch ||
+ locale.Script().empty() ||
+ supported_locale.Script() == locale.Script();
+ const bool region_matches = supported_locale.Region().empty() ||
+ supported_locale.Region() == kAnyMatch ||
+ locale.Region().empty() ||
+ supported_locale.Region() == locale.Region();
+ if (language_matches && script_matches && region_matches) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool Locale::IsAnyLocaleSupported(const std::vector<Locale>& locales,
+ const std::vector<Locale>& supported_locales,
+ bool default_value) {
+ if (locales.empty()) {
+ return default_value;
+ }
+ if (supported_locales.empty()) {
+ return default_value;
+ }
+ for (const Locale& locale : locales) {
+ if (IsLocaleSupported(locale, supported_locales, default_value)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Locale& locale) {
+ return stream << "Locale(language=" << locale.Language()
+ << ", script=" << locale.Script()
+ << ", region=" << locale.Region()
+ << ", is_valid=" << locale.IsValid()
+ << ", is_unknown=" << locale.IsUnknown() << ")";
+}
+
+bool ParseLocales(StringPiece locales_list, std::vector<Locale>* locales) {
+ for (const auto& locale_str : strings::Split(locales_list, ',')) {
+ const Locale locale = Locale::FromBCP47(locale_str.ToString());
+ if (!locale.IsValid()) {
+ TC3_LOG(ERROR) << "Invalid locale " << locale_str.ToString();
+ return false;
+ }
+ locales->push_back(locale);
+ }
+ return true;
+}
+
} // namespace libtextclassifier3
diff --git a/utils/i18n/locale.h b/utils/i18n/locale.h
index 4cfcc22..4420b56 100644
--- a/utils/i18n/locale.h
+++ b/utils/i18n/locale.h
@@ -18,8 +18,11 @@
#define LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_
#include <string>
+#include <vector>
#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
@@ -43,6 +46,15 @@
std::string Region() const { return region_; }
bool IsValid() const { return is_valid_; }
+ bool IsUnknown() const;
+
+ // Returns whether any of the given locales is supported by any of the
+ // supported locales. Returns the default value if the given 'locales' list
+ // or the 'supported_locales' list is empty, or an unknown locale is found.
+ // Locale::FromBCP47("*") means any locale.
+ static bool IsAnyLocaleSupported(const std::vector<Locale>& locales,
+ const std::vector<Locale>& supported_locales,
+ bool default_value);
private:
Locale(const std::string& language, const std::string& script,
@@ -52,12 +64,23 @@
region_(region),
is_valid_(true) {}
+ static bool IsLocaleSupported(const Locale& locale,
+ const std::vector<Locale>& supported_locales,
+ bool default_value);
+
std::string language_;
std::string script_;
std::string region_;
bool is_valid_;
};
+// Pretty-printing function for Locale.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Locale& locale);
+
+// Parses a comma-separated list of BCP47 tags.
+bool ParseLocales(StringPiece locales_list, std::vector<Locale>* locales);
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_
diff --git a/utils/i18n/locale_test.cc b/utils/i18n/locale_test.cc
index 3722727..faea4f6 100644
--- a/utils/i18n/locale_test.cc
+++ b/utils/i18n/locale_test.cc
@@ -66,5 +66,38 @@
EXPECT_EQ(locale.Region(), "");
}
+TEST(LocaleTest, IsAnyLocaleSupportedMatch) {
+ std::vector<Locale> locales = {Locale::FromBCP47("zh-HK"),
+ Locale::FromBCP47("en-UK")};
+ std::vector<Locale> supported_locales = {Locale::FromBCP47("en")};
+
+ EXPECT_TRUE(Locale::IsAnyLocaleSupported(locales, supported_locales,
+ /*default_value=*/false));
+}
+
+TEST(LocaleTest, IsAnyLocaleSupportedNotMatch) {
+ std::vector<Locale> locales = {Locale::FromBCP47("zh-tw")};
+ std::vector<Locale> supported_locales = {Locale::FromBCP47("en"),
+ Locale::FromBCP47("fr")};
+
+ EXPECT_FALSE(Locale::IsAnyLocaleSupported(locales, supported_locales,
+ /*default_value=*/false));
+}
+
+TEST(LocaleTest, IsAnyLocaleSupportedAnyLocale) {
+ std::vector<Locale> locales = {Locale::FromBCP47("zh-tw")};
+ std::vector<Locale> supported_locales = {Locale::FromBCP47("*")};
+
+ EXPECT_TRUE(Locale::IsAnyLocaleSupported(locales, supported_locales,
+ /*default_value=*/false));
+}
+
+TEST(LocaleTest, IsAnyLocaleSupportedEmptyLocales) {
+ std::vector<Locale> supported_locales = {Locale::FromBCP47("en")};
+
+ EXPECT_TRUE(Locale::IsAnyLocaleSupported({}, supported_locales,
+ /*default_value=*/true));
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/utils/intents/IntentGeneratorTest.java b/utils/intents/IntentGeneratorTest.java
new file mode 100644
index 0000000..f43ecc0
--- /dev/null
+++ b/utils/intents/IntentGeneratorTest.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier.utils.intents;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Context;
+import androidx.test.InstrumentationRegistry;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
+@RunWith(JUnit4.class)
+public final class IntentGeneratorTest {
+
+ @Before
+ public void setUp() throws Exception {
+ System.loadLibrary("intent-generator-test-lib");
+ }
+
+ private native boolean testsMain(Context context);
+
+ @Test
+ public void testNative() {
+ assertThat(testsMain(InstrumentationRegistry.getContext())).isTrue();
+ }
+}
diff --git a/utils/intents/intent-config.fbs b/utils/intents/intent-config.fbs
index 93a6fc9..09ebbb4 100755
--- a/utils/intents/intent-config.fbs
+++ b/utils/intents/intent-config.fbs
@@ -14,6 +14,8 @@
// limitations under the License.
//
+include "utils/zlib/buffer.fbs";
+
// The type of variable to fetch.
namespace libtextclassifier3;
enum AndroidSimpleIntentGeneratorVariableType : int {
@@ -73,7 +75,7 @@
// implements the Intent generation logic.
namespace libtextclassifier3;
table AndroidIntentFactoryOptions {
- entity:[libtextclassifier3.AndroidIntentFactoryEntityOptions];
+ entity:[AndroidIntentFactoryEntityOptions];
}
// Describes how intents should be generated for a particular entity type.
@@ -85,17 +87,17 @@
// List of generators for all the different types of intents that should
// be made available for the entity type.
- generator:[libtextclassifier3.AndroidIntentGeneratorOptions];
+ generator:[AndroidIntentGeneratorOptions];
}
// Configures a single Android Intent generator.
namespace libtextclassifier3;
table AndroidIntentGeneratorOptions {
// Strings for UI elements.
- strings:[libtextclassifier3.AndroidIntentGeneratorStrings];
+ strings:[AndroidIntentGeneratorStrings];
// Generator specific configuration.
- simple:libtextclassifier3.AndroidSimpleIntentGeneratorOptions;
+ simple:AndroidSimpleIntentGeneratorOptions;
}
// Language dependent configuration for an Android Intent generator.
@@ -122,7 +124,7 @@
name:string;
// The type of the extra to set.
- type:libtextclassifier3.AndroidSimpleIntentGeneratorExtraType;
+ type:AndroidSimpleIntentGeneratorExtraType;
string_:string;
@@ -133,7 +135,7 @@
// A condition that needs to be fulfilled for an Intent to get generated.
namespace libtextclassifier3;
table AndroidSimpleIntentGeneratorCondition {
- type:libtextclassifier3.AndroidSimpleIntentGeneratorConditionType;
+ type:AndroidSimpleIntentGeneratorConditionType;
string_:string;
@@ -161,32 +163,37 @@
type:string;
// The list of all the extras to add to the Intent.
- extra:[libtextclassifier3.AndroidSimpleIntentGeneratorExtra];
+ extra:[AndroidSimpleIntentGeneratorExtra];
// The list of all the variables that become available for substitution in
// the action, data, type and extra strings. To e.g. set a field to the value
// of the first variable, use "%0$s".
- variable:[libtextclassifier3.AndroidSimpleIntentGeneratorVariableType];
+ variable:[AndroidSimpleIntentGeneratorVariableType];
// The list of all conditions that need to be fulfilled for Intent generation.
- condition:[libtextclassifier3.AndroidSimpleIntentGeneratorCondition];
+ condition:[AndroidSimpleIntentGeneratorCondition];
}
// Describes how intents should be generated for a particular entity type.
namespace libtextclassifier3.IntentFactoryModel_;
table IntentGenerator {
- // The entity type as defined by on the TextClassifier ENTITY_TYPE constants
- // e.g. "address", "phone", etc.
- entity_type:string;
+ // The type of the intent generator, e.g. the entity type as defined by
+ // the TextClassifier ENTITY_TYPE constants, e.g. "address", "phone", etc.
+ type:string;
// The template generator lua code, either as text source or precompiled
// bytecode.
lua_template_generator:[ubyte];
+
+ compressed_lua_template_generator:CompressedBuffer;
}
// Describes how intents for the various entity types should be generated.
namespace libtextclassifier3;
table IntentFactoryModel {
- entities:[libtextclassifier3.IntentFactoryModel_.IntentGenerator];
+ generator:[IntentFactoryModel_.IntentGenerator];
+
+ // Whether to precompile the generators when loading.
+ precompile_generators:bool = false;
}
diff --git a/utils/intents/intent-generator.cc b/utils/intents/intent-generator.cc
new file mode 100644
index 0000000..f882515
--- /dev/null
+++ b/utils/intents/intent-generator.cc
@@ -0,0 +1,899 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/intents/intent-generator.h"
+
+#include <vector>
+
+#include "actions/lua-utils.h"
+#include "actions/types.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/hash/farmhash.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/string_utils.h"
+#include "utils/lua-utils.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/strings/substitute.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/variant.h"
+#include "utils/zlib/zlib.h"
+#include "flatbuffers/reflection_generated.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lua.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+
+static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
+static constexpr const char* kHashKey = "hash";
+static constexpr const char* kUrlSchemaKey = "url_schema";
+static constexpr const char* kUrlHostKey = "url_host";
+static constexpr const char* kUrlEncodeKey = "urlencode";
+static constexpr const char* kPackageNameKey = "package_name";
+static constexpr const char* kDeviceLocaleKey = "device_locales";
+static constexpr const char* kFormatKey = "format";
+
+// An Android specific Lua environment with JNI backed callbacks.
+class JniLuaEnvironment : public LuaEnvironment {
+ public:
+ JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
+ const jobject context,
+ const std::vector<Locale>& device_locales);
+ // Environment setup.
+ bool Initialize();
+
+ // Runs an intent generator snippet.
+ bool RunIntentGenerator(const std::string& generator_snippet,
+ std::vector<RemoteActionTemplate>* remote_actions);
+
+ protected:
+ virtual void SetupExternalHook();
+
+ int HandleExternalCallback();
+ int HandleAndroidCallback();
+ int HandleUserRestrictionsCallback();
+ int HandleUrlEncode();
+ int HandleUrlSchema();
+ int HandleHash();
+ int HandleFormat();
+ int HandleAndroidStringResources();
+ int HandleUrlHost();
+
+ // Checks and retrieves string resources from the model.
+ bool LookupModelStringResource();
+
+ // Reads and creates a RemoteActionTemplate result from Lua.
+ RemoteActionTemplate ReadRemoteActionTemplateResult();
+
+ // Reads the extras from the Lua result.
+ void ReadExtras(std::map<std::string, Variant>* extra);
+
+ // Reads the intent categories array from a Lua result.
+ void ReadCategories(std::vector<std::string>* category);
+
+ // Retrieves user manager if not previously done.
+ bool RetrieveUserManager();
+
+ // Retrieves system resources if not previously done.
+ bool RetrieveSystemResources();
+
+ // Parses the URL string by calling Java's Uri.parse.
+ ScopedLocalRef<jobject> ParseUri(StringPiece url) const;
+
+ // Reads remote action templates from the Lua generator result.
+ int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
+
+ const Resources& resources_;
+ JNIEnv* jenv_;
+ const JniCache* jni_cache_;
+ const jobject context_;
+ std::vector<Locale> device_locales_;
+
+ ScopedGlobalRef<jobject> usermanager_;
+ // Whether we previously attempted to retrieve the UserManager.
+ bool usermanager_retrieved_;
+
+ ScopedGlobalRef<jobject> system_resources_;
+ // Whether we previously attempted to retrieve the system resources.
+ bool system_resources_resources_retrieved_;
+
+ // Cached JNI references for Java strings `string` and `android`.
+ ScopedGlobalRef<jstring> string_;
+ ScopedGlobalRef<jstring> android_;
+};
+
+JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
+ const JniCache* jni_cache,
+ const jobject context,
+ const std::vector<Locale>& device_locales)
+ : resources_(resources),
+ jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
+ jni_cache_(jni_cache),
+ context_(context),
+ device_locales_(device_locales),
+ usermanager_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+ usermanager_retrieved_(false),
+ system_resources_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+ system_resources_resources_retrieved_(false),
+ string_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+ android_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
+
+bool JniLuaEnvironment::Initialize() {
+ string_ =
+ MakeGlobalRef(jenv_->NewStringUTF("string"), jenv_, jni_cache_->jvm);
+ android_ =
+ MakeGlobalRef(jenv_->NewStringUTF("android"), jenv_, jni_cache_->jvm);
+ if (string_ == nullptr || android_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not allocate constant strings references.";
+ return false;
+ }
+ return (RunProtected([this] {
+ LoadDefaultLibraries();
+ SetupExternalHook();
+ lua_setglobal(state_, "external");
+ return LUA_OK;
+ }) == LUA_OK);
+}
+
+void JniLuaEnvironment::SetupExternalHook() {
+ // This exposes an `external` object with the following fields:
+ // * entity: the bundle with all information about a classification.
+ // * android: callbacks into specific android provided methods.
+ // * android.user_restrictions: callbacks to check user permissions.
+ // * android.R: callbacks to retrieve string resources.
+ BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleExternalCallback>(
+ "external");
+
+ // android
+ BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleAndroidCallback>(
+ "android");
+ {
+ // android.user_restrictions
+ BindTable<JniLuaEnvironment,
+ &JniLuaEnvironment::HandleUserRestrictionsCallback>(
+ "user_restrictions");
+ lua_setfield(state_, /*idx=*/-2, "user_restrictions");
+
+ // android.R
+ // Callback to access android string resources.
+ BindTable<JniLuaEnvironment,
+ &JniLuaEnvironment::HandleAndroidStringResources>("R");
+ lua_setfield(state_, /*idx=*/-2, "R");
+ }
+ lua_setfield(state_, /*idx=*/-2, "android");
+}
+
+int JniLuaEnvironment::HandleExternalCallback() {
+ const StringPiece key = ReadString(/*index=*/-1);
+ if (key.Equals(kHashKey)) {
+ Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleHash>();
+ return 1;
+ } else if (key.Equals(kFormatKey)) {
+ Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleFormat>();
+ return 1;
+ } else {
+ TC3_LOG(ERROR) << "Undefined external access " << key.ToString();
+ lua_error(state_);
+ return 0;
+ }
+}
+
+int JniLuaEnvironment::HandleAndroidCallback() {
+ const StringPiece key = ReadString(/*index=*/-1);
+ if (key.Equals(kDeviceLocaleKey)) {
+ // Provide the locale as table with the individual fields set.
+ lua_newtable(state_);
+ for (int i = 0; i < device_locales_.size(); i++) {
+ // Adjust index to 1-based indexing for Lua.
+ lua_pushinteger(state_, i + 1);
+ lua_newtable(state_);
+ PushString(device_locales_[i].Language());
+ lua_setfield(state_, -2, "language");
+ PushString(device_locales_[i].Region());
+ lua_setfield(state_, -2, "region");
+ PushString(device_locales_[i].Script());
+ lua_setfield(state_, -2, "script");
+ lua_settable(state_, /*idx=*/-3);
+ }
+ return 1;
+ } else if (key.Equals(kPackageNameKey)) {
+ if (context_ == nullptr) {
+ TC3_LOG(ERROR) << "Context invalid.";
+ lua_error(state_);
+ return 0;
+ }
+ ScopedLocalRef<jstring> package_name_str(
+ static_cast<jstring>(jenv_->CallObjectMethod(
+ context_, jni_cache_->context_get_package_name)));
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling Context.getPackageName";
+ lua_error(state_);
+ return 0;
+ }
+ PushString(ToStlString(jenv_, package_name_str.get()));
+ return 1;
+ } else if (key.Equals(kUrlEncodeKey)) {
+ Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlEncode>();
+ return 1;
+ } else if (key.Equals(kUrlHostKey)) {
+ Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlHost>();
+ return 1;
+ } else if (key.Equals(kUrlSchemaKey)) {
+ Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlSchema>();
+ return 1;
+ } else {
+ TC3_LOG(ERROR) << "Undefined android reference " << key.ToString();
+ lua_error(state_);
+ return 0;
+ }
+}
+
+int JniLuaEnvironment::HandleUserRestrictionsCallback() {
+ if (jni_cache_->usermanager_class == nullptr ||
+ jni_cache_->usermanager_get_user_restrictions == nullptr) {
+ // UserManager is only available for API level >= 17 and
+ // getUserRestrictions only for API level >= 18, so we just return false
+ // normally here.
+ lua_pushboolean(state_, false);
+ return 1;
+ }
+
+ // Get user manager if not previously retrieved.
+ if (!RetrieveUserManager()) {
+ TC3_LOG(ERROR) << "Error retrieving user manager.";
+ lua_error(state_);
+ return 0;
+ }
+
+ ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod(
+ usermanager_.get(), jni_cache_->usermanager_get_user_restrictions));
+ if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) {
+ TC3_LOG(ERROR) << "Error calling getUserRestrictions";
+ lua_error(state_);
+ return 0;
+ }
+
+ const StringPiece key_str = ReadString(/*index=*/-1);
+ if (key_str.empty()) {
+ TC3_LOG(ERROR) << "Expected string, got null.";
+ lua_error(state_);
+ return 0;
+ }
+
+ ScopedLocalRef<jstring> key = jni_cache_->ConvertToJavaString(key_str);
+ if (jni_cache_->ExceptionCheckAndClear() || key == nullptr) {
+ TC3_LOG(ERROR) << "Expected string, got null.";
+ lua_error(state_);
+ return 0;
+ }
+ const bool permission = jenv_->CallBooleanMethod(
+ bundle.get(), jni_cache_->bundle_get_boolean, key.get());
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error getting bundle value";
+ lua_pushboolean(state_, false);
+ } else {
+ lua_pushboolean(state_, permission);
+ }
+ return 1;
+}
+
+int JniLuaEnvironment::HandleUrlEncode() {
+ const StringPiece input = ReadString(/*index=*/1);
+ if (input.empty()) {
+ TC3_LOG(ERROR) << "Expected string, got null.";
+ lua_error(state_);
+ return 0;
+ }
+
+ // Call Java URL encoder.
+ ScopedLocalRef<jstring> input_str = jni_cache_->ConvertToJavaString(input);
+ if (jni_cache_->ExceptionCheckAndClear() || input_str == nullptr) {
+ TC3_LOG(ERROR) << "Expected string, got null.";
+ lua_error(state_);
+ return 0;
+ }
+ ScopedLocalRef<jstring> encoded_str(
+ static_cast<jstring>(jenv_->CallStaticObjectMethod(
+ jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode,
+ input_str.get(), jni_cache_->string_utf8.get())));
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
+ lua_error(state_);
+ return 0;
+ }
+ PushString(ToStlString(jenv_, encoded_str.get()));
+ return 1;
+}
+
+ScopedLocalRef<jobject> JniLuaEnvironment::ParseUri(StringPiece url) const {
+ if (url.empty()) {
+ return nullptr;
+ }
+
+ // Call to Java URI parser.
+ ScopedLocalRef<jstring> url_str = jni_cache_->ConvertToJavaString(url);
+ if (jni_cache_->ExceptionCheckAndClear() || url_str == nullptr) {
+ TC3_LOG(ERROR) << "Expected string, got null";
+ return nullptr;
+ }
+
+ // Try to parse uri and get scheme.
+ ScopedLocalRef<jobject> uri(jenv_->CallStaticObjectMethod(
+ jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get()));
+ if (jni_cache_->ExceptionCheckAndClear() || uri == nullptr) {
+ TC3_LOG(ERROR) << "Error calling Uri.parse";
+ }
+ return uri;
+}
+
+int JniLuaEnvironment::HandleUrlSchema() {
+ StringPiece url = ReadString(/*index=*/1);
+
+ ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
+ if (parsed_uri == nullptr) {
+ lua_error(state_);
+ return 0;
+ }
+
+ ScopedLocalRef<jstring> scheme_str(static_cast<jstring>(
+ jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_scheme)));
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling Uri.getScheme";
+ lua_error(state_);
+ return 0;
+ }
+ if (scheme_str == nullptr) {
+ lua_pushnil(state_);
+ } else {
+ PushString(ToStlString(jenv_, scheme_str.get()));
+ }
+ return 1;
+}
+
+int JniLuaEnvironment::HandleUrlHost() {
+ StringPiece url = ReadString(/*index=*/-1);
+
+ ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
+ if (parsed_uri == nullptr) {
+ lua_error(state_);
+ return 0;
+ }
+
+ ScopedLocalRef<jstring> host_str(static_cast<jstring>(
+ jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_host)));
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling Uri.getHost";
+ lua_error(state_);
+ return 0;
+ }
+ if (host_str == nullptr) {
+ lua_pushnil(state_);
+ } else {
+ PushString(ToStlString(jenv_, host_str.get()));
+ }
+ return 1;
+}
+
+int JniLuaEnvironment::HandleHash() {
+ const StringPiece input = ReadString(/*index=*/-1);
+ lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
+ return 1;
+}
+
+int JniLuaEnvironment::HandleFormat() {
+ const int num_args = lua_gettop(state_);
+ std::vector<StringPiece> args(num_args - 1);
+ for (int i = 0; i < num_args - 1; i++) {
+ args[i] = ReadString(/*index=*/i + 2);
+ }
+ PushString(strings::Substitute(ReadString(/*index=*/1), args));
+ return 1;
+}
+
+bool JniLuaEnvironment::LookupModelStringResource() {
+ // Handle only lookup by name.
+ if (lua_type(state_, 2) != LUA_TSTRING) {
+ return false;
+ }
+
+ const StringPiece resource_name = ReadString(/*index=*/-1);
+ std::string resource_content;
+ if (!resources_.GetResourceContent(device_locales_, resource_name,
+ &resource_content)) {
+ // Resource cannot be provided by the model.
+ return false;
+ }
+
+ PushString(resource_content);
+ return true;
+}
+
+int JniLuaEnvironment::HandleAndroidStringResources() {
+ // Check whether the requested resource can be served from the model data.
+ if (LookupModelStringResource()) {
+ return 1;
+ }
+
+ // Get system resources if not previously retrieved.
+ if (!RetrieveSystemResources()) {
+ TC3_LOG(ERROR) << "Error retrieving system resources.";
+ lua_error(state_);
+ return 0;
+ }
+
+ int resource_id;
+ switch (lua_type(state_, -1)) {
+ case LUA_TNUMBER:
+ resource_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
+ break;
+ case LUA_TSTRING: {
+ const StringPiece resource_name_str = ReadString(/*index=*/-1);
+ if (resource_name_str.empty()) {
+ TC3_LOG(ERROR) << "No resource name provided.";
+ lua_error(state_);
+ return 0;
+ }
+ ScopedLocalRef<jstring> resource_name =
+ jni_cache_->ConvertToJavaString(resource_name_str);
+ if (resource_name == nullptr) {
+ TC3_LOG(ERROR) << "Invalid resource name.";
+ lua_error(state_);
+ return 0;
+ }
+ resource_id = jenv_->CallIntMethod(
+ system_resources_.get(), jni_cache_->resources_get_identifier,
+ resource_name.get(), string_.get(), android_.get());
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling getIdentifier.";
+ lua_error(state_);
+ return 0;
+ }
+ break;
+ }
+ default:
+ TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
+ lua_error(state_);
+ return 0;
+ }
+ if (resource_id == 0) {
+ TC3_LOG(ERROR) << "Resource not found.";
+ lua_pushnil(state_);
+ return 1;
+ }
+ ScopedLocalRef<jstring> resource_str(static_cast<jstring>(
+ jenv_->CallObjectMethod(system_resources_.get(),
+ jni_cache_->resources_get_string, resource_id)));
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling getString.";
+ lua_error(state_);
+ return 0;
+ }
+ if (resource_str == nullptr) {
+ lua_pushnil(state_);
+ } else {
+ PushString(ToStlString(jenv_, resource_str.get()));
+ }
+ return 1;
+}
+
+bool JniLuaEnvironment::RetrieveSystemResources() {
+ if (system_resources_resources_retrieved_) {
+ return (system_resources_ != nullptr);
+ }
+ system_resources_resources_retrieved_ = true;
+ jobject system_resources_ref = jenv_->CallStaticObjectMethod(
+ jni_cache_->resources_class.get(), jni_cache_->resources_get_system);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling getSystem.";
+ return false;
+ }
+ system_resources_ =
+ MakeGlobalRef(system_resources_ref, jenv_, jni_cache_->jvm);
+ return (system_resources_ != nullptr);
+}
+
+bool JniLuaEnvironment::RetrieveUserManager() {
+ if (context_ == nullptr) {
+ return false;
+ }
+ if (usermanager_retrieved_) {
+ return (usermanager_ != nullptr);
+ }
+ usermanager_retrieved_ = true;
+ ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user"));
+ jobject usermanager_ref = jenv_->CallObjectMethod(
+ context_, jni_cache_->context_get_system_service, service.get());
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling getSystemService.";
+ return false;
+ }
+ usermanager_ = MakeGlobalRef(usermanager_ref, jenv_, jni_cache_->jvm);
+ return (usermanager_ != nullptr);
+}
+
+RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() {
+ RemoteActionTemplate result;
+ // Read intent template.
+ lua_pushnil(state_);
+ while (lua_next(state_, /*idx=*/-2)) {
+ const StringPiece key = ReadString(/*index=*/-2);
+ if (key.Equals("title_without_entity")) {
+ result.title_without_entity = ReadString(/*index=*/-1).ToString();
+ } else if (key.Equals("title_with_entity")) {
+ result.title_with_entity = ReadString(/*index=*/-1).ToString();
+ } else if (key.Equals("description")) {
+ result.description = ReadString(/*index=*/-1).ToString();
+ } else if (key.Equals("description_with_app_name")) {
+ result.description_with_app_name = ReadString(/*index=*/-1).ToString();
+ } else if (key.Equals("action")) {
+ result.action = ReadString(/*index=*/-1).ToString();
+ } else if (key.Equals("data")) {
+ result.data = ReadString(/*index=*/-1).ToString();
+ } else if (key.Equals("type")) {
+ result.type = ReadString(/*index=*/-1).ToString();
+ } else if (key.Equals("flags")) {
+ result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
+ } else if (key.Equals("package_name")) {
+ result.package_name = ReadString(/*index=*/-1).ToString();
+ } else if (key.Equals("request_code")) {
+ result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
+ } else if (key.Equals("category")) {
+ ReadCategories(&result.category);
+ } else if (key.Equals("extra")) {
+ ReadExtras(&result.extra);
+ } else {
+ TC3_LOG(INFO) << "Unknown entry: " << key.ToString();
+ }
+ lua_pop(state_, 1);
+ }
+ lua_pop(state_, 1);
+ return result;
+}
+
+void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) {
+ // Read category array.
+ if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected categories table, got: "
+ << lua_type(state_, /*idx=*/-1);
+ lua_pop(state_, 1);
+ return;
+ }
+ lua_pushnil(state_);
+ while (lua_next(state_, /*idx=*/-2)) {
+ category->push_back(ReadString(/*index=*/-1).ToString());
+ lua_pop(state_, 1);
+ }
+}
+
+// Reads the intent "extra" table at the top of the Lua stack into |extra|.
+// Each array element is itself a table carrying a "name" field plus exactly
+// one type-tagged value field (int_value, long_value, float_value,
+// bool_value or string_value). Entries without a name are skipped with an
+// error log; a malformed (non-table) entry aborts the whole read.
+void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) {
+ if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected extras table, got: "
+ << lua_type(state_, /*idx=*/-1);
+ lua_pop(state_, 1);
+ return;
+ }
+ lua_pushnil(state_);
+ while (lua_next(state_, /*idx=*/-2)) {
+ // Each entry is a table specifying name and value.
+ // The value is specified via a type specific field as Lua doesn't allow
+ // to easily distinguish between different number types.
+ if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected a table for an extra, got: "
+ << lua_type(state_, /*idx=*/-1);
+ lua_pop(state_, 1);
+ return;
+ }
+ std::string name;
+ Variant value;
+
+ lua_pushnil(state_);
+ while (lua_next(state_, /*idx=*/-2)) {
+ const StringPiece key = ReadString(/*index=*/-2);
+ if (key.Equals("name")) {
+ name = ReadString(/*index=*/-1).ToString();
+ } else if (key.Equals("int_value")) {
+ // NOTE(review): int/long values go through lua_tonumber + cast rather
+ // than lua_tointeger (as used for "flags" elsewhere); confirm the
+ // truncation semantics for non-integral numbers are intended.
+ value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
+ } else if (key.Equals("long_value")) {
+ value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
+ } else if (key.Equals("float_value")) {
+ value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
+ } else if (key.Equals("bool_value")) {
+ value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
+ } else if (key.Equals("string_value")) {
+ value = Variant(ReadString(/*index=*/-1).ToString());
+ } else {
+ TC3_LOG(INFO) << "Unknown extra field: " << key.ToString();
+ }
+ lua_pop(state_, 1);
+ }
+ if (!name.empty()) {
+ // A later duplicate name overwrites an earlier one.
+ (*extra)[name] = value;
+ } else {
+ TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
+ }
+ lua_pop(state_, 1);
+ }
+}
+
+// Reads the generator result (a Lua array of intent tables) at the top of
+// the stack into |result|. Returns a Lua status code; intended to be run
+// under RunProtected so that raised Lua errors are caught.
+int JniLuaEnvironment::ReadRemoteActionTemplates(
+ std::vector<RemoteActionTemplate>* result) {
+ // Read result.
+ if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Unexpected result for snippet: " << lua_type(state_, -1);
+ // NOTE(review): lua_error performs a longjmp and does not return, so the
+ // return statement below is effectively unreachable; kept for clarity.
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+
+ // Read remote action templates array.
+ lua_pushnil(state_);
+ while (lua_next(state_, /*idx=*/-2)) {
+ if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ // Skip non-table entries but keep reading the rest of the array.
+ TC3_LOG(ERROR) << "Expected intent table, got: "
+ << lua_type(state_, /*idx=*/-1);
+ lua_pop(state_, 1);
+ continue;
+ }
+ result->push_back(ReadRemoteActionTemplateResult());
+ }
+ lua_pop(state_, /*n=*/1);
+ return LUA_OK;
+}
+
+// Loads and executes |generator_snippet| (Lua source or a precompiled
+// chunk, see IntentGenerator::Create) and parses the single returned value
+// into |remote_actions|. Returns false if loading, running or result
+// parsing fails, or if the Lua stack is left in an unexpected state.
+bool JniLuaEnvironment::RunIntentGenerator(
+ const std::string& generator_snippet,
+ std::vector<RemoteActionTemplate>* remote_actions) {
+ int status;
+ status = luaL_loadbuffer(state_, generator_snippet.data(),
+ generator_snippet.size(),
+ /*name=*/nullptr);
+ if (status != LUA_OK) {
+ TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
+ return false;
+ }
+ // Run the snippet; it is expected to leave exactly one result on the stack.
+ status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
+ if (status != LUA_OK) {
+ TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
+ return false;
+ }
+ // Parse the result under pcall protection so Lua errors are contained.
+ if (RunProtected(
+ [this, remote_actions] {
+ return ReadRemoteActionTemplates(remote_actions);
+ },
+ /*num_args=*/1) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not read results.";
+ return false;
+ }
+ // Check that we correctly cleaned-up the state.
+ const int stack_size = lua_gettop(state_);
+ if (stack_size > 0) {
+ TC3_LOG(ERROR) << "Unexpected stack size.";
+ lua_settop(state_, 0);
+ return false;
+ }
+ return true;
+}
+
+// Lua environment for classification result intent generation. Exposes the
+// reference time and the classified entity to the Lua snippet via the
+// external hook.
+class AnnotatorJniEnvironment : public JniLuaEnvironment {
+ public:
+ // All reference parameters must outlive this environment; they are stored
+ // as references, not copies.
+ AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
+ const jobject context,
+ const std::vector<Locale>& device_locales,
+ const std::string& entity_text,
+ const ClassificationResult& classification,
+ const int64 reference_time_ms_utc,
+ const reflection::Schema* entity_data_schema)
+ : JniLuaEnvironment(resources, jni_cache, context, device_locales),
+ entity_text_(entity_text),
+ classification_(classification),
+ reference_time_ms_utc_(reference_time_ms_utc),
+ entity_data_schema_(entity_data_schema) {}
+
+ protected:
+ void SetupExternalHook() override {
+ JniLuaEnvironment::SetupExternalHook();
+ // NOTE(review): the key name mentions "Usec" but the value pushed is in
+ // milliseconds (reference_time_ms_utc_) — confirm consumers expect ms.
+ lua_pushinteger(state_, reference_time_ms_utc_);
+ lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
+
+ PushAnnotation(classification_, entity_text_, entity_data_schema_, this);
+ lua_setfield(state_, /*idx=*/-2, "entity");
+ }
+
+ const std::string& entity_text_;
+ const ClassificationResult& classification_;
+ const int64 reference_time_ms_utc_;
+
+ // Reflection schema data.
+ const reflection::Schema* const entity_data_schema_;
+};
+
+// Lua environment for actions intent generation. Exposes the conversation
+// messages and the suggested action to the Lua snippet via the external
+// hook.
+class ActionsJniLuaEnvironment : public JniLuaEnvironment {
+ public:
+ // The conversation and action are held by reference and must outlive this
+ // environment.
+ ActionsJniLuaEnvironment(
+ const Resources& resources, const JniCache* jni_cache,
+ const jobject context, const std::vector<Locale>& device_locales,
+ const Conversation& conversation, const ActionSuggestion& action,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema)
+ : JniLuaEnvironment(resources, jni_cache, context, device_locales),
+ conversation_(conversation),
+ action_(action),
+ annotation_iterator_(annotations_entity_data_schema, this),
+ conversation_iterator_(annotations_entity_data_schema, this),
+ entity_data_schema_(actions_entity_data_schema) {}
+
+ protected:
+ void SetupExternalHook() override {
+ JniLuaEnvironment::SetupExternalHook();
+ conversation_iterator_.NewIterator("conversation", &conversation_.messages,
+ state_);
+ lua_setfield(state_, /*idx=*/-2, "conversation");
+
+ PushAction(action_, entity_data_schema_, annotation_iterator_, this);
+ lua_setfield(state_, /*idx=*/-2, "entity");
+ }
+
+ const Conversation& conversation_;
+ const ActionSuggestion& action_;
+ // Both iterators read the *annotations* schema; the actions schema is only
+ // used for the action's own entity data (entity_data_schema_).
+ const AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
+ const ConversationIterator conversation_iterator_;
+ const reflection::Schema* entity_data_schema_;
+};
+
+} // namespace
+
+// Factory: decompresses (and optionally precompiles) each per-type Lua
+// template from |options| into the generators_ map. Returns nullptr on
+// missing options, decompression failure or compilation failure.
+std::unique_ptr<IntentGenerator> IntentGenerator::Create(
+ const IntentFactoryModel* options, const ResourcePool* resources,
+ const std::shared_ptr<JniCache>& jni_cache) {
+ // NOTE(review): the generator is allocated before the options null-check
+ // below; harmless (it is discarded on error) but slightly wasteful.
+ std::unique_ptr<IntentGenerator> intent_generator(
+ new IntentGenerator(options, resources, jni_cache));
+
+ if (options == nullptr || options->generator() == nullptr) {
+ TC3_LOG(ERROR) << "No intent generator options.";
+ return nullptr;
+ }
+
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return nullptr;
+ }
+
+ for (const IntentFactoryModel_::IntentGenerator* generator :
+ *options->generator()) {
+ // Templates may be stored compressed; this yields the plain Lua source
+ // either way.
+ std::string lua_template_generator;
+ if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
+ generator->lua_template_generator(),
+ generator->compressed_lua_template_generator(),
+ &lua_template_generator)) {
+ TC3_LOG(ERROR) << "Could not decompress generator template.";
+ return nullptr;
+ }
+
+ std::string lua_code = lua_template_generator;
+ if (options->precompile_generators()) {
+ // Replace the source with precompiled Lua bytecode.
+ if (!Compile(lua_template_generator, &lua_code)) {
+ TC3_LOG(ERROR) << "Could not precompile generator template.";
+ return nullptr;
+ }
+ }
+
+ intent_generator->generators_[generator->type()->str()] = lua_code;
+ }
+
+ return intent_generator;
+}
+
+// Converts the JNI locale string (e.g. a comma-separated locale list) into
+// parsed Locale objects. Returns an empty vector (with an error log) on any
+// failure; callers treat that as "no locale constraints".
+std::vector<Locale> IntentGenerator::ParseDeviceLocales(
+ const jstring device_locales) const {
+ if (device_locales == nullptr) {
+ TC3_LOG(ERROR) << "No locales provided.";
+ return {};
+ }
+ ScopedStringChars locales_str =
+ GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
+ if (locales_str == nullptr) {
+ TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
+ return {};
+ }
+ std::vector<Locale> locales;
+ if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
+ &locales)) {
+ TC3_LOG(ERROR) << "Cannot parse locales.";
+ return {};
+ }
+ return locales;
+}
+
+// Generates remote action templates for a classification result by running
+// the Lua generator registered for the result's collection. Returns true
+// when there is simply no generator for the collection (nothing to do);
+// returns false only on actual failures.
+bool IntentGenerator::GenerateIntents(
+ const jstring device_locales, const ClassificationResult& classification,
+ const int64 reference_time_ms_utc, const std::string& text,
+ const CodepointSpan selection_indices, const jobject context,
+ const reflection::Schema* annotations_entity_data_schema,
+ std::vector<RemoteActionTemplate>* remote_actions) const {
+ if (options_ == nullptr) {
+ return false;
+ }
+
+ // Retrieve generator for specified entity.
+ auto it = generators_.find(classification.collection);
+ if (it == generators_.end()) {
+ return true;
+ }
+
+ // Extract the selected entity text; selection_indices are codepoint
+ // offsets, mapped here to a UTF-8 substring.
+ const std::string entity_text =
+ UTF8ToUnicodeText(text, /*do_copy=*/false)
+ .UTF8Substring(selection_indices.first, selection_indices.second);
+
+ std::unique_ptr<AnnotatorJniEnvironment> interpreter(
+ new AnnotatorJniEnvironment(
+ resources_, jni_cache_.get(), context,
+ ParseDeviceLocales(device_locales), entity_text, classification,
+ reference_time_ms_utc, annotations_entity_data_schema));
+
+ if (!interpreter->Initialize()) {
+ TC3_LOG(ERROR) << "Could not create Lua interpreter.";
+ return false;
+ }
+
+ return interpreter->RunIntentGenerator(it->second, remote_actions);
+}
+
+// Generates remote action templates for an action suggestion by running the
+// Lua generator registered for the action type. Mirrors the classification
+// overload above: a missing generator is a successful no-op (returns true);
+// false indicates a real failure.
+bool IntentGenerator::GenerateIntents(
+ const jstring device_locales, const ActionSuggestion& action,
+ const Conversation& conversation, const jobject context,
+ const reflection::Schema* annotations_entity_data_schema,
+ const reflection::Schema* actions_entity_data_schema,
+ std::vector<RemoteActionTemplate>* remote_actions) const {
+ if (options_ == nullptr) {
+ return false;
+ }
+
+ // Retrieve generator for specified action.
+ auto it = generators_.find(action.type);
+ if (it == generators_.end()) {
+ return true;
+ }
+
+ std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
+ new ActionsJniLuaEnvironment(
+ resources_, jni_cache_.get(), context,
+ ParseDeviceLocales(device_locales), conversation, action,
+ actions_entity_data_schema, annotations_entity_data_schema));
+
+ if (!interpreter->Initialize()) {
+ TC3_LOG(ERROR) << "Could not create Lua interpreter.";
+ return false;
+ }
+
+ return interpreter->RunIntentGenerator(it->second, remote_actions);
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/intents/intent-generator.h b/utils/intents/intent-generator.h
new file mode 100644
index 0000000..9177adb
--- /dev/null
+++ b/utils/intents/intent-generator.h
@@ -0,0 +1,113 @@
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
+
+#include <jni.h>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "actions/types.h"
+#include "annotator/types.h"
+#include "utils/i18n/locale.h"
+#include "utils/intents/intent-config_generated.h"
+#include "utils/java/jni-cache.h"
+#include "utils/java/scoped_local_ref.h"
+#include "utils/optional.h"
+#include "utils/resources.h"
+#include "utils/resources_generated.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// A template with parameters for an Android remote action. Optional fields
+// left unset are translated to null by the JNI layer (see
+// RemoteActionTemplatesHandler).
+struct RemoteActionTemplate {
+ // Title shown for the action (see: RemoteAction.getTitle).
+ Optional<std::string> title_without_entity;
+
+ // Title with entity for the action. It is not guaranteed that the client
+ // will use this, so title should be always given and general enough.
+ Optional<std::string> title_with_entity;
+
+ // Description shown for the action (see: RemoteAction.getContentDescription).
+ Optional<std::string> description;
+
+ // Description shown for the action (see: RemoteAction.getContentDescription)
+ // when app name is available. Caller is expected to replace the placeholder
+ // by the name of the app that is going to handle the action.
+ Optional<std::string> description_with_app_name;
+
+ // The action to set on the Intent (see: Intent.setAction).
+ Optional<std::string> action;
+
+ // The data to set on the Intent (see: Intent.setData).
+ Optional<std::string> data;
+
+ // The type to set on the Intent (see: Intent.setType).
+ Optional<std::string> type;
+
+ // Flags for launching the Intent (see: Intent.setFlags).
+ Optional<int> flags;
+
+ // Categories to set on the Intent (see: Intent.addCategory).
+ std::vector<std::string> category;
+
+ // Explicit application package to set on the Intent (see: Intent.setPackage).
+ Optional<std::string> package_name;
+
+ // The list of all the extras to add to the Intent.
+ std::map<std::string, Variant> extra;
+
+ // Private request code to use for the Intent.
+ Optional<int> request_code;
+};
+
+// Helper class to generate Android intents for text classifier results.
+// Holds one Lua snippet per entity collection / action type and runs it in
+// a sandboxed JNI-aware Lua environment.
+class IntentGenerator {
+ public:
+ // Builds the generator from flatbuffer options; returns nullptr if options
+ // are missing or a template cannot be decompressed/precompiled.
+ static std::unique_ptr<IntentGenerator> Create(
+ const IntentFactoryModel* options, const ResourcePool* resources,
+ const std::shared_ptr<JniCache>& jni_cache);
+
+ // Generates intents for a classification result.
+ // Returns true, if the intent generator snippets could be successfully run,
+ // returns false otherwise.
+ bool GenerateIntents(const jstring device_locales,
+ const ClassificationResult& classification,
+ const int64 reference_time_ms_utc,
+ const std::string& text,
+ const CodepointSpan selection_indices,
+ const jobject context,
+ const reflection::Schema* annotations_entity_data_schema,
+ std::vector<RemoteActionTemplate>* remote_actions) const;
+
+ // Generates intents for an action suggestion.
+ // Returns true, if the intent generator snippets could be successfully run,
+ // returns false otherwise.
+ bool GenerateIntents(const jstring device_locales,
+ const ActionSuggestion& action,
+ const Conversation& conversation, const jobject context,
+ const reflection::Schema* annotations_entity_data_schema,
+ const reflection::Schema* actions_entity_data_schema,
+ std::vector<RemoteActionTemplate>* remote_actions) const;
+
+ private:
+ IntentGenerator(const IntentFactoryModel* options,
+ const ResourcePool* resources,
+ const std::shared_ptr<JniCache>& jni_cache)
+ : options_(options),
+ resources_(Resources(resources)),
+ jni_cache_(jni_cache) {}
+
+ std::vector<Locale> ParseDeviceLocales(const jstring device_locales) const;
+
+ const IntentFactoryModel* options_;
+ const Resources resources_;
+ std::shared_ptr<JniCache> jni_cache_;
+ // Maps a generator type (entity collection or action type) to its Lua
+ // code, possibly precompiled to bytecode (see Create).
+ std::map<std::string, std::string> generators_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
diff --git a/utils/intents/jni.cc b/utils/intents/jni.cc
new file mode 100644
index 0000000..d6274b1
--- /dev/null
+++ b/utils/intents/jni.cc
@@ -0,0 +1,227 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/intents/jni.h"
+#include <memory>
+#include "utils/intents/intent-generator.h"
+#include "utils/java/scoped_local_ref.h"
+
+namespace libtextclassifier3 {
+
+// The macros below are intended to reduce the boilerplate and avoid
+// easily introduced copy/paste errors.
+#define TC3_CHECK_JNI_PTR(PTR) TC3_CHECK((PTR) != nullptr)
+#define TC3_GET_CLASS(FIELD, NAME) \
+ handler->FIELD = MakeGlobalRef(env->FindClass(NAME), env, jni_cache->jvm); \
+ TC3_CHECK_JNI_PTR(handler->FIELD) << "Error finding class: " << NAME;
+#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ handler->FIELD = env->GetMethodID(handler->CLASS.get(), NAME, SIGNATURE); \
+ TC3_CHECK(handler->FIELD) << "Error finding method: " << NAME;
+
+// Factory: resolves and caches the Java classes and method ids needed to
+// build RemoteActionTemplate/NamedVariant objects. Returns nullptr only if
+// no JNIEnv is available; the TC3_GET_CLASS/TC3_GET_METHOD macros TC3_CHECK
+// (i.e. abort) if a class or method cannot be found.
+std::unique_ptr<RemoteActionTemplatesHandler>
+RemoteActionTemplatesHandler::Create(
+ const std::shared_ptr<JniCache>& jni_cache) {
+ JNIEnv* env = jni_cache->GetEnv();
+ if (env == nullptr) {
+ return nullptr;
+ }
+
+ std::unique_ptr<RemoteActionTemplatesHandler> handler(
+ new RemoteActionTemplatesHandler(jni_cache));
+
+ TC3_GET_CLASS(integer_class_, "java/lang/Integer");
+ TC3_GET_METHOD(integer_class_, integer_init_, "<init>", "(I)V");
+
+ TC3_GET_CLASS(remote_action_template_class_,
+ TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR);
+ // Constructor signature must stay in sync with the Java class and with the
+ // argument order used in RemoteActionTemplatesToJObjectArray below.
+ TC3_GET_METHOD(
+ remote_action_template_class_, remote_action_template_init_, "<init>",
+ "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
+ "String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
+ "Integer;[Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
+ TC3_NAMED_VARIANT_CLASS_NAME_STR ";Ljava/lang/Integer;)V");
+
+ TC3_GET_CLASS(named_variant_class_,
+ TC3_PACKAGE_PATH TC3_NAMED_VARIANT_CLASS_NAME_STR);
+
+ // One NamedVariant constructor overload per supported Variant type.
+ TC3_GET_METHOD(named_variant_class_, named_variant_from_int_, "<init>",
+ "(Ljava/lang/String;I)V");
+ TC3_GET_METHOD(named_variant_class_, named_variant_from_long_, "<init>",
+ "(Ljava/lang/String;J)V");
+ TC3_GET_METHOD(named_variant_class_, named_variant_from_float_, "<init>",
+ "(Ljava/lang/String;F)V");
+ TC3_GET_METHOD(named_variant_class_, named_variant_from_double_, "<init>",
+ "(Ljava/lang/String;D)V");
+ TC3_GET_METHOD(named_variant_class_, named_variant_from_bool_, "<init>",
+ "(Ljava/lang/String;Z)V");
+ TC3_GET_METHOD(named_variant_class_, named_variant_from_string_, "<init>",
+ "(Ljava/lang/String;Ljava/lang/String;)V");
+
+ return handler;
+}
+
+// Converts an optional UTF-8 string to a Java string; returns nullptr when
+// unset. Ownership of the released local reference passes to the caller.
+jstring RemoteActionTemplatesHandler::AsUTF8String(
+ const Optional<std::string>& optional) const {
+ if (!optional.has_value()) {
+ return nullptr;
+ }
+ return jni_cache_->ConvertToJavaString(optional.value()).release();
+}
+
+// Boxes an optional int into a java.lang.Integer; returns nullptr when
+// unset. The returned local reference is owned by the caller.
+jobject RemoteActionTemplatesHandler::AsInteger(
+ const Optional<int>& optional) const {
+ return (optional.has_value()
+ ? jni_cache_->GetEnv()->NewObject(integer_class_.get(),
+ integer_init_, optional.value())
+ : nullptr);
+}
+
+// Converts a string vector into a Java String[]; an empty vector maps to
+// nullptr (not an empty array), matching the nullable Java-side fields.
+jobjectArray RemoteActionTemplatesHandler::AsStringArray(
+ const std::vector<std::string>& values) const {
+ if (values.empty()) {
+ return nullptr;
+ }
+ jobjectArray result = jni_cache_->GetEnv()->NewObjectArray(
+ values.size(), jni_cache_->string_class.get(), nullptr);
+ if (result == nullptr) {
+ return nullptr;
+ }
+ // NOTE(review): signed int index vs. size_t values.size() comparison; fine
+ // for realistic sizes but triggers -Wsign-compare.
+ for (int k = 0; k < values.size(); k++) {
+ ScopedLocalRef<jstring> value_str =
+ jni_cache_->ConvertToJavaString(values[k]);
+ jni_cache_->GetEnv()->SetObjectArrayElement(result, k, value_str.get());
+ }
+ return result;
+}
+
+// Builds a NamedVariant Java object for (name, value), dispatching on the
+// Variant's runtime type to the matching constructor overload cached in
+// Create. Returns nullptr on conversion failure or for unsupported types.
+jobject RemoteActionTemplatesHandler::AsNamedVariant(
+ const std::string& name_str, const Variant& value) const {
+ ScopedLocalRef<jstring> name = jni_cache_->ConvertToJavaString(name_str);
+ if (name == nullptr) {
+ return nullptr;
+ }
+ switch (value.GetType()) {
+ case Variant::TYPE_INT_VALUE:
+ return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
+ named_variant_from_int_,
+ name.get(), value.IntValue());
+ case Variant::TYPE_INT64_VALUE:
+ return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
+ named_variant_from_long_,
+ name.get(), value.Int64Value());
+ case Variant::TYPE_FLOAT_VALUE:
+ return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
+ named_variant_from_float_,
+ name.get(), value.FloatValue());
+ case Variant::TYPE_DOUBLE_VALUE:
+ return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
+ named_variant_from_double_,
+ name.get(), value.DoubleValue());
+ case Variant::TYPE_BOOL_VALUE:
+ return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
+ named_variant_from_bool_,
+ name.get(), value.BoolValue());
+ case Variant::TYPE_STRING_VALUE: {
+ ScopedLocalRef<jstring> value_jstring =
+ jni_cache_->ConvertToJavaString(value.StringValue());
+ if (value_jstring == nullptr) {
+ return nullptr;
+ }
+ return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
+ named_variant_from_string_,
+ name.get(), value_jstring.get());
+ }
+ default:
+ return nullptr;
+ }
+}
+
+// Converts an extras map to a NamedVariant[]; an empty map maps to nullptr.
+// Entries whose Variant holds no value are skipped, leaving a null slot in
+// the array at that position.
+jobjectArray RemoteActionTemplatesHandler::AsNamedVariantArray(
+ const std::map<std::string, Variant>& values) const {
+ if (values.empty()) {
+ return nullptr;
+ }
+ // NOTE(review): unlike RemoteActionTemplatesToJObjectArray, the
+ // NewObjectArray result is not null-checked here — confirm OOM handling.
+ jobjectArray result = jni_cache_->GetEnv()->NewObjectArray(
+ values.size(), named_variant_class_.get(), nullptr);
+ int element_index = 0;
+ for (auto key_value_pair : values) {
+ if (!key_value_pair.second.HasValue()) {
+ element_index++;
+ continue;
+ }
+ ScopedLocalRef<jobject> named_extra(
+ AsNamedVariant(key_value_pair.first, key_value_pair.second),
+ jni_cache_->GetEnv());
+ if (named_extra == nullptr) {
+ return nullptr;
+ }
+ jni_cache_->GetEnv()->SetObjectArrayElement(result, element_index,
+ named_extra.get());
+ element_index++;
+ }
+ return result;
+}
+
+// Converts the C++ remote action templates into a Java RemoteActionTemplate[]
+// by invoking the constructor cached in Create. Argument order must match
+// the JNI signature registered there. Returns nullptr on any failure.
+jobjectArray RemoteActionTemplatesHandler::RemoteActionTemplatesToJObjectArray(
+ const std::vector<RemoteActionTemplate>& remote_actions) const {
+ const jobjectArray results = jni_cache_->GetEnv()->NewObjectArray(
+ remote_actions.size(), remote_action_template_class_.get(), nullptr);
+ if (results == nullptr) {
+ return nullptr;
+ }
+ for (int i = 0; i < remote_actions.size(); i++) {
+ const RemoteActionTemplate& remote_action = remote_actions[i];
+ // NOTE(review): the per-field local refs below are raw (not Scoped) and
+ // are not explicitly deleted; presumably the caller pops a local frame —
+ // verify against the JNI entry point.
+ const jstring title_without_entity =
+ AsUTF8String(remote_action.title_without_entity);
+ const jstring title_with_entity =
+ AsUTF8String(remote_action.title_with_entity);
+ const jstring description = AsUTF8String(remote_action.description);
+ const jstring description_with_app_name =
+ AsUTF8String(remote_action.description_with_app_name);
+ const jstring action = AsUTF8String(remote_action.action);
+ const jstring data = AsUTF8String(remote_action.data);
+ const jstring type = AsUTF8String(remote_action.type);
+ const jobject flags = AsInteger(remote_action.flags);
+ const jobjectArray category = AsStringArray(remote_action.category);
+ const jstring package = AsUTF8String(remote_action.package_name);
+ const jobjectArray extra = AsNamedVariantArray(remote_action.extra);
+ const jobject request_code = AsInteger(remote_action.request_code);
+ ScopedLocalRef<jobject> result(
+ jni_cache_->GetEnv()->NewObject(
+ remote_action_template_class_.get(), remote_action_template_init_,
+ title_without_entity, title_with_entity, description,
+ description_with_app_name, action, data, type, flags, category,
+ package, extra, request_code),
+ jni_cache_->GetEnv());
+ if (result == nullptr) {
+ return nullptr;
+ }
+ jni_cache_->GetEnv()->SetObjectArrayElement(results, i, result.get());
+ }
+ return results;
+}
+
+// Deserializes flatbuffer-encoded entity data against |entity_data_schema|
+// and flattens it into a NamedVariant[] for the Java side.
+// NOTE(review): the MergeFromSerializedFlatbuffer result is ignored —
+// confirm malformed data cannot reach this point.
+jobject RemoteActionTemplatesHandler::EntityDataAsNamedVariantArray(
+ const reflection::Schema* entity_data_schema,
+ const std::string& serialized_entity_data) const {
+ ReflectiveFlatbufferBuilder entity_data_builder(entity_data_schema);
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = entity_data_builder.NewRoot();
+ buffer->MergeFromSerializedFlatbuffer(serialized_entity_data);
+ std::map<std::string, Variant> entity_data_map = buffer->AsFlatMap();
+ return AsNamedVariantArray(entity_data_map);
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/intents/jni.h b/utils/intents/jni.h
new file mode 100644
index 0000000..37952a2
--- /dev/null
+++ b/utils/intents/jni.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
+
+#include <jni.h>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/flatbuffers.h"
+#include "utils/intents/intent-generator.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-cache.h"
+#include "utils/optional.h"
+#include "utils/variant.h"
+
+#ifndef TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME
+#define TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME RemoteActionTemplate
+#endif
+
+#define TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR \
+ TC3_ADD_QUOTES(TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME)
+
+#ifndef TC3_NAMED_VARIANT_CLASS_NAME
+#define TC3_NAMED_VARIANT_CLASS_NAME NamedVariant
+#endif
+
+#define TC3_NAMED_VARIANT_CLASS_NAME_STR \
+ TC3_ADD_QUOTES(TC3_NAMED_VARIANT_CLASS_NAME)
+
+namespace libtextclassifier3 {
+
+// A helper class to create RemoteActionTemplate object from model results.
+// Caches global refs to the Java classes and the method ids it needs; all
+// As*() conversion helpers return caller-owned local references (or nullptr
+// for unset optionals / empty containers).
+class RemoteActionTemplatesHandler {
+ public:
+ // Returns nullptr if no JNIEnv is available; class/method lookups are
+ // TC3_CHECKed (abort on failure).
+ static std::unique_ptr<RemoteActionTemplatesHandler> Create(
+ const std::shared_ptr<JniCache>& jni_cache);
+
+ jstring AsUTF8String(const Optional<std::string>& optional) const;
+ jobject AsInteger(const Optional<int>& optional) const;
+ jobjectArray AsStringArray(const std::vector<std::string>& values) const;
+ jobject AsNamedVariant(const std::string& name, const Variant& value) const;
+ jobjectArray AsNamedVariantArray(
+ const std::map<std::string, Variant>& values) const;
+
+ jobjectArray RemoteActionTemplatesToJObjectArray(
+ const std::vector<RemoteActionTemplate>& remote_actions) const;
+
+ jobject EntityDataAsNamedVariantArray(
+ const reflection::Schema* entity_data_schema,
+ const std::string& serialized_entity_data) const;
+
+ private:
+ explicit RemoteActionTemplatesHandler(
+ const std::shared_ptr<JniCache>& jni_cache)
+ : jni_cache_(jni_cache),
+ integer_class_(nullptr, jni_cache->jvm),
+ remote_action_template_class_(nullptr, jni_cache->jvm),
+ named_variant_class_(nullptr, jni_cache->jvm) {}
+
+ std::shared_ptr<JniCache> jni_cache_;
+
+ // java.lang.Integer
+ ScopedGlobalRef<jclass> integer_class_;
+ jmethodID integer_init_ = nullptr;
+
+ // RemoteActionTemplate
+ ScopedGlobalRef<jclass> remote_action_template_class_;
+ jmethodID remote_action_template_init_ = nullptr;
+
+ // NamedVariant — one cached constructor per supported Variant type.
+ ScopedGlobalRef<jclass> named_variant_class_;
+ jmethodID named_variant_from_int_ = nullptr;
+ jmethodID named_variant_from_long_ = nullptr;
+ jmethodID named_variant_from_float_ = nullptr;
+ jmethodID named_variant_from_double_ = nullptr;
+ jmethodID named_variant_from_bool_ = nullptr;
+ jmethodID named_variant_from_string_ = nullptr;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
diff --git a/utils/intents/zlib-utils.cc b/utils/intents/zlib-utils.cc
new file mode 100644
index 0000000..9f29b46
--- /dev/null
+++ b/utils/intents/zlib-utils.cc
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/intents/zlib-utils.h"
+
+#include <memory>
+
+#include "utils/zlib/buffer_generated.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Compresses each generator's Lua template in place (filling
+// compressed_lua_template_generator and clearing the plain copy).
+// Always returns true.
+// NOTE(review): ZlibCompressor::Instance() is used without a null check,
+// unlike the decompressor path in IntentGenerator::Create — confirm it
+// cannot fail here.
+bool CompressIntentModel(IntentFactoryModelT* intent_model) {
+ std::unique_ptr<ZlibCompressor> intent_zlib_compressor =
+ ZlibCompressor::Instance();
+ for (auto& generator : intent_model->generator) {
+ generator->compressed_lua_template_generator.reset(new CompressedBufferT);
+ intent_zlib_compressor->Compress(
+ std::string(reinterpret_cast<const char*>(
+ generator->lua_template_generator.data()),
+ generator->lua_template_generator.size()),
+ generator->compressed_lua_template_generator.get());
+ // Drop the uncompressed copy to save model size.
+ generator->lua_template_generator.clear();
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/intents/zlib-utils.h b/utils/intents/zlib-utils.h
new file mode 100644
index 0000000..afefa3d
--- /dev/null
+++ b/utils/intents/zlib-utils.h
@@ -0,0 +1,28 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_ZLIB_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_ZLIB_UTILS_H_
+
+#include "utils/intents/intent-config_generated.h"
+
+namespace libtextclassifier3 {
+
+bool CompressIntentModel(IntentFactoryModelT* intent_model);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_ZLIB_UTILS_H_
diff --git a/utils/java/jni-base.cc b/utils/java/jni-base.cc
index 330732c..4483b79 100644
--- a/utils/java/jni-base.cc
+++ b/utils/java/jni-base.cc
@@ -46,6 +46,10 @@
jfieldID fd_class_descriptor =
env->GetFieldID(fd_class.get(), "descriptor", "I");
if (fd_class_descriptor == nullptr) {
+ env->ExceptionClear();
+ fd_class_descriptor = env->GetFieldID(fd_class.get(), "fd", "I");
+ }
+ if (fd_class_descriptor == nullptr) {
TC3_LOG(ERROR) << "Couldn't find descriptor.";
return reinterpret_cast<jlong>(nullptr);
}
diff --git a/utils/java/jni-cache.cc b/utils/java/jni-cache.cc
index e2a6676..8c2f00a 100644
--- a/utils/java/jni-cache.cc
+++ b/utils/java/jni-cache.cc
@@ -38,7 +38,8 @@
context_class(nullptr, jvm),
uri_class(nullptr, jvm),
usermanager_class(nullptr, jvm),
- bundle_class(nullptr, jvm)
+ bundle_class(nullptr, jvm),
+ resources_class(nullptr, jvm)
#endif
{
}
@@ -218,6 +219,7 @@
TC3_GET_STATIC_METHOD(uri, parse, "parse",
"(Ljava/lang/String;)Landroid/net/Uri;");
TC3_GET_METHOD(uri, get_scheme, "getScheme", "()Ljava/lang/String;");
+ TC3_GET_METHOD(uri, get_host, "getHost", "()Ljava/lang/String;");
// UserManager.
TC3_GET_OPTIONAL_CLASS(usermanager, "android/os/UserManager");
@@ -227,6 +229,14 @@
// Bundle.
TC3_GET_CLASS(bundle, "android/os/Bundle");
TC3_GET_METHOD(bundle, get_boolean, "getBoolean", "(Ljava/lang/String;)Z");
+
+ // String resources.
+ TC3_GET_CLASS(resources, "android/content/res/Resources");
+ TC3_GET_STATIC_METHOD(resources, get_system, "getSystem",
+ "()Landroid/content/res/Resources;");
+ TC3_GET_METHOD(resources, get_identifier, "getIdentifier",
+ "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)I");
+ TC3_GET_METHOD(resources, get_string, "getString", "(I)Ljava/lang/String;");
#endif
return result;
@@ -261,17 +271,17 @@
}
ScopedLocalRef<jstring> JniCache::ConvertToJavaString(
- const UnicodeText& text) const {
+ const char* utf8_text, const int utf8_text_size_bytes) const {
// Create java byte array.
JNIEnv* jenv = GetEnv();
const ScopedLocalRef<jbyteArray> text_java_utf8(
- jenv->NewByteArray(text.size_bytes()), jenv);
+ jenv->NewByteArray(utf8_text_size_bytes), jenv);
if (!text_java_utf8) {
return nullptr;
}
- jenv->SetByteArrayRegion(text_java_utf8.get(), 0, text.size_bytes(),
- reinterpret_cast<const jbyte*>(text.data()));
+ jenv->SetByteArrayRegion(text_java_utf8.get(), 0, utf8_text_size_bytes,
+ reinterpret_cast<const jbyte*>(utf8_text));
// Create the string with a UTF-8 charset.
return ScopedLocalRef<jstring>(
@@ -281,4 +291,14 @@
jenv);
}
+ScopedLocalRef<jstring> JniCache::ConvertToJavaString(
+ StringPiece utf8_text) const {
+ return ConvertToJavaString(utf8_text.data(), utf8_text.size());
+}
+
+ScopedLocalRef<jstring> JniCache::ConvertToJavaString(
+ const UnicodeText& text) const {
+ return ConvertToJavaString(text.data(), text.size_bytes());
+}
+
} // namespace libtextclassifier3
diff --git a/utils/java/jni-cache.h b/utils/java/jni-cache.h
index 18675fc..609ddb1 100644
--- a/utils/java/jni-cache.h
+++ b/utils/java/jni-cache.h
@@ -20,6 +20,7 @@
#include <jni.h>
#include "utils/java/scoped_global_ref.h"
#include "utils/java/scoped_local_ref.h"
+#include "utils/strings/stringpiece.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -109,7 +110,6 @@
ScopedGlobalRef<jclass> urlencoder_class;
jmethodID urlencoder_encode = nullptr;
-#ifdef __ANDROID__
// android.content.Context
ScopedGlobalRef<jclass> context_class;
jmethodID context_get_package_name = nullptr;
@@ -119,6 +119,7 @@
ScopedGlobalRef<jclass> uri_class;
jmethodID uri_parse = nullptr;
jmethodID uri_get_scheme = nullptr;
+ jmethodID uri_get_host = nullptr;
// android.os.UserManager
ScopedGlobalRef<jclass> usermanager_class;
@@ -127,9 +128,17 @@
// android.os.Bundle
ScopedGlobalRef<jclass> bundle_class;
jmethodID bundle_get_boolean = nullptr;
-#endif
+
+ // android.content.res.Resources
+ ScopedGlobalRef<jclass> resources_class;
+ jmethodID resources_get_system = nullptr;
+ jmethodID resources_get_identifier = nullptr;
+ jmethodID resources_get_string = nullptr;
// Helper to convert lib3 UnicodeText to Java strings.
+ ScopedLocalRef<jstring> ConvertToJavaString(
+ const char* utf8_text, const int utf8_text_size_bytes) const;
+ ScopedLocalRef<jstring> ConvertToJavaString(StringPiece utf8_text) const;
ScopedLocalRef<jstring> ConvertToJavaString(const UnicodeText& text) const;
private:
diff --git a/utils/lua-utils.cc b/utils/lua-utils.cc
new file mode 100644
index 0000000..64071ca
--- /dev/null
+++ b/utils/lua-utils.cc
@@ -0,0 +1,303 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/lua-utils.h"
+
+// lua_dump takes an extra "strip" argument in Lua 5.3, but not in 5.2; the macro below adapts the 5.3-style call for non-TC3_AOSP (5.2) builds.
+#ifndef TC3_AOSP
+#define lua_dump(L, w, d, s) lua_dump((L), (w), (d))
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+// Upvalue indices for the flatbuffer callback.
+static constexpr int kSchemaArgId = 1;
+static constexpr int kTypeArgId = 2;
+static constexpr int kTableArgId = 3;
+
+static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base},
+ {LUA_TABLIBNAME, luaopen_table},
+ {LUA_STRLIBNAME, luaopen_string},
+ {LUA_BITLIBNAME, luaopen_bit32},
+ {LUA_MATHLIBNAME, luaopen_math},
+ {nullptr, nullptr}};
+
+// Implementation of a lua_Writer that appends the data to a string.
+int LuaStringWriter(lua_State *state, const void *data, size_t size,
+ void *result) {
+ std::string *const result_string = static_cast<std::string *>(result);
+ result_string->insert(result_string->size(), static_cast<const char *>(data),
+ size);
+ return LUA_OK;
+}
+
+} // namespace
+
+LuaEnvironment::LuaEnvironment() { state_ = luaL_newstate(); }
+
+LuaEnvironment::~LuaEnvironment() {
+ if (state_ != nullptr) {
+ lua_close(state_);
+ }
+}
+
+int LuaEnvironment::Iterator::NextCallback(lua_State *state) {
+ return FromUpValue<Iterator *>(kIteratorArgId, state)->Next(state);
+}
+
+int LuaEnvironment::Iterator::LengthCallback(lua_State *state) {
+ return FromUpValue<Iterator *>(kIteratorArgId, state)->Length(state);
+}
+
+int LuaEnvironment::Iterator::ItemCallback(lua_State *state) {
+ return FromUpValue<Iterator *>(kIteratorArgId, state)->Item(state);
+}
+
+int LuaEnvironment::Iterator::IteritemsCallback(lua_State *state) {
+ return FromUpValue<Iterator *>(kIteratorArgId, state)->Iteritems(state);
+}
+
+void LuaEnvironment::PushFlatbuffer(const char *name,
+ const reflection::Schema *schema,
+ const reflection::Object *type,
+ const flatbuffers::Table *table,
+ lua_State *state) {
+ lua_newtable(state);
+ luaL_newmetatable(state, name);
+ lua_pushlightuserdata(state, AsUserData(schema));
+ lua_pushlightuserdata(state, AsUserData(type));
+ lua_pushlightuserdata(state, AsUserData(table));
+ lua_pushcclosure(state, &GetFieldCallback, 3);
+ lua_setfield(state, -2, kIndexKey);
+ lua_setmetatable(state, -2);
+}
+
+int LuaEnvironment::GetFieldCallback(lua_State *state) {
+ // Fetch the arguments.
+ const reflection::Schema *schema =
+ FromUpValue<reflection::Schema *>(kSchemaArgId, state);
+ const reflection::Object *type =
+ FromUpValue<reflection::Object *>(kTypeArgId, state);
+ const flatbuffers::Table *table =
+ FromUpValue<flatbuffers::Table *>(kTableArgId, state);
+ return GetField(schema, type, table, state);
+}
+
+int LuaEnvironment::GetField(const reflection::Schema *schema,
+ const reflection::Object *type,
+ const flatbuffers::Table *table,
+ lua_State *state) {
+ const char *field_name = lua_tostring(state, -1);
+ const reflection::Field *field = type->fields()->LookupByKey(field_name);
+ if (field == nullptr) {
+ lua_error(state);
+ return 0;
+ }
+ // Provide primitive fields directly.
+ const reflection::BaseType field_type = field->type()->base_type();
+ switch (field_type) {
+ case reflection::Bool:
+ lua_pushboolean(state, table->GetField<uint8_t>(
+ field->offset(), field->default_integer()));
+ break;
+ case reflection::Int:
+ lua_pushinteger(state, table->GetField<int32>(field->offset(),
+ field->default_integer()));
+ break;
+ case reflection::Long:
+ lua_pushinteger(state, table->GetField<int64>(field->offset(),
+ field->default_integer()));
+ break;
+ case reflection::Float:
+ lua_pushnumber(state, table->GetField<float>(field->offset(),
+ field->default_real()));
+ break;
+ case reflection::Double:
+ lua_pushnumber(state, table->GetField<double>(field->offset(),
+ field->default_real()));
+ break;
+ case reflection::String: {
+ const flatbuffers::String *string_value =
+ table->GetPointer<const flatbuffers::String *>(field->offset());
+ if (string_value != nullptr) {
+ lua_pushlstring(state, string_value->data(), string_value->Length());
+ } else {
+ lua_pushlstring(state, "", 0);
+ }
+ break;
+ }
+ case reflection::Obj: {
+ const flatbuffers::Table *field_table =
+ table->GetPointer<const flatbuffers::Table *>(field->offset());
+ if (field_table == nullptr) {
+ TC3_LOG(ERROR) << "Field was not set in entity data.";
+ lua_error(state);
+ return 0;
+ }
+ const reflection::Object *field_type =
+ schema->objects()->Get(field->type()->index());
+ PushFlatbuffer(field->name()->c_str(), schema, field_type, field_table,
+ state);
+ break;
+ }
+ default:
+ TC3_LOG(ERROR) << "Unsupported type: " << field_type;
+ lua_error(state);
+ return 0;
+ }
+ return 1;
+}
+
+int LuaEnvironment::ReadFlatbuffer(ReflectiveFlatbuffer *buffer) {
+ if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected actions table, got: "
+ << lua_type(state_, /*idx=*/-1);
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+
+ lua_pushnil(state_);
+ while (lua_next(state_, /*idx=*/-2)) {
+ const StringPiece key = ReadString(/*index=*/-2);
+ const reflection::Field *field = buffer->GetFieldOrNull(key);
+ if (field == nullptr) {
+ TC3_LOG(ERROR) << "Unknown field: " << key.ToString();
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+ switch (field->type()->base_type()) {
+ case reflection::Obj:
+ return ReadFlatbuffer(buffer->Mutable(field));
+ case reflection::Bool:
+ buffer->Set(field,
+ static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
+ break;
+ case reflection::Int:
+ buffer->Set(field, static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
+ break;
+ case reflection::Long:
+ buffer->Set(field,
+ static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
+ break;
+ case reflection::Float:
+ buffer->Set(field,
+ static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
+ break;
+ case reflection::Double:
+ buffer->Set(field,
+ static_cast<double>(lua_tonumber(state_, /*idx=*/-1)));
+ break;
+ case reflection::String: {
+ buffer->Set(field, ReadString(/*index=*/-1));
+ break;
+ }
+ default:
+ TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+ lua_pop(state_, 1);
+ }
+  // NOTE(review): the table is intentionally left on the stack here; presumably the caller pops it — verify.
+ return LUA_OK;
+}
+
+void LuaEnvironment::LoadDefaultLibraries() {
+ for (const luaL_Reg *lib = defaultlibs; lib->func; lib++) {
+ luaL_requiref(state_, lib->name, lib->func, 1);
+ lua_pop(state_, 1); /* remove lib */
+ }
+}
+
+void LuaEnvironment::PushValue(const Variant &value) {
+ if (value.HasInt()) {
+ lua_pushnumber(state_, value.IntValue());
+ } else if (value.HasInt64()) {
+ lua_pushnumber(state_, value.Int64Value());
+ } else if (value.HasBool()) {
+ lua_pushboolean(state_, value.BoolValue());
+ } else if (value.HasFloat()) {
+ lua_pushnumber(state_, value.FloatValue());
+ } else if (value.HasDouble()) {
+ lua_pushnumber(state_, value.DoubleValue());
+ } else if (value.HasString()) {
+ lua_pushlstring(state_, value.StringValue().data(),
+ value.StringValue().size());
+ } else {
+ TC3_LOG(FATAL) << "Unknown value type.";
+ }
+}
+
+StringPiece LuaEnvironment::ReadString(const int index) const {
+ size_t length = 0;
+ const char *data = lua_tolstring(state_, index, &length);
+ return StringPiece(data, length);
+}
+
+void LuaEnvironment::PushString(const StringPiece str) {
+ lua_pushlstring(state_, str.data(), str.size());
+}
+
+void LuaEnvironment::PushFlatbuffer(const reflection::Schema *schema,
+ const flatbuffers::Table *table) {
+ PushFlatbuffer(schema->root_table()->name()->c_str(), schema,
+ schema->root_table(), table, state_);
+}
+
+int LuaEnvironment::RunProtected(const std::function<int()> &func,
+ const int num_args, const int num_results) {
+ struct ProtectedCall {
+ std::function<int()> func;
+
+ static int run(lua_State *state) {
+ // Read the pointer to the ProtectedCall struct.
+ ProtectedCall *p = static_cast<ProtectedCall *>(
+ lua_touserdata(state, lua_upvalueindex(1)));
+ return p->func();
+ }
+ };
+ ProtectedCall protected_call = {func};
+ lua_pushlightuserdata(state_, &protected_call);
+ lua_pushcclosure(state_, &ProtectedCall::run, /*n=*/1);
+ // Put the closure before the arguments on the stack.
+ if (num_args > 0) {
+ lua_insert(state_, -(1 + num_args));
+ }
+ return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
+}
+
+bool LuaEnvironment::Compile(StringPiece snippet, std::string *bytecode) {
+ if (luaL_loadbuffer(state_, snippet.data(), snippet.size(),
+ /*name=*/nullptr) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not compile lua snippet: "
+ << ReadString(/*index=*/-1).ToString();
+ lua_pop(state_, 1);
+ return false;
+ }
+ if (lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
+ lua_pop(state_, 1);
+ return false;
+ }
+ lua_pop(state_, 1);
+ return true;
+}
+
+bool Compile(StringPiece snippet, std::string *bytecode) {
+ return LuaEnvironment().Compile(snippet, bytecode);
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/lua-utils.h b/utils/lua-utils.h
new file mode 100644
index 0000000..d825cb9
--- /dev/null
+++ b/utils/lua-utils.h
@@ -0,0 +1,264 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
+
+#include <functional>
+#include <vector>
+
+#include "utils/flatbuffers.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
+#include "flatbuffers/reflection_generated.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lua.h"
+#include "lualib.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+
+static constexpr const char *kLengthKey = "__len";
+static constexpr const char *kPairsKey = "__pairs";
+static constexpr const char *kIndexKey = "__index";
+
+// Casts to the lua user data type.
+template <typename T>
+void *AsUserData(const T *value) {
+ return static_cast<void *>(const_cast<T *>(value));
+}
+template <typename T>
+void *AsUserData(const T value) {
+ return reinterpret_cast<void *>(value);
+}
+
+// Retrieves up-values.
+template <typename T>
+T FromUpValue(const int index, lua_State *state) {
+ return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index)));
+}
+
+class LuaEnvironment {
+ public:
+ // Wrapper for handling an iterator.
+ class Iterator {
+ public:
+ virtual ~Iterator() {}
+ static int NextCallback(lua_State *state);
+ static int LengthCallback(lua_State *state);
+ static int ItemCallback(lua_State *state);
+ static int IteritemsCallback(lua_State *state);
+
+ // Called when the next element of an iterator is fetched.
+ virtual int Next(lua_State *state) const = 0;
+
+ // Called when the length of the iterator is queried.
+ virtual int Length(lua_State *state) const = 0;
+
+ // Called when an item is queried.
+ virtual int Item(lua_State *state) const = 0;
+
+ // Called when a new iterator is started.
+ virtual int Iteritems(lua_State *state) const = 0;
+
+ protected:
+ static constexpr int kIteratorArgId = 1;
+ };
+
+ template <typename T>
+ class ItemIterator : public Iterator {
+ public:
+ void NewIterator(StringPiece name, const T *items, lua_State *state) const {
+ lua_newtable(state);
+ luaL_newmetatable(state, name.data());
+ lua_pushlightuserdata(state, AsUserData(this));
+ lua_pushlightuserdata(state, AsUserData(items));
+ lua_pushcclosure(state, &Iterator::ItemCallback, 2);
+ lua_setfield(state, -2, kIndexKey);
+ lua_pushlightuserdata(state, AsUserData(this));
+ lua_pushlightuserdata(state, AsUserData(items));
+ lua_pushcclosure(state, &Iterator::LengthCallback, 2);
+ lua_setfield(state, -2, kLengthKey);
+ lua_pushlightuserdata(state, AsUserData(this));
+ lua_pushlightuserdata(state, AsUserData(items));
+ lua_pushcclosure(state, &Iterator::IteritemsCallback, 2);
+ lua_setfield(state, -2, kPairsKey);
+ lua_setmetatable(state, -2);
+ }
+
+ int Iteritems(lua_State *state) const override {
+ lua_pushlightuserdata(state, AsUserData(this));
+ lua_pushlightuserdata(
+ state, lua_touserdata(state, lua_upvalueindex(kItemsArgId)));
+ lua_pushnumber(state, 0);
+ lua_pushcclosure(state, &Iterator::NextCallback, 3);
+ return /*num results=*/1;
+ }
+
+ int Length(lua_State *state) const override {
+ lua_pushinteger(state, FromUpValue<T *>(kItemsArgId, state)->size());
+ return /*num results=*/1;
+ }
+
+ int Next(lua_State *state) const override {
+ return Next(FromUpValue<T *>(kItemsArgId, state),
+ lua_tointeger(state, lua_upvalueindex(kIterValueArgId)),
+ state);
+ }
+
+ int Next(const T *items, const int64 pos, lua_State *state) const {
+ if (pos >= items->size()) {
+ return 0;
+ }
+
+ // Update iterator value.
+ lua_pushnumber(state, pos + 1);
+      lua_replace(state, lua_upvalueindex(kIterValueArgId));
+
+ // Push key.
+ lua_pushinteger(state, pos + 1);
+
+ // Push item.
+ return 1 + Item(items, pos, state);
+ }
+
+ int Item(lua_State *state) const override {
+ const T *items = FromUpValue<T *>(kItemsArgId, state);
+ switch (lua_type(state, -1)) {
+ case LUA_TNUMBER: {
+ // Lua is one based, so adjust the index here.
+ const int64 index =
+ static_cast<int64>(lua_tonumber(state, /*idx=*/-1)) - 1;
+ if (index < 0 || index >= items->size()) {
+ TC3_LOG(ERROR) << "Invalid index: " << index;
+ lua_error(state);
+ return 0;
+ }
+ return Item(items, index, state);
+ }
+ case LUA_TSTRING: {
+ size_t key_length = 0;
+ const char *key = lua_tolstring(state, /*idx=*/-1, &key_length);
+ return Item(items, StringPiece(key, key_length), state);
+ }
+ default:
+ TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(state, -1);
+ lua_error(state);
+ return 0;
+ }
+ }
+
+ virtual int Item(const T *items, const int64 pos,
+ lua_State *state) const = 0;
+
+ virtual int Item(const T *items, StringPiece key, lua_State *state) const {
+ TC3_LOG(ERROR) << "Unexpected key access: " << key.ToString();
+ lua_error(state);
+ return 0;
+ }
+
+ protected:
+ static constexpr int kItemsArgId = 2;
+ static constexpr int kIterValueArgId = 3;
+ };
+
+ virtual ~LuaEnvironment();
+ LuaEnvironment();
+
+ // Compile a lua snippet into binary bytecode.
+ // NOTE: The compiled bytecode might not be compatible across Lua versions
+ // and platforms.
+ bool Compile(StringPiece snippet, std::string *bytecode);
+
+ typedef int (*CallbackHandler)(lua_State *);
+
+ // Loads default libraries.
+ void LoadDefaultLibraries();
+
+ // Provides a callback to Lua.
+ template <typename T, int (T::*handler)()>
+ void Bind() {
+ lua_pushlightuserdata(state_, static_cast<void *>(this));
+ lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
+ }
+
+  // Sets up a named table that calls back whenever a member is accessed.
+ // This allows to lazily provide required information to the script.
+ template <typename T, int (T::*handler)()>
+ void BindTable(const char *name) {
+ lua_newtable(state_);
+ luaL_newmetatable(state_, name);
+ lua_pushlightuserdata(state_, static_cast<void *>(this));
+ lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
+ lua_setfield(state_, -2, kIndexKey);
+ lua_setmetatable(state_, -2);
+ }
+
+ void PushValue(const Variant &value);
+
+ // Reads a string from the stack.
+ StringPiece ReadString(const int index) const;
+
+ // Pushes a string to the stack.
+ void PushString(const StringPiece str);
+
+ // Pushes a flatbuffer to the stack.
+ void PushFlatbuffer(const reflection::Schema *schema,
+ const flatbuffers::Table *table);
+
+ // Reads a flatbuffer from the stack.
+ int ReadFlatbuffer(ReflectiveFlatbuffer *buffer);
+
+ // Runs a closure in protected mode.
+ // `func`: closure to run in protected mode.
+  //   `num_args`: number of arguments from the lua stack to process.
+ // `num_results`: number of result values pushed on the stack.
+ int RunProtected(const std::function<int()> &func, const int num_args = 0,
+ const int num_results = 0);
+
+ lua_State *state() const { return state_; }
+
+ protected:
+ lua_State *state_;
+
+ private:
+ // Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
+ static void PushFlatbuffer(const char *name, const reflection::Schema *schema,
+ const reflection::Object *type,
+ const flatbuffers::Table *table, lua_State *state);
+ static int GetFieldCallback(lua_State *state);
+ static int GetField(const reflection::Schema *schema,
+ const reflection::Object *type,
+ const flatbuffers::Table *table, lua_State *state);
+
+ template <typename T, int (T::*handler)()>
+ static int Dispatch(lua_State *state) {
+ T *env = FromUpValue<T *>(1, state);
+ return ((*env).*handler)();
+ }
+};
+
+bool Compile(StringPiece snippet, std::string *bytecode);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
diff --git a/utils/math/fastexp.h b/utils/math/fastexp.h
index 63e5d5d..f690c73 100644
--- a/utils/math/fastexp.h
+++ b/utils/math/fastexp.h
@@ -60,7 +60,6 @@
extern FastMathClass FastMathInstance;
-inline float VeryFastExp2(float f) { return FastMathInstance.VeryFastExp2(f); }
inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); }
} // namespace libtextclassifier3
diff --git a/utils/regex-match.cc b/utils/regex-match.cc
new file mode 100644
index 0000000..8c55e6b
--- /dev/null
+++ b/utils/regex-match.cc
@@ -0,0 +1,180 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/regex-match.h"
+
+#include <memory>
+
+#include "annotator/types.h"
+#include "utils/lua-utils.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lualib.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+
+// Provides a Lua environment for running regex-match post-verification scripts.
+// It sets up and exposes the match data as well as the context.
+class LuaVerifier : private LuaEnvironment {
+ public:
+ static std::unique_ptr<LuaVerifier> Create(
+ const std::string& context, const std::string& verifier_code,
+ const UniLib::RegexMatcher* matcher);
+
+ bool Verify(bool* result);
+
+ private:
+ explicit LuaVerifier(const std::string& context,
+ const std::string& verifier_code,
+ const UniLib::RegexMatcher* matcher)
+ : context_(context), verifier_code_(verifier_code), matcher_(matcher) {}
+ bool Initialize();
+
+ // Provides details of a capturing group to lua.
+ int GetCapturingGroup();
+
+ const std::string& context_;
+ const std::string& verifier_code_;
+ const UniLib::RegexMatcher* matcher_;
+};
+
+bool LuaVerifier::Initialize() {
+  // Run in protected mode so that a setup failure does not trigger a Lua panic.
+ return RunProtected([this] {
+ LoadDefaultLibraries();
+
+ // Expose context of the match as `context` global variable.
+ PushString(context_);
+ lua_setglobal(state_, "context");
+
+ // Expose match array as `match` global variable.
+ // Each entry `match[i]` exposes the ith capturing group as:
+ // * `begin`: span start
+ // * `end`: span end
+ // * `text`: the text
+ BindTable<LuaVerifier, &LuaVerifier::GetCapturingGroup>("match");
+ lua_setglobal(state_, "match");
+ return LUA_OK;
+ }) == LUA_OK;
+}
+
+std::unique_ptr<LuaVerifier> LuaVerifier::Create(
+ const std::string& context, const std::string& verifier_code,
+ const UniLib::RegexMatcher* matcher) {
+ auto verifier = std::unique_ptr<LuaVerifier>(
+ new LuaVerifier(context, verifier_code, matcher));
+ if (!verifier->Initialize()) {
+ TC3_LOG(ERROR) << "Could not initialize lua environment.";
+ return nullptr;
+ }
+ return verifier;
+}
+
+int LuaVerifier::GetCapturingGroup() {
+ if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
+ TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
+ << lua_type(state_, /*idx=*/-1);
+ lua_error(state_);
+ return 0;
+ }
+ const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
+ int status = UniLib::RegexMatcher::kNoError;
+ const CodepointSpan span = {matcher_->Start(group_id, &status),
+ matcher_->End(group_id, &status)};
+ std::string text = matcher_->Group(group_id, &status).ToUTF8String();
+ if (status != UniLib::RegexMatcher::kNoError) {
+ TC3_LOG(ERROR) << "Could not extract span from capturing group.";
+ lua_error(state_);
+ return 0;
+ }
+ lua_newtable(state_);
+ lua_pushinteger(state_, span.first);
+ lua_setfield(state_, /*idx=*/-2, "begin");
+ lua_pushinteger(state_, span.second);
+ lua_setfield(state_, /*idx=*/-2, "end");
+ PushString(text);
+ lua_setfield(state_, /*idx=*/-2, "text");
+ return 1;
+}
+
+bool LuaVerifier::Verify(bool* result) {
+ if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
+ /*name=*/nullptr) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not load verifier snippet.";
+ return false;
+ }
+
+ if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not run verifier snippet.";
+ return false;
+ }
+
+ if (RunProtected(
+ [this, result] {
+ if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
+ TC3_LOG(ERROR) << "Unexpected verification result type: "
+ << lua_type(state_, /*idx=*/-1);
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+ *result = lua_toboolean(state_, /*idx=*/-1);
+ return LUA_OK;
+ },
+ /*num_args=*/1) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not read lua result.";
+ return false;
+ }
+ return true;
+}
+
+} // namespace
+
+bool SetFieldFromCapturingGroup(const int group_id,
+ const FlatbufferFieldPath* field_path,
+ const UniLib::RegexMatcher* matcher,
+ ReflectiveFlatbuffer* flatbuffer) {
+ int status = UniLib::RegexMatcher::kNoError;
+ std::string group_text = matcher->Group(group_id, &status).ToUTF8String();
+ if (status != UniLib::RegexMatcher::kNoError || group_text.empty()) {
+ return false;
+ }
+ return flatbuffer->ParseAndSet(field_path, group_text);
+}
+
+bool VerifyMatch(const std::string& context,
+ const UniLib::RegexMatcher* matcher,
+ const std::string& lua_verifier_code) {
+ bool status = false;
+ auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
+ if (verifier == nullptr) {
+ TC3_LOG(ERROR) << "Could not create verifier.";
+ return false;
+ }
+ if (!verifier->Verify(&status)) {
+    TC3_LOG(ERROR) << "Could not run verification.";
+ return false;
+ }
+ return status;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/regex-match.h b/utils/regex-match.h
new file mode 100644
index 0000000..f77f6b1
--- /dev/null
+++ b/utils/regex-match.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
+#define LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
+
+#include "utils/flatbuffers.h"
+#include "utils/flatbuffers_generated.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+// Sets a field in the flatbuffer from a regex match group.
+// Returns true if successful, and false if the field couldn't be set.
+bool SetFieldFromCapturingGroup(const int group_id,
+ const FlatbufferFieldPath* field_path,
+ const UniLib::RegexMatcher* matcher,
+ ReflectiveFlatbuffer* flatbuffer);
+
+// Post-checks a regular expression match with a lua verifier script.
+// The verifier can access:
+// * `context`: The context as a string.
+// * `match`: The groups of the regex match as an array, each group gives
+// * `begin`: span start
+// * `end`: span end
+// * `text`: the text
+// The verifier is expected to return a boolean, indicating whether the
+// verification succeeded or not.
+// Returns true if the verification was successful, false if not.
+bool VerifyMatch(const std::string& context,
+ const UniLib::RegexMatcher* matcher,
+ const std::string& lua_verifier_code);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
diff --git a/utils/regex-match_test.cc b/utils/regex-match_test.cc
new file mode 100644
index 0000000..ef86d65
--- /dev/null
+++ b/utils/regex-match_test.cc
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/regex-match.h"
+
+#include <memory>
+
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+class LuaVerifierTest : public testing::Test {
+ protected:
+ LuaVerifierTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+ UniLib unilib_;
+};
+
+#ifdef TC3_UNILIB_ICU
+TEST_F(LuaVerifierTest, HandlesSimpleVerification) {
+ EXPECT_TRUE(VerifyMatch(/*context=*/"", /*matcher=*/nullptr, "return true;"));
+}
+
+TEST_F(LuaVerifierTest, HandlesCustomVerification) {
+ UnicodeText pattern = UTF8ToUnicodeText("(\\d{16})",
+ /*do_copy=*/true);
+ UnicodeText message = UTF8ToUnicodeText("cc: 4012888888881881",
+ /*do_copy=*/true);
+ const std::string verifier = R"(
+function luhn(candidate)
+ local sum = 0
+ local num_digits = string.len(candidate)
+ local parity = num_digits % 2
+ for pos = 1,num_digits do
+ d = tonumber(string.sub(candidate, pos, pos))
+ if pos % 2 ~= parity then
+ d = d * 2
+ end
+ if d > 9 then
+ d = d - 9
+ end
+ sum = sum + d
+ end
+ return (sum % 10) == 0
+end
+return luhn(match[1].text);
+ )";
+ auto regex_pattern = unilib_.CreateRegexPattern(pattern);
+ ASSERT_TRUE(regex_pattern != nullptr);
+ auto matcher = regex_pattern->Matcher(message);
+ ASSERT_TRUE(matcher != nullptr);
+ int status = UniLib::RegexMatcher::kNoError;
+ ASSERT_TRUE(matcher->Find(&status) &&
+ status == UniLib::RegexMatcher::kNoError);
+
+ EXPECT_TRUE(VerifyMatch(message.ToUTF8String(), matcher.get(), verifier));
+}
+#endif
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/resources.cc b/utils/resources.cc
new file mode 100644
index 0000000..ddfa499
--- /dev/null
+++ b/utils/resources.cc
@@ -0,0 +1,217 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/resources.h"
+#include "utils/base/logging.h"
+#include "utils/zlib/buffer_generated.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+namespace {
+// A "wildcard" match: either side omits the subtag entirely (absent model
+// subtag or empty user subtag), so it is compatible with anything.
+bool isWildcardMatch(const flatbuffers::String* left,
+                     const std::string& right) {
+  return (left == nullptr || right.empty());
+}
+
+// Exact subtag equality; an absent model subtag only exactly matches an
+// empty user subtag.
+bool isExactMatch(const flatbuffers::String* left, const std::string& right) {
+  if (left == nullptr) {
+    return right.empty();
+  }
+  return left->str() == right;
+}
+
+} // namespace
+
+// Scores how well `entry_locale` matches the user `locale` as a bitmask of
+// the LocaleMatch flags declared in resources.h. The bit weights encode the
+// priority language > script > region, with wildcard matches weaker than
+// exact ones, so masks are directly comparable with `>`.
+// NOTE(review): assumes `entry_locale` is non-null — confirm at call sites.
+int Resources::LocaleMatch(const Locale& locale,
+                           const LanguageTag* entry_locale) const {
+  int match = LOCALE_NO_MATCH;
+  if (isExactMatch(entry_locale->language(), locale.Language())) {
+    match |= LOCALE_LANGUAGE_MATCH;
+  } else if (isWildcardMatch(entry_locale->language(), locale.Language())) {
+    match |= LOCALE_LANGUAGE_WILDCARD_MATCH;
+  }
+
+  if (isExactMatch(entry_locale->script(), locale.Script())) {
+    match |= LOCALE_SCRIPT_MATCH;
+  } else if (isWildcardMatch(entry_locale->script(), locale.Script())) {
+    match |= LOCALE_SCRIPT_WILDCARD_MATCH;
+  }
+
+  if (isExactMatch(entry_locale->region(), locale.Region())) {
+    match |= LOCALE_REGION_MATCH;
+  } else if (isWildcardMatch(entry_locale->region(), locale.Region())) {
+    match |= LOCALE_REGION_WILDCARD_MATCH;
+  }
+
+  return match;
+}
+
+// Looks up a resource entry by name via the flatbuffer's sorted-key binary
+// search. Returns nullptr (with an error log) if the pool is missing or the
+// name is unknown.
+// NOTE(review): LookupByKey takes a C string, so this assumes
+// resource_name.data() is NUL-terminated — confirm for all callers, since
+// StringPiece does not guarantee it.
+const ResourceEntry* Resources::FindResource(
+    const StringPiece resource_name) const {
+  if (resources_ == nullptr || resources_->resource_entry() == nullptr) {
+    TC3_LOG(ERROR) << "No resources defined.";
+    return nullptr;
+  }
+  const ResourceEntry* entry =
+      resources_->resource_entry()->LookupByKey(resource_name.data());
+  if (entry == nullptr) {
+    TC3_LOG(ERROR) << "Resource " << resource_name.ToString() << " not found";
+    return nullptr;
+  }
+  return entry;
+}
+
+// Returns the index into `resource->resource()` best matching the locale
+// preference list, or -1 if no entry at least (weakly) matches a language.
+// Earlier locales take precedence: the scan stops as soon as one of them
+// produced an exact language match.
+int Resources::BestResourceForLocales(
+    const ResourceEntry* resource, const std::vector<Locale>& locales) const {
+  // Find best match based on locale.
+  int resource_id = -1;
+  int locale_match = LOCALE_NO_MATCH;
+  const auto* resources = resource->resource();
+  for (int user_locale = 0; user_locale < locales.size(); user_locale++) {
+    if (!locales[user_locale].IsValid()) {
+      continue;
+    }
+    // NOTE(review): assumes each resource has a non-null `locale` vector —
+    // confirm for hand-built models.
+    for (int i = 0; i < resources->size(); i++) {
+      for (const int locale_id : *resources->Get(i)->locale()) {
+        const int candidate_match = LocaleMatch(
+            locales[user_locale], resources_->locale()->Get(locale_id));
+
+        // Only consider if at least the language matches.
+        if ((candidate_match & LOCALE_LANGUAGE_MATCH) == 0 &&
+            (candidate_match & LOCALE_LANGUAGE_WILDCARD_MATCH) == 0) {
+          continue;
+        }
+
+        // Strict `>` keeps the first candidate in case of a tie.
+        if (candidate_match > locale_match) {
+          locale_match = candidate_match;
+          resource_id = i;
+        }
+      }
+    }
+
+    // If the language matches exactly, we are already finished.
+    // We found an exact language match.
+    if (locale_match & LOCALE_LANGUAGE_MATCH) {
+      return resource_id;
+    }
+  }
+  return resource_id;
+}
+
+// Fetches the content of `resource_name` best matching `locales` into
+// `result`, decompressing it if it is stored compressed. Returns false when
+// the resource is unknown, no locale matches, or decompression fails.
+bool Resources::GetResourceContent(const std::vector<Locale>& locales,
+                                   const StringPiece resource_name,
+                                   std::string* result) const {
+  const ResourceEntry* entry = FindResource(resource_name);
+  if (entry == nullptr || entry->resource() == nullptr) {
+    return false;
+  }
+
+  int resource_id = BestResourceForLocales(entry, locales);
+  if (resource_id < 0) {
+    return false;
+  }
+  const auto* resource = entry->resource()->Get(resource_id);
+  if (resource->content() != nullptr) {
+    *result = resource->content()->str();
+    return true;
+  } else if (resource->compressed_content() != nullptr) {
+    // The shared compression dictionary is optional: flatbuffers returns
+    // nullptr for an absent/empty vector (e.g. when resources were
+    // compressed without building a dictionary), so guard the dereference.
+    const auto* dictionary = resources_->compression_dictionary();
+    std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(
+        dictionary != nullptr ? dictionary->data() : nullptr,
+        dictionary != nullptr ? dictionary->size() : 0);
+    if (decompressor != nullptr &&
+        decompressor->MaybeDecompress(resource->compressed_content(), result)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Compresses the content of every resource in `resources` in place,
+// keeping the compressed form only when it is actually smaller.
+// When `build_compression_dictionary` is set, a shared zlib dictionary is
+// first built from every `dictionary_sample_every`-th non-empty entry and
+// stored in the pool. Returns false if a compressor could not be created.
+bool CompressResources(ResourcePoolT* resources,
+                       const bool build_compression_dictionary,
+                       const int dictionary_sample_every) {
+  std::vector<unsigned char> dictionary;
+  if (build_compression_dictionary) {
+    {
+      // Build up a compression dictionary.
+      std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
+      int i = 0;
+      for (auto& entry : resources->resource_entry) {
+        for (auto& resource : entry->resource) {
+          if (resource->content.empty()) {
+            continue;
+          }
+          i++;
+
+          // Use a sample of the entries to build up a custom compression
+          // dictionary. Using all entries will generally not give a benefit
+          // for small data sizes, so we subsample here.
+          if (i % dictionary_sample_every != 0) {
+            continue;
+          }
+          CompressedBufferT compressed_content;
+          // The result is discarded: the call only feeds the compressor's
+          // internal state, which GetDictionary() snapshots below.
+          compressor->Compress(resource->content, &compressed_content);
+        }
+      }
+      compressor->GetDictionary(&dictionary);
+      resources->compression_dictionary.assign(
+          dictionary.data(), dictionary.data() + dictionary.size());
+    }
+  }
+
+  for (auto& entry : resources->resource_entry) {
+    for (auto& resource : entry->resource) {
+      if (resource->content.empty()) {
+        continue;
+      }
+      // Try compressing the data.
+      std::unique_ptr<ZlibCompressor> compressor =
+          build_compression_dictionary
+              ? ZlibCompressor::Instance(dictionary.data(), dictionary.size())
+              : ZlibCompressor::Instance();
+      if (!compressor) {
+        TC3_LOG(ERROR) << "Cannot create zlib compressor.";
+        return false;
+      }
+
+      CompressedBufferT compressed_content;
+      compressor->Compress(resource->content, &compressed_content);
+
+      // Only keep compressed version if smaller.
+      if (compressed_content.uncompressed_size >
+          compressed_content.buffer.size()) {
+        resource->content.clear();
+        resource->compressed_content.reset(new CompressedBufferT);
+        *resource->compressed_content = compressed_content;
+      }
+    }
+  }
+  return true;
+}
+
+// Unpacks a serialized ResourcePool, compresses its resources (optionally
+// building a shared dictionary) and returns the re-serialized pool.
+// The parameter list must match the declaration in resources.h: the previous
+// two-parameter definition never defined the declared symbol (link error for
+// callers) and silently converted `dictionary_sample_every` to the
+// `build_compression_dictionary` bool inside CompressResources().
+std::string CompressSerializedResources(
+    const std::string& resources, const bool build_compression_dictionary,
+    const int dictionary_sample_every) {
+  std::unique_ptr<ResourcePoolT> unpacked_resources(
+      flatbuffers::GetRoot<ResourcePool>(resources.data())->UnPack());
+  TC3_CHECK(unpacked_resources != nullptr);
+  TC3_CHECK(CompressResources(unpacked_resources.get(),
+                              build_compression_dictionary,
+                              dictionary_sample_every));
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(ResourcePool::Pack(builder, unpacked_resources.get()));
+  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+                     builder.GetSize());
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/resources.fbs b/utils/resources.fbs
new file mode 100755
index 0000000..a88c56d
--- /dev/null
+++ b/utils/resources.fbs
@@ -0,0 +1,46 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "utils/zlib/buffer.fbs";
+
+namespace libtextclassifier3;
+// A single localized resource value.
+table Resource {
+  // Indices into ResourcePool.locale of the locales this value serves.
+  locale:[int];
+  // Plain-text value; unset when compressed_content is used instead.
+  content:string;
+  // Zlib-compressed value, possibly using the pool's shared dictionary.
+  compressed_content:CompressedBuffer;
+}
+
+namespace libtextclassifier3;
+// A named resource with its per-locale values.
+table ResourceEntry {
+  // Lookup key; (key) keeps entries sorted for LookupByKey binary search.
+  name:string (key);
+  resource:[Resource];
+}
+
+// BCP 47 tag for the supported locale.
+namespace libtextclassifier3;
+table LanguageTag {
+  language:string;
+  script:string;
+  region:string;
+}
+
+namespace libtextclassifier3;
+// Root: locale table, resource entries and optional zlib dictionary.
+table ResourcePool {
+  locale:[LanguageTag];
+  resource_entry:[ResourceEntry];
+  compression_dictionary:[ubyte];
+}
+
diff --git a/utils/resources.h b/utils/resources.h
new file mode 100644
index 0000000..28db0cc
--- /dev/null
+++ b/utils/resources.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
+#define LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
+
+#include <vector>
+
+#include "utils/i18n/locale.h"
+#include "utils/resources_generated.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Class for accessing localized model resources.
+class Resources {
+ public:
+  // Does not take ownership of `resources` (stored as a raw pointer); the
+  // backing flatbuffer must outlive this instance.
+  explicit Resources(const ResourcePool* resources) : resources_(resources) {}
+
+  // Returns the string value associated with the particular resource.
+  // `locales` are locales in preference order.
+  bool GetResourceContent(const std::vector<Locale>& locales,
+                          const StringPiece resource_name,
+                          std::string* result) const;
+
+ private:
+  // Match priorities: language > script > region with wildcard matches being
+  // weaker than an exact match.
+  // For a resource lookup, at least language needs to (weakly) match.
+  // c.f. developer.android.com/guide/topics/resources/multilingual-support
+  // Bit weights make the masks directly comparable with `>`.
+  enum LocaleMatch {
+    LOCALE_NO_MATCH = 0,
+    LOCALE_REGION_WILDCARD_MATCH = 1 << 0,
+    LOCALE_REGION_MATCH = 1 << 1,
+    LOCALE_SCRIPT_WILDCARD_MATCH = 1 << 2,
+    LOCALE_SCRIPT_MATCH = 1 << 3,
+    LOCALE_LANGUAGE_WILDCARD_MATCH = 1 << 4,
+    LOCALE_LANGUAGE_MATCH = 1 << 5
+  };
+  // Scores `entry_locale` against `locale` as a bitmask of the above flags.
+  int LocaleMatch(const Locale& locale, const LanguageTag* entry_locale) const;
+
+  // Finds a resource entry by name.
+  const ResourceEntry* FindResource(const StringPiece resource_name) const;
+
+  // Finds the best locale matching resource from a resource entry.
+  int BestResourceForLocales(const ResourceEntry* resource,
+                             const std::vector<Locale>& locales) const;
+
+  // Not owned.
+  const ResourcePool* resources_;
+};
+
+// Compresses resources in place. When `build_compression_dictionary` is set,
+// a shared zlib dictionary is built from every `dictionary_sample_every`-th
+// non-empty resource before compressing.
+bool CompressResources(ResourcePoolT* resources,
+                       const bool build_compression_dictionary = false,
+                       const int dictionary_sample_every = 1);
+// Same as above, but takes and returns a serialized ResourcePool flatbuffer.
+std::string CompressSerializedResources(
+    const std::string& resources,
+    const bool build_compression_dictionary = false,
+    const int dictionary_sample_every = 1);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
diff --git a/utils/resources_test.cc b/utils/resources_test.cc
new file mode 100644
index 0000000..c385f39
--- /dev/null
+++ b/utils/resources_test.cc
@@ -0,0 +1,287 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/resources.h"
+#include "utils/i18n/locale.h"
+#include "utils/resources_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Parameterized over (compress, build_dictionary) so every lookup test runs
+// against plain, compressed, and dictionary-compressed resource pools.
+class ResourcesTest
+    : public testing::TestWithParam<testing::tuple<bool, bool>> {
+ protected:
+  ResourcesTest() {}
+
+  std::string BuildTestResources(bool add_default_language = true) const {
+    ResourcePoolT test_resources;
+
+    // Test locales. Index: 0: en-US, 1: en-GB, 2: de-DE, 3: fr-FR, 4: pt-PT,
+    // 5: pt, 6: zh-Hans-CN, 7: zh, 8: fr-CA, 9: default (when added).
+    test_resources.locale.emplace_back(new LanguageTagT);
+    test_resources.locale.back()->language = "en";
+    test_resources.locale.back()->region = "US";
+    test_resources.locale.emplace_back(new LanguageTagT);
+    test_resources.locale.back()->language = "en";
+    test_resources.locale.back()->region = "GB";
+    test_resources.locale.emplace_back(new LanguageTagT);
+    test_resources.locale.back()->language = "de";
+    test_resources.locale.back()->region = "DE";
+    test_resources.locale.emplace_back(new LanguageTagT);
+    test_resources.locale.back()->language = "fr";
+    test_resources.locale.back()->region = "FR";
+    test_resources.locale.emplace_back(new LanguageTagT);
+    test_resources.locale.back()->language = "pt";
+    test_resources.locale.back()->region = "PT";
+    test_resources.locale.emplace_back(new LanguageTagT);
+    test_resources.locale.back()->language = "pt";
+    test_resources.locale.emplace_back(new LanguageTagT);
+    test_resources.locale.back()->language = "zh";
+    test_resources.locale.back()->script = "Hans";
+    test_resources.locale.back()->region = "CN";
+    test_resources.locale.emplace_back(new LanguageTagT);
+    test_resources.locale.back()->language = "zh";
+    test_resources.locale.emplace_back(new LanguageTagT);
+    test_resources.locale.back()->language = "fr";
+    // Fixed: `language` was assigned twice ("fr" then "fr-CA"), leaving a
+    // locale that could never match; the second assignment was meant to set
+    // the region so that this entry is fr-CA.
+    test_resources.locale.back()->region = "CA";
+    if (add_default_language) {
+      test_resources.locale.emplace_back(new LanguageTagT);  // default
+    }
+
+    // Test entries.
+    test_resources.resource_entry.emplace_back(new ResourceEntryT);
+    test_resources.resource_entry.back()->name = /*resource_name=*/"A";
+
+    // en-US, default
+    test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+    test_resources.resource_entry.back()->resource.back()->content = "localize";
+    test_resources.resource_entry.back()->resource.back()->locale.push_back(0);
+    if (add_default_language) {
+      test_resources.resource_entry.back()->resource.back()->locale.push_back(
+          9);
+    }
+
+    // en-GB
+    test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+    test_resources.resource_entry.back()->resource.back()->content = "localise";
+    test_resources.resource_entry.back()->resource.back()->locale.push_back(1);
+
+    // de-DE
+    test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+    test_resources.resource_entry.back()->resource.back()->content =
+        "lokalisieren";
+    test_resources.resource_entry.back()->resource.back()->locale.push_back(2);
+
+    // fr-FR, fr-CA
+    test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+    test_resources.resource_entry.back()->resource.back()->content =
+        "localiser";
+    test_resources.resource_entry.back()->resource.back()->locale.push_back(3);
+    test_resources.resource_entry.back()->resource.back()->locale.push_back(8);
+
+    // pt-PT
+    test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+    test_resources.resource_entry.back()->resource.back()->content =
+        "localizar";
+    test_resources.resource_entry.back()->resource.back()->locale.push_back(4);
+
+    // pt
+    test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+    test_resources.resource_entry.back()->resource.back()->content =
+        "concentrar";
+    test_resources.resource_entry.back()->resource.back()->locale.push_back(5);
+
+    // zh-Hans-CN
+    test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+    test_resources.resource_entry.back()->resource.back()->content = "龙";
+    test_resources.resource_entry.back()->resource.back()->locale.push_back(6);
+
+    // zh
+    test_resources.resource_entry.back()->resource.emplace_back(new ResourceT);
+    test_resources.resource_entry.back()->resource.back()->content = "龍";
+    test_resources.resource_entry.back()->resource.back()->locale.push_back(7);
+
+    if (compress()) {
+      EXPECT_TRUE(CompressResources(
+          &test_resources,
+          /*build_compression_dictionary=*/build_dictionary()));
+    }
+
+    flatbuffers::FlatBufferBuilder builder;
+    builder.Finish(ResourcePool::Pack(builder, &test_resources));
+
+    return std::string(
+        reinterpret_cast<const char*>(builder.GetBufferPointer()),
+        builder.GetSize());
+  }
+
+  // Test parameters.
+  bool compress() const { return testing::get<0>(GetParam()); }
+
+  bool build_dictionary() const { return testing::get<1>(GetParam()); }
+};
+
+INSTANTIATE_TEST_SUITE_P(Compression, ResourcesTest,
+ testing::Combine(testing::Bool(), testing::Bool()));
+
+// Each locale that has its own entry must resolve to its own translation.
+TEST_P(ResourcesTest, CorrectlyHandlesExactMatch) {
+  std::string test_resources = BuildTestResources();
+  Resources resources(
+      flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
+  std::string content;
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("en-US")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("localize", content);
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("en-GB")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("localise", content);
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("pt-PT")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("localizar", content);
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-Hans-CN")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("龙", content);
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("龍", content);
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("fr-CA")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("localiser", content);
+}
+
+// en-CA matches en-US and en-GB equally well (language only); the first
+// best candidate must win.
+TEST_P(ResourcesTest, CorrectlyHandlesTie) {
+  std::string test_resources = BuildTestResources();
+  Resources resources(
+      flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
+  // Uses first best match in case of a tie.
+  std::string content;
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("en-CA")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("localize", content);
+}
+
+// Without a wildcard default locale, a lookup for an unknown language must
+// fail; with it, the default entry is served.
+// Note: `result` is nullptr only in the case where no match (and hence no
+// write through `result`) is expected.
+TEST_P(ResourcesTest, RequiresLanguageMatch) {
+  {
+    std::string test_resources =
+        BuildTestResources(/*add_default_language=*/false);
+    Resources resources(
+        flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
+    EXPECT_FALSE(resources.GetResourceContent({Locale::FromBCP47("es-US")},
+                                              /*resource_name=*/"A",
+                                              /*result=*/nullptr));
+  }
+  {
+    std::string test_resources =
+        BuildTestResources(/*add_default_language=*/true);
+    Resources resources(
+        flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
+    std::string content;
+    EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("es-US")},
+                                             /*resource_name=*/"A",
+                                             /*result=*/&content));
+    EXPECT_EQ("localize", content);
+  }
+}
+
+// Locales without an exact entry must fall back: first to a same-language
+// entry, and ultimately to the wildcard default.
+TEST_P(ResourcesTest, HandlesFallback) {
+  std::string test_resources = BuildTestResources();
+  Resources resources(
+      flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
+  std::string content;
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("fr-CH")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("localiser", content);
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-Hans")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("龙", content);
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-Hans-ZZ")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("龙", content);
+
+  // Fallback to default, en-US.
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("ru")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("localize", content);
+}
+
+// With multiple preferred locales, an inexact language match for an earlier
+// locale must still outrank an exact match for a later one; later locales
+// are consulted only when earlier ones fail the language test.
+TEST_P(ResourcesTest, HandlesFallbackMultipleLocales) {
+  std::string test_resources = BuildTestResources();
+  Resources resources(
+      flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
+  std::string content;
+
+  // Still use inexact match with primary locale if language matches,
+  // even though secondary locale would match exactly.
+  EXPECT_TRUE(resources.GetResourceContent(
+      {Locale::FromBCP47("fr-CH"), Locale::FromBCP47("en-US")},
+      /*resource_name=*/"A", &content));
+  EXPECT_EQ("localiser", content);
+
+  // Use secondary language instead of default fallback if that is an exact
+  // language match.
+  EXPECT_TRUE(resources.GetResourceContent(
+      {Locale::FromBCP47("ru"), Locale::FromBCP47("de")},
+      /*resource_name=*/"A", &content));
+  EXPECT_EQ("lokalisieren", content);
+
+  // Use tertiary language.
+  EXPECT_TRUE(resources.GetResourceContent(
+      {Locale::FromBCP47("ru"), Locale::FromBCP47("it-IT"),
+       Locale::FromBCP47("de")},
+      /*resource_name=*/"A", &content));
+  EXPECT_EQ("lokalisieren", content);
+
+  // Default fallback if no locale matches.
+  EXPECT_TRUE(resources.GetResourceContent(
+      {Locale::FromBCP47("ru"), Locale::FromBCP47("it-IT"),
+       Locale::FromBCP47("es")},
+      /*resource_name=*/"A", &content));
+  EXPECT_EQ("localize", content);
+}
+
+// A script/region mismatch must fall back to the language-only entry rather
+// than a more specific entry with conflicting subtags.
+// NOTE(review): the name likely meant "PreferGenericFallback" — confirm.
+TEST_P(ResourcesTest, PreferGenericCallback) {
+  std::string test_resources = BuildTestResources();
+  Resources resources(
+      flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
+  std::string content;
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("pt-BR")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("concentrar", content);  // Falls back to pt, not pt-PT.
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-Hant")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("龍", content);  // Falls back to zh, not zh-Hans-CN.
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-Hant-CN")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("龍", content);  // Falls back to zh, not zh-Hans-CN.
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("zh-CN")},
+                                           /*resource_name=*/"A", &content));
+  EXPECT_EQ("龍", content);  // Falls back to zh, not zh-Hans-CN.
+}
+
+// A language-only query must resolve to the language-only entry, not a
+// region-specific one.
+TEST_P(ResourcesTest, PreferGenericWhenGeneric) {
+  std::string test_resources = BuildTestResources();
+  Resources resources(
+      flatbuffers::GetRoot<ResourcePool>(test_resources.data()));
+  std::string content;
+  EXPECT_TRUE(resources.GetResourceContent({Locale::FromBCP47("pt")},
+                                           /*resource_name=*/"A", &content));
+
+  // Uses pt, not pt-PT.
+  EXPECT_EQ("concentrar", content);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/double_array_trie.cc b/utils/sentencepiece/double_array_trie.cc
new file mode 100644
index 0000000..a2b66ea
--- /dev/null
+++ b/utils/sentencepiece/double_array_trie.cc
@@ -0,0 +1,69 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+// Walks the double-array trie over `input`, invoking `update_fn` for every
+// prefix of `input` that is a piece in the trie. Returns false only on a
+// corrupt trie (out-of-bounds child offset).
+bool DoubleArrayTrie::GatherPrefixMatches(
+    StringPiece input, const std::function<void(TrieMatch)>& update_fn) const {
+  uint32 pos = 0;
+  if (nodes_length_ == 0) {
+    TC3_LOG(WARNING) << "Trie is empty. Skipping.";
+    return true;
+  }
+  pos = offset(0);
+  for (int i = 0; i < input.size(); i++) {
+    if (input[i] == 0) {
+      break;
+    }
+    pos ^= static_cast<unsigned char>(input[i]);
+    // We exhausted the trie, no more matches possible.
+    // `pos` is unsigned, so the previous `pos < 0` check was always false;
+    // a single upper-bound test suffices.
+    if (pos >= nodes_length_) {
+      break;
+    }
+    // NOTE: comparing uint32 label() against char relies on char being
+    // unsigned (the build sets -funsigned-char).
+    if (label(pos) != input[i]) {
+      break;
+    }
+    const bool node_has_leaf = has_leaf(pos);
+    pos ^= offset(pos);
+    // value(pos) below reads nodes_[pos], so pos == nodes_length_ is already
+    // out of bounds (was incorrectly checked with `>`).
+    if (pos >= nodes_length_) {
+      TC3_LOG(ERROR) << "Out-of-bounds trie search position.";
+      return false;
+    }
+    if (node_has_leaf) {
+      update_fn(TrieMatch(/*id=*/value(pos), /*match_length=*/i + 1));
+    }
+  }
+  return true;
+}
+
+// Collects every prefix of `input` found in the trie into `matches`,
+// in increasing match-length order.
+bool DoubleArrayTrie::FindAllPrefixMatches(
+    StringPiece input, std::vector<TrieMatch>* matches) const {
+  return GatherPrefixMatches(
+      input, [matches](const TrieMatch match) { matches->push_back(match); });
+}
+
+// Keeps only the last (longest) prefix match. When nothing matches,
+// *longest_match is left default-constructed (the tests rely on its id
+// being -1 in that case).
+bool DoubleArrayTrie::LongestPrefixMatch(StringPiece input,
+                                         TrieMatch* longest_match) const {
+  *longest_match = TrieMatch();
+  return GatherPrefixMatches(input, [longest_match](const TrieMatch match) {
+    *longest_match = match;
+  });
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/double_array_trie.h b/utils/sentencepiece/double_array_trie.h
new file mode 100644
index 0000000..0614fb4
--- /dev/null
+++ b/utils/sentencepiece/double_array_trie.h
@@ -0,0 +1,85 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
+
+#include <functional>
+#include <vector>
+
+#include "utils/base/endian.h"
+#include "utils/base/integral_types.h"
+#include "utils/sentencepiece/matcher.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// A trie node specifies a node in the tree, either an intermediate node or
+// a leaf node.
+// A leaf node contains the id as an int of the string match. This id is encoded
+// in the lower 30 bits, thus the number of distinct ids is 2^30.
+// An intermediate node has an associated label and an offset to its children.
+// The label is encoded in the least significant byte and must match the input
+// character during matching.
+// We account for endianness when using the node values, as they are serialized
+// (in little endian) as bytes in the flatbuffer model.
+typedef uint32 TrieNode;
+
+// A memory mappable trie, compatible with Darts::DoubleArray.
+class DoubleArrayTrie : public SentencePieceMatcher {
+ public:
+  // nodes and nodes_length specify the array of the nodes of the trie.
+  DoubleArrayTrie(const TrieNode* nodes, const int nodes_length)
+      : nodes_(nodes), nodes_length_(nodes_length) {}
+
+  // Find matches that are prefixes of a string.
+  bool FindAllPrefixMatches(StringPiece input,
+                            std::vector<TrieMatch>* matches) const override;
+  // Find the longest prefix match of a string.
+  bool LongestPrefixMatch(StringPiece input,
+                          TrieMatch* longest_match) const override;
+
+ private:
+  // Returns whether a node has a leaf as a child.
+  // Fixed to convert to host byte order like every other accessor — the raw
+  // node value is serialized little-endian, so the unconverted read was
+  // wrong on big-endian hosts.
+  bool has_leaf(uint32 i) const {
+    return LittleEndian::ToHost32(nodes_[i]) & 0x100;
+  }
+
+  // Available when a node is a leaf: id of the matched piece.
+  int value(uint32 i) const {
+    return static_cast<int>(LittleEndian::ToHost32(nodes_[i]) & 0x7fffffff);
+  }
+
+  // Label associated with a node.
+  // A leaf node will have the MSB set and thus return an invalid label.
+  uint32 label(uint32 i) const {
+    return LittleEndian::ToHost32(nodes_[i]) & 0x800000ff;
+  }
+
+  // Returns offset to children.
+  uint32 offset(uint32 i) const {
+    const uint32 node = LittleEndian::ToHost32(nodes_[i]);
+    return (node >> 10) << ((node & 0x200) >> 6);
+  }
+
+  // Shared walk used by both public lookup methods.
+  bool GatherPrefixMatches(
+      StringPiece input, const std::function<void(TrieMatch)>& update_fn) const;
+
+  const TrieNode* nodes_;    // Not owned; memory-mapped model data.
+  const int nodes_length_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
diff --git a/utils/sentencepiece/double_array_trie_test.cc b/utils/sentencepiece/double_array_trie_test.cc
new file mode 100644
index 0000000..d7fc44b
--- /dev/null
+++ b/utils/sentencepiece/double_array_trie_test.cc
@@ -0,0 +1,114 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "utils/sentencepiece/double_array_trie.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Path to the serialized test trie (pieces "hell", "hello", "o", "there").
+// NOTE(review): returns the empty string, so the Lookup test below reads no
+// data and its expectations cannot pass as-is — the data path was presumably
+// stripped in this merge; restore it or check in the test fixture.
+std::string GetTestConfigPath() {
+  return "";
+}
+
+// Exercises prefix enumeration and longest-prefix lookup against a small
+// serialized trie, including empty, non-matching, NUL and non-ASCII inputs.
+TEST(DoubleArrayTest, Lookup) {
+  // Test trie that contains pieces "hell", "hello", "o", "there".
+  std::ifstream test_config_stream(GetTestConfigPath());
+  std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+                     (std::istreambuf_iterator<char>()));
+  DoubleArrayTrie trie(reinterpret_cast<const TrieNode*>(config.data()),
+                       config.size() / sizeof(TrieNode));
+
+  {
+    std::vector<TrieMatch> matches;
+    EXPECT_TRUE(trie.FindAllPrefixMatches("hello there", &matches));
+    EXPECT_EQ(matches.size(), 2);
+    EXPECT_EQ(matches[0].id, 0 /*hell*/);
+    EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
+    EXPECT_EQ(matches[1].id, 1 /*hello*/);
+    EXPECT_EQ(matches[1].match_length, 5 /*hello*/);
+  }
+
+  {
+    std::vector<TrieMatch> matches;
+    EXPECT_TRUE(trie.FindAllPrefixMatches("he", &matches));
+    EXPECT_THAT(matches, testing::IsEmpty());
+  }
+
+  {
+    std::vector<TrieMatch> matches;
+    EXPECT_TRUE(trie.FindAllPrefixMatches("abcd", &matches));
+    EXPECT_THAT(matches, testing::IsEmpty());
+  }
+
+  {
+    std::vector<TrieMatch> matches;
+    EXPECT_TRUE(trie.FindAllPrefixMatches("", &matches));
+    EXPECT_THAT(matches, testing::IsEmpty());
+  }
+
+  {
+    std::vector<TrieMatch> matches;
+    EXPECT_TRUE(trie.FindAllPrefixMatches("hi there", &matches));
+    EXPECT_THAT(matches, testing::IsEmpty());
+  }
+
+  // An embedded NUL terminates the walk without a match.
+  {
+    std::vector<TrieMatch> matches;
+    EXPECT_TRUE(trie.FindAllPrefixMatches(StringPiece("\0", 1), &matches));
+    EXPECT_THAT(matches, testing::IsEmpty());
+  }
+
+  // High-bit bytes must not index out of the trie.
+  {
+    std::vector<TrieMatch> matches;
+    EXPECT_TRUE(
+        trie.FindAllPrefixMatches(StringPiece("\xff, \xfe", 2), &matches));
+    EXPECT_THAT(matches, testing::IsEmpty());
+  }
+
+  {
+    TrieMatch match;
+    EXPECT_TRUE(trie.LongestPrefixMatch("hella there", &match));
+    EXPECT_EQ(match.id, 0 /*hell*/);
+  }
+
+  {
+    TrieMatch match;
+    EXPECT_TRUE(trie.LongestPrefixMatch("hello there", &match));
+    EXPECT_EQ(match.id, 1 /*hello*/);
+  }
+
+  // No match leaves the default TrieMatch (id == -1).
+  {
+    TrieMatch match;
+    EXPECT_TRUE(trie.LongestPrefixMatch("abcd", &match));
+    EXPECT_EQ(match.id, -1);
+  }
+
+  {
+    TrieMatch match;
+    EXPECT_TRUE(trie.LongestPrefixMatch("", &match));
+    EXPECT_EQ(match.id, -1);
+  }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/encoder.cc b/utils/sentencepiece/encoder.cc
new file mode 100644
index 0000000..51cda30
--- /dev/null
+++ b/utils/sentencepiece/encoder.cc
@@ -0,0 +1,96 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/encoder.h"
+
+namespace libtextclassifier3 {
+
+bool Encoder::Encode(StringPiece normalized_text,
+ std::vector<int>* encoded_text) const {
+ const int len = normalized_text.size();
+ if (len <= 0) {
+ *encoded_text = {start_code_, end_code_};
+ return true;
+ }
+ // We use `previous_pos` to indicate whether a dynamic programming state was
+ // reachable.
+ std::vector<SegmentationEntry> segmentation(
+ len + 1, {/*score=*/0, /*previous_pos=*/-1, /*piece_id=*/-1,
+ /*num_pieces=*/0});
+ for (int i = 0; i < len; i++) {
+ // State couldn't be reached.
+ if (i > 0 && segmentation[i].previous_pos < 0) {
+ // Advance position.
+ normalized_text.RemovePrefix(1);
+ continue;
+ }
+ // Check whether we can use the unknown token.
+ if (unknown_code_ >= 0) {
+ const int pos = i + 1;
+ const float unknown_penalty = segmentation[i].score + unknown_score_;
+ if (segmentation[pos].previous_pos < 0 ||
+ segmentation[pos].score < unknown_penalty) {
+ // Merge multiple unknown tokens into one.
+ if (segmentation[i].piece_id == unknown_code_) {
+ segmentation[pos] = {/*score=*/unknown_penalty,
+ /*previous_pos=*/segmentation[i].previous_pos,
+ /*piece_id=*/unknown_code_,
+ /*num_pieces=*/segmentation[i].num_pieces};
+ } else {
+ segmentation[pos] = {/*score=*/unknown_penalty,
+ /*previous_pos=*/i,
+ /*piece_id=*/unknown_code_,
+ /*num_pieces=*/segmentation[i].num_pieces + 1};
+ }
+ }
+ }
+ std::vector<TrieMatch> matches;
+ if (!matcher_->FindAllPrefixMatches(normalized_text, &matches)) {
+ TC3_LOG(ERROR)
+ << "Couldn't successfully gather prefix sentence piece matches.";
+ return false;
+ }
+ for (const auto& match : matches) {
+ TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
+ const int pos = i + match.match_length;
+ const float candidate_score = segmentation[i].score + scores_[match.id];
+ if (segmentation[pos].previous_pos < 0 ||
+ segmentation[pos].score < candidate_score) {
+ segmentation[pos] = {/*score=*/candidate_score, /*previous_pos=*/i,
+ /*piece_id=*/match.id + encoding_offset_,
+ /*num_pieces=*/segmentation[i].num_pieces + 1};
+ }
+ }
+ // Advance position.
+ normalized_text.RemovePrefix(1);
+ }
+ if (segmentation[len].num_pieces <= 0) {
+ *encoded_text = {start_code_, end_code_};
+ return true;
+ }
+ const int num_pieces = segmentation[len].num_pieces;
+ encoded_text->resize(num_pieces + 2);
+ (*encoded_text)[num_pieces + 1] = end_code_;
+ int pos = len;
+ for (int i = num_pieces; i > 0; i--) {
+ (*encoded_text)[i] = segmentation[pos].piece_id;
+ pos = segmentation[pos].previous_pos;
+ }
+ (*encoded_text)[0] = start_code_;
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/encoder.h b/utils/sentencepiece/encoder.h
new file mode 100644
index 0000000..6c69077
--- /dev/null
+++ b/utils/sentencepiece/encoder.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
+
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/sentencepiece/matcher.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Encoder to segment/tokenize strings into pieces such that the sum of the
+// scores of the pieces used is maximized.
+class Encoder {
+ public:
+ // matcher: the list of valid sentence pieces represented as a matcher, e.g.
+ // a trie.
+ // num_pieces: the number of pieces in the trie.
+ // pieces_scores: the scores of the individual pieces.
+ // start_code: code that is used as encoding of the start of input.
+ // end_code: code that is used as encoding of the end of input.
+ // encoding_offset: value added to the sentence piece ids to make them
+ // not intersecting with start_code and end_code.
+ // unknown_code: code that is used for out-of-dictionary characters.
+ // unknown_score: the penalty score associated with the unknown code.
+ Encoder(const SentencePieceMatcher* matcher, const int num_pieces,
+ const float* pieces_scores, int start_code = 0, int end_code = 1,
+ int encoding_offset = 2, int unknown_code = -1,
+ float unknown_score = 0.f)
+ : num_pieces_(num_pieces),
+ scores_(pieces_scores),
+ matcher_(matcher),
+ start_code_(start_code),
+ end_code_(end_code),
+ encoding_offset_(encoding_offset),
+ unknown_code_(unknown_code),
+ unknown_score_(unknown_score) {}
+
+ // Segment the input so that the total score of the pieces used is maximized.
+ // This is a simplified implementation of the general Viterbi algorithm,
+ // assuming independence between individual pieces.
+ bool Encode(StringPiece normalized_text,
+ std::vector<int>* encoded_text) const;
+
+ private:
+ // State in the dynamic programming algorithm.
+ struct SegmentationEntry {
+ // Accumulated score.
+ float score;
+
+ // Position before last piece.
+ int previous_pos;
+
+ // Last piece used.
+ int piece_id;
+
+ // Total number of pieces used.
+ int num_pieces;
+ };
+
+ const int num_pieces_;
+ const float* scores_;
+ const SentencePieceMatcher* matcher_;
+ const int start_code_;
+ const int end_code_;
+ const int encoding_offset_;
+ const int unknown_code_;
+ // Must be float: the ctor takes a float and Encode() adds this to float
+ // scores; declaring it int silently truncated penalties like -0.5 to 0.
+ const float unknown_score_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
diff --git a/utils/sentencepiece/encoder_test.cc b/utils/sentencepiece/encoder_test.cc
new file mode 100644
index 0000000..9082cca
--- /dev/null
+++ b/utils/sentencepiece/encoder_test.cc
@@ -0,0 +1,122 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "utils/base/integral_types.h"
+#include "utils/sentencepiece/encoder.h"
+#include "utils/sentencepiece/sorted_strings_table.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+
+TEST(EncoderTest, SimpleTokenization) {
+ const char pieces[] = "hell\0hello\0o\0there\0";
+ const uint32 offsets[] = {0, 5, 11, 13};
+ float scores[] = {-0.5, -1.0, -10.0, -1.0};
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores);
+
+ {
+ std::vector<int> encoded_text;
+ EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
+ EXPECT_THAT(encoded_text, ElementsAre(0, 3, 5, 1));
+ }
+
+ // Make probability of hello very low:
+ // hello gets now tokenized as hell + o.
+ scores[1] = -100.0;
+ {
+ std::vector<int> encoded_text;
+ EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
+ EXPECT_THAT(encoded_text, ElementsAre(0, 2, 4, 5, 1));
+ }
+}
+
+TEST(EncoderTest, HandlesEdgeCases) {
+ const char pieces[] = "hell\0hello\0o\0there\0";
+ const uint32 offsets[] = {0, 5, 11, 13};
+ float scores[] = {-0.5, -1.0, -10.0, -1.0};
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores);
+ {
+ std::vector<int> encoded_text;
+ EXPECT_TRUE(encoder.Encode("hellhello", &encoded_text));
+ EXPECT_THAT(encoded_text, ElementsAre(0, 2, 3, 1));
+ }
+ {
+ std::vector<int> encoded_text;
+ EXPECT_TRUE(encoder.Encode("hellohell", &encoded_text));
+ EXPECT_THAT(encoded_text, ElementsAre(0, 3, 2, 1));
+ }
+ {
+ std::vector<int> encoded_text;
+ EXPECT_TRUE(encoder.Encode("", &encoded_text));
+ EXPECT_THAT(encoded_text, ElementsAre(0, 1));
+ }
+ {
+ std::vector<int> encoded_text;
+ EXPECT_TRUE(encoder.Encode("hellathere", &encoded_text));
+ EXPECT_THAT(encoded_text, ElementsAre(0, 1));
+ }
+}
+
+TEST(EncoderTest, HandlesOutOfDictionary) {
+ const char pieces[] = "hell\0hello\0o\0there\0";
+ const uint32 offsets[] = {0, 5, 11, 13};
+ float scores[] = {-0.5, -1.0, -10.0, -1.0};
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores,
+ /*start_code=*/0, /*end_code=*/1,
+ /*encoding_offset=*/3, /*unknown_code=*/2,
+ /*unknown_score=*/-100.0);
+ {
+ std::vector<int> encoded_text;
+ EXPECT_TRUE(encoder.Encode("hellhello", &encoded_text));
+ EXPECT_THAT(encoded_text, ElementsAre(0, 3, 4, 1));
+ }
+ {
+ std::vector<int> encoded_text;
+ EXPECT_TRUE(encoder.Encode("hellohell", &encoded_text));
+ EXPECT_THAT(encoded_text, ElementsAre(0, 4, 3, 1));
+ }
+ {
+ std::vector<int> encoded_text;
+ EXPECT_TRUE(encoder.Encode("", &encoded_text));
+ EXPECT_THAT(encoded_text, ElementsAre(0, 1));
+ }
+ {
+ std::vector<int> encoded_text;
+ EXPECT_TRUE(encoder.Encode("hellathere", &encoded_text));
+ EXPECT_THAT(encoded_text,
+ ElementsAre(0, /*hell*/ 3, /*unknown*/ 2, /*there*/ 6, 1));
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/matcher.h b/utils/sentencepiece/matcher.h
new file mode 100644
index 0000000..47e6560
--- /dev/null
+++ b/utils/sentencepiece/matcher.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
+
+#include <vector>
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+struct TrieMatch {
+ TrieMatch() {}
+ TrieMatch(int id, int match_length) : id(id), match_length(match_length) {}
+ int id = -1;
+ int match_length = -1;
+};
+
+class SentencePieceMatcher {
+ public:
+ virtual ~SentencePieceMatcher() {}
+
+ // Find matches that are prefixes of a string.
+ virtual bool FindAllPrefixMatches(StringPiece input,
+ std::vector<TrieMatch>* matches) const = 0;
+
+ // Find the longest prefix match of a string.
+ virtual bool LongestPrefixMatch(StringPiece input,
+ TrieMatch* longest_match) const = 0;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
diff --git a/utils/sentencepiece/normalizer.cc b/utils/sentencepiece/normalizer.cc
new file mode 100644
index 0000000..9d893fd
--- /dev/null
+++ b/utils/sentencepiece/normalizer.cc
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/normalizer.h"
+
+#include "utils/base/logging.h"
+#include "utils/strings/utf8.h"
+
+namespace libtextclassifier3 {
+
+bool SentencePieceNormalizer::Normalize(StringPiece input,
+ std::string* normalized_input) const {
+ // Ignores heading space.
+ if (remove_extra_whitespaces_) {
+ while (!input.empty()) {
+ std::pair<StringPiece, int> suffix_and_length;
+ if (!NormalizePrefix(input, &suffix_and_length)) {
+ TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
+ return false;
+ }
+ if (suffix_and_length.second <= 0) {
+ TC3_LOG(ERROR) << "Consumed string is empty.";
+ return false;
+ }
+ if (suffix_and_length.first.size() != 1 ||
+ suffix_and_length.first[0] != ' ') {
+ break;
+ }
+ input.RemovePrefix(suffix_and_length.second);
+ }
+ }
+
+ if (input.empty()) {
+ *normalized_input = "";
+ return true;
+ }
+
+ // Reserves the output buffer to avoid re-allocations.
+ const int kReservedSize = input.size() * 3;
+ normalized_input->reserve(kReservedSize);
+
+ // Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK)
+ // if escape_whitespaces() is set (default = true).
+ const StringPiece kSpaceSymbol = "\xe2\x96\x81";
+
+ // Adds a space symbol as a prefix (default is true)
+ // With this prefix, "world" and "hello world" are converted into
+ // "_world" and "_hello_world", which help the trainer to extract
+ // "_world" as one symbol.
+ if (add_dummy_prefix_) {
+ if (escape_whitespaces_) {
+ normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
+ } else {
+ normalized_input->append(" ");
+ }
+ }
+
+ bool is_prev_space = remove_extra_whitespaces_;
+ while (!input.empty()) {
+ std::pair<StringPiece, int> p;
+ if (!NormalizePrefix(input, &p)) {
+ TC3_LOG(ERROR) << "Couldn't normalize string.";
+ return false;
+ }
+ if (p.second <= 0) {
+ TC3_LOG(ERROR) << "Consumed string is empty.";
+ return false;
+ }
+
+ StringPiece sp = p.first;
+
+ // Removes heading spaces in sentence piece,
+ // if the previous sentence piece ends with whitespace.
+ while (is_prev_space && ConsumePrefix(&sp, " ")) {
+ }
+
+ if (!sp.empty()) {
+ const char* data = sp.data();
+ for (int n = 0; n < sp.size(); ++n) {
+ if (escape_whitespaces_ && data[n] == ' ') {
+ normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
+ } else {
+ *normalized_input += data[n];
+ }
+ }
+ // Checks whether the last character of sp is whitespace.
+ is_prev_space = EndsWith(sp, " ");
+ }
+ input.RemovePrefix(p.second);
+ is_prev_space = is_prev_space && remove_extra_whitespaces_;
+ }
+
+ // Ignores tailing space.
+ if (remove_extra_whitespaces_) {
+ const StringPiece space = escape_whitespaces_ ? kSpaceSymbol : " ";
+ while (EndsWith(*normalized_input, space)) {
+ const int length = normalized_input->size() - space.size();
+ normalized_input->resize(length);
+ }
+ }
+ return true;
+}
+
+bool SentencePieceNormalizer::NormalizePrefix(
+ StringPiece input, std::pair<StringPiece, int>* prefix) const {
+ if (input.empty()) return true;
+ TrieMatch match;
+ if (!charsmap_trie_.LongestPrefixMatch(input, &match)) {
+ TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
+ return false;
+ }
+ const bool no_match = match.match_length <= 0;
+ if (no_match) {
+ const int char_length = ValidUTF8CharLength(input.data(), input.size());
+ if (char_length <= 0) {
+ // Found a malformed utf8.
+ // The rune is set to be 0xFFFD (REPLACEMENT CHARACTER),
+ // which is a valid Unicode of three bytes in utf8,
+ // but here we only consume one byte.
+ static const char kReplacementChar[] = "\xEF\xBF\xBD";
+ prefix->first = StringPiece(kReplacementChar, 3);
+ prefix->second = 1; // Consumes 1 byte, buts emit 0xFFFD.
+ } else {
+ prefix->first = StringPiece(input.data(), char_length);
+ prefix->second = char_length;
+ }
+ } else {
+ if (match.id < 0 || match.id >= charsmap_normalized_.size()) {
+ TC3_LOG(ERROR) << "Invalid entry in normalization table.";
+ return false;
+ }
+ prefix->first = StringPiece(&charsmap_normalized_.data()[match.id]);
+ prefix->second = match.match_length;
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/normalizer.h b/utils/sentencepiece/normalizer.h
new file mode 100644
index 0000000..1d3aeb5
--- /dev/null
+++ b/utils/sentencepiece/normalizer.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
+
+#include <memory>
+#include <string>
+
+#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Normalizer implements a simple text normalizer with user-defined
+// string-to-string rules and leftmost longest matching.
+class SentencePieceNormalizer {
+ public:
+ // charsmap_trie and charsmap_normalized specify the normalization/replacement
+ // string-to-string rules in the following way:
+ // A match in the trie for a string will return the offset in
+ // charsmap_normalized that contains the replacement string.
+ //
+ // add_dummy_prefix: Whether to add dummy whitespace at the beginning of the
+ // text in order to treat "world" in "world" and "hello world" uniformly.
+ //
+ // remove_extra_whitespaces: Whether to remove leading, trailing and duplicate
+ // internal whitespace.
+ //
+ // escape_whitespaces: Whether to replace whitespace with a meta symbol.
+ SentencePieceNormalizer(const DoubleArrayTrie& charsmap_trie,
+ StringPiece charsmap_normalized,
+ bool add_dummy_prefix = true,
+ bool remove_extra_whitespaces = true,
+ bool escape_whitespaces = true)
+ : charsmap_trie_(charsmap_trie),
+ charsmap_normalized_(charsmap_normalized),
+ add_dummy_prefix_(add_dummy_prefix),
+ remove_extra_whitespaces_(remove_extra_whitespaces),
+ escape_whitespaces_(escape_whitespaces) {}
+
+ // Normalizes a plain utf8 string into an internal representation for
+ // Sentencepiece model.
+ bool Normalize(StringPiece input, std::string* normalized_input) const;
+
+ private:
+ // Normalizes the prefix of `input` and returns the pair of
+ // normalized prefix and the length of the prefix of `input` processed in the
+ // normalization.
+ bool NormalizePrefix(StringPiece input,
+ std::pair<StringPiece, int>* prefix) const;
+
+ // Internal trie for efficient longest prefix string matching.
+ DoubleArrayTrie charsmap_trie_;
+
+ // "\0" delimitered concatenated normalized strings.
+ // the value of `charsmap_trie_` stores offsets into this string.
+ StringPiece charsmap_normalized_;
+
+ const bool add_dummy_prefix_;
+ const bool remove_extra_whitespaces_;
+ const bool escape_whitespaces_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
diff --git a/utils/sentencepiece/normalizer_test.cc b/utils/sentencepiece/normalizer_test.cc
new file mode 100644
index 0000000..a5d6bf9
--- /dev/null
+++ b/utils/sentencepiece/normalizer_test.cc
@@ -0,0 +1,198 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/sentencepiece/normalizer.h"
+#include "utils/sentencepiece/test_utils.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+std::string GetTestConfigPath() {
+ return "";
+}
+
+TEST(NormalizerTest, NormalizesAsReferenceNormalizer) {
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/true,
+ /*remove_extra_whitespaces=*/true,
+ /*escape_whitespaces=*/true);
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "▁hello▁there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "▁when▁is▁the▁world▁cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "▁general▁kenobi");
+ }
+
+ // NFKC char to multi-char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("㍿", &normalized));
+ EXPECT_EQ(normalized, "▁株式会社");
+ }
+
+ // Half width katakana, character composition happens.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize(" グーグル ", &normalized));
+ EXPECT_EQ(normalized, "▁グーグル");
+ }
+
+ // NFKC char to char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("①②③", &normalized));
+ EXPECT_EQ(normalized, "▁123");
+ }
+}
+
+TEST(NormalizerTest, NoDummyPrefix) {
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/true,
+ /*escape_whitespaces=*/true);
+
+ // NFKC char to char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "hello▁there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "when▁is▁the▁world▁cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "general▁kenobi");
+ }
+
+ // NFKC char to multi-char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("㍿", &normalized));
+ EXPECT_EQ(normalized, "株式会社");
+ }
+
+ // Half width katakana, character composition happens.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize(" グーグル ", &normalized));
+ EXPECT_EQ(normalized, "グーグル");
+ }
+
+ // NFKC char to char normalization.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("①②③", &normalized));
+ EXPECT_EQ(normalized, "123");
+ }
+}
+
+TEST(NormalizerTest, NoRemoveExtraWhitespace) {
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/false,
+ /*escape_whitespaces=*/true);
+
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "hello▁there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "when▁is▁▁the▁▁world▁cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "general▁kenobi");
+ }
+}
+
+TEST(NormalizerTest, NoEscapeWhitespaces) {
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/false,
+ /*escape_whitespaces=*/false);
+
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
+ EXPECT_EQ(normalized, "hello there");
+ }
+
+ // Redundant whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
+ EXPECT_EQ(normalized, "when is the world cup?");
+ }
+
+ // Different whitespace.
+ {
+ std::string normalized;
+ EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
+ EXPECT_EQ(normalized, "general kenobi");
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/sorted_strings_table.cc b/utils/sentencepiece/sorted_strings_table.cc
new file mode 100644
index 0000000..8e7e9ba
--- /dev/null
+++ b/utils/sentencepiece/sorted_strings_table.cc
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/sorted_strings_table.h"
+
+#include <algorithm>
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+void SortedStringsTable::GatherPrefixMatches(
+ StringPiece input, const std::function<void(TrieMatch)>& update_fn) const {
+ int left = 0;
+ int right = num_pieces_;
+ int span_size = right - left;
+ int match_length = 0;
+
+ // Loop invariant:
+ // at the ith iteration, all strings from `left` ... `right` match the input
+ // on the first `match_length` characters.
+ while (span_size > use_linear_scan_threshold_) {
+ if (match_length >= input.length()) {
+ return;
+ }
+
+ // We find the possible range of pieces in `left` ... `right` matching the
+ // `match_length` + 1 character with two binary searches:
+ // `lower_bound` to find the start of the range of matching pieces.
+ // `upper_bound` to find the non-inclusive end of the range.
+ left = (std::lower_bound(
+ offsets_ + left, offsets_ + right,
+ static_cast<unsigned char>(input[match_length]),
+ [this, match_length](uint32 piece_offset, uint32 c) -> bool {
+ return static_cast<unsigned char>(
+ pieces_[piece_offset + match_length]) < c;
+ }) -
+ offsets_);
+ right = (std::upper_bound(
+ offsets_ + left, offsets_ + right,
+ static_cast<unsigned char>(input[match_length]),
+ [this, match_length](uint32 c, uint32 piece_offset) -> bool {
+ return c < static_cast<unsigned char>(
+ pieces_[piece_offset + match_length]);
+ }) -
+ offsets_);
+ span_size = right - left;
+ if (span_size <= 0) {
+ return;
+ }
+ ++match_length;
+
+ // Due to the loop invariant and the fact that the strings are sorted, there
+ // can only be one piece matching completely now, namely at left.
+ if (pieces_[offsets_[left] + match_length] == 0) {
+ update_fn(TrieMatch(/*id=*/left,
+ /*match_length=*/match_length));
+ left++;
+ }
+ }
+
+ // Use linear scan for small problem instances.
+ // By the loop invariant characters 0...`match_length` of all pieces in
+ // `left`...`right` match the input on 0...`match_length`.
+ for (int i = left; i < right; i++) {
+ bool matches = true;
+ int piece_match_length = match_length;
+ for (int k = offsets_[i] + piece_match_length; pieces_[k] != 0; k++) {
+ // Guard on `piece_match_length` (the advancing index), not
+ // `match_length`: the latter allowed reads past the end of `input`.
+ if (piece_match_length >= input.size() ||
+ input[piece_match_length] != pieces_[k]) {
+ matches = false;
+ break;
+ }
+ piece_match_length++;
+ }
+ if (matches) {
+ update_fn(TrieMatch(/*id=*/i,
+ /*match_length=*/piece_match_length));
+ }
+ }
+}
+
+bool SortedStringsTable::FindAllPrefixMatches(
+ StringPiece input, std::vector<TrieMatch>* matches) const {
+ GatherPrefixMatches(
+ input, [matches](const TrieMatch match) { matches->push_back(match); });
+ return true;
+}
+
+bool SortedStringsTable::LongestPrefixMatch(StringPiece input,
+ TrieMatch* longest_match) const {
+ *longest_match = TrieMatch();
+ GatherPrefixMatches(input, [longest_match](const TrieMatch match) {
+ *longest_match = match;
+ });
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/sorted_strings_table.h b/utils/sentencepiece/sorted_strings_table.h
new file mode 100644
index 0000000..69f638a
--- /dev/null
+++ b/utils/sentencepiece/sorted_strings_table.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_SORTED_STRINGS_TABLE_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_SORTED_STRINGS_TABLE_H_
+
+#include <functional>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/sentencepiece/matcher.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// A matcher to find string pieces matching prefixes of an input string.
+// The list of reference strings are kept in sorted order in a zero separated
+// string.
+// binary search is used to find all prefix matches.
+// num_pieces: Number of sentence pieces.
+// offsets: Offsets into `pieces` where a string starts.
+// pieces: String pieces, concatenated in sorted order and zero byte separated.
+// use_linear_scan_threshold: Minimum size of binary search range before
+// switching to a linear sweep for prefix match testing.
+class SortedStringsTable : public SentencePieceMatcher {
+ public:
+ SortedStringsTable(const int num_pieces, const uint32* offsets,
+ StringPiece pieces,
+ const int use_linear_scan_threshold = 10)
+ : num_pieces_(num_pieces),
+ offsets_(offsets),
+ pieces_(pieces),
+ use_linear_scan_threshold_(use_linear_scan_threshold) {}
+
+ // Find matches that are prefixes of a string.
+ bool FindAllPrefixMatches(StringPiece input,
+ std::vector<TrieMatch>* matches) const override;
+ // Find the longest prefix match of a string.
+ bool LongestPrefixMatch(StringPiece input,
+ TrieMatch* longest_match) const override;
+
+ private:
+ void GatherPrefixMatches(
+ StringPiece input, const std::function<void(TrieMatch)>& update_fn) const;
+
+ const int num_pieces_;
+ const uint32* offsets_;
+ const StringPiece pieces_;
+ const int use_linear_scan_threshold_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_SORTED_STRINGS_TABLE_H_
diff --git a/utils/sentencepiece/sorted_strings_table_test.cc b/utils/sentencepiece/sorted_strings_table_test.cc
new file mode 100644
index 0000000..4dff29d
--- /dev/null
+++ b/utils/sentencepiece/sorted_strings_table_test.cc
@@ -0,0 +1,114 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "utils/base/integral_types.h"
+#include "utils/sentencepiece/sorted_strings_table.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(SortedStringsTest, Lookup) {
+ const char pieces[] = "hell\0hello\0o\0there\0";
+ const uint32 offsets[] = {0, 5, 11, 13};
+
+ SortedStringsTable table(/*num_pieces=*/4, offsets, StringPiece(pieces, 18),
+ /*use_linear_scan_threshold=*/1);
+
+ {
+ std::vector<TrieMatch> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("hello there", &matches));
+ EXPECT_EQ(matches.size(), 2);
+ EXPECT_EQ(matches[0].id, 0 /*hell*/);
+ EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
+ EXPECT_EQ(matches[1].id, 1 /*hello*/);
+ EXPECT_EQ(matches[1].match_length, 5 /*hello*/);
+ }
+
+ {
+ std::vector<TrieMatch> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("he", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<TrieMatch> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("he", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<TrieMatch> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("abcd", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<TrieMatch> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<TrieMatch> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("hi there", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<TrieMatch> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches(StringPiece("\0", 1), &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<TrieMatch> matches;
+ EXPECT_TRUE(
+ table.FindAllPrefixMatches(StringPiece("\xff, \xfe", 2), &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ TrieMatch match;
+ EXPECT_TRUE(table.LongestPrefixMatch("hella there", &match));
+ EXPECT_EQ(match.id, 0 /*hell*/);
+ }
+
+ {
+ TrieMatch match;
+ EXPECT_TRUE(table.LongestPrefixMatch("hello there", &match));
+ EXPECT_EQ(match.id, 1 /*hello*/);
+ }
+
+ {
+ TrieMatch match;
+ EXPECT_TRUE(table.LongestPrefixMatch("abcd", &match));
+ EXPECT_EQ(match.id, -1);
+ }
+
+ {
+ TrieMatch match;
+ EXPECT_TRUE(table.LongestPrefixMatch("", &match));
+ EXPECT_EQ(match.id, -1);
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/test_utils.cc b/utils/sentencepiece/test_utils.cc
new file mode 100644
index 0000000..1ed2bf3
--- /dev/null
+++ b/utils/sentencepiece/test_utils.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/test_utils.h"
+
+#include <memory>
+
+#include "utils/base/integral_types.h"
+#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
+ bool add_dummy_prefix,
+ bool remove_extra_whitespaces,
+ bool escape_whitespaces) {
+ const uint32 trie_blob_size = reinterpret_cast<const uint32*>(spec.data())[0];
+ spec.RemovePrefix(sizeof(trie_blob_size));
+ const TrieNode* trie_blob = reinterpret_cast<const TrieNode*>(spec.data());
+ spec.RemovePrefix(trie_blob_size);
+ const int num_nodes = trie_blob_size / sizeof(TrieNode);
+ return SentencePieceNormalizer(
+ DoubleArrayTrie(trie_blob, num_nodes),
+ /*charsmap_normalized=*/StringPiece(spec.data(), spec.size()),
+ add_dummy_prefix, remove_extra_whitespaces, escape_whitespaces);
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/test_utils.h b/utils/sentencepiece/test_utils.h
new file mode 100644
index 0000000..0c833da
--- /dev/null
+++ b/utils/sentencepiece/test_utils.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "utils/sentencepiece/normalizer.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
+ bool add_dummy_prefix,
+ bool remove_extra_whitespaces,
+ bool escape_whitespaces);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
diff --git a/utils/strings/stringpiece.h b/utils/strings/stringpiece.h
index 3ec414f..0dec1b8 100644
--- a/utils/strings/stringpiece.h
+++ b/utils/strings/stringpiece.h
@@ -30,7 +30,7 @@
StringPiece() : StringPiece(nullptr, 0) {}
StringPiece(const char *str) // NOLINT(runtime/explicit)
- : start_(str), size_(strlen(str)) {}
+ : start_(str), size_(str == nullptr ? 0 : strlen(str)) {}
StringPiece(const char *start, size_t size) : start_(start), size_(size) {}
@@ -70,6 +70,10 @@
memcmp(start_, prefix.data(), prefix.size()) == 0);
}
+ bool Equals(StringPiece other) const {
+ return size() == other.size() && memcmp(start_, other.data(), size_) == 0;
+ }
+
// Removes the first `n` characters from the string piece. Note that the
// underlying string is not changed, only the view.
void RemovePrefix(int n) {
diff --git a/utils/strings/substitute.cc b/utils/strings/substitute.cc
new file mode 100644
index 0000000..bba53f5
--- /dev/null
+++ b/utils/strings/substitute.cc
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/strings/substitute.h"
+
+#include <algorithm>
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+namespace strings {
+
+bool Substitute(const StringPiece format, const std::vector<StringPiece>& args,
+ std::string* output) {
+ // Determine total size needed.
+ size_t size = 0;
+ for (size_t i = 0; i < format.size(); i++) {
+ if (format[i] == '$') {
+ if (i + 1 >= format.size()) {
+ TC3_LOG(ERROR) << "Invalid format string: " << format.ToString();
+ return false;
+ } else if (isdigit(format[i + 1])) {
+ int index = format[i + 1] - '0';
+ if (static_cast<size_t>(index) >= args.size()) {
+ TC3_LOG(ERROR) << "Asked for " << index << ", but only "
+ << args.size() << " arguments given";
+ return false;
+ }
+ size += args[index].size();
+ ++i; // Skip next char.
+ } else if (format[i + 1] == '$') {
+ ++size;
+ ++i; // Skip next char.
+ } else {
+ TC3_LOG(ERROR) << "Invalid format string: " << format.ToString();
+ return false;
+ }
+ } else {
+ ++size;
+ }
+ }
+
+ if (size == 0) {
+ output->clear();
+ return true;
+ }
+
+ // Build the string.
+ output->resize(size);
+ char* target = &(*output)[0];
+ for (size_t i = 0; i < format.size(); i++) {
+ if (format[i] == '$') {
+ if (isdigit(format[i + 1])) {
+ const StringPiece src = args[format[i + 1] - '0'];
+ target = std::copy(src.data(), src.data() + src.size(), target);
+ ++i; // Skip next char.
+ } else if (format[i + 1] == '$') {
+ *target++ = '$';
+ ++i; // Skip next char.
+ }
+ } else {
+ *target++ = format[i];
+ }
+ }
+ return true;
+}
+
+std::string Substitute(const StringPiece format,
+ const std::vector<StringPiece>& args) {
+ std::string result;
+ if (!Substitute(format, args, &result)) {
+ return "";
+ }
+ return result;
+}
+
+} // namespace strings
+} // namespace libtextclassifier3
diff --git a/utils/strings/substitute.h b/utils/strings/substitute.h
new file mode 100644
index 0000000..f7e6714
--- /dev/null
+++ b/utils/strings/substitute.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_SUBSTITUTE_H_
+#define LIBTEXTCLASSIFIER_UTILS_STRINGS_SUBSTITUTE_H_
+
+#include <string>
+#include <vector>
+
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace strings {
+
+// Formats a string with argument-binding.
+// Uses a format string that contains positional identifiers indicated by a
+// dollar sign ($) and a single digit positional id to indicate which
+// substitution arguments to use at that location within the format string.
+// A '$$' sequence in the format string means output a literal '$' character.
+// Returns whether the substitution was successful.
+bool Substitute(const StringPiece format, const std::vector<StringPiece>& args,
+ std::string* output);
+
+std::string Substitute(const StringPiece format,
+ const std::vector<StringPiece>& args);
+
+} // namespace strings
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_SUBSTITUTE_H_
diff --git a/utils/strings/substitute_test.cc b/utils/strings/substitute_test.cc
new file mode 100644
index 0000000..94b37ab
--- /dev/null
+++ b/utils/strings/substitute_test.cc
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/strings/substitute.h"
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(SubstituteTest, Substitute) {
+ EXPECT_EQ("Hello, world!",
+ strings::Substitute("$0, $1!", {"Hello", "world"}));
+
+ // Out of order.
+ EXPECT_EQ("world, Hello!",
+ strings::Substitute("$1, $0!", {"Hello", "world"}));
+ EXPECT_EQ("b, a, c, b",
+ strings::Substitute("$1, $0, $2, $1", {"a", "b", "c"}));
+
+ // Literal $
+ EXPECT_EQ("$", strings::Substitute("$$", {}));
+ EXPECT_EQ("$1", strings::Substitute("$$1", {}));
+
+ const char* null_cstring = nullptr;
+ EXPECT_EQ("Text: ''", strings::Substitute("Text: '$0'", {null_cstring}));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/tensor-view.h b/utils/tensor-view.h
index a46ebd1..ed75ab8 100644
--- a/utils/tensor-view.h
+++ b/utils/tensor-view.h
@@ -51,7 +51,12 @@
const T* data() const { return data_; }
- int size() const { return size_; }
+ int size() const {
+ if (!is_valid()) {
+ return 0;
+ }
+ return size_;
+ }
bool copy_to(T* dest, int dest_size) const {
if (dest_size < size_) {
diff --git a/utils/test-utils.cc b/utils/test-utils.cc
new file mode 100644
index 0000000..e37105a
--- /dev/null
+++ b/utils/test-utils.cc
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/test-utils.h"
+
+#include <iterator>
+
+#include "utils/strings/split.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+using libtextclassifier3::Token;
+
+// Returns a list of Tokens for given input string. Can't handle non-ASCII
+// input.
+std::vector<Token> TokenizeAsciiOnSpace(const std::string& text) {
+ std::vector<Token> result;
+ for (const StringPiece token : strings::Split(text, ' ')) {
+ const int start_offset = std::distance(text.data(), token.data());
+ const int token_length = token.length();
+ result.push_back(
+ Token{token.ToString(), start_offset, start_offset + token_length});
+ }
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/test-utils.h b/utils/test-utils.h
new file mode 100644
index 0000000..7e227dc
--- /dev/null
+++ b/utils/test-utils.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utilities for tests.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
+
+#include <string>
+
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+// Returns a list of Tokens for given input string. Can't handle non-ASCII
+// input.
+std::vector<Token> TokenizeAsciiOnSpace(const std::string& text);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
diff --git a/utils/testing/annotator.h b/utils/testing/annotator.h
new file mode 100644
index 0000000..b988d0b
--- /dev/null
+++ b/utils/testing/annotator.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Helper utilities for testing Annotator.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
+
+#include <memory>
+#include <string>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3 {
+
+// Loads FlatBuffer model, unpacks it and passes it to the visitor_fn so that it
+// can modify it. Afterwards the modified unpacked model is serialized back to a
+// flatbuffer.
+template <typename Fn>
+std::string ModifyAnnotatorModel(const std::string& model_flatbuffer,
+ Fn visitor_fn) {
+ std::unique_ptr<ModelT> unpacked_model =
+ UnPackModel(model_flatbuffer.c_str());
+
+ visitor_fn(unpacked_model.get());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ return std::string(reinterpret_cast<char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
diff --git a/utils/tflite-model-executor.cc b/utils/tflite-model-executor.cc
index e3ec73c..9ba232e 100644
--- a/utils/tflite-model-executor.cc
+++ b/utils/tflite-model-executor.cc
@@ -16,24 +16,142 @@
#include "utils/tflite-model-executor.h"
-#include "tensorflow/lite/kernels/register.h"
#include "utils/base/logging.h"
+#include "tensorflow/lite/kernels/register.h"
// Forward declaration of custom TensorFlow Lite ops for registration.
namespace tflite {
namespace ops {
namespace builtin {
-TfLiteRegistration* Register_DIV();
+TfLiteRegistration* Register_ADD();
+TfLiteRegistration* Register_CONCATENATION();
+TfLiteRegistration* Register_CONV_2D();
TfLiteRegistration* Register_FULLY_CONNECTED();
-TfLiteRegistration* Register_SOFTMAX(); // TODO(smillius): remove.
+TfLiteRegistration* Register_L2_NORMALIZATION();
+TfLiteRegistration* Register_MUL();
+TfLiteRegistration* Register_RESHAPE();
+TfLiteRegistration* Register_SOFTMAX();
+TfLiteRegistration* Register_GATHER();
+TfLiteRegistration* Register_TRANSPOSE();
+TfLiteRegistration* Register_SUB();
+TfLiteRegistration* Register_DIV();
+TfLiteRegistration* Register_STRIDED_SLICE();
+TfLiteRegistration* Register_EXP();
+TfLiteRegistration* Register_TOPK_V2();
+TfLiteRegistration* Register_SPLIT();
+TfLiteRegistration* Register_CAST();
+TfLiteRegistration* Register_MAXIMUM();
+TfLiteRegistration* Register_MINIMUM();
+TfLiteRegistration* Register_NEG();
+TfLiteRegistration* Register_SLICE();
+TfLiteRegistration* Register_LOG();
+TfLiteRegistration* Register_SUM();
+TfLiteRegistration* Register_PACK();
+TfLiteRegistration* Register_DEQUANTIZE();
+TfLiteRegistration* Register_MEAN();
} // namespace builtin
} // namespace ops
} // namespace tflite
-void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {
- resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
- ::tflite::ops::builtin::Register_FULLY_CONNECTED());
+#ifdef TC3_WITH_ACTIONS_OPS
+#include "utils/tflite/dist_diversification.h"
+#include "utils/tflite/text_encoder.h"
+#include "utils/tflite/token_encoder.h"
+
+void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
+ resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
+ tflite::ops::builtin::Register_ADD(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
+ tflite::ops::builtin::Register_CONCATENATION(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
+ tflite::ops::builtin::Register_CONV_2D(),
+ /*min_version=*/1,
+ /*max_version=*/3);
+ resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
+ tflite::ops::builtin::Register_FULLY_CONNECTED(),
+ /*min_version=*/1,
+ /*max_version=*/4);
+ resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
+ tflite::ops::builtin::Register_L2_NORMALIZATION(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_MUL,
+ tflite::ops::builtin::Register_MUL());
+ resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE,
+ tflite::ops::builtin::Register_RESHAPE());
+ resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
+ tflite::ops::builtin::Register_SOFTMAX(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_GATHER,
+ tflite::ops::builtin::Register_GATHER(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
+ tflite::ops::builtin::Register_TRANSPOSE(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_SUB,
+ tflite::ops::builtin::Register_SUB(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
+ tflite::ops::builtin::Register_DIV());
+ resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE,
+ tflite::ops::builtin::Register_STRIDED_SLICE(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_EXP,
+ tflite::ops::builtin::Register_EXP());
+ resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2,
+ tflite::ops::builtin::Register_TOPK_V2(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
+ tflite::ops::builtin::Register_SPLIT(),
+ /*min_version=*/1,
+ /*max_version=*/3);
+ resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
+ tflite::ops::builtin::Register_CAST());
+ resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM,
+ tflite::ops::builtin::Register_MAXIMUM(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM,
+ tflite::ops::builtin::Register_MINIMUM(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_NEG,
+ tflite::ops::builtin::Register_NEG());
+ resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
+ tflite::ops::builtin::Register_SLICE(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
+ tflite::ops::builtin::Register_LOG());
+ resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
+ tflite::ops::builtin::Register_SUM());
+ resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
+ tflite::ops::builtin::Register_PACK(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE,
+ tflite::ops::builtin::Register_DEQUANTIZE(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_MEAN,
+ tflite::ops::builtin::Register_MEAN());
}
+#else
+void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
+ resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
+ tflite::ops::builtin::Register_FULLY_CONNECTED());
+}
+#endif // TC3_WITH_ACTIONS_OPS
namespace libtextclassifier3 {
@@ -41,17 +159,19 @@
#ifdef TC3_USE_SELECTIVE_REGISTRATION
std::unique_ptr<tflite::MutableOpResolver> resolver(
new tflite::MutableOpResolver);
- resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
- tflite::ops::builtin::Register_DIV());
- resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
- tflite::ops::builtin::Register_FULLY_CONNECTED());
- resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
- tflite::ops::builtin::Register_SOFTMAX());
RegisterSelectedOps(resolver.get());
#else
std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
new tflite::ops::builtin::BuiltinOpResolver);
#endif
+#ifdef TC3_WITH_ACTIONS_OPS
+ resolver->AddCustom("DistanceDiversification",
+ tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
+ resolver->AddCustom("TextEncoder",
+ tflite::ops::custom::Register_TEXT_ENCODER());
+ resolver->AddCustom("TokenEncoder",
+ tflite::ops::custom::Register_TOKEN_ENCODER());
+#endif // TC3_WITH_ACTIONS_OPS
return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
}
@@ -98,12 +218,12 @@
buf.AddString(s.data(), s.length());
}
buf.WriteToTensorAsVector(
- interpreter->tensor(interpreter->inputs()[input_index]));
+ interpreter->tensor(interpreter->inputs()[input_index]));
}
template <>
std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
- const int output_index, tflite::Interpreter* interpreter) const {
+ const int output_index, const tflite::Interpreter* interpreter) const {
const TfLiteTensor* output_tensor =
interpreter->tensor(interpreter->outputs()[output_index]);
const int num_strings = tflite::GetStringCount(output_tensor);
@@ -116,7 +236,7 @@
template <>
std::vector<std::string> TfLiteModelExecutor::Output(
- const int output_index, tflite::Interpreter* interpreter) const {
+ const int output_index, const tflite::Interpreter* interpreter) const {
std::vector<std::string> output;
for (const tflite::StringRef& s :
Output<tflite::StringRef>(output_index, interpreter)) {
diff --git a/utils/tflite-model-executor.h b/utils/tflite-model-executor.h
index ab1e76d..10d4233 100644
--- a/utils/tflite-model-executor.h
+++ b/utils/tflite-model-executor.h
@@ -21,13 +21,13 @@
#include <memory>
+#include "utils/base/logging.h"
+#include "utils/tensor-view.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"
-#include "utils/base/logging.h"
-#include "utils/tensor-view.h"
namespace libtextclassifier3 {
@@ -79,9 +79,41 @@
}
template <typename T>
+ void SetInput(const int input_index, const T input_value,
+ tflite::Interpreter* interpreter) const {
+ TfLiteTensor* input_tensor =
+ interpreter->tensor(interpreter->inputs()[input_index]);
+ switch (input_tensor->type) {
+ case kTfLiteFloat32:
+ *(input_tensor->data.f) = input_value;
+ break;
+ case kTfLiteInt32:
+ *(input_tensor->data.i32) = input_value;
+ break;
+ case kTfLiteUInt8:
+ *(input_tensor->data.uint8) = input_value;
+ break;
+ case kTfLiteInt64:
+ *(input_tensor->data.i64) = input_value;
+ break;
+ case kTfLiteBool:
+ *(input_tensor->data.b) = input_value;
+ break;
+ case kTfLiteInt16:
+ *(input_tensor->data.i16) = input_value;
+ break;
+ case kTfLiteInt8:
+ *(input_tensor->data.int8) = input_value;
+ break;
+ default:
+ break;
+ }
+ }
+
+ template <typename T>
TensorView<T> OutputView(const int output_index,
- tflite::Interpreter* interpreter) const {
- TfLiteTensor* output_tensor =
+ const tflite::Interpreter* interpreter) const {
+ const TfLiteTensor* output_tensor =
interpreter->tensor(interpreter->outputs()[output_index]);
return TensorView<T>(interpreter->typed_output_tensor<T>(output_index),
std::vector<int>(output_tensor->dims->data,
@@ -91,7 +123,7 @@
template <typename T>
std::vector<T> Output(const int output_index,
- tflite::Interpreter* interpreter) const {
+ const tflite::Interpreter* interpreter) const {
TensorView<T> output_view = OutputView<T>(output_index, interpreter);
return std::vector<T>(output_view.data(),
output_view.data() + output_view.size());
@@ -112,11 +144,11 @@
template <>
std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
- const int output_index, tflite::Interpreter* interpreter) const;
+ const int output_index, const tflite::Interpreter* interpreter) const;
template <>
std::vector<std::string> TfLiteModelExecutor::Output(
- const int output_index, tflite::Interpreter* interpreter) const;
+ const int output_index, const tflite::Interpreter* interpreter) const;
} // namespace libtextclassifier3
diff --git a/utils/tflite/dist_diversification.cc b/utils/tflite/dist_diversification.cc
new file mode 100644
index 0000000..6dfc329
--- /dev/null
+++ b/utils/tflite/dist_diversification.cc
@@ -0,0 +1,155 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/dist_diversification.h"
+
+#include <algorithm>
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Returns a vector of row indices in a distance matrix.
+// Indices are increasing and the distance of every selected index to others
+// is larger than `min_distance`.
+template <typename DistanceMatrixType>
+std::vector<int> DiversifyByDistance(const DistanceMatrixType& distance_matrix,
+ const int matrix_size,
+ const float min_distance,
+ const int max_num_results) {
+ std::vector<int> result{0};
+ result.reserve(max_num_results);
+ int index = 1;
+ while (result.size() < max_num_results && index < matrix_size) {
+ for (; index < matrix_size; ++index) {
+ bool too_close = false;
+ for (const int selected_index : result) {
+ if (distance_matrix(index, selected_index) < min_distance) {
+ too_close = true;
+ break;
+ }
+ }
+ if (!too_close) {
+ result.push_back(index);
+ ++index;
+ break;
+ }
+ }
+ }
+ return result;
+}
+
+// Input parameters for the op.
+enum DistDiversificationInputs {
+ DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX = 0,
+ DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE = 1,
+ DIST_DIVERSIFICATION_INPUT_NUM_RESULTS = 2
+};
+
+// Output parameters for the op.
+enum DistDiversificationOutputs {
+ DIST_DIVERSIFICATION_OUTPUT_INDICES = 0,
+ DIST_DIVERSIFICATION_OUTPUT_LENGTH = 1,
+};
+
+TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
+ TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
+ int index = 0;
+ for (const int size : sizes) {
+ array_size->data[index++] = size;
+ }
+ return array_size;
+}
+
+TfLiteStatus AllocateOutputIndexes(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor& num_results =
+ context
+ ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
+ TfLiteTensor& output_indices =
+ context
+ ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
+ return context->ResizeTensor(context, &output_indices,
+ CreateSizeArray({num_results.data.i32[0]}));
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor& num_results =
+ context
+ ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
+ if (tflite::IsConstantTensor(&num_results)) {
+ TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
+ } else {
+ TfLiteTensor& output_indices =
+ context
+ ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
+ tflite::SetTensorToDynamic(&output_indices);
+ }
+ TfLiteTensor& output_length =
+ context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_length,
+ CreateSizeArray({1})));
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor& output_indices =
+ context
+ ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
+ if (tflite::IsDynamicTensor(&output_indices)) {
+ TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
+ }
+ const TfLiteTensor& distance_matrix =
+ context->tensors[node->inputs
+ ->data[DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX]];
+ const int distance_matrix_dim = distance_matrix.dims->data[0];
+ const float min_distance =
+ context
+ ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE]]
+ .data.f[0];
+ const int num_results =
+ context
+ ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]]
+ .data.i32[0];
+ const auto indices = DiversifyByDistance(
+ [&](int row, int col) {
+ return distance_matrix.data.f[row * distance_matrix_dim + col];
+ },
+ distance_matrix_dim, min_distance, num_results);
+ std::copy(indices.begin(), indices.end(), output_indices.data.i32);
+ std::fill_n(output_indices.data.i32 + indices.size(),
+ num_results - indices.size(), -1);
+ TfLiteTensor& output_length =
+ context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
+ *output_length.data.i32 = indices.size();
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace libtextclassifier3
+
+namespace tflite {
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_DISTANCE_DIVERSIFICATION() {
+ static TfLiteRegistration r = {nullptr, nullptr, libtextclassifier3::Prepare,
+ libtextclassifier3::Eval};
+ return &r;
+}
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/utils/tflite/dist_diversification.h b/utils/tflite/dist_diversification.h
new file mode 100644
index 0000000..2a9f17e
--- /dev/null
+++ b/utils/tflite/dist_diversification.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_DIST_DIVERSIFICATION_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_DIST_DIVERSIFICATION_H_
+
+#include "tensorflow/lite/context.h"
+
namespace tflite {
namespace ops {
namespace custom {

// Returns the TfLiteRegistration for the "DistanceDiversification" custom
// op. The caller does not take ownership of the returned pointer.
TfLiteRegistration* Register_DISTANCE_DIVERSIFICATION();

} // namespace custom
} // namespace ops
} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_DIST_DIVERSIFICATION_H_
diff --git a/utils/tflite/dist_diversification_test.cc b/utils/tflite/dist_diversification_test.cc
new file mode 100644
index 0000000..2380116
--- /dev/null
+++ b/utils/tflite/dist_diversification_test.cc
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/dist_diversification.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Test harness that wraps the DistanceDiversification custom op in a
// single-op TFLite model. Inputs: distance matrix, minimum distance, number
// of results; outputs: selected indices and valid-result count.
class DistanceDiversificationOpModel : public tflite::SingleOpModel {
 public:
  explicit DistanceDiversificationOpModel(int matrix_rows);
  // Fills the (matrix_rows, matrix_rows) row-major distance matrix input.
  void SetDistanceMatrix(const std::initializer_list<float>& values) {
    PopulateTensor(distance_matrix_, values);
  }
  // Sets the maximum number of indices the op may emit.
  void SetNumOutput(int length) { PopulateTensor(num_results_, {length}); }
  // Sets the minimum pairwise distance between selected indices.
  void SetMinDistance(float min_distance) {
    PopulateTensor(min_distance_, {min_distance});
  }
  // Number of valid entries in the index output.
  int GetOutputLen() { return ExtractVector<int>(output_len_).front(); }
  // Returns the first `output_length` selected indices (the remainder of the
  // output tensor is -1 padding).
  std::vector<int> GetOutputIndexes(int output_length) {
    auto res = ExtractVector<int>(output_indexes_);
    res.resize(output_length);
    return res;
  }

 private:
  // Input tensor indices.
  int distance_matrix_;
  int num_results_;
  int min_distance_;

  // Output tensor indices.
  int output_len_;
  int output_indexes_;
};
+
// Wires up the single-op graph: three inputs (distance matrix, minimum
// distance, number of results) and two outputs (indices, length), then
// builds the interpreter with a square matrix and scalar-shaped parameters.
DistanceDiversificationOpModel::DistanceDiversificationOpModel(
    int matrix_rows) {
  distance_matrix_ = AddInput(tflite::TensorType_FLOAT32);
  min_distance_ = AddInput(tflite::TensorType_FLOAT32);
  num_results_ = AddInput(tflite::TensorType_INT32);

  output_indexes_ = AddOutput(tflite::TensorType_INT32);
  output_len_ = AddOutput(tflite::TensorType_INT32);
  // No custom options are needed; the op reads everything from its inputs.
  SetCustomOp("DistanceDiversification", {},
              tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION);
  BuildInterpreter({{matrix_rows, matrix_rows}, {1}, {1}});
}
+
+// Tests
// With threshold 0.21 on the 5x5 matrix below, only indices 0 and 3 are at
// least 0.21 apart (rows 1, 2 are within 0.2 of row 0; row 4 is within 0.1
// of row 3), so the op reports 2 valid results even though 3 were requested.
TEST(DistanceDiversificationOp, Simple) {
  DistanceDiversificationOpModel m(5);
  m.SetDistanceMatrix({0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.0, 0.1, 0.2,
                       0.3, 0.2, 0.1, 0.0, 0.1, 0.2, 0.3, 0.2, 0.1,
                       0.0, 0.1, 0.4, 0.3, 0.2, 0.1, 0.0});
  m.SetMinDistance(0.21);
  m.SetNumOutput(3);
  m.Invoke();
  const int output_length = m.GetOutputLen();
  EXPECT_EQ(output_length, 2);
  EXPECT_THAT(m.GetOutputIndexes(output_length), testing::ElementsAre(0, 3));
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/tflite/encoder_common.cc b/utils/tflite/encoder_common.cc
new file mode 100644
index 0000000..8f9f2a8
--- /dev/null
+++ b/utils/tflite/encoder_common.cc
@@ -0,0 +1,122 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/encoder_common.h"
+
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+
+TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values) {
+ TfLiteIntArray* array_size = TfLiteIntArrayCreate(values.size());
+ int index = 0;
+ for (const int size : values) {
+ array_size->data[index++] = size;
+ }
+ return array_size;
+}
+
// Replicates each attribute value of `in` across the output positions of the
// item it belongs to (item boundaries given by `encoding_end_offsets`,
// shifted left by `start_offset` for truncated input), then pads the rest of
// `out` with the last written value. Supports int32 and float32 attributes.
TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
    const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets,
    int start_offset, TfLiteContext* context, TfLiteTensor* out) {
  TF_LITE_ENSURE_EQ(context, in.dims->size, kEncoderInputRank);
  TF_LITE_ENSURE_EQ(context, in.dims->data[0], kEncoderBatchSize);
  const int output_size = out->dims->data[1];
  int output_offset = 0;
  for (int value_index = 0;
       value_index < encoding_end_offsets.size() && output_offset < output_size;
       ++value_index) {
    // Calculate how many elements need to be set with this value.
    // The low bound depends on the offset from the beginning. If this is 0,
    // it means that this value is truncated.
    // The upper bound depends on how many elements are in the output tensor.
    const int from_this_element =
        std::min(std::max(0, encoding_end_offsets[value_index] - start_offset -
                                 output_offset),
                 output_size - output_offset);
    if (from_this_element == 0) {
      continue;
    }

    switch (in.type) {
      case kTfLiteInt32: {
        std::fill(out->data.i32 + output_offset,
                  out->data.i32 + output_offset + from_this_element,
                  in.data.i32[value_index]);
      } break;
      case kTfLiteFloat32: {
        std::fill(out->data.f + output_offset,
                  out->data.f + output_offset + from_this_element,
                  in.data.f[value_index]);
      } break;
      default:
        context->ReportError(
            (context), __FILE__ " Not supported attribute type %d", in.type);
        return kTfLiteError;
    }
    output_offset += from_this_element;
  }
  // Do final padding with the last written value (0 if nothing was written).
  switch (in.type) {
    case kTfLiteInt32: {
      const int32_t value =
          (output_offset > 0) ? out->data.i32[output_offset - 1] : 0;
      std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
                value);
    } break;
    case kTfLiteFloat32: {
      const float value =
          (output_offset > 0) ? out->data.f[output_offset - 1] : 0;
      std::fill(out->data.f + output_offset, out->data.f + output_size, value);
    } break;
    default:
      // Unsupported types were already rejected in the loop above.
      break;
  }
  return kTfLiteOk;
}
+
+TfLiteStatus ResizeOutputTensor(const int max_output_length,
+ TfLiteTensor* tensor, TfLiteContext* context) {
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(
+ context, tensor,
+ CreateIntArray({kEncoderBatchSize, max_output_length})));
+ return kTfLiteOk;
+}
+
+int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,
+ const std::vector<int32_t>& data,
+ const int32_t padding_value,
+ TfLiteTensor* output_tensor) {
+ const int num_skip =
+ std::max(0, static_cast<int>(data.size()) - max_output_length);
+ int output_offset = 0;
+ int32_t* output_buffer = output_tensor->data.i32;
+ for (int i = num_skip; i < data.size(); ++i, ++output_offset) {
+ output_buffer[output_offset] = data[i];
+ }
+
+ // Do padding.
+ for (; output_offset < max_output_length; ++output_offset) {
+ output_buffer[output_offset] = padding_value;
+ }
+
+ // Return number of skipped entries from the beginning.
+ return num_skip;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/tflite/encoder_common.h b/utils/tflite/encoder_common.h
new file mode 100644
index 0000000..10ae6df
--- /dev/null
+++ b/utils/tflite/encoder_common.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Shared methods for the text and token encoders.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_ENCODER_COMMON_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_ENCODER_COMMON_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/lite/model.h"
+
+namespace libtextclassifier3 {
+
+// Input rank for the encoder ops is 2, because the first dimension is
+// always considered to be for batching, and during inference is always set to
+// 1, and the second dimension indexes the input values (texts or token
+// lengths).
+constexpr const int kEncoderInputRank = 2;
+constexpr const int kEncoderBatchSize = 1;
+
+// Creates a TensorFlow Lite array from an initializer list.
+TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values);
+
+// Copies values associated with the input to the output.
+// Typically we have attribute values associated with each item in the input,
+// e.g. user id per message in the conversation.
+// This aligns and replicates the attribute values with the encoded input, e.g.
+// replicates the same user id per token or sentence piece of the input.
+// As the input for the whole conversation is concatenated and (potentially)
+// trimmed, `encoding_end_offset` indicates where each item ends and
+// `start_offset` indicates how many elements at the beginning were dropped.
+TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
+ const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets,
+ int start_offset, TfLiteContext* context, TfLiteTensor* out);
+
+// Resizes an output tensor to shape {kBatchSize, max_output_length}.
+TfLiteStatus ResizeOutputTensor(const int max_output_length,
+ TfLiteTensor* tensor, TfLiteContext* context);
+
+// Copy a slice of data to output.
+// If the size of the data is smaller than `max_output_length` then the output
+// is padded with `padding_value`.
+// If the size of the data is larger than `max_output_length` then entries at
+// the beginning are dropped to fit into the limit.
+int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,
+ const std::vector<int32_t>& data,
+ const int32_t padding_value,
+ TfLiteTensor* output_tensor);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_ENCODER_COMMON_H_
diff --git a/utils/tflite/encoder_common_test.cc b/utils/tflite/encoder_common_test.cc
new file mode 100644
index 0000000..247689f
--- /dev/null
+++ b/utils/tflite/encoder_common_test.cc
@@ -0,0 +1,34 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/encoder_common.h"
+
+#include "gtest/gtest.h"
+#include "tensorflow/lite/model.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(EncoderUtilsTest, CreateIntArray) {
+ TfLiteIntArray* a = CreateIntArray({1, 2, 3});
+ EXPECT_EQ(a->data[0], 1);
+ EXPECT_EQ(a->data[1], 2);
+ EXPECT_EQ(a->data[2], 3);
+ TfLiteIntArrayFree(a);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/tflite/text_encoder.cc b/utils/tflite/text_encoder.cc
new file mode 100644
index 0000000..c7811ea
--- /dev/null
+++ b/utils/tflite/text_encoder.cc
@@ -0,0 +1,298 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/sentencepiece/encoder.h"
+#include "utils/sentencepiece/normalizer.h"
+#include "utils/sentencepiece/sorted_strings_table.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/tflite/encoder_common.h"
+#include "utils/tflite/text_encoder.h"
+#include "utils/tflite/text_encoder_config_generated.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Per-instance state of the text encoder op, created in Initialize() and
// destroyed in Free(). The encoder holds a non-owning pointer to `matcher`
// (see Initialize), so both must live for the lifetime of this struct.
struct TextEncoderOp {
  std::unique_ptr<SentencePieceNormalizer> normalizer;
  std::unique_ptr<Encoder> encoder;
  std::unique_ptr<SentencePieceMatcher> matcher;
};
+
+// Input parameters for the op.
+// The conversation message as a (1, conversation length) string tensor.
+constexpr const int kInputTexts = 0;
+
+// The number of messages, the conversation length, int scalar.
+constexpr const int kInputNumInputs = 1;
+
+// Maximum output length of the encoding, int scalar.
+constexpr const int kInputMaxLength = 2;
+
+// Additional attributes to align to the sentence pieces, e.g. user ids per
+// message.
+constexpr const int kInputAttr = 3;
+
+// Output parameters for the op.
+// The text sentence piece encodings as ids, (1, max output length) int tensor.
+constexpr const int kOutputEncoded = 0;
+
+// Relative position of each sentence piece in the input text,
+// (1, max output length) int tensor.
+constexpr const int kOutputPosition = 1;
+
+// Output length after trimming to the maximum output length specified.
+// int scalar.
+constexpr const int kOutputLengths = 2;
+
+// Padded and sentence piece aligned provided attributes, e.g. user id per
+// sentence piece.
+constexpr const int kOutputAttr = 3;
+
+const char kTextEncoderConfigAttr[] = "text_encoder_config";
+
+// Initializes text encoder object from serialized options:
+// The options are a flexbuffers attribute map that contain the op config
+// with the key `text_encoder_config` as `TextEncoderConfig`.
// Initializes the text encoder object from serialized options:
// the options are a flexbuffers attribute map containing the op config
// under the key `text_encoder_config` as a serialized `TextEncoderConfig`.
// Returns an owning TextEncoderOp* (released in Free), or nullptr on error.
void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
  const flexbuffers::Map& attr_map =
      flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
          .AsMap();
  const flexbuffers::Blob serialized_config =
      attr_map[kTextEncoderConfigAttr].AsBlob();
  const TextEncoderConfig* config =
      flatbuffers::GetRoot<TextEncoderConfig>(serialized_config.data());

  std::unique_ptr<TextEncoderOp> encoder_op(new TextEncoderOp());

  // Create normalizer from options.
  // NOTE(review): config fields (normalization_charsmap, pieces_scores, ...)
  // are dereferenced without null checks -- presumably the config comes from
  // a trusted model-building pipeline; confirm.
  const TrieNode* charsmap_trie_nodes = reinterpret_cast<const TrieNode*>(
      config->normalization_charsmap()->Data());
  const int charsmap_trie_nodes_length =
      config->normalization_charsmap()->Length() / sizeof(TrieNode);
  encoder_op->normalizer.reset(new SentencePieceNormalizer(
      DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length),
      StringPiece(config->normalization_charsmap_values()->data(),
                  config->normalization_charsmap_values()->size()),
      config->add_dummy_prefix(), config->remove_extra_whitespaces(),
      config->escape_whitespaces()));

  const int num_pieces = config->pieces_scores()->Length();

  // The sentence pieces may be stored either as a mapped trie or as a sorted
  // strings table; instantiate the matching matcher implementation.
  switch (config->matcher_type()) {
    case SentencePieceMatcherType_MAPPED_TRIE: {
      const TrieNode* pieces_trie_nodes =
          reinterpret_cast<const TrieNode*>(config->pieces()->Data());
      const int pieces_trie_nodes_length =
          config->pieces()->Length() / sizeof(TrieNode);
      encoder_op->matcher.reset(
          new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
      break;
    }
    case SentencePieceMatcherType_SORTED_STRING_TABLE: {
      encoder_op->matcher.reset(new SortedStringsTable(
          num_pieces, config->pieces_offsets()->data(),
          StringPiece(config->pieces()->data(), config->pieces()->Length())));
      break;
    }
    default: {
      TC3_LOG(ERROR) << "Unknown sentence piece matcher type.";
      return nullptr;
    }
  }
  // The encoder borrows the matcher; both are owned by the returned op.
  encoder_op->encoder.reset(new Encoder(
      encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
      config->start_code(), config->end_code(), config->encoding_offset(),
      config->unknown_code(), config->unknown_score()));
  // Ownership passes to the TFLite runtime; reclaimed in Free().
  return encoder_op.release();
}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<TextEncoderOp*>(buffer);
+}
+
// Resizes the encoding, position and every attribute output tensor to
// {kEncoderBatchSize, max_output_length}.
TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
                                 int max_output_length) {
  TF_LITE_ENSURE_OK(
      context,
      ResizeOutputTensor(max_output_length,
                         &context->tensors[node->outputs->data[kOutputEncoded]],
                         context));

  TF_LITE_ENSURE_OK(
      context,
      ResizeOutputTensor(
          max_output_length,
          &context->tensors[node->outputs->data[kOutputPosition]], context));

  // Every output from kOutputAttr onwards is an aligned attribute tensor.
  const int num_output_attrs = node->outputs->size - kOutputAttr;
  for (int i = 0; i < num_output_attrs; ++i) {
    TF_LITE_ENSURE_OK(
        context,
        ResizeOutputTensor(
            max_output_length,
            &context->tensors[node->outputs->data[kOutputAttr + i]], context));
  }
  return kTfLiteOk;
}
+
// Validates the input/output arity, propagates attribute tensor types, and
// resizes the outputs. If the max-length input is not a constant tensor the
// outputs are marked dynamic and resized later in Eval().
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check that the batch dimension is kBatchSize.
  const TfLiteTensor& input_text =
      context->tensors[node->inputs->data[kInputTexts]];
  TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
  TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);

  TfLiteTensor& output_lengths =
      context->tensors[node->outputs->data[kOutputLengths]];
  TfLiteTensor& output_encoded =
      context->tensors[node->outputs->data[kOutputEncoded]];
  TfLiteTensor& output_positions =
      context->tensors[node->outputs->data[kOutputPosition]];

  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, &output_lengths,
                                          CreateIntArray({kEncoderBatchSize})));

  // Check that there are enough outputs for attributes.
  const int num_output_attrs = node->outputs->size - kOutputAttr;
  TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);

  // Copy attribute types from input to output tensors.
  for (int i = 0; i < num_output_attrs; ++i) {
    TfLiteTensor& input = context->tensors[node->inputs->data[kInputAttr + i]];
    TfLiteTensor& output =
        context->tensors[node->outputs->data[kOutputAttr + i]];
    output.type = input.type;
  }

  const TfLiteTensor& output_length =
      context->tensors[node->inputs->data[kInputMaxLength]];

  if (tflite::IsConstantTensor(&output_length)) {
    // NOTE(review): the max-length value is read via data.i64, but the unit
    // test feeds this input as an INT32 tensor -- confirm the tensor type
    // this op expects (i64 read of an i32 buffer reads past the value).
    return ResizeOutputTensors(context, node, output_length.data.i64[0]);
  } else {
    tflite::SetTensorToDynamic(&output_encoded);
    tflite::SetTensorToDynamic(&output_positions);
    for (int i = 0; i < num_output_attrs; ++i) {
      TfLiteTensor& output_attr =
          context->tensors[node->outputs->data[kOutputAttr + i]];
      tflite::SetTensorToDynamic(&output_attr);
    }
  }

  return kTfLiteOk;
}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ if (node->user_data == nullptr) {
+ return kTfLiteError;
+ }
+ const TextEncoderOp* encoder_op =
+ reinterpret_cast<TextEncoderOp*>(node->user_data);
+ const TfLiteTensor& input_text =
+ context->tensors[node->inputs->data[kInputTexts]];
+ const int num_strings = tflite::GetStringCount(&input_text);
+ // Check that the number of strings matches the length parameter.
+ const int num_strings_param =
+ context->tensors[node->inputs->data[kInputNumInputs]].data.i32[0];
+ TF_LITE_ENSURE_EQ(context, num_strings, num_strings_param);
+
+ TfLiteTensor& output_encoded =
+ context->tensors[node->outputs->data[kOutputEncoded]];
+ if (tflite::IsDynamicTensor(&output_encoded)) {
+ const TfLiteTensor& output_length =
+ context->tensors[node->inputs->data[kInputMaxLength]];
+ TF_LITE_ENSURE_OK(
+ context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
+ }
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[kOutputPosition]];
+
+ std::vector<int> encoded_total;
+ std::vector<int> encoded_offsets;
+ std::vector<int> encoded_positions;
+ encoded_offsets.reserve(num_strings);
+ const int max_output_length = output_encoded.dims->data[1];
+ const int max_encoded_position = max_output_length;
+
+ for (int i = 0; i < num_strings; ++i) {
+ const auto& strref = tflite::GetString(&input_text, i);
+ std::string normalized;
+ TF_LITE_ENSURE(context,
+ encoder_op->normalizer->Normalize(
+ StringPiece(strref.str, strref.len), &normalized));
+ std::vector<int> encoded;
+ TF_LITE_ENSURE(context, encoder_op->encoder->Encode(normalized, &encoded));
+ encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
+ encoded_offsets.push_back(encoded_total.size());
+ for (int i = 0; i < encoded.size(); i++) {
+ encoded_positions.push_back(std::min(i, max_encoded_position - 1));
+ }
+ }
+
+ const int num_skip = CopyDataToTensorAndPadOrTruncate(
+ max_output_length, encoded_total,
+ /*padding_value=*/encoded_total.back(), &output_encoded);
+ TfLiteTensor& output_lengths =
+ context->tensors[node->outputs->data[kOutputLengths]];
+ output_lengths.data.i32[0] = encoded_total.size() - num_skip;
+ CopyDataToTensorAndPadOrTruncate(max_output_length, encoded_positions,
+ /*padding_value=*/max_encoded_position,
+ &output_positions);
+
+ // Process attributes, all checks of sizes and types are done in Prepare.
+ const int num_output_attrs = node->outputs->size - kOutputAttr;
+ TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
+ context->tensors[node->inputs->data[kInputAttr + i]], encoded_offsets,
+ num_skip, context,
+ &context->tensors[node->outputs->data[kOutputAttr + i]]);
+ if (attr_status != kTfLiteOk) {
+ return attr_status;
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace libtextclassifier3
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_TEXT_ENCODER() {
+ static TfLiteRegistration registration = {
+ libtextclassifier3::Initialize, libtextclassifier3::Free,
+ libtextclassifier3::Prepare, libtextclassifier3::Eval};
+ return ®istration;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/utils/tflite/text_encoder.h b/utils/tflite/text_encoder.h
new file mode 100644
index 0000000..a2cf56b
--- /dev/null
+++ b/utils/tflite/text_encoder.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// An encoder that produces positional and attributes encodings for a
+// transformer style model based on sentence piece segmentation of text.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER_H_
+
+#include "tensorflow/lite/context.h"
+
namespace tflite {
namespace ops {
namespace custom {

// Returns the TfLiteRegistration for the "TextEncoder" custom op.
// The caller does not take ownership of the returned pointer.
TfLiteRegistration* Register_TEXT_ENCODER();

} // namespace custom
} // namespace ops
} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER_H_
diff --git a/utils/tflite/text_encoder_config.fbs b/utils/tflite/text_encoder_config.fbs
new file mode 100644
index 0000000..4ffade4
--- /dev/null
+++ b/utils/tflite/text_encoder_config.fbs
@@ -0,0 +1,65 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Configuration for the text encoder op.
+
+namespace libtextclassifier3;
+
+enum SentencePieceMatcherType : byte {
+ MAPPED_TRIE = 0,
+ SORTED_STRING_TABLE = 1,
+}
+
+table TextEncoderConfig {
+ // Code that is used as encoding of the start code.
+ start_code:int32 = 0;
+
+ // Code that is used as encoding of the end code.
+ end_code:int32 = 1;
+
+ // This value is added to all codes to make them not intersect with
+ // `start_code` and `end_code`.
+ encoding_offset:int32 = 2;
+
+ // Code that is used for out-of-dictionary characters.
+ unknown_code:int32 = -1;
+
+ // Penalty associated with the unknown code.
+ unknown_score:float;
+
+ // Normalization options.
+ // Serialized normalization charsmap.
+ normalization_charsmap:string;
+ normalization_charsmap_values:string;
+
+ // Whether to add dummy whitespace at the beginning of the text in order to
+ // treat "world" in "world" and "hello world" uniformly.
+ add_dummy_prefix:bool = true;
+
+ // Whether to remove leading, trailing and duplicate internal whitespace.
+ remove_extra_whitespaces:bool = true;
+
+ // Whether to replace whitespace with a meta symbol.
+ escape_whitespaces:bool = true;
+
+ // Sentence pieces scores.
+ pieces_scores:[float];
+
+ // Serialized sentence pieces.
+ pieces:string;
+ pieces_offsets:[uint32];
+ matcher_type: SentencePieceMatcherType = MAPPED_TRIE;
+}
diff --git a/utils/tflite/text_encoder_test.cc b/utils/tflite/text_encoder_test.cc
new file mode 100644
index 0000000..ae752f5
--- /dev/null
+++ b/utils/tflite/text_encoder_test.cc
@@ -0,0 +1,176 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "utils/tflite/text_encoder.h"
+#include "gtest/gtest.h"
+#include "third_party/absl/flags/flag.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Returns the path of the serialized TextEncoderConfig used by the tests.
// NOTE(review): returns an empty string in this copy (the path appears to
// have been stripped for AOSP); the config read below will then be empty --
// confirm how the test data is wired up in this build.
std::string GetTestConfigPath() {
  return "";
}
+
// Test harness wrapping the TextEncoder custom op in a single-op model.
// Inputs: message strings, message count, max output length, one int32 and
// one float attribute; outputs: encoding ids, positions, encoded length and
// the two aligned attribute tensors.
class TextEncoderOpModel : public tflite::SingleOpModel {
 public:
  TextEncoderOpModel(std::initializer_list<int> input_strings_shape,
                     std::initializer_list<int> attribute_shape);
  // Sets the conversation messages and the matching message-count scalar.
  void SetInputText(const std::initializer_list<string>& strings) {
    PopulateStringTensor(input_string_, strings);
    PopulateTensor(input_length_, {static_cast<int32_t>(strings.size())});
  }
  void SetMaxOutputLength(int length) {
    PopulateTensor(input_output_maxlength_, {length});
  }
  // One attribute value per input message.
  void SetInt32Attribute(const std::initializer_list<int>& attribute) {
    PopulateTensor(input_attributes_int32_, attribute);
  }
  void SetFloatAttribute(const std::initializer_list<float>& attribute) {
    PopulateTensor(input_attributes_float_, attribute);
  }

  std::vector<int> GetOutputEncoding() {
    return ExtractVector<int>(output_encoding_);
  }
  std::vector<int> GetOutputPositions() {
    return ExtractVector<int>(output_positions_);
  }
  std::vector<int> GetOutputAttributeInt32() {
    return ExtractVector<int>(output_attributes_int32_);
  }
  std::vector<float> GetOutputAttributeFloat() {
    return ExtractVector<float>(output_attributes_float_);
  }
  // Number of valid (non-padding) entries in the encoding output.
  int GetEncodedLength() { return ExtractVector<int>(output_length_)[0]; }

 private:
  // Input tensor indices.
  int input_string_;
  int input_length_;
  int input_output_maxlength_;
  int input_attributes_int32_;
  int input_attributes_float_;

  // Output tensor indices.
  int output_encoding_;
  int output_positions_;
  int output_length_;
  int output_attributes_int32_;
  int output_attributes_float_;
};
+
// Wires up inputs/outputs and passes the serialized TextEncoderConfig
// (read from the test data file) to the op as a flexbuffers options map
// under the "text_encoder_config" key.
TextEncoderOpModel::TextEncoderOpModel(
    std::initializer_list<int> input_strings_shape,
    std::initializer_list<int> attribute_shape) {
  input_string_ = AddInput(tflite::TensorType_STRING);
  input_length_ = AddInput(tflite::TensorType_INT32);
  input_output_maxlength_ = AddInput(tflite::TensorType_INT32);
  input_attributes_int32_ = AddInput(tflite::TensorType_INT32);
  input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32);

  output_encoding_ = AddOutput(tflite::TensorType_INT32);
  output_positions_ = AddOutput(tflite::TensorType_INT32);
  output_length_ = AddOutput(tflite::TensorType_INT32);
  output_attributes_int32_ = AddOutput(tflite::TensorType_INT32);
  output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32);

  // Slurp the whole serialized config file into a string.
  std::ifstream test_config_stream(GetTestConfigPath());
  std::string config((std::istreambuf_iterator<char>(test_config_stream)),
                     (std::istreambuf_iterator<char>()));
  flexbuffers::Builder builder;
  builder.Map([&]() { builder.String("text_encoder_config", config); });
  builder.Finish();
  SetCustomOp("TextEncoder", builder.GetBuffer(),
              tflite::ops::custom::Register_TEXT_ENCODER);
  BuildInterpreter(
      {input_strings_shape, {1}, {1}, attribute_shape, attribute_shape});
}
+
+// Tests
// Single short message: the 5-piece encoding is padded to length 10 with the
// last id (2), positions beyond the message are padded with the max position
// (10), and each attribute value is replicated across all output slots.
TEST(TextEncoderTest, SimpleEncoder) {
  TextEncoderOpModel m({1, 1}, {1, 1});
  m.SetInputText({"Hello"});
  m.SetMaxOutputLength(10);
  m.SetInt32Attribute({7});
  m.SetFloatAttribute({3.f});

  m.Invoke();

  EXPECT_EQ(m.GetEncodedLength(), 5);
  EXPECT_THAT(m.GetOutputEncoding(),
              testing::ElementsAre(1, 90, 547, 58, 2, 2, 2, 2, 2, 2));
  EXPECT_THAT(m.GetOutputPositions(),
              testing::ElementsAre(0, 1, 2, 3, 4, 10, 10, 10, 10, 10));
  EXPECT_THAT(m.GetOutputAttributeInt32(),
              testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
  EXPECT_THAT(
      m.GetOutputAttributeFloat(),
      testing::ElementsAre(3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f));
}
+
// Multiple messages: encodings are concatenated (total exceeds the limit, so
// the beginning of the first message is truncated), positions restart per
// message, and attribute values align with their message's pieces.
TEST(TextEncoderTest, ManyStrings) {
  TextEncoderOpModel m({1, 3}, {1, 3});
  m.SetInt32Attribute({1, 2, 3});
  m.SetFloatAttribute({5.f, 4.f, 3.f});
  m.SetInputText({"Hello", "Hi", "Bye"});
  m.SetMaxOutputLength(10);

  m.Invoke();

  EXPECT_EQ(m.GetEncodedLength(), 10);
  EXPECT_THAT(m.GetOutputEncoding(),
              testing::ElementsAre(547, 58, 2, 1, 862, 2, 1, 1919, 19, 2));
  EXPECT_THAT(m.GetOutputPositions(),
              testing::ElementsAre(2, 3, 4, 0, 1, 2, 0, 1, 2, 3));
  EXPECT_THAT(m.GetOutputAttributeInt32(),
              testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
  EXPECT_THAT(
      m.GetOutputAttributeFloat(),
      testing::ElementsAre(5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 3.f));
}
+
// Total encoding longer than the output: the whole first message (and part
// of the second) is dropped from the front; attributes for fully dropped
// messages do not appear in the output at all.
TEST(TextEncoderTest, LongStrings) {
  TextEncoderOpModel m({1, 4}, {1, 4});
  m.SetInt32Attribute({1, 2, 3, 4});
  m.SetFloatAttribute({5.f, 4.f, 3.f, 2.f});
  m.SetInputText({"Hello", "Hi", "Bye", "Hi"});
  m.SetMaxOutputLength(9);

  m.Invoke();

  EXPECT_EQ(m.GetEncodedLength(), 9);
  EXPECT_THAT(m.GetOutputEncoding(),
              testing::ElementsAre(862, 2, 1, 1919, 19, 2, 1, 862, 2));
  EXPECT_THAT(m.GetOutputPositions(),
              testing::ElementsAre(1, 2, 0, 1, 2, 3, 0, 1, 2));
  EXPECT_THAT(m.GetOutputAttributeInt32(),
              testing::ElementsAre(2, 2, 3, 3, 3, 3, 4, 4, 4));
  EXPECT_THAT(
      m.GetOutputAttributeFloat(),
      testing::ElementsAre(4.f, 4.f, 3.f, 3.f, 3.f, 3.f, 2.f, 2.f, 2.f));
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/tflite/token_encoder.cc b/utils/tflite/token_encoder.cc
new file mode 100644
index 0000000..75e5d1e
--- /dev/null
+++ b/utils/tflite/token_encoder.cc
@@ -0,0 +1,190 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/token_encoder.h"
+
+#include <vector>
+
+#include "utils/tflite/encoder_common.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Input parameters for the op.
+// The number of tokens per message as (1, conversation length) int tensor.
+constexpr const int kInputNumTokens = 0;
+
+// The number of messages, the conversation length, int scalar.
+constexpr const int kInputNumInputs = 1;
+
+// Maximum output length of the encoding, int scalar.
+constexpr const int kInputMaxLength = 2;
+
+// Additional attributes to align to the tokens, e.g. user ids per
+// message.
+constexpr const int kInputAttr = 3;
+
+// Output parameters for the op.
+// Relative position of each token in the input text,
+// (1, max output length) int tensor.
+constexpr const int kOutputPosition = 0;
+
+// Output length after trimming to the maximum output length specified.
+// int scalar.
+constexpr const int kOutputLengths = 1;
+
+// Padded and token aligned provided attributes, e.g. user id per
+// token.
+constexpr const int kOutputAttr = 2;
+
+TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
+ int max_output_length) {
+ TF_LITE_ENSURE_OK(
+ context,
+ ResizeOutputTensor(
+ max_output_length,
+ &context->tensors[node->outputs->data[kOutputPosition]], context));
+
+ const int num_output_attrs = node->outputs->size - kOutputAttr;
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TF_LITE_ENSURE_OK(
+ context,
+ ResizeOutputTensor(
+ max_output_length,
+ &context->tensors[node->outputs->data[kOutputAttr + i]], context));
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Check that the batch dimension is kBatchSize.
+ const TfLiteTensor& num_tokens =
+ context->tensors[node->inputs->data[kInputNumTokens]];
+ TF_LITE_ENSURE_EQ(context, num_tokens.dims->size, kEncoderInputRank);
+ TF_LITE_ENSURE_EQ(context, num_tokens.dims->data[0], kEncoderBatchSize);
+
+ TfLiteTensor& output_lengths =
+ context->tensors[node->outputs->data[kOutputLengths]];
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[kOutputPosition]];
+
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, &output_lengths,
+ CreateIntArray({kEncoderBatchSize})));
+
+ // Check that there are enough outputs for attributes.
+ const int num_output_attrs = node->outputs->size - kOutputAttr;
+ TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
+
+ // Copy attribute types from input to output tensors.
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteTensor& input = context->tensors[node->inputs->data[kInputAttr + i]];
+ TfLiteTensor& output =
+ context->tensors[node->outputs->data[kOutputAttr + i]];
+ output.type = input.type;
+ }
+
+ const TfLiteTensor& output_length =
+ context->tensors[node->inputs->data[kInputMaxLength]];
+
+ if (tflite::IsConstantTensor(&output_length)) {
+ return ResizeOutputTensors(context, node, output_length.data.i64[0]);
+ } else {
+ tflite::SetTensorToDynamic(&output_positions);
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteTensor& output_attr =
+ context->tensors[node->outputs->data[kOutputAttr + i]];
+ tflite::SetTensorToDynamic(&output_attr);
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor& num_tokens =
+ context->tensors[node->inputs->data[kInputNumTokens]];
+ const int num_inputs =
+ context->tensors[node->inputs->data[kInputNumInputs]].data.i32[0];
+
+ const TfLiteTensor& output_length =
+ context->tensors[node->inputs->data[kInputMaxLength]];
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[kOutputPosition]];
+ if (!tflite::IsConstantTensor(&output_length)) {
+ TF_LITE_ENSURE_OK(
+ context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
+ }
+
+ std::vector<int> encoded_offsets;
+ std::vector<int> encoded_positions;
+ encoded_offsets.reserve(num_inputs);
+ const int max_output_length = output_positions.dims->data[1];
+ const int max_encoded_position = max_output_length;
+ int total_tokens = 0;
+
+ for (int i = 0; i < num_inputs; ++i) {
+ const int num_message_tokens =
+ num_tokens.data.i32[i] + 2; /* num_tokens + start and end token. */
+ total_tokens += num_message_tokens;
+ encoded_offsets.push_back(total_tokens);
+ for (int k = 0; k < num_message_tokens; k++) {
+ encoded_positions.push_back(std::min(k, max_encoded_position - 1));
+ }
+ }
+
+ const int num_skip = CopyDataToTensorAndPadOrTruncate(
+ max_output_length, encoded_positions,
+ /*padding_value=*/max_encoded_position, &output_positions);
+ TfLiteTensor& output_lengths =
+ context->tensors[node->outputs->data[kOutputLengths]];
+ output_lengths.data.i32[0] = encoded_positions.size() - num_skip;
+
+ // Process attributes, all checks of sizes and types are done in Prepare.
+ const int num_output_attrs = node->outputs->size - kOutputAttr;
+ TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
+ context->tensors[node->inputs->data[kInputAttr + i]], encoded_offsets,
+ num_skip, context,
+ &context->tensors[node->outputs->data[kOutputAttr + i]]);
+ if (attr_status != kTfLiteOk) {
+ return attr_status;
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace libtextclassifier3
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_TOKEN_ENCODER() {
+ static TfLiteRegistration registration = {/*init=*/nullptr, /*free=*/nullptr,
+ libtextclassifier3::Prepare,
+ libtextclassifier3::Eval};
+ return &registration;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/utils/tflite/token_encoder.h b/utils/tflite/token_encoder.h
new file mode 100644
index 0000000..94b3f70
--- /dev/null
+++ b/utils/tflite/token_encoder.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// An encoder that produces positional and attributes encodings for a
+// transformer style model based on tokens (rather than sentence pieces).
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_TOKEN_ENCODER_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_TOKEN_ENCODER_H_
+
+#include "tensorflow/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_TOKEN_ENCODER();
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_TOKEN_ENCODER_H_
diff --git a/utils/tflite/token_encoder_test.cc b/utils/tflite/token_encoder_test.cc
new file mode 100644
index 0000000..c7f51a1
--- /dev/null
+++ b/utils/tflite/token_encoder_test.cc
@@ -0,0 +1,148 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+
+#include "utils/tflite/token_encoder.h"
+#include "gtest/gtest.h"
+#include "third_party/absl/flags/flag.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+class TokenEncoderOpModel : public tflite::SingleOpModel {
+ public:
+ TokenEncoderOpModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> attribute_shape);
+ void SetNumTokens(const std::initializer_list<int>& num_tokens) {
+ PopulateTensor(input_num_tokens_, num_tokens);
+ PopulateTensor(input_length_, {static_cast<int32_t>(num_tokens.size())});
+ }
+ void SetMaxOutputLength(int length) {
+ PopulateTensor(input_output_maxlength_, {length});
+ }
+ void SetInt32Attribute(const std::initializer_list<int>& attribute) {
+ PopulateTensor(input_attributes_int32_, attribute);
+ }
+ void SetFloatAttribute(const std::initializer_list<float>& attribute) {
+ PopulateTensor(input_attributes_float_, attribute);
+ }
+ std::vector<int> GetOutputPositions() {
+ return ExtractVector<int>(output_positions_);
+ }
+ std::vector<int> GetOutputAttributeInt32() {
+ return ExtractVector<int>(output_attributes_int32_);
+ }
+ std::vector<float> GetOutputAttributeFloat() {
+ return ExtractVector<float>(output_attributes_float_);
+ }
+ int GetOutputLength() { return ExtractVector<int>(output_length_)[0]; }
+
+ private:
+ int input_num_tokens_;
+ int input_length_;
+ int input_output_maxlength_;
+ int input_attributes_int32_;
+ int input_attributes_float_;
+
+ int output_positions_;
+ int output_length_;
+ int output_attributes_int32_;
+ int output_attributes_float_;
+};
+
+TokenEncoderOpModel::TokenEncoderOpModel(
+ std::initializer_list<int> input_shape,
+ std::initializer_list<int> attribute_shape) {
+ input_num_tokens_ = AddInput(tflite::TensorType_INT32);
+ input_length_ = AddInput(tflite::TensorType_INT32);
+ input_output_maxlength_ = AddInput(tflite::TensorType_INT32);
+ input_attributes_int32_ = AddInput(tflite::TensorType_INT32);
+ input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32);
+
+ output_positions_ = AddOutput(tflite::TensorType_INT32);
+ output_length_ = AddOutput(tflite::TensorType_INT32);
+ output_attributes_int32_ = AddOutput(tflite::TensorType_INT32);
+ output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32);
+
+ SetCustomOp("TokenEncoder", {}, tflite::ops::custom::Register_TOKEN_ENCODER);
+ BuildInterpreter({input_shape, {1}, {1}, attribute_shape, attribute_shape});
+}
+
+// Tests
+TEST(TokenEncoderTest, SimpleEncoder) {
+ TokenEncoderOpModel m({1, 1}, {1, 1});
+ m.SetNumTokens({1});
+ m.SetMaxOutputLength(10);
+ m.SetInt32Attribute({7});
+ m.SetFloatAttribute({3.f});
+
+ m.Invoke();
+
+ EXPECT_EQ(m.GetOutputLength(), 3);
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(0, 1, 2, 10, 10, 10, 10, 10, 10, 10));
+ EXPECT_THAT(m.GetOutputAttributeInt32(),
+ testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
+ EXPECT_THAT(
+ m.GetOutputAttributeFloat(),
+ testing::ElementsAre(3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f));
+}
+
+TEST(TokenEncoderTest, ManyMessages) {
+ TokenEncoderOpModel m({1, 3}, {1, 3});
+ m.SetInt32Attribute({1, 2, 3});
+ m.SetFloatAttribute({5.f, 4.f, 3.f});
+ m.SetNumTokens({1, 1, 1});
+ m.SetMaxOutputLength(10);
+
+ m.Invoke();
+
+ EXPECT_EQ(m.GetOutputLength(), 9);
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(0, 1, 2, 0, 1, 2, 0, 1, 2, 10));
+ EXPECT_THAT(m.GetOutputAttributeInt32(),
+ testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
+ EXPECT_THAT(
+ m.GetOutputAttributeFloat(),
+ testing::ElementsAre(5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 3.f));
+}
+
+TEST(TokenEncoderTest, ManyMessagesMultipleTokens) {
+ TokenEncoderOpModel m({1, 4}, {1, 4});
+ m.SetInt32Attribute({1, 2, 3, 4});
+ m.SetFloatAttribute({5.f, 4.f, 3.f, 2.f});
+ m.SetNumTokens({1, 2, 3, 4});
+ m.SetMaxOutputLength(9);
+
+ m.Invoke();
+
+ EXPECT_EQ(m.GetOutputLength(), 9);
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(2, 3, 4, 0, 1, 2, 3, 4, 5));
+ EXPECT_THAT(m.GetOutputAttributeInt32(),
+ testing::ElementsAre(3, 3, 3, 4, 4, 4, 4, 4, 4));
+ EXPECT_THAT(
+ m.GetOutputAttributeFloat(),
+ testing::ElementsAre(3.f, 3.f, 3.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/annotator/token-feature-extractor.cc b/utils/token-feature-extractor.cc
similarity index 99%
rename from annotator/token-feature-extractor.cc
rename to utils/token-feature-extractor.cc
index 77ad7a4..9faebca 100644
--- a/annotator/token-feature-extractor.cc
+++ b/utils/token-feature-extractor.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "annotator/token-feature-extractor.h"
+#include "utils/token-feature-extractor.h"
#include <cctype>
#include <string>
diff --git a/annotator/token-feature-extractor.h b/utils/token-feature-extractor.h
similarity index 94%
rename from annotator/token-feature-extractor.h
rename to utils/token-feature-extractor.h
index 7dc19fe..fed113b 100644
--- a/annotator/token-feature-extractor.h
+++ b/utils/token-feature-extractor.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_
+#ifndef LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
#include <memory>
#include <unordered_set>
@@ -112,4 +112,4 @@
} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TOKEN_FEATURE_EXTRACTOR_H_
+#endif // LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
diff --git a/annotator/token-feature-extractor_test.cc b/utils/token-feature-extractor_test.cc
similarity index 99%
rename from annotator/token-feature-extractor_test.cc
rename to utils/token-feature-extractor_test.cc
index 32383a9..9a97e42 100644
--- a/annotator/token-feature-extractor_test.cc
+++ b/utils/token-feature-extractor_test.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "annotator/token-feature-extractor.h"
+#include "utils/token-feature-extractor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
diff --git a/utils/tokenizer.cc b/utils/tokenizer.cc
new file mode 100644
index 0000000..87a5c8d
--- /dev/null
+++ b/utils/tokenizer.cc
@@ -0,0 +1,261 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenizer.h"
+
+#include <algorithm>
+
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+#include "utils/strings/utf8.h"
+
+namespace libtextclassifier3 {
+
+Tokenizer::Tokenizer(
+ const TokenizationType type, const UniLib* unilib,
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ const std::vector<const CodepointRange*>&
+ internal_tokenizer_codepoint_ranges,
+ const bool split_on_script_change,
+ const bool icu_preserve_whitespace_tokens)
+ : type_(type),
+ unilib_(unilib),
+ split_on_script_change_(split_on_script_change),
+ icu_preserve_whitespace_tokens_(icu_preserve_whitespace_tokens) {
+ for (const TokenizationCodepointRange* range : codepoint_ranges) {
+ codepoint_ranges_.emplace_back(range->UnPack());
+ }
+
+ std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
+ [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
+ const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
+ return a->start < b->start;
+ });
+
+ SortCodepointRanges(internal_tokenizer_codepoint_ranges,
+ &internal_tokenizer_codepoint_ranges_);
+}
+
+const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
+ int codepoint) const {
+ auto it = std::lower_bound(
+ codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
+ [](const std::unique_ptr<const TokenizationCodepointRangeT>& range,
+ int codepoint) {
+ // This function compares range with the codepoint for the purpose of
+ // finding the first greater or equal range. Because of the use of
+ // std::lower_bound it needs to return true when range < codepoint;
+ // the first time it will return false the lower bound is found and
+ // returned.
+ //
+ // It might seem weird that the condition is range.end <= codepoint
+ // here but when codepoint == range.end it means it's actually just
+ // outside of the range, thus the range is less than the codepoint.
+ return range->end <= codepoint;
+ });
+ if (it != codepoint_ranges_.end() && (*it)->start <= codepoint &&
+ (*it)->end > codepoint) {
+ return it->get();
+ } else {
+ return nullptr;
+ }
+}
+
+void Tokenizer::GetScriptAndRole(char32 codepoint,
+ TokenizationCodepointRange_::Role* role,
+ int* script) const {
+ const TokenizationCodepointRangeT* range = FindTokenizationRange(codepoint);
+ if (range) {
+ *role = range->role;
+ *script = range->script_id;
+ } else {
+ *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ *script = kUnknownScript;
+ }
+}
+
+std::vector<Token> Tokenizer::Tokenize(const std::string& text) const {
+ UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
+ return Tokenize(text_unicode);
+}
+
+std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const {
+ switch (type_) {
+ case TokenizationType_INTERNAL_TOKENIZER:
+ return InternalTokenize(text_unicode);
+ case TokenizationType_ICU:
+ TC3_FALLTHROUGH_INTENDED;
+ case TokenizationType_MIXED: {
+ std::vector<Token> result;
+ if (!ICUTokenize(text_unicode, &result)) {
+ return {};
+ }
+ if (type_ == TokenizationType_MIXED) {
+ InternalRetokenize(text_unicode, &result);
+ }
+ return result;
+ }
+ default:
+ TC3_LOG(ERROR) << "Unknown tokenization type specified. Using internal.";
+ return InternalTokenize(text_unicode);
+ }
+}
+
+std::vector<Token> Tokenizer::InternalTokenize(
+ const UnicodeText& text_unicode) const {
+ std::vector<Token> result;
+ Token new_token("", 0, 0);
+ int codepoint_index = 0;
+
+ int last_script = kInvalidScript;
+ for (auto it = text_unicode.begin(); it != text_unicode.end();
+ ++it, ++codepoint_index) {
+ TokenizationCodepointRange_::Role role;
+ int script;
+ GetScriptAndRole(*it, &role, &script);
+
+ if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE ||
+ (split_on_script_change_ && last_script != kInvalidScript &&
+ last_script != script)) {
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+ new_token = Token("", codepoint_index, codepoint_index);
+ }
+ if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) {
+ new_token.value += std::string(
+ it.utf8_data(),
+ it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
+ ++new_token.end;
+ }
+ if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) {
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+ new_token = Token("", codepoint_index + 1, codepoint_index + 1);
+ }
+
+ last_script = script;
+ }
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+
+ return result;
+}
+
+void Tokenizer::TokenizeSubstring(const UnicodeText& unicode_text,
+ CodepointSpan span,
+ std::vector<Token>* result) const {
+ if (span.first < 0) {
+ // There is no span to tokenize.
+ return;
+ }
+
+ // Extract the substring.
+ UnicodeText text = UnicodeText::Substring(unicode_text, span.first,
+ span.second, /*do_copy=*/false);
+
+ // Run the tokenizer and update the token bounds to reflect the offset of the
+ // substring.
+ std::vector<Token> tokens = InternalTokenize(text);
+
+ // Avoids progressive capacity increases in the for loop.
+ result->reserve(result->size() + tokens.size());
+ for (Token& token : tokens) {
+ token.start += span.first;
+ token.end += span.first;
+ result->emplace_back(std::move(token));
+ }
+}
+
+void Tokenizer::InternalRetokenize(const UnicodeText& unicode_text,
+ std::vector<Token>* tokens) const {
+ std::vector<Token> result;
+ CodepointSpan span(-1, -1);
+ for (Token& token : *tokens) {
+ const UnicodeText unicode_token_value =
+ UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+ bool should_retokenize = true;
+ for (const int codepoint : unicode_token_value) {
+ if (!IsCodepointInRanges(codepoint,
+ internal_tokenizer_codepoint_ranges_)) {
+ should_retokenize = false;
+ break;
+ }
+ }
+
+ if (should_retokenize) {
+ if (span.first < 0) {
+ span.first = token.start;
+ }
+ span.second = token.end;
+ } else {
+ TokenizeSubstring(unicode_text, span, &result);
+ span.first = -1;
+ result.emplace_back(std::move(token));
+ }
+ }
+ TokenizeSubstring(unicode_text, span, &result);
+
+ *tokens = std::move(result);
+}
+
+bool Tokenizer::ICUTokenize(const UnicodeText& context_unicode,
+ std::vector<Token>* result) const {
+ std::unique_ptr<UniLib::BreakIterator> break_iterator =
+ unilib_->CreateBreakIterator(context_unicode);
+ if (!break_iterator) {
+ return false;
+ }
+ int last_break_index = 0;
+ int break_index = 0;
+ int last_unicode_index = 0;
+ int unicode_index = 0;
+ auto token_begin_it = context_unicode.begin();
+ while ((break_index = break_iterator->Next()) !=
+ UniLib::BreakIterator::kDone) {
+ const int token_length = break_index - last_break_index;
+ unicode_index = last_unicode_index + token_length;
+
+ auto token_end_it = token_begin_it;
+ std::advance(token_end_it, token_length);
+
+ // Determine if the whole token is whitespace.
+ bool is_whitespace = true;
+ for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
+ if (!unilib_->IsWhitespace(*char_it)) {
+ is_whitespace = false;
+ break;
+ }
+ }
+
+ const std::string token =
+ context_unicode.UTF8Substring(token_begin_it, token_end_it);
+
+ if (!is_whitespace || icu_preserve_whitespace_tokens_) {
+ result->push_back(Token(token, last_unicode_index, unicode_index));
+ }
+
+ last_break_index = break_index;
+ last_unicode_index = unicode_index;
+ token_begin_it = token_end_it;
+ }
+
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/tokenizer.fbs b/utils/tokenizer.fbs
new file mode 100755
index 0000000..2a19999
--- /dev/null
+++ b/utils/tokenizer.fbs
@@ -0,0 +1,70 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Controls the type of tokenization the model will use for the input text.
+namespace libtextclassifier3;
+enum TokenizationType : int {
+ INVALID_TOKENIZATION_TYPE = 0,
+
+ // Use the internal tokenizer for tokenization.
+ INTERNAL_TOKENIZER = 1,
+
+ // Use ICU for tokenization.
+ ICU = 2,
+
+ // First apply ICU tokenization. Then identify stretches of tokens
+ // consisting only of codepoints in internal_tokenizer_codepoint_ranges
+ // and re-tokenize them using the internal tokenizer.
+ MIXED = 3,
+}
+
+// Role of the codepoints in the range.
+namespace libtextclassifier3.TokenizationCodepointRange_;
+enum Role : int {
+ // Concatenates the codepoint to the current run of codepoints.
+ DEFAULT_ROLE = 0,
+
+ // Splits a run of codepoints before the current codepoint.
+ SPLIT_BEFORE = 1,
+
+ // Splits a run of codepoints after the current codepoint.
+ SPLIT_AFTER = 2,
+
+ // Each codepoint will be a separate token. Good e.g. for Chinese
+ // characters.
+ TOKEN_SEPARATOR = 3,
+
+ // Discards the codepoint.
+ DISCARD_CODEPOINT = 4,
+
+ // Common values:
+ // Splits on the characters and discards them. Good e.g. for the space
+ // character.
+ WHITESPACE_SEPARATOR = 7,
+}
+
+// Represents a codepoint range [start, end) with its role for tokenization.
+namespace libtextclassifier3;
+table TokenizationCodepointRange {
+ start:int;
+ end:int;
+ role:TokenizationCodepointRange_.Role;
+
+ // Integer identifier of the script this range denotes. Negative values are
+ // reserved for Tokenizer's internal use.
+ script_id:int;
+}
+
diff --git a/utils/tokenizer.h b/utils/tokenizer.h
new file mode 100644
index 0000000..3a9ef6c
--- /dev/null
+++ b/utils/tokenizer.h
@@ -0,0 +1,124 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENIZER_H_
+#define LIBTEXTCLASSIFIER_UTILS_TOKENIZER_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/codepoint-range.h"
+#include "utils/tokenizer_generated.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+const int kInvalidScript = -1;
+const int kUnknownScript = -2;
+
+// Tokenizer splits the input string into a sequence of tokens, according to
+// the configuration.
+class Tokenizer {
+ public:
+ // `codepoint_ranges`: Codepoint ranges that determine how different
+ // codepoints are tokenized. The ranges must not overlap.
+ // `internal_tokenizer_codepoint_ranges`: Codepoint ranges that define which
+ // tokens should be re-tokenized with the internal tokenizer in the mixed
+ // tokenization mode.
+ // `split_on_script_change`: Whether to consider a change of codepoint script
+ // in a sequence of characters as a token boundary. If True, will treat
+ // script change as a token boundary.
+ // `icu_preserve_whitespace_tokens`: If true, will include empty tokens in the
+ // output (in the ICU tokenization mode).
+ Tokenizer(
+ const TokenizationType type, const UniLib* unilib,
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ const std::vector<const CodepointRange*>&
+ internal_tokenizer_codepoint_ranges,
+ const bool split_on_script_change,
+ const bool icu_preserve_whitespace_tokens);
+
+ Tokenizer(
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ const bool split_on_script_change)
+ : Tokenizer(TokenizationType_INTERNAL_TOKENIZER, /*unilib=*/nullptr,
+ codepoint_ranges, /*internal_tokenizer_codepoint_ranges=*/{},
+ split_on_script_change,
+ /*icu_preserve_whitespace_tokens=*/false) {}
+
+ // Tokenizes the input string using the selected tokenization method.
+ std::vector<Token> Tokenize(const std::string& text) const;
+
+ // Same as above but takes UnicodeText.
+ std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
+
+ protected:
+ // Finds the tokenization codepoint range config for given codepoint.
+ // Internally uses binary search so should be O(log(# of codepoint_ranges)).
+ const TokenizationCodepointRangeT* FindTokenizationRange(int codepoint) const;
+
+ // Finds the role and script for given codepoint. If not found, DEFAULT_ROLE
+ // and kUnknownScript are assigned.
+ void GetScriptAndRole(char32 codepoint,
+ TokenizationCodepointRange_::Role* role,
+ int* script) const;
+
+ // Tokenizes a substring of the unicode string, appending the resulting tokens
+ // to the output vector. The resulting tokens have bounds relative to the full
+ // string. Does nothing if the start of the span is negative.
+ void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
+ std::vector<Token>* result) const;
+
+ std::vector<Token> InternalTokenize(const UnicodeText& text_unicode) const;
+
+ // Takes the result of ICU tokenization and retokenizes stretches of tokens
+ // made of a specific subset of characters using the internal tokenizer.
+ void InternalRetokenize(const UnicodeText& unicode_text,
+ std::vector<Token>* tokens) const;
+
+ // Tokenizes the input text using ICU tokenizer.
+ bool ICUTokenize(const UnicodeText& context_unicode,
+ std::vector<Token>* result) const;
+
+ private:
+ const TokenizationType type_;
+
+ const UniLib* unilib_;
+
+ // Codepoint ranges that determine how different codepoints are tokenized.
+ // The ranges must not overlap.
+ std::vector<std::unique_ptr<const TokenizationCodepointRangeT>>
+ codepoint_ranges_;
+
+ // Codepoint ranges that define which tokens (consisting of which codepoints)
+ // should be re-tokenized with the internal tokenizer in the mixed
+ // tokenization mode.
+ // NOTE: Must be sorted.
+ std::vector<CodepointRangeStruct> internal_tokenizer_codepoint_ranges_;
+
+ // If true, tokens will be additionally split when the codepoint's script_id
+ // changes.
+ const bool split_on_script_change_;
+
+ const bool icu_preserve_whitespace_tokens_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TOKENIZER_H_
diff --git a/utils/tokenizer_test.cc b/utils/tokenizer_test.cc
new file mode 100644
index 0000000..4f4f763
--- /dev/null
+++ b/utils/tokenizer_test.cc
@@ -0,0 +1,485 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenizer.h"
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAreArray;
+
+class TestingTokenizer : public Tokenizer {
+ public:
+ TestingTokenizer(
+ const TokenizationType type, const UniLib* unilib,
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ const std::vector<const CodepointRange*>&
+ internal_tokenizer_codepoint_ranges,
+ const bool split_on_script_change,
+ const bool icu_preserve_whitespace_tokens)
+ : Tokenizer(type, unilib, codepoint_ranges,
+ internal_tokenizer_codepoint_ranges, split_on_script_change,
+ icu_preserve_whitespace_tokens) {}
+
+ using Tokenizer::FindTokenizationRange;
+};
+
+class TestingTokenizerProxy {
+ public:
+ TestingTokenizerProxy(
+ TokenizationType type,
+ const std::vector<TokenizationCodepointRangeT>& codepoint_range_configs,
+ const std::vector<CodepointRangeT>& internal_codepoint_range_configs,
+ const bool split_on_script_change,
+ const bool icu_preserve_whitespace_tokens)
+ : INIT_UNILIB_FOR_TESTING(unilib_) {
+ const int num_configs = codepoint_range_configs.size();
+ std::vector<const TokenizationCodepointRange*> configs_fb;
+ configs_fb.reserve(num_configs);
+ const int num_internal_configs = internal_codepoint_range_configs.size();
+ std::vector<const CodepointRange*> internal_configs_fb;
+ internal_configs_fb.reserve(num_internal_configs);
+ buffers_.reserve(num_configs + num_internal_configs);
+ for (int i = 0; i < num_configs; i++) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateTokenizationCodepointRange(
+ builder, &codepoint_range_configs[i]));
+ buffers_.push_back(builder.Release());
+ configs_fb.push_back(flatbuffers::GetRoot<TokenizationCodepointRange>(
+ buffers_.back().data()));
+ }
+ for (int i = 0; i < num_internal_configs; i++) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(
+ CreateCodepointRange(builder, &internal_codepoint_range_configs[i]));
+ buffers_.push_back(builder.Release());
+ internal_configs_fb.push_back(
+ flatbuffers::GetRoot<CodepointRange>(buffers_.back().data()));
+ }
+ tokenizer_ = std::unique_ptr<TestingTokenizer>(new TestingTokenizer(
+ type, &unilib_, configs_fb, internal_configs_fb, split_on_script_change,
+ icu_preserve_whitespace_tokens));
+ }
+
+ TokenizationCodepointRange_::Role TestFindTokenizationRole(int c) const {
+ const TokenizationCodepointRangeT* range =
+ tokenizer_->FindTokenizationRange(c);
+ if (range != nullptr) {
+ return range->role;
+ } else {
+ return TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ }
+ }
+
+ std::vector<Token> Tokenize(const std::string& utf8_text) const {
+ return tokenizer_->Tokenize(utf8_text);
+ }
+
+ private:
+ UniLib unilib_;
+ std::vector<flatbuffers::DetachedBuffer> buffers_;
+ std::unique_ptr<TestingTokenizer> tokenizer_;
+};
+
+TEST(TokenizerTest, FindTokenizationRange) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0;
+ config->end = 10;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 1234;
+ config->end = 12345;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
+ {}, /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false);
+
+ // Test hits to the first group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(0),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(5),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(10),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+ // Test a hit to the second group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(31),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(32),
+ TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(33),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+ // Test hits to the third group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+ // Test a hit outside.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(99),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+}
+
+TEST(TokenizerTest, TokenizeOnSpace) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ configs.emplace_back();
+ config = &configs.back();
+ // Space character.
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
+ {},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false);
+ std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
+
+ EXPECT_THAT(tokens,
+ ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
+}
+
+TEST(TokenizerTest, TokenizeOnSpaceAndScriptChange) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ // Latin.
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0;
+ config->end = 32;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ config->script_id = 1;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+ config->script_id = 1;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 33;
+ config->end = 0x77F + 1;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ config->script_id = 1;
+
+ TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
+ {},
+ /*split_on_script_change=*/true,
+ /*icu_preserve_whitespace_tokens=*/false);
+ EXPECT_THAT(tokenizer.Tokenize("앨라배마 주 전화(123) 456-789웹사이트"),
+ std::vector<Token>({Token("앨라배마", 0, 4), Token("주", 5, 6),
+ Token("전화", 7, 10), Token("(123)", 10, 15),
+ Token("456-789", 16, 23),
+ Token("웹사이트", 23, 28)}));
+}
+
+TEST(TokenizerTest, TokenizeComplex) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt
+ // Latin - Cyrillic.
+ // 0000..007F; Basic Latin
+ // 0080..00FF; Latin-1 Supplement
+ // 0100..017F; Latin Extended-A
+ // 0180..024F; Latin Extended-B
+ // 0250..02AF; IPA Extensions
+ // 02B0..02FF; Spacing Modifier Letters
+ // 0300..036F; Combining Diacritical Marks
+ // 0370..03FF; Greek and Coptic
+ // 0400..04FF; Cyrillic
+ // 0500..052F; Cyrillic Supplement
+ // 0530..058F; Armenian
+ // 0590..05FF; Hebrew
+ // 0600..06FF; Arabic
+ // 0700..074F; Syriac
+ // 0750..077F; Arabic Supplement
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0;
+ config->end = 32;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 33;
+ config->end = 0x77F + 1;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+
+ // CJK
+ // 2E80..2EFF; CJK Radicals Supplement
+ // 3000..303F; CJK Symbols and Punctuation
+ // 3040..309F; Hiragana
+ // 30A0..30FF; Katakana
+ // 3100..312F; Bopomofo
+ // 3130..318F; Hangul Compatibility Jamo
+ // 3190..319F; Kanbun
+ // 31A0..31BF; Bopomofo Extended
+ // 31C0..31EF; CJK Strokes
+ // 31F0..31FF; Katakana Phonetic Extensions
+ // 3200..32FF; Enclosed CJK Letters and Months
+ // 3300..33FF; CJK Compatibility
+ // 3400..4DBF; CJK Unified Ideographs Extension A
+ // 4DC0..4DFF; Yijing Hexagram Symbols
+ // 4E00..9FFF; CJK Unified Ideographs
+ // A000..A48F; Yi Syllables
+ // A490..A4CF; Yi Radicals
+ // A4D0..A4FF; Lisu
+ // A500..A63F; Vai
+ // F900..FAFF; CJK Compatibility Ideographs
+ // FE30..FE4F; CJK Compatibility Forms
+ // 20000..2A6DF; CJK Unified Ideographs Extension B
+ // 2A700..2B73F; CJK Unified Ideographs Extension C
+ // 2B740..2B81F; CJK Unified Ideographs Extension D
+ // 2B820..2CEAF; CJK Unified Ideographs Extension E
+ // 2CEB0..2EBEF; CJK Unified Ideographs Extension F
+ // 2F800..2FA1F; CJK Compatibility Ideographs Supplement
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2E80;
+ config->end = 0x2EFF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x3000;
+ config->end = 0xA63F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0xF900;
+ config->end = 0xFAFF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0xFE30;
+ config->end = 0xFE4F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x20000;
+ config->end = 0x2A6DF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2A700;
+ config->end = 0x2B73F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2B740;
+ config->end = 0x2B81F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2B820;
+ config->end = 0x2CEAF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2CEB0;
+ config->end = 0x2EBEF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2F800;
+ config->end = 0x2FA1F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ // Thai.
+ // 0E00..0E7F; Thai
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x0E00;
+ config->end = 0x0E7F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
+ {},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false);
+ std::vector<Token> tokens;
+
+ tokens = tokenizer.Tokenize(
+ "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。");
+ EXPECT_EQ(tokens.size(), 30);
+
+ tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ");
+ // clang-format off
+ EXPECT_THAT(
+ tokens,
+ ElementsAreArray({Token("問", 0, 1),
+ Token("少", 1, 2),
+ Token("目", 2, 3),
+ Token("hello", 4, 9),
+ Token("木", 10, 11),
+ Token("輸", 11, 12),
+ Token("ย", 12, 13),
+ Token("า", 13, 14),
+ Token("ม", 14, 15),
+ Token("き", 15, 16),
+ Token("ゃ", 16, 17)}));
+ // clang-format on
+}
+
+#ifdef TC3_TEST_ICU
+TEST(TokenizerTest, ICUTokenize) {
+ TestingTokenizerProxy tokenizer(TokenizationType_ICU, {}, {},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false);
+ std::vector<Token> tokens = tokenizer.Tokenize("พระบาทสมเด็จพระปรมิ");
+ ASSERT_EQ(tokens,
+ // clang-format off
+ std::vector<Token>({Token("พระบาท", 0, 6),
+ Token("สมเด็จ", 6, 12),
+ Token("พระ", 12, 15),
+ Token("ปร", 15, 17),
+ Token("มิ", 17, 19)}));
+ // clang-format on
+}
+
+TEST(TokenizerTest, ICUTokenizeWithWhitespaces) {
+ TestingTokenizerProxy tokenizer(TokenizationType_ICU, {}, {},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/true);
+ std::vector<Token> tokens = tokenizer.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
+ ASSERT_EQ(tokens,
+ // clang-format off
+ std::vector<Token>({Token("พระบาท", 0, 6),
+ Token(" ", 6, 7),
+ Token("สมเด็จ", 7, 13),
+ Token(" ", 13, 14),
+ Token("พระ", 14, 17),
+ Token(" ", 17, 18),
+ Token("ปร", 18, 20),
+ Token(" ", 20, 21),
+ Token("มิ", 21, 23)}));
+ // clang-format on
+}
+
+TEST(TokenizerTest, MixedTokenize) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ std::vector<CodepointRangeT> internal_configs;
+ CodepointRangeT* interal_config;
+
+ internal_configs.emplace_back();
+ interal_config = &internal_configs.back();
+ interal_config->start = 0;
+ interal_config->end = 128;
+
+ internal_configs.emplace_back();
+ interal_config = &internal_configs.back();
+ interal_config->start = 128;
+ interal_config->end = 256;
+
+ internal_configs.emplace_back();
+ interal_config = &internal_configs.back();
+ interal_config->start = 256;
+ interal_config->end = 384;
+
+ internal_configs.emplace_back();
+ interal_config = &internal_configs.back();
+ interal_config->start = 384;
+ interal_config->end = 592;
+
+ TestingTokenizerProxy tokenizer(TokenizationType_MIXED, configs,
+ internal_configs,
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false);
+
+ std::vector<Token> tokens = tokenizer.Tokenize(
+ "こんにちはJapanese-ląnguagę text 世界 http://www.google.com/");
+ ASSERT_EQ(tokens,
+ // clang-format off
+ std::vector<Token>({Token("こんにちは", 0, 5),
+ Token("Japanese-ląnguagę", 5, 22),
+ Token("text", 23, 27),
+ Token("世界", 28, 30),
+ Token("http://www.google.com/", 31, 53)}));
+ // clang-format on
+}
+
+TEST(TokenizerTest, InternalTokenizeOnScriptChange) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0;
+ config->end = 256;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+
+ {
+ TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER,
+ configs, {},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false);
+
+ EXPECT_EQ(tokenizer.Tokenize("앨라배마123웹사이트"),
+ std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)}));
+ }
+
+ {
+ TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER,
+ configs, {},
+ /*split_on_script_change=*/true,
+ /*icu_preserve_whitespace_tokens=*/false);
+ EXPECT_EQ(tokenizer.Tokenize("앨라배마123웹사이트"),
+ std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7),
+ Token("웹사이트", 7, 11)}));
+ }
+}
+#endif
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/utf8/UniLibJavaIcuTest.java b/utils/utf8/UniLibJavaIcuTest.java
new file mode 100644
index 0000000..d6a0a06
--- /dev/null
+++ b/utils/utf8/UniLibJavaIcuTest.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier.utils.utf8;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
+@RunWith(JUnit4.class)
+public class UniLibJavaIcuTest {
+
+ @Before
+ public void setUp() throws Exception {
+ System.loadLibrary("unilib-javaicu_test-jni");
+ }
+
+ private native boolean testsMain();
+
+ @Test
+ public void testNative() {
+ assertThat(testsMain()).isTrue();
+ }
+}
diff --git a/utils/utf8/unicodetext.cc b/utils/utf8/unicodetext.cc
index 81492d8..b3b092e 100644
--- a/utils/utf8/unicodetext.cc
+++ b/utils/utf8/unicodetext.cc
@@ -20,6 +20,7 @@
#include <algorithm>
+#include "utils/base/logging.h"
#include "utils/strings/utf8.h"
namespace libtextclassifier3 {
@@ -206,9 +207,36 @@
return UTF8Substring(begin(), end());
}
-std::string UnicodeText::UTF8Substring(const const_iterator& first,
- const const_iterator& last) {
- return std::string(first.it_, last.it_ - first.it_);
+std::string UnicodeText::UTF8Substring(int begin_codepoint,
+ int end_codepoint) const {
+ auto span_begin = begin();
+ std::advance(span_begin, begin_codepoint);
+ auto span_end = begin();
+ std::advance(span_end, end_codepoint);
+ return UTF8Substring(span_begin, span_end);
+}
+
+std::string UnicodeText::UTF8Substring(const const_iterator& it_begin,
+ const const_iterator& it_end) {
+ return std::string(it_begin.it_, it_end.it_ - it_begin.it_);
+}
+
+UnicodeText UnicodeText::Substring(const UnicodeText& text, int begin_codepoint,
+ int end_codepoint, bool do_copy) {
+ auto it_begin = text.begin();
+ std::advance(it_begin, begin_codepoint);
+ auto it_end = text.begin();
+ std::advance(it_end, end_codepoint);
+
+ if (do_copy) {
+ UnicodeText result;
+ result.repr_.Copy(it_begin.it_, it_end.it_ - it_begin.it_);
+ return result;
+ } else {
+ UnicodeText result;
+ result.repr_.PointTo(it_begin.it_, it_end.it_ - it_begin.it_);
+ return result;
+ }
}
UnicodeText::~UnicodeText() {}
diff --git a/utils/utf8/unicodetext.h b/utils/utf8/unicodetext.h
index eb206b8..310fd38 100644
--- a/utils/utf8/unicodetext.h
+++ b/utils/utf8/unicodetext.h
@@ -75,7 +75,7 @@
typedef const_iterator CI;
public:
- typedef std::input_iterator_tag iterator_category;
+ typedef std::bidirectional_iterator_tag iterator_category;
typedef char32 value_type;
typedef int difference_type;
typedef void pointer; // (Not needed.)
@@ -172,8 +172,11 @@
void clear();
std::string ToUTF8String() const;
- static std::string UTF8Substring(const const_iterator& first,
- const const_iterator& last);
+ std::string UTF8Substring(int begin_codepoint, int end_codepoint) const;
+ static std::string UTF8Substring(const const_iterator& it_begin,
+ const const_iterator& it_end);
+ static UnicodeText Substring(const UnicodeText& text, int begin_codepoint,
+ int end_codepoint, bool do_copy = true);
private:
friend class const_iterator;
@@ -214,9 +217,10 @@
// std::string, or from ::string to std::string, because if this happens it
// often results in invalid memory access to a temporary object created during
// such conversion (if do_copy == false).
-UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len, bool do_copy);
-UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy);
-UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy);
+UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len,
+ bool do_copy = true);
+UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy = true);
+UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy = true);
UnicodeText UTF8ToUnicodeText(const std::string& str);
} // namespace libtextclassifier3
diff --git a/utils/utf8/unicodetext_test.cc b/utils/utf8/unicodetext_test.cc
index 7ebb415..e6926ce 100644
--- a/utils/utf8/unicodetext_test.cc
+++ b/utils/utf8/unicodetext_test.cc
@@ -49,6 +49,15 @@
EXPECT_EQ(text.UTF8Substring(it_begin, it_end), "😋h");
}
+TEST(UnicodeTextTest, Substring) {
+ UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);
+
+ EXPECT_EQ(UnicodeText::Substring(text, 4, 6, /*do_copy=*/true),
+ UTF8ToUnicodeText("😋h"));
+ EXPECT_EQ(UnicodeText::Substring(text, 4, 6, /*do_copy=*/false),
+ UTF8ToUnicodeText("😋h"));
+}
+
TEST(UnicodeTextTest, Ownership) {
const std::string src = "\u304A\u00B0\u106B";
diff --git a/utils/utf8/unilib-javaicu.cc b/utils/utf8/unilib-javaicu.cc
index dab4f70..8cddddd 100644
--- a/utils/utf8/unilib-javaicu.cc
+++ b/utils/utf8/unilib-javaicu.cc
@@ -340,21 +340,50 @@
std::unique_ptr<UniLib::RegexPattern> UniLib::CreateRegexPattern(
const UnicodeText& regex) const {
return std::unique_ptr<UniLib::RegexPattern>(
- new UniLib::RegexPattern(jni_cache_.get(), regex));
+ new UniLib::RegexPattern(jni_cache_.get(), regex, /*lazy=*/false));
+}
+
+std::unique_ptr<UniLib::RegexPattern> UniLib::CreateLazyRegexPattern(
+ const UnicodeText& regex) const {
+ return std::unique_ptr<UniLib::RegexPattern>(
+ new UniLib::RegexPattern(jni_cache_.get(), regex, /*lazy=*/true));
}
UniLib::RegexPattern::RegexPattern(const JniCache* jni_cache,
- const UnicodeText& regex)
+ const UnicodeText& pattern, bool lazy)
: jni_cache_(jni_cache),
- pattern_(nullptr, jni_cache ? jni_cache->jvm : nullptr) {
+ pattern_(nullptr, jni_cache ? jni_cache->jvm : nullptr),
+ initialized_(false),
+ initialization_failure_(false),
+ pattern_text_(pattern) {
+ if (!lazy) {
+ LockedInitializeIfNotAlready();
+ }
+}
+
+void UniLib::RegexPattern::LockedInitializeIfNotAlready() const {
+ std::lock_guard<std::mutex> guard(mutex_);
+ if (initialized_ || initialization_failure_) {
+ return;
+ }
+
if (jni_cache_) {
JNIEnv* jenv = jni_cache_->GetEnv();
const ScopedLocalRef<jstring> regex_java =
- jni_cache->ConvertToJavaString(regex);
+ jni_cache_->ConvertToJavaString(pattern_text_);
pattern_ = MakeGlobalRef(jenv->CallStaticObjectMethod(
jni_cache_->pattern_class.get(),
jni_cache_->pattern_compile, regex_java.get()),
jenv, jni_cache_->jvm);
+
+ if (jni_cache_->ExceptionCheckAndClear() || pattern_ == nullptr) {
+ initialization_failure_ = true;
+ pattern_.reset();
+ return;
+ }
+
+ initialized_ = true;
+ pattern_text_.clear(); // We don't need this anymore.
}
}
@@ -363,6 +392,11 @@
std::unique_ptr<UniLib::RegexMatcher> UniLib::RegexPattern::Matcher(
const UnicodeText& context) const {
+ LockedInitializeIfNotAlready(); // Possibly lazy initialization.
+ if (initialization_failure_) {
+ return nullptr;
+ }
+
if (jni_cache_) {
JNIEnv* env = jni_cache_->GetEnv();
const jstring context_java =
diff --git a/utils/utf8/unilib-javaicu.h b/utils/utf8/unilib-javaicu.h
index a5ea54f..0a5d339 100644
--- a/utils/utf8/unilib-javaicu.h
+++ b/utils/utf8/unilib-javaicu.h
@@ -23,12 +23,14 @@
#include <jni.h>
#include <memory>
+#include <mutex> // NOLINT
#include <string>
#include "utils/base/integral_types.h"
#include "utils/java/jni-cache.h"
#include "utils/java/scoped_global_ref.h"
#include "utils/java/scoped_local_ref.h"
+#include "utils/java/string_utils.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -104,12 +106,17 @@
// was not called previously.
UnicodeText Group(int group_idx, int* status) const;
- protected:
+ // Returns the matched text (the 0th capturing group).
+ std::string Text() const {
+ ScopedStringChars text_str =
+ GetScopedStringChars(jni_cache_->GetEnv(), text_.get());
+ return text_str.get();
+ }
+
+ private:
friend class RegexPattern;
RegexMatcher(const JniCache* jni_cache, ScopedGlobalRef<jobject> matcher,
ScopedGlobalRef<jstring> text);
-
- private:
bool UpdateLastFindOffset() const;
const JniCache* jni_cache_;
@@ -124,13 +131,23 @@
public:
std::unique_ptr<RegexMatcher> Matcher(const UnicodeText& context) const;
- protected:
- friend class UniLib;
- RegexPattern(const JniCache* jni_cache, const UnicodeText& regex);
-
private:
+ friend class UniLib;
+ RegexPattern(const JniCache* jni_cache, const UnicodeText& pattern,
+ bool lazy);
+ void LockedInitializeIfNotAlready() const;
+
const JniCache* jni_cache_;
- ScopedGlobalRef<jobject> pattern_;
+
+ // These members need to be mutable because of the lazy initialization.
+ // NOTE: The Matcher method first ensures (using a lock) that the
+ // initialization was attempted (by using LockedInitializeIfNotAlready) and
+ // then can access them without locking.
+ mutable std::mutex mutex_;
+ mutable ScopedGlobalRef<jobject> pattern_;
+ mutable bool initialized_;
+ mutable bool initialization_failure_;
+ mutable UnicodeText pattern_text_;
};
class BreakIterator {
@@ -139,11 +156,10 @@
static constexpr int kDone = -1;
- protected:
+ private:
friend class UniLib;
BreakIterator(const JniCache* jni_cache, const UnicodeText& text);
- private:
const JniCache* jni_cache_;
ScopedGlobalRef<jstring> text_;
ScopedGlobalRef<jobject> iterator_;
@@ -153,6 +169,8 @@
std::unique_ptr<RegexPattern> CreateRegexPattern(
const UnicodeText& regex) const;
+ std::unique_ptr<RegexPattern> CreateLazyRegexPattern(
+ const UnicodeText& regex) const;
std::unique_ptr<BreakIterator> CreateBreakIterator(
const UnicodeText& text) const;
diff --git a/utils/utf8/unilib_test.cc b/utils/utf8/unilib_test-include.cc
similarity index 87%
rename from utils/utf8/unilib_test.cc
rename to utils/utf8/unilib_test-include.cc
index 96b2c2d..bd53208 100644
--- a/utils/utf8/unilib_test.cc
+++ b/utils/utf8/unilib_test-include.cc
@@ -14,24 +14,15 @@
* limitations under the License.
*/
-#include "utils/utf8/unilib.h"
+#include "utils/utf8/unilib_test-include.h"
-#include "utils/base/logging.h"
-#include "utils/utf8/unicodetext.h"
#include "gmock/gmock.h"
-#include "gtest/gtest.h"
namespace libtextclassifier3 {
-namespace {
+namespace test_internal {
using ::testing::ElementsAre;
-class UniLibTest : public ::testing::Test {
- protected:
- UniLibTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- UniLib unilib_;
-};
-
TEST_F(UniLibTest, CharacterClassesAscii) {
EXPECT_TRUE(unilib_.IsOpeningBracket('('));
EXPECT_TRUE(unilib_.IsClosingBracket(')'));
@@ -50,7 +41,6 @@
EXPECT_EQ(unilib_.GetPairedBracket('}'), '{');
}
-#ifndef TC3_UNILIB_DUMMY
TEST_F(UniLibTest, CharacterClassesUnicode) {
EXPECT_TRUE(unilib_.IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON
EXPECT_TRUE(unilib_.IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS
@@ -72,7 +62,6 @@
EXPECT_EQ(unilib_.GetPairedBracket(0x0F3C), 0x0F3D);
EXPECT_EQ(unilib_.GetPairedBracket(0x0F3D), 0x0F3C);
}
-#endif // ndef TC3_UNILIB_DUMMY
TEST_F(UniLibTest, RegexInterface) {
const UnicodeText regex_pattern =
@@ -89,7 +78,6 @@
TC3_LOG(INFO) << matcher->Group(0, &status).size_codepoints();
}
-#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, Regex) {
// The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
// test the regex functionality with it to verify we are handling the indices
@@ -126,15 +114,34 @@
EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋");
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
}
-#endif // TC3_UNILIB_ICU
-#ifdef TC3_UNILIB_ICU
+TEST_F(UniLibTest, RegexLazy) {
+ std::unique_ptr<UniLib::RegexPattern> pattern =
+ unilib_.CreateLazyRegexPattern(
+ UTF8ToUnicodeText("[a-z][0-9]", /*do_copy=*/false));
+ int status;
+ std::unique_ptr<UniLib::RegexMatcher> matcher;
+
+ matcher = pattern->Matcher(UTF8ToUnicodeText("a3", /*do_copy=*/false));
+ EXPECT_TRUE(matcher->Matches(&status));
+ EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+ EXPECT_TRUE(matcher->Matches(&status)); // Check that the state is reset.
+ EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+
+ matcher = pattern->Matcher(UTF8ToUnicodeText("3a", /*do_copy=*/false));
+ EXPECT_FALSE(matcher->Matches(&status));
+ EXPECT_FALSE(matcher->ApproximatelyMatches(&status));
+ EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
+}
+
TEST_F(UniLibTest, RegexGroups) {
// The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
// test the regex functionality with it to verify we are handling the indices
// correctly.
- const UnicodeText regex_pattern = UTF8ToUnicodeText(
- "(?<group1>[0-9])(?<group2>[0-9]+)😋", /*do_copy=*/false);
+ const UnicodeText regex_pattern =
+ UTF8ToUnicodeText("([0-9])([0-9]+)😋", /*do_copy=*/false);
std::unique_ptr<UniLib::RegexPattern> pattern =
unilib_.CreateRegexPattern(regex_pattern);
int status;
@@ -163,9 +170,6 @@
EXPECT_EQ(matcher->Group(2, &status).ToUTF8String(), "123");
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
}
-#endif // TC3_UNILIB_ICU
-
-#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, BreakIterator) {
const UnicodeText text = UTF8ToUnicodeText("some text", /*do_copy=*/false);
@@ -178,9 +182,7 @@
}
EXPECT_THAT(break_indices, ElementsAre(4, 5, 9));
}
-#endif // TC3_UNILIB_ICU
-#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, BreakIterator4ByteUTF8) {
const UnicodeText text = UTF8ToUnicodeText("😀😂😋", /*do_copy=*/false);
std::unique_ptr<UniLib::BreakIterator> iterator =
@@ -192,18 +194,14 @@
}
EXPECT_THAT(break_indices, ElementsAre(1, 2, 3));
}
-#endif // TC3_UNILIB_ICU
-#ifndef TC3_UNILIB_JAVAICU
TEST_F(UniLibTest, IntegerParse) {
int result;
EXPECT_TRUE(
unilib_.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
EXPECT_EQ(result, 123);
}
-#endif // ndef TC3_UNILIB_JAVAICU
-#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, IntegerParseFullWidth) {
int result;
// The input string here is full width
@@ -211,16 +209,13 @@
&result));
EXPECT_EQ(result, 123);
}
-#endif // TC3_UNILIB_ICU
-#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, IntegerParseFullWidthWithAlpha) {
int result;
// The input string here is full width
EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("1a3", /*do_copy=*/false),
&result));
}
-#endif // TC3_UNILIB_ICU
-} // namespace
+} // namespace test_internal
} // namespace libtextclassifier3
diff --git a/utils/utf8/unilib_test-include.h b/utils/utf8/unilib_test-include.h
new file mode 100644
index 0000000..151a6f0
--- /dev/null
+++ b/utils/utf8/unilib_test-include.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
+#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
+
+// Include the version of UniLib depending on the macro.
+#if defined TC3_UNILIB_ICU
+#include "utils/utf8/unilib-icu.h"
+#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
+#elif defined TC3_UNILIB_JAVAICU
+#include <jni.h>
+extern JNIEnv* g_jenv;
+#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR(JniCache::Create(g_jenv))
+#include "utils/utf8/unilib-javaicu.h"
+#elif defined TC3_UNILIB_DUMMY
+#include "utils/utf8/unilib-dummy.h"
+#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
+#endif
+
+#include "utils/base/logging.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+class UniLibTest : public ::testing::Test {
+ protected:
+ UniLibTest() : TC3_TESTING_CREATE_UNILIB_INSTANCE(unilib_) {}
+ UniLib unilib_;
+};
+
+} // namespace test_internal
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
diff --git a/utils/variant.h b/utils/variant.h
index ddb0d60..68bb04b 100644
--- a/utils/variant.h
+++ b/utils/variant.h
@@ -17,46 +17,104 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_VARIANT_H_
#define LIBTEXTCLASSIFIER_UTILS_VARIANT_H_
+#include <map>
#include <string>
#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
// Represents a type-tagged union of different basic types.
-struct Variant {
- Variant() : type(TYPE_INVALID) {}
- explicit Variant(int value) : type(TYPE_INT_VALUE), int_value(value) {}
- explicit Variant(int64 value) : type(TYPE_LONG_VALUE), long_value(value) {}
- explicit Variant(float value) : type(TYPE_FLOAT_VALUE), float_value(value) {}
- explicit Variant(double value)
- : type(TYPE_DOUBLE_VALUE), double_value(value) {}
- explicit Variant(StringPiece value)
- : type(TYPE_STRING_VALUE), string_value(value.ToString()) {}
- explicit Variant(std::string value)
- : type(TYPE_STRING_VALUE), string_value(value) {}
- explicit Variant(const char* value)
- : type(TYPE_STRING_VALUE), string_value(value) {}
- explicit Variant(bool value) : type(TYPE_BOOL_VALUE), bool_value(value) {}
+class Variant {
+ public:
enum Type {
- TYPE_INVALID = 0,
+ TYPE_EMPTY = 0,
TYPE_INT_VALUE = 1,
- TYPE_LONG_VALUE = 2,
+ TYPE_INT64_VALUE = 2,
TYPE_FLOAT_VALUE = 3,
TYPE_DOUBLE_VALUE = 4,
TYPE_BOOL_VALUE = 5,
TYPE_STRING_VALUE = 6,
};
- Type type;
+
+ Variant() : type_(TYPE_EMPTY) {}
+ explicit Variant(const int value)
+ : type_(TYPE_INT_VALUE), int_value_(value) {}
+ explicit Variant(const int64 value)
+ : type_(TYPE_INT64_VALUE), long_value_(value) {}
+ explicit Variant(const float value)
+ : type_(TYPE_FLOAT_VALUE), float_value_(value) {}
+ explicit Variant(const double value)
+ : type_(TYPE_DOUBLE_VALUE), double_value_(value) {}
+ explicit Variant(const StringPiece value)
+ : type_(TYPE_STRING_VALUE), string_value_(value.ToString()) {}
+ explicit Variant(const std::string value)
+ : type_(TYPE_STRING_VALUE), string_value_(value) {}
+ explicit Variant(const char* value)
+ : type_(TYPE_STRING_VALUE), string_value_(value) {}
+ explicit Variant(const bool value)
+ : type_(TYPE_BOOL_VALUE), bool_value_(value) {}
+
+ Variant& operator=(const Variant&) = default;
+
+ int IntValue() const {
+ TC3_CHECK(HasInt());
+ return int_value_;
+ }
+
+ int64 Int64Value() const {
+ TC3_CHECK(HasInt64());
+ return long_value_;
+ }
+
+ float FloatValue() const {
+ TC3_CHECK(HasFloat());
+ return float_value_;
+ }
+
+ double DoubleValue() const {
+ TC3_CHECK(HasDouble());
+ return double_value_;
+ }
+
+ bool BoolValue() const {
+ TC3_CHECK(HasBool());
+ return bool_value_;
+ }
+
+ const std::string& StringValue() const {
+ TC3_CHECK(HasString());
+ return string_value_;
+ }
+
+ bool HasInt() const { return type_ == TYPE_INT_VALUE; }
+
+ bool HasInt64() const { return type_ == TYPE_INT64_VALUE; }
+
+ bool HasFloat() const { return type_ == TYPE_FLOAT_VALUE; }
+
+ bool HasDouble() const { return type_ == TYPE_DOUBLE_VALUE; }
+
+ bool HasBool() const { return type_ == TYPE_BOOL_VALUE; }
+
+ bool HasString() const { return type_ == TYPE_STRING_VALUE; }
+
+ Type GetType() const { return type_; }
+
+ bool HasValue() const { return type_ != TYPE_EMPTY; }
+
+ private:
+ Type type_;
union {
- int int_value;
- int64 long_value;
- float float_value;
- double double_value;
- bool bool_value;
+ int int_value_;
+ int64 long_value_;
+ float float_value_;
+ double double_value_;
+ bool bool_value_;
};
- std::string string_value;
+ std::string string_value_;
};
} // namespace libtextclassifier3
diff --git a/utils/zlib/zlib.cc b/utils/zlib/zlib.cc
index e9991e0..4cb7760 100644
--- a/utils/zlib/zlib.cc
+++ b/utils/zlib/zlib.cc
@@ -16,26 +16,36 @@
#include "utils/zlib/zlib.h"
-#include <memory>
-
-#include "utils/base/logging.h"
#include "utils/flatbuffers.h"
namespace libtextclassifier3 {
-std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance() {
- std::unique_ptr<ZlibDecompressor> result(new ZlibDecompressor());
+std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance(
+ const unsigned char* dictionary, const unsigned int dictionary_size) {
+ std::unique_ptr<ZlibDecompressor> result(
+ new ZlibDecompressor(dictionary, dictionary_size));
if (!result->initialized_) {
result.reset();
}
return result;
}
-ZlibDecompressor::ZlibDecompressor() {
+ZlibDecompressor::ZlibDecompressor(const unsigned char* dictionary,
+ const unsigned int dictionary_size) {
memset(&stream_, 0, sizeof(stream_));
stream_.zalloc = Z_NULL;
stream_.zfree = Z_NULL;
- initialized_ = (inflateInit(&stream_) == Z_OK);
+ initialized_ = false;
+ if (inflateInit(&stream_) != Z_OK) {
+ TC3_LOG(ERROR) << "Could not initialize decompressor.";
+ return;
+ }
+ if (dictionary != nullptr &&
+ inflateSetDictionary(&stream_, dictionary, dictionary_size) != Z_OK) {
+ TC3_LOG(ERROR) << "Could not set dictionary.";
+ return;
+ }
+ initialized_ = true;
}
ZlibDecompressor::~ZlibDecompressor() {
@@ -78,21 +88,57 @@
compressed_buffer->uncompressed_size, out);
}
-std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance() {
- std::unique_ptr<ZlibCompressor> result(new ZlibCompressor());
+bool ZlibDecompressor::MaybeDecompressOptionallyCompressedBuffer(
+ const flatbuffers::String* uncompressed_buffer,
+ const CompressedBuffer* compressed_buffer, std::string* out) {
+ if (uncompressed_buffer != nullptr) {
+ *out = uncompressed_buffer->str();
+ return true;
+ }
+ return MaybeDecompress(compressed_buffer, out);
+}
+
+bool ZlibDecompressor::MaybeDecompressOptionallyCompressedBuffer(
+ const flatbuffers::Vector<uint8>* uncompressed_buffer,
+ const CompressedBuffer* compressed_buffer, std::string* out) {
+ if (uncompressed_buffer != nullptr) {
+ *out =
+ std::string(reinterpret_cast<const char*>(uncompressed_buffer->data()),
+ uncompressed_buffer->size());
+ return true;
+ }
+ return MaybeDecompress(compressed_buffer, out);
+}
+
+std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance(
+ const unsigned char* dictionary, const unsigned int dictionary_size) {
+ std::unique_ptr<ZlibCompressor> result(
+ new ZlibCompressor(dictionary, dictionary_size));
if (!result->initialized_) {
result.reset();
}
return result;
}
-ZlibCompressor::ZlibCompressor(int level, int tmp_buffer_size) {
+ZlibCompressor::ZlibCompressor(const unsigned char* dictionary,
+ const unsigned int dictionary_size,
+ const int level, const int tmp_buffer_size) {
memset(&stream_, 0, sizeof(stream_));
stream_.zalloc = Z_NULL;
stream_.zfree = Z_NULL;
buffer_size_ = tmp_buffer_size;
buffer_.reset(new Bytef[buffer_size_]);
- initialized_ = (deflateInit(&stream_, level) == Z_OK);
+ initialized_ = false;
+ if (deflateInit(&stream_, level) != Z_OK) {
+ TC3_LOG(ERROR) << "Could not initialize compressor.";
+ return;
+ }
+ if (dictionary != nullptr &&
+ deflateSetDictionary(&stream_, dictionary, dictionary_size) != Z_OK) {
+ TC3_LOG(ERROR) << "Could not set dictionary.";
+ return;
+ }
+ initialized_ = true;
}
ZlibCompressor::~ZlibCompressor() { deflateEnd(&stream_); }
@@ -131,44 +177,14 @@
} while (status == Z_OK);
}
-std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern(
- const UniLib& unilib, const flatbuffers::String* uncompressed_pattern,
- const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor,
- std::string* result_pattern_text) {
- UnicodeText unicode_regex_pattern;
- std::string decompressed_pattern;
- if (compressed_pattern != nullptr &&
- compressed_pattern->buffer() != nullptr) {
- if (decompressor == nullptr ||
- !decompressor->MaybeDecompress(compressed_pattern,
- &decompressed_pattern)) {
- TC3_LOG(ERROR) << "Cannot decompress pattern.";
- return nullptr;
- }
- unicode_regex_pattern =
- UTF8ToUnicodeText(decompressed_pattern.data(),
- decompressed_pattern.size(), /*do_copy=*/false);
- } else {
- if (uncompressed_pattern == nullptr) {
- TC3_LOG(ERROR) << "Cannot load uncompressed pattern.";
- return nullptr;
- }
- unicode_regex_pattern =
- UTF8ToUnicodeText(uncompressed_pattern->c_str(),
- uncompressed_pattern->Length(), /*do_copy=*/false);
+bool ZlibCompressor::GetDictionary(std::vector<unsigned char>* dictionary) {
+  // First query the dictionary size (deflateGetDictionary; requires zlib >= 1.2.8).
+ unsigned int size;
+ if (deflateGetDictionary(&stream_, /*dictionary=*/Z_NULL, &size) != Z_OK) {
+ return false;
}
-
- if (result_pattern_text != nullptr) {
- *result_pattern_text = unicode_regex_pattern.ToUTF8String();
- }
-
- std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib.CreateRegexPattern(unicode_regex_pattern);
- if (!regex_pattern) {
- TC3_LOG(ERROR) << "Could not create pattern: "
- << unicode_regex_pattern.ToUTF8String();
- }
- return regex_pattern;
+ dictionary->resize(size);
+ return deflateGetDictionary(&stream_, dictionary->data(), &size) == Z_OK;
}
} // namespace libtextclassifier3
diff --git a/utils/zlib/zlib.h b/utils/zlib/zlib.h
index d93527e..f773c27 100644
--- a/utils/zlib/zlib.h
+++ b/utils/zlib/zlib.h
@@ -19,9 +19,9 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_ZLIB_ZLIB_H_
#define LIBTEXTCLASSIFIER_UTILS_ZLIB_ZLIB_H_
-#include <memory>
+#include <vector>
-#include "utils/utf8/unilib.h"
+#include "utils/base/integral_types.h"
#include "utils/zlib/buffer_generated.h"
#include <zlib.h>
@@ -29,7 +29,9 @@
class ZlibDecompressor {
public:
- static std::unique_ptr<ZlibDecompressor> Instance();
+ static std::unique_ptr<ZlibDecompressor> Instance(
+ const unsigned char* dictionary = nullptr,
+ unsigned int dictionary_size = 0);
~ZlibDecompressor();
bool Decompress(const uint8* buffer, const int buffer_size,
@@ -38,38 +40,45 @@
std::string* out);
bool MaybeDecompress(const CompressedBufferT* compressed_buffer,
std::string* out);
+ bool MaybeDecompressOptionallyCompressedBuffer(
+ const flatbuffers::String* uncompressed_buffer,
+ const CompressedBuffer* compressed_buffer, std::string* out);
+ bool MaybeDecompressOptionallyCompressedBuffer(
+ const flatbuffers::Vector<uint8>* uncompressed_buffer,
+ const CompressedBuffer* compressed_buffer, std::string* out);
private:
- ZlibDecompressor();
+ ZlibDecompressor(const unsigned char* dictionary,
+ const unsigned int dictionary_size);
z_stream stream_;
bool initialized_;
};
class ZlibCompressor {
public:
- static std::unique_ptr<ZlibCompressor> Instance();
+ static std::unique_ptr<ZlibCompressor> Instance(
+ const unsigned char* dictionary = nullptr,
+ unsigned int dictionary_size = 0);
~ZlibCompressor();
void Compress(const std::string& uncompressed_content,
CompressedBufferT* out);
+ bool GetDictionary(std::vector<unsigned char>* dictionary);
+
private:
- explicit ZlibCompressor(int level = Z_BEST_COMPRESSION,
+ explicit ZlibCompressor(const unsigned char* dictionary = nullptr,
+ const unsigned int dictionary_size = 0,
+ const int level = Z_BEST_COMPRESSION,
+                          // The temporary buffer size was chosen based on the
+                          // current set of patterns to be compressed.
- int tmp_buffer_size = 64 * 1024);
+ const int tmp_buffer_size = 64 * 1024);
z_stream stream_;
std::unique_ptr<Bytef[]> buffer_;
unsigned int buffer_size_;
bool initialized_;
};
-// Create and compile a regex pattern from optionally compressed pattern.
-std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern(
- const UniLib& unilib, const flatbuffers::String* uncompressed_pattern,
- const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor,
- std::string* result_pattern_text = nullptr);
-
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_ZLIB_ZLIB_H_
diff --git a/utils/zlib/zlib_regex.cc b/utils/zlib/zlib_regex.cc
new file mode 100644
index 0000000..bfe3f5b
--- /dev/null
+++ b/utils/zlib/zlib_regex.cc
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/zlib/zlib_regex.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/flatbuffers.h"
+
+namespace libtextclassifier3 {
+
+std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern(
+ const UniLib& unilib, const flatbuffers::String* uncompressed_pattern,
+ const CompressedBuffer* compressed_pattern, bool lazy_compile_regex,
+ ZlibDecompressor* decompressor, std::string* result_pattern_text) {
+ UnicodeText unicode_regex_pattern;
+ std::string decompressed_pattern;
+ if (compressed_pattern != nullptr &&
+ compressed_pattern->buffer() != nullptr) {
+ if (decompressor == nullptr ||
+ !decompressor->MaybeDecompress(compressed_pattern,
+ &decompressed_pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern.";
+ return nullptr;
+ }
+ unicode_regex_pattern =
+ UTF8ToUnicodeText(decompressed_pattern.data(),
+ decompressed_pattern.size(), /*do_copy=*/false);
+ } else {
+ if (uncompressed_pattern == nullptr) {
+ TC3_LOG(ERROR) << "Cannot load uncompressed pattern.";
+ return nullptr;
+ }
+ unicode_regex_pattern =
+ UTF8ToUnicodeText(uncompressed_pattern->c_str(),
+ uncompressed_pattern->Length(), /*do_copy=*/false);
+ }
+
+ if (result_pattern_text != nullptr) {
+ *result_pattern_text = unicode_regex_pattern.ToUTF8String();
+ }
+
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern;
+ if (lazy_compile_regex) {
+ regex_pattern = unilib.CreateLazyRegexPattern(unicode_regex_pattern);
+ } else {
+ regex_pattern = unilib.CreateRegexPattern(unicode_regex_pattern);
+ }
+
+ if (!regex_pattern) {
+ TC3_LOG(ERROR) << "Could not create pattern: "
+ << unicode_regex_pattern.ToUTF8String();
+ }
+ return regex_pattern;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/zlib/zlib_regex.h b/utils/zlib/zlib_regex.h
new file mode 100644
index 0000000..27360ec
--- /dev/null
+++ b/utils/zlib/zlib_regex.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_ZLIB_ZLIB_REGEX_H_
+#define LIBTEXTCLASSIFIER_UTILS_ZLIB_ZLIB_REGEX_H_
+
+#include <memory>
+
+#include "utils/utf8/unilib.h"
+#include "utils/zlib/buffer_generated.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Create and compile a regex pattern from optionally compressed pattern.
+std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern(
+ const UniLib& unilib, const flatbuffers::String* uncompressed_pattern,
+ const CompressedBuffer* compressed_pattern, bool lazy_compile_regex,
+ ZlibDecompressor* decompressor, std::string* result_pattern_text = nullptr);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_ZLIB_ZLIB_REGEX_H_