Update include dir after TFLite rebase. am: db61aac008 am: f6dd76cb57
am: 537cdd1408
Change-Id: Ie763ab095ff9e135f248a144c2fbaabba2790a49
diff --git a/Android.mk b/Android.mk
index 39f7895..17f0373 100644
--- a/Android.mk
+++ b/Android.mk
@@ -41,6 +41,7 @@
-DLIBTEXTCLASSIFIER_UNILIB_ICU \
-DZLIB_CONST \
-DSAFTM_COMPACT_LOGGING \
+ -DTC3_WITH_ACTIONS_OPS \
-DTC3_UNILIB_JAVAICU \
-DTC3_CALENDAR_JAVAICU
@@ -79,6 +80,8 @@
LOCAL_REQUIRED_MODULES := libtextclassifier_annotator_en_model
LOCAL_REQUIRED_MODULES += libtextclassifier_annotator_universal_model
+LOCAL_REQUIRED_MODULES += libtextclassifier_actions_suggestions_model
+LOCAL_REQUIRED_MODULES += libtextclassifier_lang_id_model
LOCAL_ADDITIONAL_DEPENDENCIES += $(LOCAL_PATH)/jni.lds
LOCAL_LDFLAGS += -Wl,-version-script=$(LOCAL_PATH)/jni.lds
@@ -104,7 +107,7 @@
LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS)
LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS)
-LOCAL_TEST_DATA := $(call find-test-data-in-subdirs, $(LOCAL_PATH), *, annotator/test_data)
+LOCAL_TEST_DATA := $(call find-test-data-in-subdirs, $(LOCAL_PATH), *, annotator/test_data, actions/test_data)
LOCAL_CPPFLAGS_32 += -DTC3_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\""
LOCAL_CPPFLAGS_64 += -DTC3_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\""
@@ -148,3 +151,30 @@
LOCAL_SRC_FILES := ./models/textclassifier.universal.model
LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
include $(BUILD_PREBUILT)
+
+# ---------------------------
+# Actions Suggestions models
+# ---------------------------
+# STOPSHIP: The model size is now around 7.5mb, we should trim it down before shipping it.
+
+include $(CLEAR_VARS)
+LOCAL_MODULE := libtextclassifier_actions_suggestions_model
+LOCAL_MODULE_STEM := actions_suggestions.model
+LOCAL_MODULE_CLASS := ETC
+LOCAL_MODULE_OWNER := google
+LOCAL_SRC_FILES := ./models/actions_suggestions.model
+LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
+include $(BUILD_PREBUILT)
+
+# ------------
+# LangId model
+# ------------
+
+include $(CLEAR_VARS)
+LOCAL_MODULE := libtextclassifier_lang_id_model
+LOCAL_MODULE_STEM := lang_id.model
+LOCAL_MODULE_CLASS := ETC
+LOCAL_MODULE_OWNER := google
+LOCAL_SRC_FILES := ./models/lang_id.model
+LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
+include $(BUILD_PREBUILT)
diff --git a/actions/actions-suggestions.cc b/actions/actions-suggestions.cc
new file mode 100644
index 0000000..7ce13f3
--- /dev/null
+++ b/actions/actions-suggestions.cc
@@ -0,0 +1,547 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/actions-suggestions.h"
+#include "utils/base/logging.h"
+#include "utils/utf8/unicodetext.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace libtextclassifier3 {
+
+// Action-type identifiers. Intentionally leaked: binding a reference to a
+// heap-allocated std::string avoids static-destruction-order problems at
+// process shutdown.
+const std::string& ActionsSuggestions::kViewCalendarType =
+    *[]() { return new std::string("view_calendar"); }();
+const std::string& ActionsSuggestions::kViewMapType =
+    *[]() { return new std::string("view_map"); }();
+const std::string& ActionsSuggestions::kTrackFlightType =
+    *[]() { return new std::string("track_flight"); }();
+const std::string& ActionsSuggestions::kOpenUrlType =
+    *[]() { return new std::string("open_url"); }();
+const std::string& ActionsSuggestions::kSendSmsType =
+    *[]() { return new std::string("send_sms"); }();
+const std::string& ActionsSuggestions::kCallPhoneType =
+    *[]() { return new std::string("call_phone"); }();
+const std::string& ActionsSuggestions::kSendEmailType =
+    *[]() { return new std::string("send_email"); }();
+const std::string& ActionsSuggestions::kShareLocation =
+    *[]() { return new std::string("share_location"); }();
+
+namespace {
+// Verifies that `addr`/`size` hold a valid ActionsModel flatbuffer and
+// returns a typed view of it, or nullptr if verification fails.
+const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
+  flatbuffers::Verifier verifier(addr, size);
+  if (VerifyActionsModelBuffer(verifier)) {
+    return GetActionsModel(addr);
+  } else {
+    return nullptr;
+  }
+}
+
+// Checks whether two annotations can be considered equivalent.
+// Only message index, span, name and entity collection are compared;
+// entity scores are ignored.
+bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
+                                  const ActionSuggestionAnnotation& other) {
+  return annotation.message_index == other.message_index &&
+         annotation.span == other.span && annotation.name == other.name &&
+         annotation.entity.collection == other.entity.collection;
+}
+
+// Checks whether two action suggestions can be considered equivalent.
+// Scores are ignored, so deduplication keeps the first (highest-ranked)
+// instance of an otherwise identical action.
+bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
+                                  const ActionSuggestion& other) {
+  if (action.type != other.type ||
+      action.response_text != other.response_text ||
+      action.annotations.size() != other.annotations.size()) {
+    return false;
+  }
+
+  // Check whether annotations are the same.
+  // Order-sensitive: annotations must match pairwise.
+  for (int i = 0; i < action.annotations.size(); i++) {
+    if (!IsEquivalentActionAnnotation(action.annotations[i],
+                                      other.annotations[i])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Checks whether any action is equivalent to the given one.
+bool IsAnyActionEquivalent(const ActionSuggestion& action,
+                           const std::vector<ActionSuggestion>& actions) {
+  for (const ActionSuggestion& other : actions) {
+    if (IsEquivalentActionSuggestion(action, other)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+}  // namespace
+
+// Creates an instance from a flatbuffer model in a caller-owned buffer.
+// The buffer must remain valid for the lifetime of the returned object.
+// Returns nullptr if verification or initialization fails.
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
+    const uint8_t* buffer, const int size, const UniLib* unilib) {
+  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
+  const ActionsModel* model = LoadAndVerifyModel(buffer, size);
+  if (model == nullptr) {
+    return nullptr;
+  }
+  actions->model_ = model;
+  actions->SetOrCreateUnilib(unilib);
+  if (!actions->ValidateAndInitialize()) {
+    return nullptr;
+  }
+  return actions;
+}
+
+// Creates an instance from a memory-mapped model file, taking ownership of
+// the mmap. Returns nullptr if the mapping is invalid, the model fails
+// flatbuffer verification, or initialization fails.
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
+    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
+    const UniLib* unilib) {
+  if (!mmap->handle().ok()) {
+    TC3_VLOG(1) << "Mmap failed.";
+    return nullptr;
+  }
+  const ActionsModel* model = LoadAndVerifyModel(
+      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
+      mmap->handle().num_bytes());
+  if (!model) {
+    TC3_LOG(ERROR) << "Model verification failed.";
+    return nullptr;
+  }
+  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
+  actions->model_ = model;
+  // Keep the mapping alive: model_ points into the mapped region.
+  actions->mmap_ = std::move(mmap);
+  actions->SetOrCreateUnilib(unilib);
+  if (!actions->ValidateAndInitialize()) {
+    return nullptr;
+  }
+  return actions;
+}
+
+// Creates an instance by mmapping `size` bytes at `offset` within the file
+// referenced by `fd`.
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
+    const int fd, const int offset, const int size, const UniLib* unilib) {
+  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+      new libtextclassifier3::ScopedMmap(fd, offset, size));
+  return FromScopedMmap(std::move(mmap), unilib);
+}
+
+// Creates an instance by mmapping the whole file referenced by `fd`.
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
+    const int fd, const UniLib* unilib) {
+  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+      new libtextclassifier3::ScopedMmap(fd));
+  return FromScopedMmap(std::move(mmap), unilib);
+}
+
+// Creates an instance by mmapping the model file at `path`.
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
+    const std::string& path, const UniLib* unilib) {
+  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+      new libtextclassifier3::ScopedMmap(path));
+  return FromScopedMmap(std::move(mmap), unilib);
+}
+
+// Uses the caller-provided UniLib if non-null (caller retains ownership);
+// otherwise creates and owns a default-constructed instance.
+void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
+  if (unilib != nullptr) {
+    unilib_ = unilib;
+  } else {
+    owned_unilib_.reset(new UniLib);
+    unilib_ = owned_unilib_.get();
+  }
+}
+
+// Checks that the loaded model contains the required fields and builds the
+// internal state (TFLite executor, compiled regex rules). Returns false on
+// any validation or initialization failure.
+bool ActionsSuggestions::ValidateAndInitialize() {
+  if (model_ == nullptr) {
+    TC3_LOG(ERROR) << "No model specified.";
+    return false;
+  }
+
+  if (model_->preconditions() == nullptr) {
+    TC3_LOG(ERROR) << "No triggering conditions specified.";
+    return false;
+  }
+
+  // The TFLite model is optional; rules-only models are valid.
+  if (model_->tflite_model_spec()) {
+    model_executor_ = TfLiteModelExecutor::FromBuffer(
+        model_->tflite_model_spec()->tflite_model());
+    if (!model_executor_) {
+      TC3_LOG(ERROR) << "Could not initialize model executor.";
+      return false;
+    }
+  }
+
+  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+  if (!InitializeRules(decompressor.get())) {
+    TC3_LOG(ERROR) << "Could not initialize rules.";
+    return false;
+  }
+
+  return true;
+}
+
+// Compiles the (possibly zlib-compressed) regex rule patterns from the model
+// into rules_. Returns false if any pattern fails to compile.
+bool ActionsSuggestions::InitializeRules(ZlibDecompressor* decompressor) {
+  if (model_->rules() == nullptr) {
+    // No rules specified.
+    return true;
+  }
+
+  const int num_rules = model_->rules()->rule()->size();
+  for (int i = 0; i < num_rules; i++) {
+    const auto* rule = model_->rules()->rule()->Get(i);
+    std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
+        UncompressMakeRegexPattern(*unilib_, rule->pattern(),
+                                   rule->compressed_pattern(), decompressor);
+    if (compiled_pattern == nullptr) {
+      TC3_LOG(ERROR) << "Failed to load rule pattern.";
+      return false;
+    }
+    // rule_id indexes back into model_->rules()->rule() for the actions.
+    rules_.push_back({/*rule_id=*/i, std::move(compiled_pattern)});
+  }
+
+  return true;
+}
+
+// Ranks the suggested actions by decreasing score and removes duplicates,
+// keeping the highest-scoring instance of each equivalent action (see
+// IsEquivalentActionSuggestion for the equivalence criteria).
+void ActionsSuggestions::RankActions(
+    ActionsSuggestionsResponse* suggestions) const {
+  // First order suggestions by score.
+  std::sort(suggestions->actions.begin(), suggestions->actions.end(),
+            [](const ActionSuggestion& a, const ActionSuggestion& b) {
+              return a.score > b.score;
+            });
+
+  // Deduplicate, keeping the higher score actions.
+  std::vector<ActionSuggestion> deduplicated_actions;
+  for (const ActionSuggestion& candidate : suggestions->actions) {
+    // Check whether we already have an equivalent action.
+    if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) {
+      deduplicated_actions.push_back(candidate);
+    }
+  }
+  // Move the deduplicated vector into place instead of copying it.
+  suggestions->actions = std::move(deduplicated_actions);
+}
+
+// Writes the conversation features into the interpreter's input tensors.
+// An input index < 0 in the model spec means the model does not consume
+// that feature and it is skipped.
+void ActionsSuggestions::SetupModelInput(
+    const std::vector<std::string>& context, const std::vector<int>& user_ids,
+    const std::vector<float>& time_diffs, const int num_suggestions,
+    tflite::Interpreter* interpreter) const {
+  if (model_->tflite_model_spec()->input_context() >= 0) {
+    model_executor_->SetInput<std::string>(
+        model_->tflite_model_spec()->input_context(), context, interpreter);
+  }
+  if (model_->tflite_model_spec()->input_context_length() >= 0) {
+    *interpreter
+         ->tensor(interpreter->inputs()[model_->tflite_model_spec()
+                                            ->input_context_length()])
+         ->data.i64 = context.size();
+  }
+  if (model_->tflite_model_spec()->input_user_id() >= 0) {
+    model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
+                                   user_ids, interpreter);
+  }
+  if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
+    *interpreter
+         ->tensor(interpreter->inputs()[model_->tflite_model_spec()
+                                            ->input_num_suggestions()])
+         ->data.i64 = num_suggestions;
+  }
+  if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
+    model_executor_->SetInput<float>(
+        model_->tflite_model_spec()->input_time_diffs(), time_diffs,
+        interpreter);
+  }
+}
+
+// Reads the interpreter's output tensors into `response`: triggering and
+// sensitivity scores, smart reply suggestions, and action class suggestions.
+// An output index < 0 in the model spec means that output is not produced
+// by the model and is skipped. A sensitive conversation suppresses all
+// suggestion outputs.
+void ActionsSuggestions::ReadModelOutput(
+    tflite::Interpreter* interpreter,
+    ActionsSuggestionsResponse* response) const {
+  // Read sensitivity and triggering score predictions.
+  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
+    const TensorView<float>& triggering_score =
+        model_executor_->OutputView<float>(
+            model_->tflite_model_spec()->output_triggering_score(),
+            interpreter);
+    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
+      TC3_LOG(ERROR) << "Could not compute triggering score.";
+      return;
+    }
+    response->triggering_score = triggering_score.data()[0];
+    response->output_filtered_min_triggering_score =
+        (response->triggering_score <
+         model_->preconditions()->min_smart_reply_triggering_score());
+  }
+  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
+    const TensorView<float>& sensitive_topic_score =
+        model_executor_->OutputView<float>(
+            model_->tflite_model_spec()->output_sensitive_topic_score(),
+            interpreter);
+    if (!sensitive_topic_score.is_valid() ||
+        sensitive_topic_score.dim(0) != 1) {
+      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
+      return;
+    }
+    response->sensitivity_score = sensitive_topic_score.data()[0];
+    response->output_filtered_sensitivity =
+        (response->sensitivity_score >
+         model_->preconditions()->max_sensitive_topic_score());
+  }
+
+  // Suppress model outputs.
+  if (response->output_filtered_sensitivity) {
+    return;
+  }
+
+  // Read smart reply predictions.
+  if (!response->output_filtered_min_triggering_score &&
+      model_->tflite_model_spec()->output_replies() >= 0) {
+    const std::vector<tflite::StringRef> replies =
+        model_executor_->Output<tflite::StringRef>(
+            model_->tflite_model_spec()->output_replies(), interpreter);
+    // NOTE(review): assumes the scores tensor is aligned index-for-index
+    // with the replies tensor -- confirm against the model spec.
+    TensorView<float> scores = model_executor_->OutputView<float>(
+        model_->tflite_model_spec()->output_replies_scores(), interpreter);
+    for (int i = 0; i < replies.size(); i++) {
+      // Skip empty replies.
+      if (replies[i].len == 0) continue;
+      response->actions.push_back({std::string(replies[i].str, replies[i].len),
+                                   model_->smart_reply_action_type()->str(),
+                                   scores.data()[i]});
+    }
+  }
+
+  // Read actions suggestions.
+  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
+    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
+        model_->tflite_model_spec()->output_actions_scores(), interpreter);
+    for (int i = 0; i < model_->action_type()->Length(); i++) {
+      // Skip disabled action classes, such as the default other category.
+      if (!(*model_->action_type())[i]->enabled()) {
+        continue;
+      }
+      const float score = actions_scores.data()[i];
+      if (score < (*model_->action_type())[i]->min_triggering_score()) {
+        continue;
+      }
+      const std::string& output_class =
+          (*model_->action_type())[i]->name()->str();
+      response->actions.push_back({/*response_text=*/"", output_class, score});
+    }
+  }
+}
+
+// Runs the TFLite model over the last `num_messages` messages of the
+// conversation and appends its predictions to `response`. No-op for
+// rules-only models (no TFLite executor).
+void ActionsSuggestions::SuggestActionsFromModel(
+    const Conversation& conversation, const int num_messages,
+    ActionsSuggestionsResponse* response) const {
+  TC3_CHECK_LE(num_messages, conversation.messages.size());
+
+  if (!model_executor_) {
+    return;
+  }
+  std::unique_ptr<tflite::Interpreter> interpreter =
+      model_executor_->CreateInterpreter();
+
+  if (!interpreter) {
+    TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
+                      "actions suggestions model.";
+    return;
+  }
+
+  if (interpreter->AllocateTensors() != kTfLiteOk) {
+    TC3_LOG(ERROR)
+        << "Failed to allocate TensorFlow Lite tensors for the actions "
+           "suggestions model.";
+    return;
+  }
+
+  std::vector<std::string> context;
+  std::vector<int> user_ids;
+  std::vector<float> time_diffs;
+
+  // Gather last `num_messages` messages from the conversation.
+  int64 last_message_reference_time_ms_utc = 0;
+  const float second_in_ms = 1000;
+  for (int i = conversation.messages.size() - num_messages;
+       i < conversation.messages.size(); i++) {
+    const ConversationMessage& message = conversation.messages[i];
+    context.push_back(message.text);
+    user_ids.push_back(message.user_id);
+
+    // Time gap (in seconds) to the previous message; 0 when either
+    // timestamp is missing, and negative gaps are clamped to 0.
+    float time_diff_secs = 0;
+    if (message.reference_time_ms_utc != 0 &&
+        last_message_reference_time_ms_utc != 0) {
+      time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
+                                       last_message_reference_time_ms_utc) /
+                                          second_in_ms);
+    }
+    if (message.reference_time_ms_utc != 0) {
+      last_message_reference_time_ms_utc = message.reference_time_ms_utc;
+    }
+    time_diffs.push_back(time_diff_secs);
+  }
+
+  SetupModelInput(context, user_ids, time_diffs,
+                  /*num_suggestions=*/model_->num_smart_replies(),
+                  interpreter.get());
+
+  if (interpreter->Invoke() != kTfLiteOk) {
+    TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
+    return;
+  }
+
+  ReadModelOutput(interpreter.get(), response);
+}
+
+// Suggests actions derived from annotations on the last message. If the
+// message carries no precomputed annotations and an annotator is provided,
+// the message text is annotated on the fly. No-op when the model defines
+// no annotation-to-action mapping.
+void ActionsSuggestions::SuggestActionsFromAnnotations(
+    const Conversation& conversation, const ActionSuggestionOptions& options,
+    const Annotator* annotator, ActionsSuggestionsResponse* response) const {
+  if (model_->annotation_actions_spec() == nullptr ||
+      model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
+      model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
+    return;
+  }
+
+  // Create actions based on the annotations present in the last message.
+  std::vector<AnnotatedSpan> annotations =
+      conversation.messages.back().annotations;
+  if (annotations.empty() && annotator != nullptr) {
+    annotations = annotator->Annotate(conversation.messages.back().text,
+                                      options.annotation_options);
+  }
+  const int message_index = conversation.messages.size() - 1;
+  for (const AnnotatedSpan& annotation : annotations) {
+    // Skip spans with no usable classification.
+    if (annotation.classification.empty() ||
+        annotation.classification[0].collection.empty()) {
+      continue;
+    }
+    CreateActionsFromAnnotationResult(message_index, annotation, response);
+  }
+}
+
+// Converts a single annotation on the message at `message_index` into action
+// suggestions according to the model's annotation->action mapping. Only the
+// top classification result of the annotation is considered.
+void ActionsSuggestions::CreateActionsFromAnnotationResult(
+    const int message_index, const AnnotatedSpan& annotation,
+    ActionsSuggestionsResponse* suggestions) const {
+  const ClassificationResult& classification_result =
+      annotation.classification[0];
+  ActionSuggestionAnnotation suggestion_annotation;
+  suggestion_annotation.message_index = message_index;
+  suggestion_annotation.span = annotation.span;
+  suggestion_annotation.entity = classification_result;
+  // Bind by reference to avoid copying the collection string.
+  const std::string& collection = classification_result.collection;
+
+  for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
+       *model_->annotation_actions_spec()->annotation_mapping()) {
+    if (collection == mapping->annotation_collection()->str()) {
+      // Drop annotations whose confidence is below the per-mapping floor.
+      if (classification_result.score < mapping->min_annotation_score()) {
+        continue;
+      }
+      const float score =
+          (mapping->use_annotation_score() ? classification_result.score
+                                           : mapping->action()->score());
+      suggestions->actions.push_back({/*response_text=*/"",
+                                      /*type=*/mapping->action()->type()->str(),
+                                      /*score=*/score,
+                                      /*annotations=*/{suggestion_annotation}});
+    }
+  }
+}
+
+// Appends suggestions from regex rules that match the last message. Each
+// matching rule contributes all of its configured actions.
+void ActionsSuggestions::SuggestActionsFromRules(
+    const Conversation& conversation,
+    ActionsSuggestionsResponse* suggestions) const {
+  // Create actions based on rules checking the last message.
+  const UnicodeText message_unicode(
+      UTF8ToUnicodeText(message, /*do_copy=*/false));
+  const std::string& message = conversation.messages.back().text;
+  for (int i = 0; i < rules_.size(); i++) {
+    const std::unique_ptr<UniLib::RegexMatcher> matcher =
+        rules_[i].pattern->Matcher(message_unicode);
+    int status = UniLib::RegexMatcher::kNoError;
+    if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+      const auto actions =
+          model_->rules()->rule()->Get(rules_[i].rule_id)->actions();
+      for (int k = 0; k < actions->size(); k++) {
+        const ActionSuggestionSpec* action = actions->Get(k);
+        // response_text is optional in the spec; default to empty.
+        suggestions->actions.push_back(
+            {/*response_text=*/(action->response_text() != nullptr
+                                    ? action->response_text()->str()
+                                    : ""),
+             /*type=*/action->type()->str(),
+             /*score=*/action->score()});
+      }
+    }
+  }
+}
+
+// Main entry point: suggests actions for the conversation by combining
+// rule-based, TFLite-model-based and annotation-based suggestions, then
+// ranking and deduplicating them. Preconditions (message count, input
+// length, sensitivity) can suppress some or all suggestions.
+ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
+    const Conversation& conversation, const Annotator* annotator,
+    const ActionSuggestionOptions& options) const {
+  ActionsSuggestionsResponse response;
+  if (conversation.messages.empty()) {
+    return response;
+  }
+
+  // Limit the history to the model's configured maximum; a negative
+  // maximum means "no limit".
+  const int conversation_history_length = conversation.messages.size();
+  const int max_conversation_history_length =
+      model_->max_conversation_history_length();
+  const int num_messages =
+      ((max_conversation_history_length < 0 ||
+        conversation_history_length < max_conversation_history_length)
+           ? conversation_history_length
+           : max_conversation_history_length);
+
+  if (num_messages <= 0) {
+    TC3_LOG(INFO) << "No messages provided for actions suggestions.";
+    return response;
+  }
+
+  int input_text_length = 0;
+  for (int i = conversation.messages.size() - num_messages;
+       i < conversation.messages.size(); i++) {
+    input_text_length += conversation.messages[i].text.length();
+  }
+
+  // Bail out if we are provided with too few or too much input.
+  if (input_text_length < model_->preconditions()->min_input_length() ||
+      (model_->preconditions()->max_input_length() >= 0 &&
+       input_text_length > model_->preconditions()->max_input_length())) {
+    TC3_LOG(INFO) << "Too much or not enough input for inference.";
+    return response;
+  }
+
+  SuggestActionsFromRules(conversation, &response);
+
+  SuggestActionsFromModel(conversation, num_messages, &response);
+
+  // Suppress all predictions if the conversation was deemed sensitive.
+  if (model_->preconditions()->suppress_on_sensitive_topic() &&
+      response.output_filtered_sensitivity) {
+    return response;
+  }
+
+  SuggestActionsFromAnnotations(conversation, options, annotator, &response);
+
+  RankActions(&response);
+
+  return response;
+}
+
+// Convenience overload without an annotator; annotation-based suggestions
+// then rely solely on annotations precomputed on the messages.
+ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
+    const Conversation& conversation,
+    const ActionSuggestionOptions& options) const {
+  return SuggestActions(conversation, /*annotator=*/nullptr, options);
+}
+
+// Interprets `buffer` as a serialized ActionsModel and returns a verified
+// read-only view of it, or nullptr on null input or verification failure.
+const ActionsModel* ViewActionsModel(const void* buffer, int size) {
+  if (buffer == nullptr) {
+    return nullptr;
+  }
+
+  return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
+}
+
+} // namespace libtextclassifier3
diff --git a/actions/actions-suggestions.h b/actions/actions-suggestions.h
new file mode 100644
index 0000000..b5f0c2e
--- /dev/null
+++ b/actions/actions-suggestions.h
@@ -0,0 +1,223 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "actions/actions_model_generated.h"
+#include "annotator/annotator.h"
+#include "annotator/types.h"
+#include "utils/memory/mmap.h"
+#include "utils/tflite-model-executor.h"
+#include "utils/utf8/unilib.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// An entity associated with an action.
+struct ActionSuggestionAnnotation {
+  // The referenced message.
+  // -1 if not referencing a particular message in the provided input.
+  int message_index;
+
+  // The span within the reference message.
+  // (-1, -1) if not referencing a particular location.
+  CodepointSpan span;
+  // The annotator classification result for the span.
+  ClassificationResult entity;
+
+  // Optional annotation name.
+  std::string name;
+
+  explicit ActionSuggestionAnnotation()
+      : message_index(kInvalidIndex), span({kInvalidIndex, kInvalidIndex}) {}
+};
+
+// Action suggestion that contains a response text and the type of the response.
+struct ActionSuggestion {
+  // Text of the action suggestion.
+  std::string response_text;
+
+  // Type of the action suggestion.
+  std::string type;
+
+  // Prediction score; suggestions are ranked by decreasing score.
+  float score;
+
+  // The associated annotations.
+  std::vector<ActionSuggestionAnnotation> annotations;
+};
+
+// Actions suggestions result containing meta-information and the suggested
+// actions.
+struct ActionsSuggestionsResponse {
+  ActionsSuggestionsResponse()
+      : sensitivity_score(-1),
+        triggering_score(-1),
+        output_filtered_sensitivity(false),
+        output_filtered_min_triggering_score(false) {}
+
+  // The sensitivity assessment.
+  // Scores stay at -1 until the corresponding model output is read.
+  float sensitivity_score;
+  float triggering_score;
+
+  // Whether the output was suppressed by the sensitivity threshold.
+  bool output_filtered_sensitivity;
+
+  // Whether the output was suppressed by the triggering score threshold.
+  bool output_filtered_min_triggering_score;
+
+  // The suggested actions.
+  std::vector<ActionSuggestion> actions;
+};
+
+// Represents a single message in the conversation.
+struct ConversationMessage {
+  // User ID distinguishing the user from other users in the conversation.
+  int user_id;
+
+  // Text of the message.
+  std::string text;
+
+  // Reference time of this message in milliseconds (UTC, per the name);
+  // 0 is treated as "unknown" when computing message time deltas.
+  int64 reference_time_ms_utc;
+
+  // Annotations on the text.
+  std::vector<AnnotatedSpan> annotations;
+
+  // Comma-separated list of locale specification for the text in the
+  // conversation (BCP 47 tags).
+  std::string locales;
+};
+
+// Conversation between multiple users.
+struct Conversation {
+  // Sequence of messages that were exchanged in the conversation.
+  // Expected in chronological order: the last element is treated as the
+  // most recent message.
+  std::vector<ConversationMessage> messages;
+};
+
+// Options for suggesting actions.
+struct ActionSuggestionOptions {
+  // Options for annotation of the messages (used only when the last message
+  // carries no precomputed annotations and an Annotator is provided).
+  AnnotationOptions annotation_options = AnnotationOptions::Default();
+
+  static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
+};
+
+// Class for predicting actions following a conversation.
+// Combines rule-based, TFLite-model-based and annotation-based suggestions.
+class ActionsSuggestions {
+ public:
+  // Factory functions; all return nullptr if the model cannot be loaded,
+  // verified, or initialized. If `unilib` is null, an owned instance is
+  // created internally.
+  static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
+      const uint8_t* buffer, const int size, const UniLib* unilib = nullptr);
+  // Takes ownership of the mmap.
+  static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
+      std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
+      const UniLib* unilib = nullptr);
+  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
+      const int fd, const int offset, const int size,
+      const UniLib* unilib = nullptr);
+  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
+      const int fd, const UniLib* unilib = nullptr);
+  static std::unique_ptr<ActionsSuggestions> FromPath(
+      const std::string& path, const UniLib* unilib = nullptr);
+
+  ActionsSuggestionsResponse SuggestActions(
+      const Conversation& conversation,
+      const ActionSuggestionOptions& options =
+          ActionSuggestionOptions::Default()) const;
+
+  ActionsSuggestionsResponse SuggestActions(
+      const Conversation& conversation, const Annotator* annotator,
+      const ActionSuggestionOptions& options =
+          ActionSuggestionOptions::Default()) const;
+
+  // Provide an annotator.
+  // NOTE(review): declared here but no definition appears in
+  // actions-suggestions.cc -- confirm it is implemented elsewhere or remove.
+  void SetAnnotator(const Annotator* annotator);
+
+  // Should be in sync with those defined in Android.
+  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
+  static const std::string& kViewCalendarType;
+  static const std::string& kViewMapType;
+  static const std::string& kTrackFlightType;
+  static const std::string& kOpenUrlType;
+  static const std::string& kSendSmsType;
+  static const std::string& kCallPhoneType;
+  static const std::string& kSendEmailType;
+  static const std::string& kShareLocation;
+
+ private:
+  // Checks that model contains all required fields, and initializes internal
+  // datastructures.
+  bool ValidateAndInitialize();
+
+  void SetOrCreateUnilib(const UniLib* unilib);
+
+  // Initializes regular expression rules.
+  bool InitializeRules(ZlibDecompressor* decompressor);
+
+  // Fills the TFLite interpreter's input tensors from conversation features.
+  void SetupModelInput(const std::vector<std::string>& context,
+                       const std::vector<int>& user_ids,
+                       const std::vector<float>& time_diffs,
+                       const int num_suggestions,
+                       tflite::Interpreter* interpreter) const;
+  // Reads predictions from the interpreter's output tensors into `response`.
+  void ReadModelOutput(tflite::Interpreter* interpreter,
+                       ActionsSuggestionsResponse* response) const;
+
+  void SuggestActionsFromModel(const Conversation& conversation,
+                               const int num_messages,
+                               ActionsSuggestionsResponse* response) const;
+
+  void SuggestActionsFromAnnotations(
+      const Conversation& conversation, const ActionSuggestionOptions& options,
+      const Annotator* annotator,
+      ActionsSuggestionsResponse* suggestions) const;
+
+  void CreateActionsFromAnnotationResult(
+      const int message_index, const AnnotatedSpan& annotation,
+      ActionsSuggestionsResponse* suggestions) const;
+
+  void SuggestActionsFromRules(const Conversation& conversation,
+                               ActionsSuggestionsResponse* suggestions) const;
+
+  // Rank and deduplicate actions suggestions.
+  void RankActions(ActionsSuggestionsResponse* suggestions) const;
+
+  // Non-owning; points into the model buffer (mmap_ when loaded from file,
+  // or the caller's buffer for FromUnownedBuffer).
+  const ActionsModel* model_;
+  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
+
+  // Tensorflow Lite models.
+  std::unique_ptr<const TfLiteModelExecutor> model_executor_;
+
+  // Rules.
+  struct CompiledRule {
+    // Index into the model's rule table.
+    int rule_id;
+    std::unique_ptr<UniLib::RegexPattern> pattern;
+  };
+  std::vector<CompiledRule> rules_;
+
+  std::unique_ptr<UniLib> owned_unilib_;
+  // Points at owned_unilib_ or a caller-provided instance.
+  const UniLib* unilib_;
+};
+
+// Interprets the buffer as a Model flatbuffer and returns it for reading.
+const ActionsModel* ViewActionsModel(const void* buffer, int size);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
diff --git a/actions/actions-suggestions_test.cc b/actions/actions-suggestions_test.cc
new file mode 100644
index 0000000..c82763d
--- /dev/null
+++ b/actions/actions-suggestions_test.cc
@@ -0,0 +1,302 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/actions-suggestions.h"
+
+#include <fstream>
+#include <iterator>
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "annotator/types.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3 {
+namespace {
+constexpr char kModelFileName[] = "actions_suggestions_test.model";
+
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+std::string GetModelPath() {
+ return "";
+}
+
+class ActionsSuggestionsTest : public testing::Test {
+ protected:
+ ActionsSuggestionsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+ std::unique_ptr<ActionsSuggestions> LoadTestModel() {
+ return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName,
+ &unilib_);
+ }
+ UniLib unilib_;
+};
+
+TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
+ EXPECT_THAT(LoadTestModel(), testing::NotNull());
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActions) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?"}}});
+ EXPECT_EQ(response.actions.size(), 4);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotations) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("address", 1.0)};
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions({{{/*user_id=*/1, "are you at home?",
+ /*time_diff_secs=*/0,
+ /*annotations=*/{annotation}}}});
+ EXPECT_EQ(response.actions.size(), 5);
+ EXPECT_EQ(response.actions.front().type, "view_map");
+ EXPECT_EQ(response.actions.front().score, 1.0);
+}
+
+void TestSuggestActionsWithThreshold(
+ const std::function<void(ActionsModelT*)>& set_value_fn,
+ const UniLib* unilib = nullptr, const int expected_size = 0) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ set_value_fn(actions_model.get());
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib);
+ ASSERT_TRUE(actions_suggestions);
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?"}}});
+ EXPECT_EQ(response.actions.size(), expected_size);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithTriggeringScore) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
+ },
+ &unilib_,
+ /*expected_size=*/1 /*no smart reply, only actions*/
+ );
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithSensitiveTopicScore) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->max_sensitive_topic_score = 0.0;
+ },
+ &unilib_,
+ /*expected_size=*/4 /* no sensitive prediction in test model*/);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithMaxInputLength) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->max_input_length = 0;
+ },
+ &unilib_);
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinInputLength) {
+ TestSuggestActionsWithThreshold(
+ [](ActionsModelT* actions_model) {
+ actions_model->preconditions->min_input_length = 100;
+ },
+ &unilib_);
+}
+
+TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+
+ // Don't test if no sensitivity score is produced
+ if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
+ return;
+ }
+
+ actions_model->preconditions->max_sensitive_topic_score = 0.0;
+ actions_model->preconditions->suppress_on_sensitive_topic = true;
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_);
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {
+ ClassificationResult(Annotator::kAddressCollection, 1.0)};
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions({{{/*user_id=*/1, "are you at home?",
+ /*time_diff_secs=*/0,
+ /*annotations=*/{annotation}}}});
+ EXPECT_THAT(response.actions, testing::IsEmpty());
+}
+
+TEST_F(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+
+ // Allow a larger conversation context.
+ actions_model->max_conversation_history_length = 10;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_);
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {
+ ClassificationResult(Annotator::kAddressCollection, 1.0)};
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/0, "hi, how are you?", /*reference_time=*/10000},
+ {/*user_id=*/1, "good! are you at home?",
+ /*reference_time=*/15000,
+ /*annotations=*/{annotation}}}});
+ EXPECT_EQ(response.actions.size(), 1);
+ EXPECT_EQ(response.actions.back().type, "view_map");
+ EXPECT_EQ(response.actions.back().score, 1.0);
+}
+
+TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ AnnotatedSpan annotation;
+ annotation.span = {13, 16};
+ annotation.classification = {
+ ClassificationResult(Annotator::kPhoneCollection, 1.0)};
+
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions({{{/*user_id=*/1, "can you call 911?",
+ /*time_diff_secs=*/0,
+ /*annotations=*/{annotation}}}});
+
+ EXPECT_EQ(response.actions.size(),
+ 5 /* smart replies + actions from annotations*/);
+ EXPECT_EQ(response.actions[0].type, "call_phone");
+ EXPECT_EQ(response.actions[0].score, 1.0);
+ EXPECT_EQ(response.actions[0].annotations.size(), 1);
+ EXPECT_EQ(response.actions[0].annotations[0].message_index, 0);
+ EXPECT_EQ(response.actions[0].annotations[0].span, annotation.span);
+ EXPECT_EQ(response.actions[1].type, "send_sms");
+ EXPECT_EQ(response.actions[1].score, 1.0);
+ EXPECT_EQ(response.actions[1].annotations.size(), 1);
+ EXPECT_EQ(response.actions[1].annotations[0].message_index, 0);
+ EXPECT_EQ(response.actions[1].annotations[0].span, annotation.span);
+}
+
+#ifdef TC3_UNILIB_ICU
+TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
+ actions_model->rules->rule.back()->pattern = "^(?i:hello\\sthere)$";
+ actions_model->rules->rule.back()->actions.emplace_back(
+ new ActionSuggestionSpecT);
+ actions_model->rules->rule.back()->actions.back()->type = "text_reply";
+ actions_model->rules->rule.back()->actions.back()->response_text =
+ "General Kenobi!";
+ actions_model->rules->rule.back()->actions.back()->score = 1.f;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_);
+
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions({{{/*user_id=*/1, "hello there"}}});
+ EXPECT_EQ(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
+}
+
+TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?"}}});
+ EXPECT_EQ(response.actions.size(), 4);
+
+ // Check that the location sharing model triggered.
+ bool has_location_sharing_action = false;
+ for (const ActionSuggestion action : response.actions) {
+ if (action.type == ActionsSuggestions::kShareLocation) {
+ has_location_sharing_action = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_location_sharing_action);
+
+ // Add custom rule for location sharing.
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
+ actions_model->rules->rule.back()->pattern = "^(?i:where are you[.?]?)$";
+ actions_model->rules->rule.back()->actions.emplace_back(
+ new ActionSuggestionSpecT);
+ actions_model->rules->rule.back()->actions.back()->type = "text_reply";
+ actions_model->rules->rule.back()->actions.back()->response_text =
+ "I am already here for the test!";
+ actions_model->rules->rule.back()->actions.back()->score = 1.f;
+ actions_model->rules->rule.back()->actions.back()->type =
+ ActionsSuggestions::kShareLocation;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_);
+
+ response = actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Where are you?"}}});
+ EXPECT_EQ(response.actions.size(), 5);
+}
+#endif
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/actions/actions_jni.cc b/actions/actions_jni.cc
new file mode 100644
index 0000000..111967e
--- /dev/null
+++ b/actions/actions_jni.cc
@@ -0,0 +1,310 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// JNI wrapper for actions.
+
+#include "actions/actions_jni.h"
+
+#include <jni.h>
+#include <type_traits>
+#include <vector>
+
+#include "actions/actions-suggestions.h"
+#include "annotator/annotator.h"
+#include "annotator/annotator_jni_common.h"
+#include "utils/base/integral_types.h"
+#include "utils/java/scoped_local_ref.h"
+#include "utils/memory/mmap.h"
+
+using libtextclassifier3::ActionsSuggestions;
+using libtextclassifier3::ActionsSuggestionsResponse;
+using libtextclassifier3::ActionSuggestion;
+using libtextclassifier3::ActionSuggestionOptions;
+using libtextclassifier3::Annotator;
+using libtextclassifier3::Conversation;
+using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::ToStlString;
+
+// When using the Java's ICU, UniLib needs to be instantiated with a JavaVM
+// pointer from JNI. When using a standard ICU the pointer is not needed and the
+// objects are instantiated implicitly.
+#ifdef TC3_UNILIB_JAVAICU
+using libtextclassifier3::UniLib;
+#endif
+
+namespace libtextclassifier3 {
+
+namespace {
+ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env,
+ jobject joptions) {
+ ActionSuggestionOptions options = ActionSuggestionOptions::Default();
+
+ if (!joptions) {
+ return options;
+ }
+
+ const ScopedLocalRef<jclass> options_class(
+ env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+ "$ActionSuggestionOptions"),
+ env);
+
+ if (!options_class) {
+ return options;
+ }
+
+ const std::pair<bool, jobject> status_or_annotation_options =
+ CallJniMethod0<jobject>(env, joptions, options_class.get(),
+ &JNIEnv::CallObjectMethod, "getAnnotationOptions",
+ "L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$AnnotationOptions;");
+
+ if (!status_or_annotation_options.first) {
+ return options;
+ }
+
+ // Create annotation options.
+ options.annotation_options =
+ FromJavaAnnotationOptions(env, status_or_annotation_options.second);
+
+ return options;
+}
+
+jobjectArray ActionSuggestionsToJObjectArray(
+ JNIEnv* env, const std::vector<ActionSuggestion>& action_result) {
+ const ScopedLocalRef<jclass> result_class(
+ env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+ "$ActionSuggestion"),
+ env);
+ if (!result_class) {
+ TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
+ return nullptr;
+ }
+
+ const jmethodID result_class_constructor = env->GetMethodID(
+ result_class.get(), "<init>", "(Ljava/lang/String;Ljava/lang/String;F)V");
+ const jobjectArray results =
+ env->NewObjectArray(action_result.size(), result_class.get(), nullptr);
+ for (int i = 0; i < action_result.size(); i++) {
+ ScopedLocalRef<jobject> result(env->NewObject(
+ result_class.get(), result_class_constructor,
+ env->NewStringUTF(action_result[i].response_text.c_str()),
+ env->NewStringUTF(action_result[i].type.c_str()),
+ static_cast<jfloat>(action_result[i].score)));
+ env->SetObjectArrayElement(results, i, result.get());
+ }
+ return results;
+}
+
+ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
+ if (!jmessage) {
+ return {};
+ }
+
+ const ScopedLocalRef<jclass> message_class(
+ env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+ "$ConversationMessage"),
+ env);
+ const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>(
+ env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText",
+ "Ljava/lang/String;");
+ const std::pair<bool, int32> status_or_user_id =
+ CallJniMethod0<int32>(env, jmessage, message_class.get(),
+ &JNIEnv::CallIntMethod, "getUserId", "I");
+ const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>(
+ env, jmessage, message_class.get(), &JNIEnv::CallLongMethod,
+ "getReferenceTimeMsUtc", "J");
+ const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
+ env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
+ "getLocales", "Ljava/lang/String;");
+ if (!status_or_text.first || !status_or_user_id.first ||
+ !status_or_locales.first || !status_or_reference_time.first) {
+ return {};
+ }
+
+ ConversationMessage message;
+ message.text =
+ ToStlString(env, reinterpret_cast<jstring>(status_or_text.second));
+ message.user_id = status_or_user_id.second;
+ message.reference_time_ms_utc = status_or_reference_time.second;
+ message.locales =
+ ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
+ return message;
+}
+
+Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) {
+ if (!jconversation) {
+ return {};
+ }
+
+ const ScopedLocalRef<jclass> conversation_class(
+ env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+ "$Conversation"),
+ env);
+
+ const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>(
+ env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod,
+ "getConversationMessages",
+ "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;");
+
+ if (!status_or_messages.first) {
+ return {};
+ }
+
+ const jobjectArray jmessages =
+ reinterpret_cast<jobjectArray>(status_or_messages.second);
+
+ const int size = env->GetArrayLength(jmessages);
+
+ std::vector<ConversationMessage> messages;
+ for (int i = 0; i < size; i++) {
+ jobject jmessage = env->GetObjectArrayElement(jmessages, i);
+ ConversationMessage message = FromJavaConversationMessage(env, jmessage);
+ messages.push_back(message);
+ }
+ Conversation conversation;
+ conversation.messages = messages;
+ return conversation;
+}
+
+jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return env->NewStringUTF("");
+ }
+ const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model || !model->locales()) {
+ return env->NewStringUTF("");
+ }
+ return env->NewStringUTF(model->locales()->c_str());
+}
+
+jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return 0;
+ }
+ const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model) {
+ return 0;
+ }
+ return model->version();
+}
+
+jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return env->NewStringUTF("");
+ }
+ const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model || !model->name()) {
+ return env->NewStringUTF("");
+ }
+ return env->NewStringUTF(model->name()->c_str());
+}
+} // namespace
+} // namespace libtextclassifier3
+
+using libtextclassifier3::ActionSuggestionsToJObjectArray;
+using libtextclassifier3::FromJavaActionSuggestionOptions;
+using libtextclassifier3::FromJavaConversation;
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
+(JNIEnv* env, jobject thiz, jint fd) {
+#ifdef TC3_UNILIB_JAVAICU
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
+ libtextclassifier3::JniCache::Create(env));
+ return reinterpret_cast<jlong>(
+ ActionsSuggestions::FromFileDescriptor(fd, new UniLib(jni_cache))
+ .release());
+#else
+ return reinterpret_cast<jlong>(
+ ActionsSuggestions::FromFileDescriptor(fd).release());
+#endif // TC3_UNILIB_JAVAICU
+}
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
+(JNIEnv* env, jobject thiz, jstring path) {
+ const std::string path_str = ToStlString(env, path);
+#ifdef TC3_UNILIB_JAVAICU
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
+ libtextclassifier3::JniCache::Create(env));
+ return reinterpret_cast<jlong>(
+ ActionsSuggestions::FromPath(path_str, new UniLib(jni_cache)).release());
+#else
+ return reinterpret_cast<jlong>(
+ ActionsSuggestions::FromPath(path_str).release());
+#endif // TC3_UNILIB_JAVAICU
+}
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME,
+ nativeNewActionModelsFromAssetFileDescriptor)
+(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
+ const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
+#ifdef TC3_UNILIB_JAVAICU
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
+ libtextclassifier3::JniCache::Create(env));
+ return reinterpret_cast<jlong>(ActionsSuggestions::FromFileDescriptor(
+ fd, offset, size, new UniLib(jni_cache))
+ .release());
+#else
+ return reinterpret_cast<jlong>(
+ ActionsSuggestions::FromFileDescriptor(fd, offset, size).release());
+#endif // TC3_UNILIB_JAVAICU
+}
+
+TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
+(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
+ jlong annotatorPtr) {
+ if (!ptr) {
+ return nullptr;
+ }
+ const Conversation conversation = FromJavaConversation(env, jconversation);
+ const ActionSuggestionOptions actionSuggestionOptions =
+ FromJavaActionSuggestionOptions(env, joptions);
+ ActionsSuggestions* action_model = reinterpret_cast<ActionsSuggestions*>(ptr);
+ Annotator* annotator = reinterpret_cast<Annotator*>(annotatorPtr);
+
+ const ActionsSuggestionsResponse response = action_model->SuggestActions(
+ conversation, annotator, actionSuggestionOptions);
+ return ActionSuggestionsToJObjectArray(env, response.actions);
+}
+
+TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
+(JNIEnv* env, jobject clazz, jlong ptr) {
+ ActionsSuggestions* model = reinterpret_cast<ActionsSuggestions*>(ptr);
+ delete model;
+}
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return libtextclassifier3::GetLocalesFromMmap(env, mmap.get());
+}
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
+(JNIEnv* env, jobject clazz, jint fd) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return libtextclassifier3::GetNameFromMmap(env, mmap.get());
+}
+
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
+}
diff --git a/actions/actions_jni.h b/actions/actions_jni.h
new file mode 100644
index 0000000..c5980bc
--- /dev/null
+++ b/actions/actions_jni.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
+
+#include <jni.h>
+#include <string>
+#include "utils/java/jni-base.h"
+
+#ifndef TC3_ACTIONS_CLASS_NAME
+#define TC3_ACTIONS_CLASS_NAME ActionsSuggestionsModel
+#endif
+
+#define TC3_ACTIONS_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_ACTIONS_CLASS_NAME)
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
+(JNIEnv* env, jobject thiz, jint fd);
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
+(JNIEnv* env, jobject thiz, jstring path);
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME,
+ nativeNewActionModelsFromAssetFileDescriptor)
+(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
+
+TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
+(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation, jobject joptions,
+ jlong annotatorPtr);
+
+TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
+(JNIEnv* env, jobject thiz, jlong ptr);
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
diff --git a/actions/actions_model.fbs b/actions/actions_model.fbs
new file mode 100755
index 0000000..88dff79
--- /dev/null
+++ b/actions/actions_model.fbs
@@ -0,0 +1,163 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "utils/zlib/buffer.fbs";
+
+file_identifier "TC3A";
+
+// Options to specify triggering behaviour per action class.
+namespace libtextclassifier3;
+table ActionTypeOptions {
+ // The name of the predicted action.
+ name:string;
+
+ // Triggering behaviour.
+ // Whether the action class is considered in the model output or not.
+ enabled:bool = true;
+
+ // Minimal output score threshold.
+ min_triggering_score:float = 0;
+}
+
+// TensorFlow Lite model for suggesting actions.
+namespace libtextclassifier3;
+table TensorflowLiteModelSpec {
+ // TensorFlow Lite model for suggesting actions.
+ tflite_model:[ubyte] (force_align: 16);
+
+ // Input specification.
+ input_user_id:int = 0;
+
+ input_context:int = 1;
+ input_context_length:int = 2;
+ input_time_diffs:int = 3;
+ input_num_suggestions:int = 4;
+
+ // Output specification.
+ output_replies:int = 0;
+
+ output_replies_scores:int = 1;
+ output_sensitive_topic_score:int = 3;
+ output_triggering_score:int = 4;
+ output_actions_scores:int = 5;
+}
+
+namespace libtextclassifier3;
+table TriggeringPreconditions {
+ // Lower bound thresholds for the smart reply model prediction output.
+ min_smart_reply_triggering_score:float;
+
+ // Maximum sensitive score for which actions and smart replies are shown.
+ max_sensitive_topic_score:float = 1;
+
+ // Whether to suppress all model output when a conversation is classified as
+ // sensitive.
+ suppress_on_sensitive_topic:bool = true;
+
+ // Thresholds on the model prediction input.
+ // The minimal length of input to consider for prediction.
+ min_input_length:int = 0;
+
+ // The maximal length of input to consider for prediction, -1 if unbounded.
+ max_input_length:int = -1;
+}
+
+namespace libtextclassifier3;
+table ActionSuggestionSpec {
+ // Type of the action suggestion.
+ type:string;
+
+ // Text of a smart reply action.
+ response_text:string;
+
+ // Score.
+ score:float;
+}
+
+namespace libtextclassifier3.AnnotationActionsSpec_;
+table AnnotationMapping {
+ // The annotation collection.
+ annotation_collection:string;
+
+ // The action name to use.
+ action:libtextclassifier3.ActionSuggestionSpec;
+
+ // Whether to use the score of the annotation as the action score.
+ use_annotation_score:bool = true;
+
+ // Minimum threshold for the annotation score for filtering.
+ min_annotation_score:float;
+}
+
+// Configuration for actions based on annotations.
+namespace libtextclassifier3;
+table AnnotationActionsSpec {
+ annotation_mapping:[libtextclassifier3.AnnotationActionsSpec_.AnnotationMapping];
+}
+
+// List of regular expression matchers.
+namespace libtextclassifier3.RulesModel_;
+table Rule {
+ // The regular expression pattern.
+ pattern:string;
+
+ compressed_pattern:libtextclassifier3.CompressedBuffer;
+
+ // The actions to produce upon triggering.
+ actions:[libtextclassifier3.ActionSuggestionSpec];
+}
+
+// Rule based actions.
+namespace libtextclassifier3;
+table RulesModel {
+ rule:[libtextclassifier3.RulesModel_.Rule];
+}
+
+namespace libtextclassifier3;
+table ActionsModel {
+ // Comma-separated list of locales supported by the model as BCP 47 tags.
+ locales:string;
+
+ // Version of the actions model.
+ version:int;
+
+ // A name for the model that can be used e.g. for logging.
+ name:string;
+
+ tflite_model_spec:libtextclassifier3.TensorflowLiteModelSpec;
+
+ // Output classes.
+ smart_reply_action_type:string;
+
+ action_type:[libtextclassifier3.ActionTypeOptions];
+
+ // Triggering conditions of the model.
+ preconditions:libtextclassifier3.TriggeringPreconditions;
+
+ // Default number of smart reply predictions.
+ num_smart_replies:int = 3;
+
+ // Length of message history to consider, -1 if unbounded.
+ max_conversation_history_length:int = 1;
+
+ // Configuration for mapping annotations to action suggestions.
+ annotation_actions_spec:libtextclassifier3.AnnotationActionsSpec;
+
+ // Configuration for rules.
+ rules:libtextclassifier3.RulesModel;
+}
+
+root_type libtextclassifier3.ActionsModel;
diff --git a/actions/test_data/actions_suggestions_test.model b/actions/test_data/actions_suggestions_test.model
new file mode 100644
index 0000000..ee60ce2
--- /dev/null
+++ b/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/actions/zlib-utils.cc b/actions/zlib-utils.cc
new file mode 100644
index 0000000..bf8dc83
--- /dev/null
+++ b/actions/zlib-utils.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/zlib-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Compress rule fields in the model.
+bool CompressActionsModel(ActionsModelT* model) {
+ std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
+ if (!zlib_compressor) {
+ TC3_LOG(ERROR) << "Cannot compress model.";
+ return false;
+ }
+
+ // Compress regex rules.
+ if (model->rules != nullptr) {
+ for (int i = 0; i < model->rules->rule.size(); i++) {
+ RulesModel_::RuleT* rule = model->rules->rule[i].get();
+ rule->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(rule->pattern, rule->compressed_pattern.get());
+ rule->pattern.clear();
+ }
+ }
+
+ return true;
+}
+
+bool DecompressActionsModel(ActionsModelT* model) {
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return false;
+ }
+
+ // Decompress regex rules.
+ if (model->rules != nullptr) {
+ for (int i = 0; i < model->rules->rule.size(); i++) {
+ RulesModel_::RuleT* rule = model->rules->rule[i].get();
+ if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
+ &rule->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
+ return false;
+ }
+ rule->compressed_pattern.reset(nullptr);
+ }
+ }
+
+ return true;
+}
+
+std::string CompressSerializedActionsModel(const std::string& model) {
+ std::unique_ptr<ActionsModelT> unpacked_model =
+ UnPackActionsModel(model.c_str());
+ TC3_CHECK(unpacked_model != nullptr);
+ TC3_CHECK(CompressActionsModel(unpacked_model.get()));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, unpacked_model.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+} // namespace libtextclassifier3
diff --git a/actions/zlib-utils.h b/actions/zlib-utils.h
new file mode 100644
index 0000000..3cf7a23
--- /dev/null
+++ b/actions/zlib-utils.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Functions to compress and decompress low entropy entries in the model.
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_ZLIB_UTILS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_ZLIB_UTILS_H_
+
+#include "actions/actions_model_generated.h"
+
+namespace libtextclassifier3 {
+
+// Compresses regex rules in the model in place.
+bool CompressActionsModel(ActionsModelT* model);
+
+// Decompresses regex rules in the model in place.
+bool DecompressActionsModel(ActionsModelT* model);
+
+// Compresses regex rules in the model.
+std::string CompressSerializedActionsModel(const std::string& model);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_ZLIB_UTILS_H_
diff --git a/actions/zlib-utils_test.cc b/actions/zlib-utils_test.cc
new file mode 100644
index 0000000..377f344
--- /dev/null
+++ b/actions/zlib-utils_test.cc
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/zlib-utils.h"
+
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "utils/zlib/zlib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+namespace {
+
+TEST(ZlibUtilsTest, CompressModel) {
+ ActionsModelT model;
+ constexpr char kTestPattern1[] = "this is a test pattern";
+ constexpr char kTestPattern2[] = "this is a second test pattern";
+ model.rules.reset(new RulesModelT);
+ model.rules->rule.emplace_back(new RulesModel_::RuleT);
+ model.rules->rule.back()->pattern = kTestPattern1;
+ model.rules->rule.emplace_back(new RulesModel_::RuleT);
+ model.rules->rule.back()->pattern = kTestPattern2;
+
+ // Compress the model.
+ EXPECT_TRUE(CompressActionsModel(&model));
+
+ // Sanity check that uncompressed field is removed.
+ EXPECT_TRUE(model.rules->rule[0]->pattern.empty());
+ EXPECT_TRUE(model.rules->rule[1]->pattern.empty());
+ // Pack and load the model.
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model));
+ const ActionsModel* compressed_model = GetActionsModel(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()));
+ ASSERT_TRUE(compressed_model != nullptr);
+
+ // Decompress the fields again and check that they match the original.
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+ ASSERT_TRUE(decompressor != nullptr);
+ std::string uncompressed_pattern;
+ EXPECT_TRUE(decompressor->MaybeDecompress(
+ compressed_model->rules()->rule()->Get(0)->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, kTestPattern1);
+ EXPECT_TRUE(decompressor->MaybeDecompress(
+ compressed_model->rules()->rule()->Get(1)->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, kTestPattern2);
+ EXPECT_TRUE(DecompressActionsModel(&model));
+ EXPECT_EQ(model.rules->rule[0]->pattern, kTestPattern1);
+ EXPECT_EQ(model.rules->rule[1]->pattern, kTestPattern2);
+}
+
+} // namespace
+
+} // namespace libtextclassifier3
diff --git a/generate_flatbuffers.mk b/generate_flatbuffers.mk
index 1522463..3d36e41 100644
--- a/generate_flatbuffers.mk
+++ b/generate_flatbuffers.mk
@@ -44,3 +44,39 @@
$(ANNOTATOR_MODEL_H): $(FLATC) $(ANNOTATOR_MODEL_FBS) $(INTENT_CONFIG_H)
$(transform-fbs-to-cpp)
LOCAL_GENERATED_SOURCES += $(ANNOTATOR_MODEL_H)
+
+# Generate actions/actions_model_generated.h using FlatBuffer schema compiler.
+ACTIONS_MODEL_FBS := $(LOCAL_PATH)/actions/actions_model.fbs
+ACTIONS_MODEL_H := $(intermediates)/actions/actions_model_generated.h
+$(ACTIONS_MODEL_H): PRIVATE_INPUT_FBS := $(ACTIONS_MODEL_FBS)
+$(ACTIONS_MODEL_H): INPUT_DIR := $(LOCAL_PATH)
+$(ACTIONS_MODEL_H): $(FLATC) $(ACTIONS_MODEL_FBS)
+ $(transform-fbs-to-cpp)
+LOCAL_GENERATED_SOURCES += $(ACTIONS_MODEL_H)
+
+# Generate utils/tflite/text_encoder_config_generated.h using FlatBuffer schema compiler.
+UTILS_TFLITE_TEXT_ENCODER_CONFIG_FBS := $(LOCAL_PATH)/utils/tflite/text_encoder_config.fbs
+UTILS_TFLITE_TEXT_ENCODER_CONFIG_H := $(intermediates)/utils/tflite/text_encoder_config_generated.h
+$(UTILS_TFLITE_TEXT_ENCODER_CONFIG_H): PRIVATE_INPUT_FBS := $(UTILS_TFLITE_TEXT_ENCODER_CONFIG_FBS)
+$(UTILS_TFLITE_TEXT_ENCODER_CONFIG_H): INPUT_DIR := $(LOCAL_PATH)
+$(UTILS_TFLITE_TEXT_ENCODER_CONFIG_H): $(FLATC) $(UTILS_TFLITE_TEXT_ENCODER_CONFIG_FBS)
+ $(transform-fbs-to-cpp)
+LOCAL_GENERATED_SOURCES += $(UTILS_TFLITE_TEXT_ENCODER_CONFIG_H)
+
+# Generate lang_id/common/flatbuffers/embedding-network_generated.h using FlatBuffer schema compiler.
+LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_FBS := $(LOCAL_PATH)/lang_id/common/flatbuffers/embedding-network.fbs
+LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_H := $(intermediates)/lang_id/common/flatbuffers/embedding-network_generated.h
+$(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_H): PRIVATE_INPUT_FBS := $(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_FBS)
+$(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_H): INPUT_DIR := $(LOCAL_PATH)
+$(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_H): $(FLATC) $(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_FBS)
+ $(transform-fbs-to-cpp)
+LOCAL_GENERATED_SOURCES += $(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_H)
+
+# Generate lang_id/common/flatbuffers/model_generated.h using FlatBuffer schema compiler.
+LANG_ID_COMMON_FLATBUFFERS_MODEL_FBS := $(LOCAL_PATH)/lang_id/common/flatbuffers/model.fbs
+LANG_ID_COMMON_FLATBUFFERS_MODEL_H := $(intermediates)/lang_id/common/flatbuffers/model_generated.h
+$(LANG_ID_COMMON_FLATBUFFERS_MODEL_H): PRIVATE_INPUT_FBS := $(LANG_ID_COMMON_FLATBUFFERS_MODEL_FBS)
+$(LANG_ID_COMMON_FLATBUFFERS_MODEL_H): INPUT_DIR := $(LOCAL_PATH)
+$(LANG_ID_COMMON_FLATBUFFERS_MODEL_H): $(FLATC) $(LANG_ID_COMMON_FLATBUFFERS_MODEL_FBS)
+ $(transform-fbs-to-cpp)
+LOCAL_GENERATED_SOURCES += $(LANG_ID_COMMON_FLATBUFFERS_MODEL_H)
diff --git a/java/com/google/android/textclassifier/ActionsSuggestionsModel.java b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
new file mode 100644
index 0000000..eaec8ba
--- /dev/null
+++ b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -0,0 +1,221 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Java wrapper for ActionsSuggestions native library interface. This library is used to suggest
+ * actions and replies in a given conversation.
+ *
+ * @hide
+ */
+public final class ActionsSuggestionsModel implements AutoCloseable {
+ private final AtomicBoolean isClosed = new AtomicBoolean(false);
+
+ static {
+ System.loadLibrary("textclassifier");
+ }
+
+ private long actionsModelPtr;
+ private AnnotatorModel annotator;
+
+ /**
+ * Creates a new instance of Actions predictor, using the provided model image, given as a file
+ * descriptor.
+ */
+ public ActionsSuggestionsModel(int fileDescriptor) {
+ this(fileDescriptor, null);
+ }
+
+ public ActionsSuggestionsModel(int fileDescriptor, AnnotatorModel annotator) {
+ actionsModelPtr = nativeNewActionsModel(fileDescriptor);
+ if (actionsModelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
+ }
+ this.annotator = annotator;
+ }
+
+ /**
+ * Creates a new instance of Actions predictor, using the provided model image, given as a file
+ * path.
+ */
+ public ActionsSuggestionsModel(String path) {
+ this(path, null);
+ }
+
+ public ActionsSuggestionsModel(String path, AnnotatorModel annotator) {
+ actionsModelPtr = nativeNewActionsModelFromPath(path);
+ if (actionsModelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize actions model from given file.");
+ }
+ this.annotator = annotator;
+ }
+
+ /** Suggests actions / replies to the given conversation. */
+ public ActionSuggestion[] suggestActions(
+ Conversation conversation, ActionSuggestionOptions options) {
+ return nativeSuggestActions(
+ actionsModelPtr,
+ conversation,
+ options,
+ (annotator != null ? annotator.getNativeAnnotator() : 0));
+ }
+
+ /** Frees up the allocated memory. */
+ @Override
+ public void close() {
+ if (isClosed.compareAndSet(false, true)) {
+ nativeCloseActionsModel(actionsModelPtr);
+ actionsModelPtr = 0L;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
+ public static String getLocales(int fd) {
+ return nativeGetLocales(fd);
+ }
+
+ /** Returns the version of the model. */
+ public static int getVersion(int fd) {
+ return nativeGetVersion(fd);
+ }
+
+ /** Returns the name of the model. */
+ public static String getName(int fd) {
+ return nativeGetName(fd);
+ }
+
+ /** Action suggestion that contains a response text and the type of the response. */
+ public static final class ActionSuggestion {
+ private final String responseText;
+ private final String actionType;
+ private final float score;
+
+ public ActionSuggestion(String responseText, String actionType, float score) {
+ this.responseText = responseText;
+ this.actionType = actionType;
+ this.score = score;
+ }
+
+ public String getResponseText() {
+ return responseText;
+ }
+
+ public String getActionType() {
+ return actionType;
+ }
+
+ /** Confidence score between 0 and 1 */
+ public float getScore() {
+ return score;
+ }
+ }
+
+ /** Represents a single message in the conversation. */
+ public static final class ConversationMessage {
+ private final int userId;
+ private final String text;
+ private final long referenceTimeMsUtc;
+ private final String locales;
+
+ public ConversationMessage(int userId, String text, long referenceTimeMsUtc, String locales) {
+ this.userId = userId;
+ this.text = text;
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
+ this.locales = locales;
+ }
+
+ /** The identifier of the sender */
+ public int getUserId() {
+ return userId;
+ }
+
+ public String getText() {
+ return text;
+ }
+
+ /**
+ * Return the reference time of the message, for example, it could be compose time or send time.
+ * {@code 0} means unspecified.
+ */
+ public long getReferenceTimeMsUtc() {
+ return referenceTimeMsUtc;
+ }
+
+ /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
+ public String getLocales() {
+ return locales;
+ }
+ }
+
+ /** Represents conversation between multiple users. */
+ public static final class Conversation {
+ public final ConversationMessage[] conversationMessages;
+
+ public Conversation(ConversationMessage[] conversationMessages) {
+ this.conversationMessages = conversationMessages;
+ }
+
+ public ConversationMessage[] getConversationMessages() {
+ return conversationMessages;
+ }
+ }
+
+ /** Represents options for the SuggestActions call. */
+ public static final class ActionSuggestionOptions {
+ private final AnnotatorModel.AnnotationOptions annotationOptions;
+
+ public ActionSuggestionOptions() {
+ this.annotationOptions = null;
+ }
+
+ public ActionSuggestionOptions(AnnotatorModel.AnnotationOptions annotationOptions) {
+ this.annotationOptions = annotationOptions;
+ }
+
+ public AnnotatorModel.AnnotationOptions getAnnotationOptions() {
+ return annotationOptions;
+ }
+ }
+
+ private static native long nativeNewActionsModel(int fd);
+
+ private static native long nativeNewActionsModelFromPath(String path);
+
+ private static native String nativeGetLocales(int fd);
+
+ private static native int nativeGetVersion(int fd);
+
+ private static native String nativeGetName(int fd);
+
+ private native ActionSuggestion[] nativeSuggestActions(
+ long context, Conversation conversation, ActionSuggestionOptions options, long annotatorPtr);
+
+ private native void nativeCloseActionsModel(long context);
+
+ private native void nativeSetAnnotator(long annotatorPtr);
+}
diff --git a/java/com/google/android/textclassifier/LangIdModel.java b/java/com/google/android/textclassifier/LangIdModel.java
new file mode 100644
index 0000000..4b10b9f
--- /dev/null
+++ b/java/com/google/android/textclassifier/LangIdModel.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Java wrapper for LangId native library interface. This class is used to detect languages in text.
+ *
+ * @hide
+ */
+public final class LangIdModel implements AutoCloseable {
+ private final AtomicBoolean isClosed = new AtomicBoolean(false);
+
+ static {
+ System.loadLibrary("textclassifier");
+ }
+
+ private long modelPtr;
+
+ /** Creates a new instance of LangId predictor, using the provided model image. */
+ public LangIdModel(int fd) {
+ modelPtr = nativeNew(fd);
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from given file descriptor.");
+ }
+ }
+
+ /** Creates a new instance of LangId predictor, using the provided model image. */
+ public LangIdModel(String modelPath) {
+ modelPtr = nativeNewFromPath(modelPath);
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from given file.");
+ }
+ }
+
+ /** Detects the languages for given text. */
+ public LanguageResult[] detectLanguages(String text) {
+ return nativeDetectLanguages(modelPtr, text);
+ }
+
+ /** Frees up the allocated memory. */
+ @Override
+ public void close() {
+ if (isClosed.compareAndSet(false, true)) {
+ nativeClose(modelPtr);
+ modelPtr = 0L;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /** Result for detectLanguages method. */
+ public static final class LanguageResult {
+ final String mLanguage;
+ final float mScore;
+
+ LanguageResult(String language, float score) {
+ mLanguage = language;
+ mScore = score;
+ }
+
+ public final String getLanguage() {
+ return mLanguage;
+ }
+
+ public final float getScore() {
+ return mScore;
+ }
+ }
+
+ /** Returns the version of the LangId model used. */
+ public int getVersion() {
+ return nativeGetVersion(modelPtr);
+ }
+
+ public static int getVersion(int fd) {
+ return nativeGetVersionFromFd(fd);
+ }
+
+ private static native long nativeNew(int fd);
+
+ private static native long nativeNewFromPath(String path);
+
+ private native LanguageResult[] nativeDetectLanguages(long nativePtr, String text);
+
+ private native void nativeClose(long nativePtr);
+
+ private native int nativeGetVersion(long nativePtr);
+
+ private static native int nativeGetVersionFromFd(int fd);
+}
diff --git a/lang_id/common/embedding-feature-extractor.cc b/lang_id/common/embedding-feature-extractor.cc
new file mode 100644
index 0000000..6235f89
--- /dev/null
+++ b/lang_id/common/embedding-feature-extractor.cc
@@ -0,0 +1,73 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/embedding-feature-extractor.h"
+
+#include <stddef.h>
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+#include "lang_id/common/lite_strings/str-split.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+bool GenericEmbeddingFeatureExtractor::Setup(TaskContext *context) {
+ // Don't use version to determine how to get feature FML.
+ const string features = context->Get(GetParamName("features"), "");
+ const string embedding_names =
+ context->Get(GetParamName("embedding_names"), "");
+ const string embedding_dims =
+ context->Get(GetParamName("embedding_dims"), "");
+
+ // NOTE: unfortunately, LiteStrSplit returns a vector of StringPieces pointing
+ // to the original string, in this case |features|, which is local to this
+ // method. We need to explicitly create new strings.
+ for (StringPiece sp : LiteStrSplit(features, ';')) {
+ embedding_fml_.emplace_back(sp);
+ }
+
+ // Same here.
+ for (StringPiece sp : LiteStrSplit(embedding_names, ';')) {
+ embedding_names_.emplace_back(sp);
+ }
+
+ std::vector<StringPiece> dim_strs = LiteStrSplit(embedding_dims, ';');
+ for (const auto &dim_str : dim_strs) {
+ int dim = 0;
+ if (!LiteAtoi(dim_str, &dim)) {
+ SAFTM_LOG(ERROR) << "Unable to parse " << dim_str;
+ return false;
+ }
+ embedding_dims_.push_back(dim);
+ }
+ return true;
+}
+
+bool GenericEmbeddingFeatureExtractor::Init(TaskContext *context) {
+ return true;
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/embedding-feature-extractor.h b/lang_id/common/embedding-feature-extractor.h
new file mode 100644
index 0000000..f51b6e5
--- /dev/null
+++ b/lang_id/common/embedding-feature-extractor.h
@@ -0,0 +1,174 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/attributes.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// An EmbeddingFeatureExtractor manages the extraction of features for
+// embedding-based models. It wraps a sequence of underlying classes of feature
+// extractors, along with associated predicate maps. Each class of feature
+// extractors is associated with a name, e.g., "words", "labels", "tags".
+//
+// The class is split between a generic abstract version,
+// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
+// signature of the ExtractFeatures method) and a typed version.
+//
+// The predicate maps must be initialized before use: they can be loaded using
+// Read() or updated via UpdateMapsForExample.
+class GenericEmbeddingFeatureExtractor {
+ public:
+ // Constructs this GenericEmbeddingFeatureExtractor.
+ //
+ // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
+ // avoid name clashes. See GetParamName().
+ explicit GenericEmbeddingFeatureExtractor(const string &arg_prefix)
+ : arg_prefix_(arg_prefix) {}
+
+ virtual ~GenericEmbeddingFeatureExtractor() {}
+
+ // Sets/inits up predicate maps and embedding space names that are common for
+ // all embedding based feature extractors.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context);
+ SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context);
+
+ // Requests workspace for the underlying feature extractors. This is
+ // implemented in the typed class.
+ virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
+
+ // Returns number of embedding spaces.
+ int NumEmbeddings() const { return embedding_dims_.size(); }
+
+ const std::vector<string> &embedding_fml() const { return embedding_fml_; }
+
+ // Get parameter name by concatenating the prefix and the original name.
+ string GetParamName(const string &param_name) const {
+ string full_name = arg_prefix_;
+ full_name.push_back('_');
+ full_name.append(param_name);
+ return full_name;
+ }
+
+ private:
+ // Prefix for TaskContext parameters.
+ const string arg_prefix_;
+
+ // Embedding space names for parameter sharing.
+ std::vector<string> embedding_names_;
+
+ // FML strings for each feature extractor.
+ std::vector<string> embedding_fml_;
+
+ // Size of each of the embedding spaces (maximum predicate id).
+ std::vector<int> embedding_sizes_;
+
+ // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
+ std::vector<int> embedding_dims_;
+};
+
+// Templated, object-specific implementation of the
+// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
+// ARGS...> class that has the appropriate FeatureTraits() to ensure that
+// locator type features work.
+//
+// Note: for backwards compatibility purposes, this always reads the FML spec
+// from "<prefix>_features".
+template <class EXTRACTOR, class OBJ, class... ARGS>
+class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
+ public:
+ // Constructs this EmbeddingFeatureExtractor.
+ //
+ // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
+ // avoid name clashes. See GetParamName().
+ explicit EmbeddingFeatureExtractor(const string &arg_prefix)
+ : GenericEmbeddingFeatureExtractor(arg_prefix) {}
+
+ // Sets up all predicate maps, feature extractors, and flags.
+ SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
+ if (!GenericEmbeddingFeatureExtractor::Setup(context)) {
+ return false;
+ }
+ feature_extractors_.resize(embedding_fml().size());
+ for (int i = 0; i < embedding_fml().size(); ++i) {
+ feature_extractors_[i].reset(new EXTRACTOR());
+ if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false;
+ if (!feature_extractors_[i]->Setup(context)) return false;
+ }
+ return true;
+ }
+
+ // Initializes resources needed by the feature extractors.
+ SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
+ if (!GenericEmbeddingFeatureExtractor::Init(context)) return false;
+ for (auto &feature_extractor : feature_extractors_) {
+ if (!feature_extractor->Init(context)) return false;
+ }
+ return true;
+ }
+
+ // Requests workspaces from the registry. Must be called after Init(), and
+ // before Preprocess().
+ void RequestWorkspaces(WorkspaceRegistry *registry) override {
+ for (auto &feature_extractor : feature_extractors_) {
+ feature_extractor->RequestWorkspaces(registry);
+ }
+ }
+
+ // Must be called once on the object for each sentence, before any
+ // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
+ void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
+ for (auto &feature_extractor : feature_extractors_) {
+ feature_extractor->Preprocess(workspaces, obj);
+ }
+ }
+
+ // Extracts features using the extractors. Note that features must already
+ // be initialized to the correct number of feature extractors. No predicate
+ // mapping is applied.
+ void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
+ ARGS... args,
+ std::vector<FeatureVector> *features) const {
+ // DCHECK(features != nullptr);
+ // DCHECK_EQ(features->size(), feature_extractors_.size());
+ for (int i = 0; i < feature_extractors_.size(); ++i) {
+ (*features)[i].clear();
+ feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
+ &(*features)[i]);
+ }
+ }
+
+ private:
+ // Templated feature extractor class.
+ std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
diff --git a/lang_id/common/embedding-feature-interface.h b/lang_id/common/embedding-feature-interface.h
new file mode 100644
index 0000000..87576c6
--- /dev/null
+++ b/lang_id/common/embedding-feature-interface.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/embedding-feature-extractor.h"
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/attributes.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+template <class EXTRACTOR, class OBJ, class... ARGS>
+class EmbeddingFeatureInterface {
+ public:
+ // Constructs this EmbeddingFeatureInterface.
+ //
+ // |arg_prefix| is a string prefix for the TaskContext parameters, passed to
+ // the underlying EmbeddingFeatureExtractor.
+ explicit EmbeddingFeatureInterface(const string &arg_prefix)
+ : feature_extractor_(arg_prefix) {}
+
+ // Sets up feature extractors and flags for processing (inference).
+ SAFTM_MUST_USE_RESULT bool SetupForProcessing(TaskContext *context) {
+ return feature_extractor_.Setup(context);
+ }
+
+ // Initializes feature extractor resources for processing (inference)
+ // including requesting a workspace for caching extracted features.
+ SAFTM_MUST_USE_RESULT bool InitForProcessing(TaskContext *context) {
+ if (!feature_extractor_.Init(context)) return false;
+ feature_extractor_.RequestWorkspaces(&workspace_registry_);
+ return true;
+ }
+
+ // Preprocesses *obj using the internal workspace registry.
+ void Preprocess(WorkspaceSet *workspace, OBJ *obj) const {
+ workspace->Reset(workspace_registry_);
+ feature_extractor_.Preprocess(workspace, obj);
+ }
+
+ // Extract features from |obj|. On return, FeatureVector features[i]
+ // contains the features for the embedding space #i.
+ //
+ // This function uses the precomputed info from |workspace|. Usage pattern:
+ //
+ // EmbeddingFeatureInterface<...> feature_interface;
+ // ...
+ // OBJ obj;
+ // WorkspaceSet workspace;
+ // feature_interface.Preprocess(&workspace, &obj);
+ //
+ // // For the same obj, but with different args:
+ // std::vector<FeatureVector> features;
+ // feature_interface.GetFeatures(obj, args, workspace, &features);
+ //
+ // This pattern is useful (more efficient) if you can pre-compute some info
+ // for the entire |obj|, which is reused by the feature extraction performed
+ // for different args. If that is not the case, you can use the simpler
+ // version GetFeaturesNoCaching below.
+ void GetFeatures(const OBJ &obj, ARGS... args, const WorkspaceSet &workspace,
+ std::vector<FeatureVector> *features) const {
+ feature_extractor_.ExtractFeatures(workspace, obj, args..., features);
+ }
+
+ // Simpler version of GetFeatures(), for cases when there is no opportunity to
+ // reuse computation between feature extractions for the same |obj|, but with
+ // different |args|. Returns the extracted features. For more info, see the
+ // doc for GetFeatures().
+ std::vector<FeatureVector> GetFeaturesNoCaching(OBJ *obj,
+ ARGS... args) const {
+ // Technically, we still use a workspace, because
+ // feature_extractor_.ExtractFeatures requires one. But there is no real
+ // caching here, as we start from scratch for each call to ExtractFeatures.
+ WorkspaceSet workspace;
+ Preprocess(&workspace, obj);
+ std::vector<FeatureVector> features(NumEmbeddings());
+ GetFeatures(*obj, args..., workspace, &features);
+ return features;
+ }
+
+ // Returns number of embedding spaces.
+ int NumEmbeddings() const { return feature_extractor_.NumEmbeddings(); }
+
+ private:
+ // Typed feature extractor for embeddings.
+ EmbeddingFeatureExtractor<EXTRACTOR, OBJ, ARGS...> feature_extractor_;
+
+ // The registry of shared workspaces in the feature extractor.
+ WorkspaceRegistry workspace_registry_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
diff --git a/lang_id/common/embedding-network-params.cc b/lang_id/common/embedding-network-params.cc
new file mode 100644
index 0000000..be7c80e
--- /dev/null
+++ b/lang_id/common/embedding-network-params.cc
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/embedding-network-params.h"
+
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+
+QuantizationType ParseQuantizationType(const string &s) {
+ if (s == "NONE") {
+ return QuantizationType::NONE;
+ }
+ if (s == "UINT8") {
+ return QuantizationType::UINT8;
+ }
+ if (s == "UINT4") {
+ return QuantizationType::UINT4;
+ }
+ if (s == "FLOAT16") {
+ return QuantizationType::FLOAT16;
+ }
+ SAFTM_LOG(FATAL) << "Unsupported quantization type: " << s;
+
+ // Execution should never reach this point; just to keep the compiler happy.
+ // TODO(salcianu): implement SAFTM_LOG(FATAL) in a way that doesn't require
+ // this trick.
+ return QuantizationType::NONE;
+}
+
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/embedding-network-params.h b/lang_id/common/embedding-network-params.h
new file mode 100755
index 0000000..f43c653
--- /dev/null
+++ b/lang_id/common/embedding-network-params.h
@@ -0,0 +1,316 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
+
+#include <string>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_base/float16.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+
+enum class QuantizationType {
+ NONE = 0,
+
+ // Quantization to 8 bit unsigned ints.
+ UINT8,
+
+ // Quantization to 4 bit unsigned ints.
+ UINT4,
+
+ // Quantization to 16 bit floats, the type defined in
+ // lang_id/common/float16.h
+ FLOAT16,
+
+ // NOTE: for backward compatibility, if you add a new value to this enum, add
+ // it *at the end*, such that you do not change the integer values of the
+ // existing enum values.
+};
+
+// Converts "UINT8" -> QuantizationType::UINT8, and so on.
+QuantizationType ParseQuantizationType(const string &s);
+
+// API for accessing parameters for a feed-forward neural network with
+// embeddings.
+//
+//
+// In fact, we provide two APIs: a high-level (and highly-recommended) API, with
+// methods named using the BigCamel notation (e.g., GetEmbeddingMatrix()) and a
+// low-level API, using C-style names (e.g., softmax_num_cols()).
+//
+// Note: the API below is meant to allow the inference code (the class
+// libtextclassifier3::mobile::EmbeddingNetwork) to use the data directly, with no need
+// for transposing any matrix (which would require extra overhead on mobile
+// devices). Hence, as indicated by the comments for the API methods, some of
+// the matrices below are the transposes of the corresponding matrices from the
+// original proto.
+class EmbeddingNetworkParams {
+ public:
+ virtual ~EmbeddingNetworkParams() {}
+
+ // Returns true if these params are valid. False otherwise (e.g., if the
+ // underlying data is corrupted). If is_valid() returns false, clients should
+ // not call any other method on that instance of EmbeddingNetworkParams. If
+ // is_valid() returns true, then calls to the API methods below should not
+ // crash *if they are called with index parameters in bounds*. E.g., if
+ // is_valid() and 0 <= i < embeddings_size(), then GetEmbeddingMatrix(i)
+ // should not crash.
+ virtual bool is_valid() const = 0;
+
+ // **** High-level API.
+
+ // Simple representation of a matrix. This small struct that doesn't own any
+ // resource intentionally supports copy / assign, to simplify our APIs.
+ struct Matrix {
+ // Number of rows.
+ int rows = 0;
+
+ // Number of columns.
+ int cols = 0;
+
+ QuantizationType quant_type = QuantizationType::NONE;
+
+ // Pointer to matrix elements, in row-major order
+ // (https://en.wikipedia.org/wiki/Row-major_order) Not owned.
+ const void *elements = nullptr;
+
+ // Quantization scales: one scale for each row.
+ const ::libtextclassifier3::mobile::float16 *quant_scales = nullptr;
+ };
+
+ // Returns i-th embedding matrix. Crashes on out of bounds indices.
+ //
+ // This is the transpose of the corresponding matrix from the original proto.
+ Matrix GetEmbeddingMatrix(int i) const {
+ CheckIndex(i, embeddings_size(), "embedding matrix");
+ Matrix matrix;
+ matrix.rows = embeddings_num_rows(i);
+ matrix.cols = embeddings_num_cols(i);
+ matrix.elements = embeddings_weights(i);
+ matrix.quant_type = embeddings_quant_type(i);
+ matrix.quant_scales = embeddings_quant_scales(i);
+ return matrix;
+ }
+
+ // Returns weight matrix for i-th hidden layer. Crashes on out of bounds
+ // indices.
+ //
+ // This is the transpose of the corresponding matrix from the original proto.
+ Matrix GetHiddenLayerMatrix(int i) const {
+ CheckIndex(i, hidden_size(), "hidden layer");
+ Matrix matrix;
+ matrix.rows = hidden_num_rows(i);
+ matrix.cols = hidden_num_cols(i);
+
+ // Quantization not supported here.
+ matrix.quant_type = hidden_weights_quant_type(i);
+ matrix.elements = hidden_weights(i);
+ return matrix;
+ }
+
+ // Returns bias for i-th hidden layer. Technically a Matrix, but we expect it
+ // to be a row/column vector (i.e., num rows or num cols is 1). However, we
+ // don't CHECK for that: we just provide access to underlying data. Crashes
+ // on out of bounds indices.
+ Matrix GetHiddenLayerBias(int i) const {
+ CheckIndex(i, hidden_bias_size(), "hidden layer bias");
+ Matrix matrix;
+ matrix.rows = hidden_bias_num_rows(i);
+ matrix.cols = hidden_bias_num_cols(i);
+
+ // Quantization not supported here.
+ matrix.quant_type = QuantizationType::NONE;
+ matrix.elements = hidden_bias_weights(i);
+ return matrix;
+ }
+
+ // Returns true if a softmax layer exists.
+ bool HasSoftmax() const {
+ return softmax_size() == 1;
+ }
+
+ // Returns weight matrix for the softmax layer. Note: should be called only
+ // if HasSoftmax() is true.
+ //
+ // This is the transpose of the corresponding matrix from the original proto.
+ Matrix GetSoftmaxMatrix() const {
+ SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
+ Matrix matrix;
+ matrix.rows = softmax_num_rows(0);
+ matrix.cols = softmax_num_cols(0);
+
+ // Quantization not supported here.
+ matrix.quant_type = softmax_weights_quant_type(0);
+ matrix.elements = softmax_weights(0);
+ return matrix;
+ }
+
+ // Returns bias for the softmax layer. Technically a Matrix, but we expect it
+ // to be a row/column vector (i.e., num rows or num cols is 1). However, we
+ // don't CHECK for that: we just provide access to underlying data.
+ Matrix GetSoftmaxBias() const {
+ SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
+ Matrix matrix;
+ matrix.rows = softmax_bias_num_rows(0);
+ matrix.cols = softmax_bias_num_cols(0);
+
+ // Quantization not supported here.
+ matrix.quant_type = QuantizationType::NONE;
+ matrix.elements = softmax_bias_weights(0);
+ return matrix;
+ }
+
+ // Updates the EmbeddingNetwork-related parameters from task_context. Returns
+ // true on success, false on error.
+ virtual bool UpdateTaskContextParameters(
+ mobile::TaskContext *task_context) = 0;
+
+ // **** Low-level API.
+ //
+ // * Most low-level API methods are documented by giving an equivalent
+ // function call on proto, the original proto (of type
+ // EmbeddingNetworkProto) which was used to generate the C++ code.
+ //
+ // * To simplify our generation code, optional proto fields of message type
+ // are treated as repeated fields with 0 or 1 instances. As such, we have
+ // *_size() methods for such optional fields: they return 0 or 1.
+ //
+ // * "transpose(M)" denotes the transpose of a matrix M.
+
+ // ** Access methods for repeated MatrixParams embeddings.
+ //
+ // Returns proto.embeddings_size().
+ virtual int embeddings_size() const = 0;
+
+ // Returns number of rows of transpose(proto.embeddings(i)).
+ virtual int embeddings_num_rows(int i) const = 0;
+
+ // Returns number of columns of transpose(proto.embeddings(i)).
+ virtual int embeddings_num_cols(int i) const = 0;
+
+ // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
+ // order. NOTE: for unquantized embeddings, this returns a pointer to float;
+ // for quantized embeddings, this returns a pointer to uint8.
+ virtual const void *embeddings_weights(int i) const = 0;
+
+ virtual QuantizationType embeddings_quant_type(int i) const {
+ return QuantizationType::NONE;
+ }
+
+ virtual const ::libtextclassifier3::mobile::float16 *embeddings_quant_scales(
+ int i) const {
+ return nullptr;
+ }
+
+ // ** Access methods for repeated MatrixParams hidden.
+ //
+ // Returns embedding_network_proto.hidden_size().
+ virtual int hidden_size() const = 0;
+
+ // Returns embedding_network_proto.hidden(i).rows().
+ virtual int hidden_num_rows(int i) const = 0;
+
+  // Returns embedding_network_proto.hidden(i).cols().
+ virtual int hidden_num_cols(int i) const = 0;
+
+ // Returns quantization mode for the weights of the i-th hidden layer.
+ virtual QuantizationType hidden_weights_quant_type(int i) const {
+ return QuantizationType::NONE;
+ }
+
+ // Returns pointer to beginning of array of floats with all values from
+ // embedding_network_proto.hidden(i).
+ virtual const void *hidden_weights(int i) const = 0;
+
+ // ** Access methods for repeated MatrixParams hidden_bias.
+ //
+ // Returns proto.hidden_bias_size().
+ virtual int hidden_bias_size() const = 0;
+
+ // Returns number of rows of proto.hidden_bias(i).
+ virtual int hidden_bias_num_rows(int i) const = 0;
+
+ // Returns number of columns of proto.hidden_bias(i).
+ virtual int hidden_bias_num_cols(int i) const = 0;
+
+ // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
+ virtual const void *hidden_bias_weights(int i) const = 0;
+
+ // ** Access methods for optional MatrixParams softmax.
+ //
+ // Returns 1 if proto has optional field softmax, 0 otherwise.
+ virtual int softmax_size() const = 0;
+
+ // Returns number of rows of transpose(proto.softmax()).
+ virtual int softmax_num_rows(int i) const = 0;
+
+ // Returns number of columns of transpose(proto.softmax()).
+ virtual int softmax_num_cols(int i) const = 0;
+
+ // Returns quantization mode for the softmax weights.
+ virtual QuantizationType softmax_weights_quant_type(int i) const {
+ return QuantizationType::NONE;
+ }
+
+ // Returns pointer to elements of transpose(proto.softmax()), in row-major
+ // order.
+ virtual const void *softmax_weights(int i) const = 0;
+
+ // ** Access methods for optional MatrixParams softmax_bias.
+ //
+ // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
+ virtual int softmax_bias_size() const = 0;
+
+ // Returns number of rows of proto.softmax_bias().
+ virtual int softmax_bias_num_rows(int i) const = 0;
+
+ // Returns number of columns of proto.softmax_bias().
+ virtual int softmax_bias_num_cols(int i) const = 0;
+
+ // Returns pointer to elements of proto.softmax_bias(), in row-major order.
+ virtual const void *softmax_bias_weights(int i) const = 0;
+
+ // ** Access methods for repeated int32 embedding_num_features.
+ //
+ // Returns proto.embedding_num_features_size().
+ virtual int embedding_num_features_size() const = 0;
+
+ // Returns proto.embedding_num_features(i).
+ virtual int embedding_num_features(int i) const = 0;
+
+ // ** Access methods for is_precomputed
+ //
+ // Returns proto.has_is_precomputed().
+ virtual bool has_is_precomputed() const = 0;
+
+ // Returns proto.is_precomputed().
+ virtual bool is_precomputed() const = 0;
+
+ protected:
+ void CheckIndex(int index, int size, const string &description) const {
+ SAFTM_CHECK_GE(index, 0)
+ << "Out-of-range index for " << description << ": " << index;
+ SAFTM_CHECK_LT(index, size)
+ << "Out-of-range index for " << description << ": " << index;
+ }
+}; // class EmbeddingNetworkParams
+
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
diff --git a/lang_id/common/embedding-network.cc b/lang_id/common/embedding-network.cc
new file mode 100644
index 0000000..469cb1f
--- /dev/null
+++ b/lang_id/common/embedding-network.cc
@@ -0,0 +1,323 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/embedding-network.h"
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace {
+
+void CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) {
+ SAFTM_CHECK_EQ(static_cast<int>(QuantizationType::NONE),
+ static_cast<int>(matrix.quant_type))
+ << "Quantization not allowed here";
+}
+
+int GetMatrixRowSizeInBytes(const EmbeddingNetworkParams::Matrix &matrix) {
+ int cols = matrix.cols;
+ QuantizationType quant_type = matrix.quant_type;
+ switch (quant_type) {
+ case QuantizationType::NONE:
+ return cols * sizeof(float);
+ case QuantizationType::UINT8:
+ return cols * sizeof(uint8);
+ case QuantizationType::UINT4:
+ SAFTM_DCHECK_EQ(cols % 2, 0) << "UINT4 with odd #cols = " << cols;
+ return cols / 2;
+ case QuantizationType::FLOAT16:
+ return cols * sizeof(float16);
+ default:
+ SAFTM_LOG(FATAL) << "Unknown quant type: "
+ << static_cast<int>(quant_type);
+ }
+}
+
+// Computes y = weights * Relu(x) + b where Relu is optionally applied.
+//
+// weights and b are the weight matrix, respectively the bias vector of a neural
+// network layer.
+//
+// Note: in the research literature, usually Relu (the activation function) is
+// the last part of a neural layer. From that perspective, this function
+// computes the Relu part of the previous layer (if any) and next the first half
+// (the computation of the state) for the current layer.
+//
+// Note: weights is expected to be the transposed version of the real weight
+// matrix. Hence, instead of computing a linear combination of the columns of
+// weights, we compute a linear combination of its rows; but we are mindful that
+// these rows are the columns of the original matrix, hence the name
+// weights_col_i in the code.
+void SparseReluProductPlusBias(bool apply_relu,
+ const EmbeddingNetworkParams::Matrix &weights,
+ const EmbeddingNetworkParams::Matrix &b,
+ const std::vector<float> &x,
+ std::vector<float> *y) {
+  // Initialize y to b.  b is a column matrix (i.e., b.cols == 1); we already
+  // CHECK-ed that in the EmbeddingNetwork constructor.
+ const float *b_start = reinterpret_cast<const float *>(b.elements);
+ SAFTM_DCHECK_EQ(b.cols, 1);
+ y->assign(b_start, b_start + b.rows);
+
+ float *const y_data = y->data();
+ const int y_size = y->size();
+ SAFTM_CHECK_EQ(weights.cols, y_size);
+ const int x_size = x.size();
+ SAFTM_CHECK_EQ(weights.rows, x_size);
+
+ // NOTE: the code below reads x_size * y_size elements from weights; these
+ // reads are safe as long as weights.elements contains weights.rows *
+ // weights.cols elements (where the element size depends on the quantization
+ // type). That requirement is checked by the params provider, e.g., by
+ // EmbeddingNetworkParamsFromFlatbuffer.
+
+ // There is some code duplication between the two main cases of the switch
+ // below: the idea was to "lift" the switch outside the loops, to reduce the
+ // number of tests at runtime.
+ switch (weights.quant_type) {
+ case QuantizationType::NONE: {
+ // We compute a linear combination of the rows from |weights|, using
+ // elements of x (optionally, Relu(x)) as scaling factors (the i-th row
+ // gets multiplied by x[i] before being added with the other rows). Note:
+ // elements of |weights| are stored in row-major order: first the elements
+ // of row #0, next the elements of row #1, etc. In the comments below, we
+ // write "weights[i][j]" to refer to the j-th element from the i-th row of
+ // weights.
+ const float *weight_ptr =
+ reinterpret_cast<const float *>(weights.elements);
+ for (int i = 0; i < x_size; ++i) {
+ // Invariant 1: weight_ptr points to the beginning of the i-th row from
+ // weights (i.e., weights[i][0]).
+ const float scale = x[i];
+ if (!apply_relu || (scale > 0)) {
+ for (int j = 0; j < y_size; ++j, ++weight_ptr) {
+ // Invariant 2: weight_ptr points to weights[i][j].
+ y_data[j] += (*weight_ptr) * scale;
+ }
+ } else {
+ // We don't update y_data, but we still have to move weight_ptr to the
+ // next row (to satisfy Invariant 1). We do this by adding y_size ==
+ // weights.cols() (see earlier CHECK_EQ).
+ weight_ptr += y_size;
+ }
+ }
+ break;
+ }
+ case QuantizationType::FLOAT16: {
+ // See comments for the QuantizationType::NONE case: the code is almost
+ // identical, except for float16 (instead of float) and the Float16To32
+ // conversion. We could unify these two cases using a template, but since
+ // this is a critical loop, don't want to risk that e.g., inlining of the
+ // conversion function doesn't happen.
+ const float16 *weight_ptr =
+ reinterpret_cast<const float16 *>(weights.elements);
+ for (int i = 0; i < x_size; ++i) {
+ const float scale = x[i];
+ if (!apply_relu || (scale > 0)) {
+ for (int j = 0; j < y_size; ++j, ++weight_ptr) {
+ y_data[j] += Float16To32(*weight_ptr) * scale;
+ }
+ } else {
+ weight_ptr += y_size;
+ }
+ }
+ break;
+ }
+ default:
+ SAFTM_LOG(FATAL) << "Unsupported weights quantization type: "
+ << static_cast<int>(weights.quant_type);
+ }
+}
+} // namespace
+
+void EmbeddingNetwork::ConcatEmbeddings(
+ const std::vector<FeatureVector> &feature_vectors,
+ std::vector<float> *concat) const {
+ concat->resize(concat_layer_size_);
+
+ // "es_index" stands for "embedding space index".
+ for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
+ const int concat_offset = concat_offset_[es_index];
+
+ const EmbeddingNetworkParams::Matrix &embedding_matrix =
+ embedding_matrices_[es_index];
+ const int embedding_dim = embedding_matrix.cols;
+ const int embedding_row_size_in_bytes =
+ embedding_row_size_in_bytes_[es_index];
+
+ const FeatureVector &feature_vector = feature_vectors[es_index];
+ const int num_features = feature_vector.size();
+ for (int fi = 0; fi < num_features; ++fi) {
+ const FeatureType *feature_type = feature_vector.type(fi);
+ int feature_offset = concat_offset + feature_type->base() * embedding_dim;
+ SAFTM_CHECK_LE(feature_offset + embedding_dim, concat->size());
+
+ // Weighted embeddings will be added starting from this address.
+ float *concat_ptr = concat->data() + feature_offset;
+
+ // Multiplier for each embedding weight. Includes feature weight (for
+ // continuous features) and quantization scale (for quantized embeddings).
+ float multiplier;
+ int feature_id;
+ const FeatureValue feature_value = feature_vector.value(fi);
+ if (feature_type->is_continuous()) {
+ // Continuous features (encoded as FloatFeatureValue).
+ FloatFeatureValue float_feature_value(feature_value);
+ feature_id = float_feature_value.id;
+ multiplier = float_feature_value.weight;
+ } else {
+ // Discrete features: every present feature has implicit value 1.0.
+ feature_id = feature_value;
+ multiplier = 1.0;
+ }
+
+ SAFTM_CHECK_GE(feature_id, 0);
+ SAFTM_CHECK_LT(feature_id, embedding_matrix.rows);
+
+ // Pointer to float / uint8 weights for relevant embedding.
+ const void *embedding_data =
+ (reinterpret_cast<const char *>(embedding_matrix.elements) +
+ feature_id * embedding_row_size_in_bytes);
+
+ switch (embedding_matrix.quant_type) {
+ case QuantizationType::NONE: {
+ const float *weights =
+ reinterpret_cast<const float *>(embedding_data);
+ for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
+ *concat_ptr += *weights * multiplier;
+ }
+ break;
+ }
+ case QuantizationType::UINT8: {
+ multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
+ const uint8 *quant_weights =
+ reinterpret_cast<const uint8 *>(embedding_data);
+ for (int i = 0; i < embedding_dim;
+ ++i, ++quant_weights, ++concat_ptr) {
+ // 128 is bias for UINT8 quantization.
+ *concat_ptr +=
+ (static_cast<int>(*quant_weights) - 128) * multiplier;
+ }
+ break;
+ }
+ case QuantizationType::UINT4: {
+ multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
+ const uint8 *quant_weights =
+ reinterpret_cast<const uint8 *>(embedding_data);
+ for (int i = 0; i < embedding_dim / 2; ++i, ++quant_weights) {
+ const uint8 qq = *quant_weights;
+ concat_ptr[0] +=
+ (static_cast<int>((qq & 0xF0) | 0x08) - 128) * multiplier;
+ concat_ptr[1] +=
+ (static_cast<int>(((qq & 0x0F) << 4) | 0x08) - 128) *
+ multiplier;
+ concat_ptr += 2;
+ }
+ break;
+ }
+ default:
+ // We already checked (in GetMatrixRowSizeInBytes) that each embedding
+ // matrix has a known quantization type. Hence, DLOG is enough here.
+ SAFTM_DLOG(ERROR) << "Unknown embeddings quantization type "
+ << static_cast<int>(embedding_matrix.quant_type);
+ break;
+ }
+ }
+ }
+}
+
+void EmbeddingNetwork::ComputeFinalScores(
+ const std::vector<FeatureVector> &features,
+ std::vector<float> *scores) const {
+ ComputeFinalScores(features, {}, scores);
+}
+
+void EmbeddingNetwork::ComputeFinalScores(
+ const std::vector<FeatureVector> &features,
+ const std::vector<float> &extra_inputs, std::vector<float> *scores) const {
+ // Construct the input layer for our feed-forward neural network (FFNN).
+ std::vector<float> input;
+ ConcatEmbeddings(features, &input);
+ if (!extra_inputs.empty()) {
+ input.reserve(input.size() + extra_inputs.size());
+ for (int i = 0; i < extra_inputs.size(); i++) {
+ input.push_back(extra_inputs[i]);
+ }
+ }
+
+ // Propagate input through all layers of our FFNN.
+
+ // Alternating storage for activations of the different layers. We can't use
+ // a single vector because all activations of the previous layer are required
+ // when computing the activations of the next one.
+ std::vector<float> storage[2];
+ const std::vector<float> *v_in = &input;
+ const int num_layers = layer_weights_.size();
+ for (int i = 0; i < num_layers; ++i) {
+ std::vector<float> *v_out = nullptr;
+ if (i == num_layers - 1) {
+ // Final layer: write results directly into |scores|.
+ v_out = scores;
+ } else {
+ // Hidden layer: write results into the alternating storage. The i % 2
+ // trick ensures the alternation.
+ v_out = &(storage[i % 2]);
+ }
+ const bool apply_relu = i > 0;
+ SparseReluProductPlusBias(
+ apply_relu, layer_weights_[i], layer_bias_[i], *v_in, v_out);
+ v_in = v_out;
+ }
+}
+
+EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model)
+ : model_(model) {
+ int offset_sum = 0;
+ for (int i = 0; i < model_->embedding_num_features_size(); ++i) {
+ concat_offset_.push_back(offset_sum);
+ EmbeddingNetworkParams::Matrix matrix = model_->GetEmbeddingMatrix(i);
+ offset_sum += matrix.cols * model_->embedding_num_features(i);
+
+ // NOTE: each Matrix is a small struct that doesn't own the actual matrix
+ // weights. Hence, the push_back below is fast.
+ embedding_matrices_.push_back(matrix);
+ embedding_row_size_in_bytes_.push_back(GetMatrixRowSizeInBytes(matrix));
+ }
+ concat_layer_size_ = offset_sum;
+
+ SAFTM_CHECK_EQ(model_->hidden_size(), model_->hidden_bias_size());
+ for (int i = 0; i < model_->hidden_size(); ++i) {
+ layer_weights_.push_back(model_->GetHiddenLayerMatrix(i));
+
+ EmbeddingNetworkParams::Matrix bias = model_->GetHiddenLayerBias(i);
+ SAFTM_CHECK_EQ(1, bias.cols);
+ CheckNoQuantization(bias);
+ layer_bias_.push_back(bias);
+ }
+
+ SAFTM_CHECK(model_->HasSoftmax());
+ layer_weights_.push_back(model_->GetSoftmaxMatrix());
+
+ EmbeddingNetworkParams::Matrix softmax_bias = model_->GetSoftmaxBias();
+ SAFTM_CHECK_EQ(1, softmax_bias.cols);
+ CheckNoQuantization(softmax_bias);
+ layer_bias_.push_back(softmax_bias);
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/embedding-network.h b/lang_id/common/embedding-network.h
new file mode 100644
index 0000000..54094d7
--- /dev/null
+++ b/lang_id/common/embedding-network.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
+
+#include <vector>
+
+#include "lang_id/common/embedding-network-params.h"
+#include "lang_id/common/fel/feature-extractor.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Classifier using a hand-coded feed-forward neural network.
+//
+// No gradient computation, just inference.
+//
+// Based on the more general nlp_saft::EmbeddingNetwork (without ::mobile).
+//
+// Classification works as follows:
+//
+// Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax
+//
+// In words: given some discrete features, this class extracts the embeddings
+// for these features, concatenates them, passes them through one or more hidden
+// layers (each layer uses Relu) and next through a softmax layer that computes
+// an unnormalized score for each possible class. Note: there is always a
+// softmax layer at the end.
+class EmbeddingNetwork {
+ public:
+ // Constructs an embedding network using the parameters from model.
+ //
+ // Note: model should stay alive for at least the lifetime of this
+ // EmbeddingNetwork object.
+ explicit EmbeddingNetwork(const EmbeddingNetworkParams *model);
+
+ virtual ~EmbeddingNetwork() {}
+
+ // Runs forward computation to fill scores with unnormalized output unit
+ // scores. This is useful for making predictions.
+ void ComputeFinalScores(const std::vector<FeatureVector> &features,
+ std::vector<float> *scores) const;
+
+  // Same as above, but allows specification of extra neural network
+ // inputs that will be appended to the embedding vector build from features.
+ void ComputeFinalScores(const std::vector<FeatureVector> &features,
+ const std::vector<float> &extra_inputs,
+ std::vector<float> *scores) const;
+
+ private:
+ // Constructs the concatenated input embedding vector in place in output
+ // vector concat.
+ void ConcatEmbeddings(const std::vector<FeatureVector> &features,
+ std::vector<float> *concat) const;
+
+ // Pointer to the model object passed to the constructor. Not owned.
+ const EmbeddingNetworkParams *model_;
+
+ // Network parameters.
+
+ // One weight matrix for each embedding.
+ std::vector<EmbeddingNetworkParams::Matrix> embedding_matrices_;
+
+ // embedding_row_size_in_bytes_[i] is the size (in bytes) of a row from
+ // embedding_matrices_[i]. We precompute this in order to quickly find the
+ // beginning of the k-th row from an embedding matrix (which is stored in
+ // row-major order).
+ std::vector<int> embedding_row_size_in_bytes_;
+
+ // concat_offset_[i] is the input layer offset for i-th embedding space.
+ std::vector<int> concat_offset_;
+
+ // Size of the input ("concatenation") layer.
+ int concat_layer_size_ = 0;
+
+ // One weight matrix and one vector of bias weights for each layer of neurons.
+ // Last layer is the softmax layer, the previous ones are the hidden layers.
+ std::vector<EmbeddingNetworkParams::Matrix> layer_weights_;
+ std::vector<EmbeddingNetworkParams::Matrix> layer_bias_;
+};
+
+} // namespace mobile
+} // namespace nlp_saft
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
diff --git a/lang_id/common/fel/feature-descriptors.cc b/lang_id/common/fel/feature-descriptors.cc
new file mode 100644
index 0000000..bf03dd5
--- /dev/null
+++ b/lang_id/common/fel/feature-descriptors.cc
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/feature-descriptors.h"
+
+#include "lang_id/common/lite_strings/str-cat.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+void ToFELFunction(const FeatureFunctionDescriptor &function, string *output) {
+ LiteStrAppend(output, function.type());
+ if (function.argument() != 0 || function.parameter_size() > 0) {
+ LiteStrAppend(output, "(");
+ bool first = true;
+ if (function.argument() != 0) {
+ LiteStrAppend(output, function.argument());
+ first = false;
+ }
+ for (int i = 0; i < function.parameter_size(); ++i) {
+ if (!first) LiteStrAppend(output, ",");
+ LiteStrAppend(output, function.parameter(i).name(), "=\"",
+ function.parameter(i).value(), "\"");
+ first = false;
+ }
+ LiteStrAppend(output, ")");
+ }
+}
+
+void ToFEL(const FeatureFunctionDescriptor &function, string *output) {
+ ToFELFunction(function, output);
+ if (function.feature_size() == 1) {
+ LiteStrAppend(output, ".");
+ ToFEL(function.feature(0), output);
+ } else if (function.feature_size() > 1) {
+ LiteStrAppend(output, " { ");
+ for (int i = 0; i < function.feature_size(); ++i) {
+ if (i > 0) LiteStrAppend(output, " ");
+ ToFEL(function.feature(i), output);
+ }
+ LiteStrAppend(output, " } ");
+ }
+}
+
+void ToFEL(const FeatureExtractorDescriptor &extractor, string *output) {
+ for (int i = 0; i < extractor.feature_size(); ++i) {
+ ToFEL(extractor.feature(i), output);
+ LiteStrAppend(output, "\n");
+ }
+}
+
+string FeatureFunctionDescriptor::DebugString() const {
+ string str;
+ ToFEL(*this, &str);
+ return str;
+}
+
+string FeatureExtractorDescriptor::DebugString() const {
+ string str;
+ ToFEL(*this, &str);
+ return str;
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/fel/feature-descriptors.h b/lang_id/common/fel/feature-descriptors.h
new file mode 100644
index 0000000..a9408c9
--- /dev/null
+++ b/lang_id/common/fel/feature-descriptors.h
@@ -0,0 +1,159 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Named feature parameter.
// Named feature parameter: a (name, value) pair of strings attached to a
// feature function, e.g. the name="value" part of type(arg,name="value").
class Parameter {
 public:
  Parameter() {}

  // Accessors for the parameter name.
  void set_name(const string &name) { name_ = name; }
  const string &name() const { return name_; }

  // Accessors for the parameter value.  Values are always stored as strings;
  // typed access (int/bool) is performed by the consumer.
  void set_value(const string &value) { value_ = value; }
  const string &value() const { return value_; }

 private:
  string name_;
  string value_;
};
+
+// Descriptor for a feature function. Used to store the results of parsing one
+// feature function.
// Descriptor for a feature function. Used to store the results of parsing one
// feature function.
class FeatureFunctionDescriptor {
 public:
  FeatureFunctionDescriptor() {}

  // Accessors for the feature function type. The function type is the string
  // that the feature extractor code is registered under.
  void set_type(const string &type) { type_ = type; }
  const string &type() const { return type_; }

  // Accessors for the feature function name. The function name (if available)
  // is used for some log messages. Otherwise, a more precise, but also more
  // verbose name based on the feature specification is used.
  void set_name(const string &name) { name_ = name; }
  const string &name() const { return name_; }

  // Accessors for the default (name-less) parameter.
  void set_argument(int32 argument) { argument_ = argument; }
  bool has_argument() const {
    // If argument has not been specified, clients should treat it as 0. This
    // makes the test below correct, without having a separate has_argument_
    // bool field.
    return argument_ != 0;
  }
  int32 argument() const { return argument_; }

  // Accessors for the named parameters.  add_parameter() returns a pointer
  // into parameters_, so callers must not hold it across further additions.
  Parameter *add_parameter() {
    parameters_.emplace_back();
    return &(parameters_.back());
  }
  int parameter_size() const { return parameters_.size(); }
  const Parameter &parameter(int i) const {
    SAFTM_DCHECK((i >= 0) && (i < parameter_size()));
    return parameters_[i];
  }

  // Accessors for the sub (i.e., nested) features. Nested features: as in
  // offset(1).label.  Sub-features are heap-allocated and owned by this
  // descriptor (via unique_ptr).
  FeatureFunctionDescriptor *add_feature() {
    sub_features_.emplace_back(new FeatureFunctionDescriptor());
    return sub_features_.back().get();
  }
  int feature_size() const { return sub_features_.size(); }
  const FeatureFunctionDescriptor &feature(int i) const {
    SAFTM_DCHECK((i >= 0) && (i < feature_size()));
    return *(sub_features_[i].get());
  }

  // Returns human-readable representation of this FeatureFunctionDescriptor.
  string DebugString() const;

 private:
  // See comments for set_type().
  string type_;

  // See comments for set_name().
  string name_;

  // See comments for set_argument().
  int32 argument_ = 0;

  // See comments for add_parameter().
  std::vector<Parameter> parameters_;

  // See comments for add_feature().
  std::vector<std::unique_ptr<FeatureFunctionDescriptor>> sub_features_;

  SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureFunctionDescriptor);
};
+
+// List of FeatureFunctionDescriptors. Used to store the result of parsing the
+// spec for several feature functions.
// List of FeatureFunctionDescriptors. Used to store the result of parsing the
// spec for several feature functions.  Owns the descriptors it holds.
class FeatureExtractorDescriptor {
 public:
  FeatureExtractorDescriptor() {}

  // Returns the number of top-level feature functions.
  int feature_size() const { return features_.size(); }

  // Appends a new (empty) top-level feature descriptor and returns a pointer
  // to it; ownership stays with this object.
  FeatureFunctionDescriptor *add_feature() {
    features_.emplace_back(new FeatureFunctionDescriptor());
    return features_.back().get();
  }

  const FeatureFunctionDescriptor &feature(int i) const {
    SAFTM_DCHECK((i >= 0) && (i < feature_size()));
    return *(features_[i].get());
  }

  // Returns human-readable representation of this FeatureExtractorDescriptor.
  string DebugString() const;

 private:
  std::vector<std::unique_ptr<FeatureFunctionDescriptor>> features_;

  SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureExtractorDescriptor);
};
+
+// Appends to |*output| the FEL representation of the top-level feature from
+// |function|, without diving into the nested features.
+void ToFELFunction(const FeatureFunctionDescriptor &function, string *output);
+
+// Appends to |*output| the FEL representation of |function|.
+void ToFEL(const FeatureFunctionDescriptor &function, string *output);
+
+// Appends to |*output| the FEL representation of |extractor|.
+void ToFEL(const FeatureExtractorDescriptor &extractor, string *output);
+
+}  // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
diff --git a/lang_id/common/fel/feature-extractor.cc b/lang_id/common/fel/feature-extractor.cc
new file mode 100644
index 0000000..c256257
--- /dev/null
+++ b/lang_id/common/fel/feature-extractor.cc
@@ -0,0 +1,139 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/feature-extractor.h"
+
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/fel-parser.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
// Out-of-line definition of the static constexpr member; required when kNone
// is odr-used (pre-C++17 semantics).
constexpr FeatureValue GenericFeatureFunction::kNone;

GenericFeatureExtractor::GenericFeatureExtractor() {}

GenericFeatureExtractor::~GenericFeatureExtractor() {}
+
+bool GenericFeatureExtractor::Parse(const string &source) {
+ // Parse feature specification into descriptor.
+ FELParser parser;
+
+ if (!parser.Parse(source, mutable_descriptor())) {
+ SAFTM_LOG(ERROR) << "Error parsing the FEL spec " << source;
+ return false;
+ }
+
+ // Initialize feature extractor from descriptor.
+ return InitializeFeatureFunctions();
+}
+
// Collects all feature types from the feature functions and assigns each its
// base index.  Returns false if any type reports an illegal domain size.
bool GenericFeatureExtractor::InitializeFeatureTypes() {
  // Register all feature types.
  GetFeatureTypes(&feature_types_);
  for (size_t i = 0; i < feature_types_.size(); ++i) {
    FeatureType *ft = feature_types_[i];
    // The base of each type is its index in the global list of types.
    ft->set_base(i);

    // Check for feature space overflow.
    double domain_size = ft->GetDomainSize();
    if (domain_size < 0) {
      SAFTM_LOG(ERROR) << "Illegal domain size for feature " << ft->name()
                       << ": " << domain_size;
      return false;
    }
  }
  return true;
}
+
+string GenericFeatureFunction::GetParameter(const string &name,
+ const string &default_value) const {
+ // Find named parameter in feature descriptor.
+ for (int i = 0; i < descriptor_->parameter_size(); ++i) {
+ if (name == descriptor_->parameter(i).name()) {
+ return descriptor_->parameter(i).value();
+ }
+ }
+ return default_value;
+}
+
GenericFeatureFunction::GenericFeatureFunction() {}

// Releases the owned feature type, if one was registered via
// set_feature_type().
GenericFeatureFunction::~GenericFeatureFunction() { delete feature_type_; }
+
+int GenericFeatureFunction::GetIntParameter(const string &name,
+ int default_value) const {
+ string value_str = GetParameter(name, "");
+ if (value_str.empty()) {
+ // Parameter not specified, use default value for it.
+ return default_value;
+ }
+ int value = 0;
+ if (!LiteAtoi(value_str, &value)) {
+ SAFTM_LOG(DFATAL) << "Unable to parse '" << value_str
+ << "' as int for parameter " << name;
+ return default_value;
+ }
+ return value;
+}
+
+bool GenericFeatureFunction::GetBoolParameter(const string &name,
+ bool default_value) const {
+ string value = GetParameter(name, "");
+ if (value.empty()) return default_value;
+ if (value == "true") return true;
+ if (value == "false") return false;
+ SAFTM_LOG(DFATAL) << "Illegal value '" << value << "' for bool parameter "
+ << name;
+ return default_value;
+}
+
+void GenericFeatureFunction::GetFeatureTypes(
+ std::vector<FeatureType *> *types) const {
+ if (feature_type_ != nullptr) types->push_back(feature_type_);
+}
+
+FeatureType *GenericFeatureFunction::GetFeatureType() const {
+ // If a single feature type has been registered return it.
+ if (feature_type_ != nullptr) return feature_type_;
+
+ // Get feature types for function.
+ std::vector<FeatureType *> types;
+ GetFeatureTypes(&types);
+
+ // If there is exactly one feature type return this, else return null.
+ if (types.size() == 1) return types[0];
+ return nullptr;
+}
+
+string GenericFeatureFunction::name() const {
+ string output;
+ if (descriptor_->name().empty()) {
+ if (!prefix_.empty()) {
+ output.append(prefix_);
+ output.append(".");
+ }
+ ToFEL(*descriptor_, &output);
+ } else {
+ output = descriptor_->name();
+ }
+ return output;
+}
+
+}  // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/fel/feature-extractor.h b/lang_id/common/fel/feature-extractor.h
new file mode 100644
index 0000000..8763852
--- /dev/null
+++ b/lang_id/common/fel/feature-extractor.h
@@ -0,0 +1,651 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Generic feature extractor for extracting features from objects. The feature
+// extractor can be used for extracting features from any object. The feature
+// extractor and feature function classes are template classes that have to
+// be instantiated for extracting feature from a specific object type.
+//
+// A feature extractor consists of a hierarchy of feature functions. Each
+// feature function extracts one or more feature type and value pairs from the
+// object.
+//
+// The feature extractor has a modular design where new feature functions can be
+// registered as components. The feature extractor is initialized from a
+// descriptor represented by a protocol buffer. The feature extractor can also
+// be initialized from a text-based source specification of the feature
+// extractor. Feature specification parsers can be added as components. By
+// default the feature extractor can be read from an ASCII protocol buffer or in
+// a simple feature modeling language (fml).
+
+// A feature function is invoked with a focus. Nested feature function can be
+// invoked with another focus determined by the parent feature function.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
+
+#include <stddef.h>
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-descriptors.h"
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/attributes.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+#include "lang_id/common/registry.h"
+#include "lang_id/common/stl-util.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// TODO(djweiss) Clean this up as well.
+// Use the same type for feature values as is used for predicated.
+typedef int64 Predicate;
+typedef Predicate FeatureValue;
+
+// A union used to represent discrete and continuous feature values.
// A union used to represent discrete and continuous feature values.
// NOTE(review): the anonymous struct member is a widely-supported compiler
// extension (not standard C++), and reading a member other than the one last
// written relies on union type punning as guaranteed by gcc/clang — confirm
// this is acceptable for all supported toolchains.
union FloatFeatureValue {
 public:
  // Wraps a discrete (integer) feature value.
  explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {}
  // Packs a continuous feature as a (feature id, weight) pair.
  FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {}
  FeatureValue discrete_value;
  struct {
    uint32 id;
    float weight;
  };
};
+
+// A feature vector contains feature type and value pairs.
+class FeatureVector {
+ public:
+ FeatureVector() {}
+
+ // Adds feature type and value pair to feature vector.
+ void add(FeatureType *type, FeatureValue value) {
+ features_.emplace_back(type, value);
+ }
+
+ // Removes all elements from the feature vector.
+ void clear() { features_.clear(); }
+
+ // Returns the number of elements in the feature vector.
+ int size() const { return features_.size(); }
+
+ // Reserves space in the underlying feature vector.
+ void reserve(int n) { features_.reserve(n); }
+
+ // Returns feature type for an element in the feature vector.
+ FeatureType *type(int index) const { return features_[index].type; }
+
+ // Returns feature value for an element in the feature vector.
+ FeatureValue value(int index) const { return features_[index].value; }
+
+ private:
+ // Structure for holding feature type and value pairs.
+ struct Element {
+ Element() : type(nullptr), value(-1) {}
+ Element(FeatureType *t, FeatureValue v) : type(t), value(v) {}
+
+ FeatureType *type;
+ FeatureValue value;
+ };
+
+ // Array for storing feature vector elements.
+ std::vector<Element> features_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureVector);
+};
+
+// The generic feature extractor is the type-independent part of a feature
+// extractor. This holds the descriptor for the feature extractor and the
+// collection of feature types used in the feature extractor. The feature
+// types are not available until FeatureExtractor<>::Init() has been called.
class GenericFeatureExtractor {
 public:
  GenericFeatureExtractor();
  virtual ~GenericFeatureExtractor();

  // Initializes the feature extractor from the FEL specification |source|.
  //
  // Returns true on success, false otherwise (e.g., FEL syntax error).
  SAFTM_MUST_USE_RESULT bool Parse(const string &source);

  // Returns the feature extractor descriptor.
  const FeatureExtractorDescriptor &descriptor() const { return descriptor_; }
  FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; }

  // Returns the number of feature types in the feature extractor. Invalid
  // before Init() has been called.
  int feature_types() const { return feature_types_.size(); }

 protected:
  // Initializes the feature types used by the extractor (assigns each type
  // its base index and validates domain sizes). Called from
  // FeatureExtractor<>::Init().
  //
  // Returns true on success, false otherwise.
  SAFTM_MUST_USE_RESULT bool InitializeFeatureTypes();

 private:
  // Initializes the top-level feature functions.
  //
  // Returns true on success, false otherwise.
  SAFTM_MUST_USE_RESULT virtual bool InitializeFeatureFunctions() = 0;

  // Returns all feature types used by the extractor. The feature types are
  // added to the result array.
  virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0;

  // Descriptor for the feature extractor. This is a protocol buffer that
  // contains all the information about the feature extractor. The feature
  // functions are initialized from the information in the descriptor.
  FeatureExtractorDescriptor descriptor_;

  // All feature types used by the feature extractor. The collection of all the
  // feature types describes the feature space of the feature set produced by
  // the feature extractor. Not owned.
  std::vector<FeatureType *> feature_types_;
};
+
+// The generic feature function is the type-independent part of a feature
+// function. Each feature function is associated with the descriptor that it is
+// instantiated from. The feature types associated with this feature function
+// will be established by the time FeatureExtractor<>::Init() completes.
class GenericFeatureFunction {
 public:
  // A feature value that represents the absence of a value.
  static constexpr FeatureValue kNone = -1;

  GenericFeatureFunction();
  virtual ~GenericFeatureFunction();

  // Sets up the feature function. NB: FeatureTypes of nested functions are not
  // guaranteed to be available until Init().
  //
  // Returns true on success, false otherwise.
  SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context) {
    return true;
  }

  // Initializes the feature function. NB: The FeatureType of this function must
  // be established when this method completes.
  //
  // Returns true on success, false otherwise.
  SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context) { return true; }

  // Requests workspaces from a registry to obtain indices into a WorkspaceSet
  // for any Workspace objects used by this feature function. NB: This will be
  // called after Init(), so it can depend on resources and arguments.
  virtual void RequestWorkspaces(WorkspaceRegistry *registry) {}

  // Appends the feature types produced by the feature function to types. The
  // default implementation appends feature_type(), if non-null. Invalid
  // before Init() has been called.
  virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const;

  // Returns the feature type for feature produced by this feature function. If
  // the feature function produces features of different types this returns
  // null. Invalid before Init() has been called.
  virtual FeatureType *GetFeatureType() const;

  // Returns value of parameter |name| from the feature function descriptor.
  // If the parameter is not present, returns the indicated |default_value|.
  string GetParameter(const string &name, const string &default_value) const;

  // Returns value of int parameter |name| from feature function descriptor.
  // If the parameter is not present, or its value can't be parsed as an int,
  // returns |default_value|.
  int GetIntParameter(const string &name, int default_value) const;

  // Returns value of bool parameter |name| from feature function descriptor.
  // If the parameter is not present, or its value is not "true" or "false",
  // returns |default_value|. NOTE: this method is case sensitive, it doesn't
  // do any lower-casing.
  bool GetBoolParameter(const string &name, bool default_value) const;

  // Returns the FEL function description for the feature function, i.e. the
  // name and parameters without the nested features.
  string FunctionName() const {
    string output;
    ToFELFunction(*descriptor_, &output);
    return output;
  }

  // Returns the prefix for nested feature functions. This is the prefix of this
  // feature function concatenated with the feature function name.
  string SubPrefix() const {
    return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName();
  }

  // Returns/sets the feature extractor this function belongs to.
  const GenericFeatureExtractor *extractor() const { return extractor_; }
  void set_extractor(const GenericFeatureExtractor *extractor) {
    extractor_ = extractor;
  }

  // Returns/sets the feature function descriptor.
  const FeatureFunctionDescriptor *descriptor() const { return descriptor_; }
  void set_descriptor(const FeatureFunctionDescriptor *descriptor) {
    descriptor_ = descriptor;
  }

  // Returns a descriptive name for the feature function. The name is taken from
  // the descriptor for the feature function. If the name is empty or the
  // feature function is a variable the name is the FEL representation of the
  // feature, including the prefix.
  string name() const;

  // Returns the argument from the feature function descriptor. It defaults to
  // 0 if the argument has not been specified.
  int argument() const {
    return descriptor_->has_argument() ? descriptor_->argument() : 0;
  }

  // Returns/sets/clears function name prefix.
  const string &prefix() const { return prefix_; }
  void set_prefix(const string &prefix) { prefix_ = prefix; }

 protected:
  // Returns the feature type for single-type feature functions.
  FeatureType *feature_type() const { return feature_type_; }

  // Sets the feature type for single-type feature functions. This takes
  // ownership of feature_type (released by the destructor). Can only be
  // called once.
  void set_feature_type(FeatureType *feature_type) {
    SAFTM_CHECK_EQ(feature_type_, nullptr);
    feature_type_ = feature_type;
  }

 private:
  // Feature extractor this feature function belongs to. Not owned. Set to a
  // pointer != nullptr as soon as this object is created by Instantiate().
  // Normal methods can safely assume this is != nullptr.
  const GenericFeatureExtractor *extractor_ = nullptr;

  // Descriptor for feature function. Not owned. Set to a pointer != nullptr
  // as soon as this object is created by Instantiate(). Normal methods can
  // safely assume this is != nullptr.
  const FeatureFunctionDescriptor *descriptor_ = nullptr;

  // Feature type for features produced by this feature function. If the
  // feature function produces features of multiple feature types this is null
  // and the feature function must return it's feature types in
  // GetFeatureTypes(). Owned.
  FeatureType *feature_type_ = nullptr;

  // Prefix used for sub-feature types of this function.
  string prefix_;
};
+
+// Feature function that can extract features from an object. Templated on
+// two type arguments:
+//
+// OBJ: The "object" from which features are extracted; e.g., a sentence. This
+// should be a plain type, rather than a reference or pointer.
+//
+// ARGS: A set of 0 or more types that are used to "index" into some part of the
+// object that should be extracted, e.g. an int token index for a sentence
+// object. This should not be a reference type.
template <class OBJ, class... ARGS>
class FeatureFunction
    : public GenericFeatureFunction,
      public RegisterableClass<FeatureFunction<OBJ, ARGS...> > {
 public:
  using Self = FeatureFunction<OBJ, ARGS...>;

  // Preprocesses the object. This will be called prior to calling Evaluate()
  // or Compute() on that object.
  virtual void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const {}

  // Appends features computed from the object and focus to the result. The
  // default implementation delegates to Compute(), adding a single value if
  // available. Multi-valued feature functions must override this method.
  virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
                        ARGS... args, FeatureVector *result) const {
    FeatureValue value = Compute(workspaces, object, args...);
    if (value != kNone) result->add(feature_type(), value);
  }

  // Returns a feature value computed from the object and focus, or kNone if no
  // value is computed. Single-valued feature functions only need to override
  // this method.
  virtual FeatureValue Compute(const WorkspaceSet &workspaces,
                               const OBJ &object, ARGS... args) const {
    return kNone;
  }

  // Instantiates a new feature function in a feature extractor from a feature
  // descriptor.
  //
  // Returns a pointer to the newly-created object if everything goes well.
  // Returns nullptr if the feature function could not be instantiated (e.g., if
  // the function with that name is not registered; this usually happens because
  // the relevant cc_library was not linked-in).
  static Self *Instantiate(const GenericFeatureExtractor *extractor,
                           const FeatureFunctionDescriptor *fd,
                           const string &prefix) {
    // Look up the registered implementation by its descriptor type string.
    Self *f = Self::Create(fd->type());
    if (f != nullptr) {
      f->set_extractor(extractor);
      f->set_descriptor(fd);
      f->set_prefix(prefix);
    }
    return f;
  }

 private:
  // Special feature function class for resolving variable references. The type
  // of the feature function is used for resolving the variable reference. When
  // evaluated it will either get the feature value(s) from the variable portion
  // of the feature vector, if present, or otherwise it will call the referenced
  // feature extractor function directly to extract the feature(s).
  class Reference;
};
+
+// Base class for features with nested feature functions. The nested functions
+// are of type NES, which may be different from the type of the parent function.
+// NB: NestedFeatureFunction will ensure that all initialization of nested
+// functions takes place during Setup() and Init() -- after the nested features
+// are initialized, the parent feature is initialized via SetupNested() and
+// InitNested(). Alternatively, a derived classes that overrides Setup() and
+// Init() directly should call Parent::Setup(), Parent::Init(), etc. first.
+//
+// Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or
+// Compute, since the nested functions may be of a different type.
template <class NES, class OBJ, class... ARGS>
class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
 public:
  using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>;

  // Clean up nested functions (owned raw pointers in nested_).
  ~NestedFeatureFunction() override { utils::STLDeleteElements(&nested_); }

  // By default, just appends the nested feature types.
  void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
    SAFTM_CHECK(!this->nested().empty())
        << "Nested features require nested features to be defined.";
    for (auto *function : nested_) function->GetFeatureTypes(types);
  }

  // Sets up the nested features: creates them from this function's
  // descriptor, runs their Setup(), then SetupNested() for this function.
  //
  // Returns true on success, false otherwise.
  SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
    bool success = CreateNested(this->extractor(), this->descriptor(), &nested_,
                                this->SubPrefix());
    if (!success) return false;
    for (auto *function : nested_) {
      if (!function->Setup(context)) return false;
    }
    if (!SetupNested(context)) return false;
    return true;
  }

  // Sets up this NestedFeatureFunction specifically.
  //
  // Returns true on success, false otherwise.
  SAFTM_MUST_USE_RESULT virtual bool SetupNested(TaskContext *context) {
    return true;
  }

  // Initializes the nested features, then this function via InitNested().
  //
  // Returns true on success, false otherwise.
  SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
    for (auto *function : nested_) {
      if (!function->Init(context)) return false;
    }
    if (!InitNested(context)) return false;
    return true;
  }

  // Initializes this NestedFeatureFunction specifically.
  //
  // Returns true on success, false otherwise.
  SAFTM_MUST_USE_RESULT virtual bool InitNested(TaskContext *context) {
    return true;
  }

  // Gets all the workspaces needed for the nested functions.
  void RequestWorkspaces(WorkspaceRegistry *registry) override {
    for (auto *function : nested_) function->RequestWorkspaces(registry);
  }

  // Returns the list of nested feature functions.
  const std::vector<NES *> &nested() const { return nested_; }

  // Instantiates nested feature functions for a feature function. Creates and
  // initializes one feature function for each sub-descriptor in the feature
  // descriptor.  On failure, functions created so far remain in |functions|
  // (the caller's destructor is expected to release them).
  //
  // Returns true on success, false otherwise.
  SAFTM_MUST_USE_RESULT static bool CreateNested(
      const GenericFeatureExtractor *extractor,
      const FeatureFunctionDescriptor *fd, std::vector<NES *> *functions,
      const string &prefix) {
    for (int i = 0; i < fd->feature_size(); ++i) {
      const FeatureFunctionDescriptor &sub = fd->feature(i);
      NES *f = NES::Instantiate(extractor, &sub, prefix);
      if (f == nullptr) return false;
      functions->push_back(f);
    }
    return true;
  }

 protected:
  // The nested feature functions, if any, in order of declaration in the
  // feature descriptor. Owned.
  std::vector<NES *> nested_;
};
+
+// Base class for a nested feature function that takes nested features with the
+// same signature as these features, i.e. a meta feature. For this class, we can
+// provide preprocessing of the nested features.
template <class OBJ, class... ARGS>
class MetaFeatureFunction
    : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ,
                                   ARGS...> {
 public:
  // Preprocesses using the nested features: forwards the call to every
  // nested function in declaration order.
  void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override {
    for (auto *function : this->nested_) {
      function->Preprocess(workspaces, object);
    }
  }
};
+
+// Template for a special type of locator: The locator of type
+// FeatureFunction<OBJ, ARGS...> calls nested functions of type
+// FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is
+// responsible for translating by providing the following:
+//
+// // Gets the new additional focus.
+// IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object);
+//
+// This is useful to e.g. add a token focus to a parser state based on some
+// desired property of that state.
template <class DER, class OBJ, class IDX, class... ARGS>
class FeatureAddFocusLocator
    : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ,
                                   ARGS...> {
 public:
  // Forwards preprocessing to every nested function.
  void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override {
    for (auto *function : this->nested_) {
      function->Preprocess(workspaces, object);
    }
  }

  // Evaluates the nested functions with the additional focus computed by the
  // derived class (CRTP: DER must provide GetFocus()).
  void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
                FeatureVector *result) const override {
    IDX focus =
        static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
    for (auto *function : this->nested()) {
      function->Evaluate(workspaces, object, focus, args..., result);
    }
  }

  // Returns the first nested feature's computed value.
  FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
                       ARGS... args) const override {
    IDX focus =
        static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
    return this->nested()[0]->Compute(workspaces, object, focus, args...);
  }
};
+
+// CRTP feature locator class. This is a meta feature that modifies ARGS and
+// then calls the nested feature functions with the modified ARGS. Note that in
+// order for this template to work correctly, all of ARGS must be types for
+// which the reference operator & can be interpreted as a pointer to the
+// argument. The derived class DER must implement the UpdateFocus method which
+// takes pointers to the ARGS arguments:
+//
+// // Updates the current arguments.
+// void UpdateArgs(const OBJ &object, ARGS *...args) const;
template <class DER, class OBJ, class... ARGS>
class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> {
 public:
  // Feature locators have an additional check that there is no intrinsic type.
  void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
    SAFTM_CHECK_EQ(this->feature_type(), nullptr)
        << "FeatureLocators should not have an intrinsic type.";
    MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types);
  }

  // Evaluates the locator: lets the derived class rewrite the arguments in
  // place (CRTP: DER must provide UpdateArgs()), then evaluates the nested
  // functions with the updated arguments.
  void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
                FeatureVector *result) const override {
    static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
    for (auto *function : this->nested()) {
      function->Evaluate(workspaces, object, args..., result);
    }
  }

  // Returns the first nested feature's computed value, after updating the
  // arguments as above.
  FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
                       ARGS... args) const override {
    static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
    return this->nested()[0]->Compute(workspaces, object, args...);
  }
};
+
+// Feature extractor for extracting features from objects of a certain class.
+// Template type parameters are as defined for FeatureFunction.
+template <class OBJ, class... ARGS>
+class FeatureExtractor : public GenericFeatureExtractor {
+ public:
+ // Feature function type for top-level functions in the feature extractor.
+ typedef FeatureFunction<OBJ, ARGS...> Function;
+ typedef FeatureExtractor<OBJ, ARGS...> Self;
+
+ // Feature locator type for the feature extractor.
+ template <class DER>
+ using Locator = FeatureLocator<DER, OBJ, ARGS...>;
+
+ // Initializes feature extractor.
+ FeatureExtractor() {}
+
+ ~FeatureExtractor() override { utils::STLDeleteElements(&functions_); }
+
+ // Sets up the feature extractor. Note that only top-level functions exist
+ // until Setup() is called. This does not take ownership over the context,
+ // which must outlive this.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) {
+ for (Function *function : functions_) {
+ if (!function->Setup(context)) return false;
+ }
+ return true;
+ }
+
+ // Initializes the feature extractor. Must be called after Setup(). This
+ // does not take ownership over the context, which must outlive this.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) {
+ for (Function *function : functions_) {
+ if (!function->Init(context)) return false;
+ }
+ if (!this->InitializeFeatureTypes()) return false;
+ return true;
+ }
+
+ // Requests workspaces from the registry. Must be called after Init(), and
+ // before Preprocess(). Does not take ownership over registry. This should be
+ // the same registry used to initialize the WorkspaceSet used in Preprocess()
+ // and ExtractFeatures(). NB: This is a different ordering from that used in
+ // SentenceFeatureRepresentation style feature computation.
+ void RequestWorkspaces(WorkspaceRegistry *registry) {
+ for (auto *function : functions_) function->RequestWorkspaces(registry);
+ }
+
+ // Preprocesses the object using feature functions for the phase. Must be
+ // called before any calls to ExtractFeatures() on that object and phase.
+ void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const {
+ for (Function *function : functions_) {
+ function->Preprocess(workspaces, object);
+ }
+ }
+
+ // Extracts features from an object with a focus. This invokes all the
+ // top-level feature functions in the feature extractor. Only feature
+ // functions belonging to the specified phase are invoked.
+ void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object,
+ ARGS... args, FeatureVector *result) const {
+ result->reserve(this->feature_types());
+
+ // Extract features.
+ for (int i = 0; i < functions_.size(); ++i) {
+ functions_[i]->Evaluate(workspaces, object, args..., result);
+ }
+ }
+
+ private:
+ // Creates and initializes all feature functions in the feature extractor.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool InitializeFeatureFunctions() override {
+ // Create all top-level feature functions.
+ for (int i = 0; i < descriptor().feature_size(); ++i) {
+ const FeatureFunctionDescriptor &fd = descriptor().feature(i);
+ Function *function = Function::Instantiate(this, &fd, "");
+ if (function == nullptr) return false;
+ functions_.push_back(function);
+ }
+ return true;
+ }
+
+ // Collect all feature types used in the feature extractor.
+ void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
+ for (int i = 0; i < functions_.size(); ++i) {
+ functions_[i]->GetFeatureTypes(types);
+ }
+ }
+
+ // Top-level feature functions (and variables) in the feature extractor.
+ // Owned.
+ std::vector<Function *> functions_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
diff --git a/lang_id/common/fel/feature-types.h b/lang_id/common/fel/feature-types.h
new file mode 100644
index 0000000..18cf69a
--- /dev/null
+++ b/lang_id/common/fel/feature-types.h
@@ -0,0 +1,189 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Common feature types for parser components.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
+
+#include <algorithm>
+#include <map>
+#include <string>
+#include <utility>
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/str-cat.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// TODO(djweiss) Clean this up as well.
+// Use the same type for feature values as is used for predicates.
+typedef int64 Predicate;
+typedef Predicate FeatureValue;
+
+// Each feature value in a feature vector has a feature type. The feature type
+// is used for converting feature type and value pairs to predicate values. The
+// feature type can also return names for feature values and calculate the size
+// of the feature value domain. The FeatureType class is abstract and must be
+// specialized for the concrete feature types.
+class FeatureType {
+ public:
+ // Initializes a feature type.
+ explicit FeatureType(const string &name)
+ : name_(name), base_(0),
+ is_continuous_(name.find("continuous") != string::npos) {
+ }
+
+ virtual ~FeatureType() {}
+
+ // Converts a feature value to a name.
+ virtual string GetFeatureValueName(FeatureValue value) const = 0;
+
+ // Returns the size of the feature values domain.
+ virtual int64 GetDomainSize() const = 0;
+
+ // Returns the feature type name.
+ const string &name() const { return name_; }
+
+ Predicate base() const { return base_; }
+ void set_base(Predicate base) { base_ = base; }
+
+ // Returns true iff this feature is continuous; see FloatFeatureValue.
+ bool is_continuous() const { return is_continuous_; }
+
+ private:
+ // Feature type name.
+ string name_;
+
+ // "Base" feature value: i.e. a "slot" in a global ordering of features.
+ Predicate base_;
+
+ // See doc for is_continuous().
+ bool is_continuous_;
+};
+
+// Feature type that is defined using an explicit map from FeatureValue to
+// string values. This can reduce some of the boilerplate when defining
+// features that generate enum values. Example usage:
+//
+// class BeverageSizeFeature : public FeatureFunction<Beverage>
+// enum FeatureValue { SMALL, MEDIUM, LARGE }; // values for this feature
+// void Init(TaskContext *context) override {
+// set_feature_type(new EnumFeatureType("beverage_size",
+// {{SMALL, "SMALL"}, {MEDIUM, "MEDIUM"}, {LARGE, "LARGE"}});
+// }
+// [...]
+// };
+class EnumFeatureType : public FeatureType {
+ public:
+ EnumFeatureType(const string &name,
+ const std::map<FeatureValue, string> &value_names)
+ : FeatureType(name), value_names_(value_names) {
+ for (const auto &pair : value_names) {
+ SAFTM_CHECK_GE(pair.first, 0)
+ << "Invalid feature value: " << pair.first << ", " << pair.second;
+ domain_size_ = std::max(domain_size_, pair.first + 1);
+ }
+ }
+
+ // Returns the feature name for a given feature value.
+ string GetFeatureValueName(FeatureValue value) const override {
+ auto it = value_names_.find(value);
+ if (it == value_names_.end()) {
+ SAFTM_LOG(ERROR) << "Invalid feature value " << value << " for "
+ << name();
+ return "<INVALID>";
+ }
+ return it->second;
+ }
+
+ // Returns the number of possible values for this feature type. This is one
+ // greater than the largest value in the value_names map.
+ FeatureValue GetDomainSize() const override { return domain_size_; }
+
+ protected:
+ // Maximum possible value this feature could take.
+ FeatureValue domain_size_ = 0;
+
+ // Names of feature values.
+ std::map<FeatureValue, string> value_names_;
+};
+
+// Feature type for binary features.
+class BinaryFeatureType : public FeatureType {
+ public:
+ BinaryFeatureType(const string &name, const string &off, const string &on)
+ : FeatureType(name), off_(off), on_(on) {}
+
+ // Returns the feature name for a given feature value.
+ string GetFeatureValueName(FeatureValue value) const override {
+ if (value == 0) return off_;
+ if (value == 1) return on_;
+ return "";
+ }
+
+ // Binary features always have two feature values.
+ FeatureValue GetDomainSize() const override { return 2; }
+
+ private:
+ // Feature value names for on and off.
+ string off_;
+ string on_;
+};
+
+// Feature type for numeric features.
+class NumericFeatureType : public FeatureType {
+ public:
+ // Initializes numeric feature.
+ NumericFeatureType(const string &name, FeatureValue size)
+ : FeatureType(name), size_(size) {}
+
+ // Returns numeric feature value.
+ string GetFeatureValueName(FeatureValue value) const override {
+ if (value < 0) return "";
+ return LiteStrCat(value);
+ }
+
+ // Returns the number of feature values.
+ FeatureValue GetDomainSize() const override { return size_; }
+
+ private:
+ // The underlying size of the numeric feature.
+ FeatureValue size_;
+};
+
+// Feature type for byte features, including an "outside" value.
+class ByteFeatureType : public NumericFeatureType {
+ public:
+ explicit ByteFeatureType(const string &name)
+ : NumericFeatureType(name, 257) {}
+
+ string GetFeatureValueName(FeatureValue value) const override {
+ if (value == 256) {
+ return "<NULL>";
+ }
+ string result;
+ result += static_cast<char>(value);
+ return result;
+ }
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
diff --git a/lang_id/common/fel/fel-parser.cc b/lang_id/common/fel/fel-parser.cc
new file mode 100644
index 0000000..4346fb7
--- /dev/null
+++ b/lang_id/common/fel/fel-parser.cc
@@ -0,0 +1,289 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/fel-parser.h"
+
+#include <ctype.h>
+#include <string>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace {
+inline bool IsValidCharAtStartOfIdentifier(char c) {
+ return isalpha(c) || (c == '_') || (c == '/');
+}
+
+// Returns true iff character c can appear inside an identifier.
+inline bool IsValidCharInsideIdentifier(char c) {
+ return isalnum(c) || (c == '_') || (c == '-') || (c == '/');
+}
+
+// Returns true iff character c can appear at the beginning of a number.
+inline bool IsValidCharAtStartOfNumber(char c) {
+ return isdigit(c) || (c == '+') || (c == '-');
+}
+
+// Returns true iff character c can appear inside a number.
+inline bool IsValidCharInsideNumber(char c) {
+ return isdigit(c) || (c == '.');
+}
+} // namespace
+
+bool FELParser::Initialize(const string &source) {
+ // Initialize parser state.
+ source_ = source;
+ current_ = source_.begin();
+ item_start_ = line_start_ = current_;
+ line_number_ = item_line_number_ = 1;
+
+ // Read first input item.
+ return NextItem();
+}
+
+void FELParser::ReportError(const string &error_message) {
+ const int position = item_start_ - line_start_ + 1;
+ const string line(line_start_, current_);
+
+ SAFTM_LOG(ERROR) << "Error in feature model, line " << item_line_number_
+ << ", position " << position << ": " << error_message
+ << "\n " << line << " <--HERE";
+}
+
+void FELParser::Next() {
+ // Move to the next input character. If we are at a line break update line
+ // number and line start position.
+ if (CurrentChar() == '\n') {
+ ++line_number_;
+ ++current_;
+ line_start_ = current_;
+ } else {
+ ++current_;
+ }
+}
+
+bool FELParser::NextItem() {
+ // Skip white space and comments.
+ while (!eos()) {
+ if (CurrentChar() == '#') {
+ // Skip comment.
+ while (!eos() && CurrentChar() != '\n') Next();
+ } else if (isspace(CurrentChar())) {
+ // Skip whitespace.
+ while (!eos() && isspace(CurrentChar())) Next();
+ } else {
+ break;
+ }
+ }
+
+ // Record start position for next item.
+ item_start_ = current_;
+ item_line_number_ = line_number_;
+
+ // Check for end of input.
+ if (eos()) {
+ item_type_ = END;
+ return true;
+ }
+
+ // Parse number.
+ if (IsValidCharAtStartOfNumber(CurrentChar())) {
+ string::iterator start = current_;
+ Next();
+ while (!eos() && IsValidCharInsideNumber(CurrentChar())) Next();
+ item_text_.assign(start, current_);
+ item_type_ = NUMBER;
+ return true;
+ }
+
+ // Parse string.
+ if (CurrentChar() == '"') {
+ Next();
+ string::iterator start = current_;
+ while (CurrentChar() != '"') {
+ if (eos()) {
+ ReportError("Unterminated string");
+ return false;
+ }
+ Next();
+ }
+ item_text_.assign(start, current_);
+ item_type_ = STRING;
+ Next();
+ return true;
+ }
+
+ // Parse identifier name.
+ if (IsValidCharAtStartOfIdentifier(CurrentChar())) {
+ string::iterator start = current_;
+ while (!eos() && IsValidCharInsideIdentifier(CurrentChar())) {
+ Next();
+ }
+ item_text_.assign(start, current_);
+ item_type_ = NAME;
+ return true;
+ }
+
+ // Single character item.
+ item_type_ = CurrentChar();
+ Next();
+ return true;
+}
+
+bool FELParser::Parse(const string &source,
+ FeatureExtractorDescriptor *result) {
+ // Initialize parser.
+ if (!Initialize(source)) {
+ return false;
+ }
+
+ while (item_type_ != END) {
+ // Current item should be a feature name.
+ if (item_type_ != NAME) {
+ ReportError("Feature type name expected");
+ return false;
+ }
+ string name = item_text_;
+ if (!NextItem()) {
+ return false;
+ }
+
+ if (item_type_ == '=') {
+ ReportError("Invalid syntax: feature expected");
+ return false;
+ } else {
+ // Parse feature.
+ FeatureFunctionDescriptor *descriptor = result->add_feature();
+ descriptor->set_type(name);
+ if (!ParseFeature(descriptor)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+bool FELParser::ParseFeature(FeatureFunctionDescriptor *result) {
+ // Parse argument and parameters.
+ if (item_type_ == '(') {
+ if (!NextItem()) return false;
+ if (!ParseParameter(result)) return false;
+ while (item_type_ == ',') {
+ if (!NextItem()) return false;
+ if (!ParseParameter(result)) return false;
+ }
+
+ if (item_type_ != ')') {
+ ReportError(") expected");
+ return false;
+ }
+ if (!NextItem()) return false;
+ }
+
+ // Parse feature name.
+ if (item_type_ == ':') {
+ if (!NextItem()) return false;
+ if (item_type_ != NAME && item_type_ != STRING) {
+ ReportError("Feature name expected");
+ return false;
+ }
+ string name = item_text_;
+ if (!NextItem()) return false;
+
+ // Set feature name.
+ result->set_name(name);
+ }
+
+ // Parse sub-features.
+ if (item_type_ == '.') {
+ // Parse dotted sub-feature.
+ if (!NextItem()) return false;
+ if (item_type_ != NAME) {
+ ReportError("Feature type name expected");
+ return false;
+ }
+ string type = item_text_;
+ if (!NextItem()) return false;
+
+ // Parse sub-feature.
+ FeatureFunctionDescriptor *subfeature = result->add_feature();
+ subfeature->set_type(type);
+ if (!ParseFeature(subfeature)) return false;
+ } else if (item_type_ == '{') {
+ // Parse sub-feature block.
+ if (!NextItem()) return false;
+ while (item_type_ != '}') {
+ if (item_type_ != NAME) {
+ ReportError("Feature type name expected");
+ return false;
+ }
+ string type = item_text_;
+ if (!NextItem()) return false;
+
+ // Parse sub-feature.
+ FeatureFunctionDescriptor *subfeature = result->add_feature();
+ subfeature->set_type(type);
+ if (!ParseFeature(subfeature)) return false;
+ }
+ if (!NextItem()) return false;
+ }
+ return true;
+}
+
+bool FELParser::ParseParameter(FeatureFunctionDescriptor *result) {
+ if (item_type_ == NUMBER) {
+ int argument;
+ if (!LiteAtoi(item_text_, &argument)) {
+ ReportError("Unable to parse number");
+ return false;
+ }
+ if (!NextItem()) return false;
+
+ // Set default argument for feature.
+ result->set_argument(argument);
+ } else if (item_type_ == NAME) {
+ string name = item_text_;
+ if (!NextItem()) return false;
+ if (item_type_ != '=') {
+ ReportError("= expected");
+ return false;
+ }
+ if (!NextItem()) return false;
+ if (item_type_ >= END) {
+ ReportError("Parameter value expected");
+ return false;
+ }
+ string value = item_text_;
+ if (!NextItem()) return false;
+
+ // Add parameter to feature.
+ Parameter *parameter;
+ parameter = result->add_parameter();
+ parameter->set_name(name);
+ parameter->set_value(value);
+ } else {
+ ReportError("Syntax error in parameter list");
+ return false;
+ }
+ return true;
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/fel/fel-parser.h b/lang_id/common/fel/fel-parser.h
new file mode 100644
index 0000000..eacb442
--- /dev/null
+++ b/lang_id/common/fel/fel-parser.h
@@ -0,0 +1,135 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Feature extraction language (FEL) parser.
+//
+// BNF grammar for FEL:
+//
+// <feature model> ::= { <feature extractor> }
+//
+// <feature extractor> ::= <extractor spec> |
+// <extractor spec> '.' <feature extractor> |
+// <extractor spec> '{' { <feature extractor> } '}'
+//
+// <extractor spec> ::= <extractor type>
+// [ '(' <parameter list> ')' ]
+// [ ':' <extractor name> ]
+//
+//   <parameter list> ::= ( <parameter> | <argument> ) { ',' <parameter> }
+//
+// <parameter> ::= <parameter name> '=' <parameter value>
+//
+// <extractor type> ::= NAME
+// <extractor name> ::= NAME | STRING
+// <argument> ::= NUMBER
+// <parameter name> ::= NAME
+// <parameter value> ::= NUMBER | STRING | NAME
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
+
+#include <string>
+
+#include "lang_id/common/fel/feature-descriptors.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+class FELParser {
+ public:
+  // Parses FEL specification into feature extractor descriptor.
+ // Returns true on success, false on error (e.g., syntax errors).
+ bool Parse(const string &source, FeatureExtractorDescriptor *result);
+
+ private:
+ // Initializes the parser with the source text.
+ // Returns true on success, false on syntax error.
+ bool Initialize(const string &source);
+
+ // Outputs an error message, with context info.
+ void ReportError(const string &error_message);
+
+ // Moves to the next input character.
+ void Next();
+
+ // Moves to the next input item. Sets item_text_ and item_type_ accordingly.
+ // Returns true on success, false on syntax error.
+ bool NextItem();
+
+ // Parses a feature descriptor.
+ // Returns true on success, false on syntax error.
+ bool ParseFeature(FeatureFunctionDescriptor *result);
+
+ // Parses a parameter specification.
+ // Returns true on success, false on syntax error.
+ bool ParseParameter(FeatureFunctionDescriptor *result);
+
+ // Returns true if end of source input has been reached.
+ bool eos() const { return current_ >= source_.end(); }
+
+ // Returns current character. Other methods should access the current
+ // character through this method (instead of using *current_ directly): this
+ // method performs extra safety checks.
+ //
+ // In case of an unsafe access, returns '\0'.
+ char CurrentChar() const {
+ if ((current_ >= source_.begin()) && (current_ < source_.end())) {
+ return *current_;
+ } else {
+ SAFTM_LOG(ERROR) << "Unsafe char read";
+ return '\0';
+ }
+ }
+
+ // Item types.
+ enum ItemTypes {
+ END = 0,
+ NAME = -1,
+ NUMBER = -2,
+ STRING = -3,
+ };
+
+ // Source text.
+ string source_;
+
+ // Current input position.
+ string::iterator current_;
+
+ // Line number for current input position.
+ int line_number_;
+
+ // Start position for current item.
+ string::iterator item_start_;
+
+ // Start position for current line.
+ string::iterator line_start_;
+
+ // Line number for current item.
+ int item_line_number_;
+
+ // Item type for current item. If this is positive it is interpreted as a
+ // character. If it is negative it is interpreted as an item type.
+ int item_type_;
+
+ // Text for current item.
+ string item_text_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
diff --git a/lang_id/common/fel/task-context.cc b/lang_id/common/fel/task-context.cc
new file mode 100644
index 0000000..f8b0701
--- /dev/null
+++ b/lang_id/common/fel/task-context.cc
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/task-context.h"
+
+#include <string>
+
+#include "lang_id/common/lite_strings/numbers.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+string TaskContext::GetInputPath(const string &name) const {
+ auto it = inputs_.find(name);
+ if (it != inputs_.end()) {
+ return it->second;
+ }
+ return "";
+}
+
+void TaskContext::SetInputPath(const string &name, const string &path) {
+ inputs_[name] = path;
+}
+
+string TaskContext::Get(const string &name, const char *defval) const {
+ auto it = parameters_.find(name);
+ if (it != parameters_.end()) {
+ return it->second;
+ }
+ return defval;
+}
+
+int TaskContext::Get(const string &name, int defval) const {
+ const string s = Get(name, "");
+ int value = defval;
+ if (LiteAtoi(s, &value)) {
+ return value;
+ }
+ return defval;
+}
+
+float TaskContext::Get(const string &name, float defval) const {
+ const string s = Get(name, "");
+ float value = defval;
+ if (LiteAtof(s, &value)) {
+ return value;
+ }
+ return defval;
+}
+
+bool TaskContext::Get(const string &name, bool defval) const {
+ string value = Get(name, "");
+ return value.empty() ? defval : value == "true";
+}
+
+void TaskContext::SetParameter(const string &name, const string &value) {
+ parameters_[name] = value;
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/fel/task-context.h b/lang_id/common/fel/task-context.h
new file mode 100644
index 0000000..ddc8cfe
--- /dev/null
+++ b/lang_id/common/fel/task-context.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
+
+#include <map>
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Class that provides access to model parameter and inputs.
+//
+// Note: This class is related to the server-side nlp_saft::TaskContext, but it
+// has been simplified to reduce code dependencies.
+class TaskContext {
+ public:
+ // Returns path for the input named |name|. Returns empty string ("") if
+ // there is no input with that name. Note: this can be a standard file path,
+ // or a path in a more special file system.
+ string GetInputPath(const string &name) const;
+
+ // Sets path for input |name|. Previous path, if any, is overwritten.
+ void SetInputPath(const string &name, const string &path);
+
+ // Returns parameter value. If the parameter is not specified in this
+ // context, the default value is returned.
+ string Get(const string &name, const char *defval) const;
+ int Get(const string &name, int defval) const;
+ float Get(const string &name, float defval) const;
+ bool Get(const string &name, bool defval) const;
+
+ // Sets value of parameter |name| to |value|.
+ void SetParameter(const string &name, const string &value);
+
+ private:
+ // Maps input name -> path.
+ std::map<string, string> inputs_;
+
+ // Maps parameter name -> value.
+ std::map<string, string> parameters_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
diff --git a/lang_id/common/fel/workspace.cc b/lang_id/common/fel/workspace.cc
new file mode 100644
index 0000000..8cab281
--- /dev/null
+++ b/lang_id/common/fel/workspace.cc
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/workspace.h"
+
+#include <atomic>
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// static
+int GetFreshTypeId() {
+ // Static local below is initialized the first time this method is run.
+ static std::atomic<int> counter(0);
+ return counter++;
+}
+
+string WorkspaceRegistry::DebugString() const {
+ string str;
+ for (auto &it : workspace_names_) {
+ const string &type_name = workspace_types_.at(it.first);
+ for (size_t index = 0; index < it.second.size(); ++index) {
+ const string &workspace_name = it.second[index];
+ str.append("\n ");
+ str.append(type_name);
+ str.append(" :: ");
+ str.append(workspace_name);
+ }
+ }
+ return str;
+}
+
+VectorIntWorkspace::VectorIntWorkspace(int size) : elements_(size) {}
+
+VectorIntWorkspace::VectorIntWorkspace(int size, int value)
+ : elements_(size, value) {}
+
+VectorIntWorkspace::VectorIntWorkspace(const std::vector<int> &elements)
+ : elements_(elements) {}
+
+string VectorIntWorkspace::TypeName() { return "Vector"; }
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/fel/workspace.h b/lang_id/common/fel/workspace.h
new file mode 100644
index 0000000..09095e4
--- /dev/null
+++ b/lang_id/common/fel/workspace.h
@@ -0,0 +1,204 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Notes on thread-safety: All of the classes here are thread-compatible. More
+// specifically, the registry machinery is thread-safe, as long as each thread
+// performs feature extraction on a different Sentence object.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
+
+#include <stddef.h>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// A base class for shared workspaces. Derived classes implement a static member
+// function TypeName() which returns a human readable string name for the class.
+class Workspace {
+ public:
+ // Polymorphic destructor.
+ virtual ~Workspace() {}
+
+ protected:
+ // Create an empty workspace.
+ Workspace() {}
+
+ private:
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(Workspace);
+};
+
+// Returns a new, strictly increasing int every time it is invoked.
+int GetFreshTypeId();
+
+// Struct to simulate typeid, but without RTTI.
+template <typename T>
+struct TypeId {
+ static int type_id;
+};
+
+template <typename T>
+int TypeId<T>::type_id = GetFreshTypeId();
+
+// A registry that keeps track of workspaces.
+class WorkspaceRegistry {
+ public:
+ // Create an empty registry.
+ WorkspaceRegistry() {}
+
+ // Returns the index of a named workspace, adding it to the registry first
+ // if necessary.
+ template <class W>
+ int Request(const string &name) {
+ const int id = TypeId<W>::type_id;
+ max_workspace_id_ = std::max(id, max_workspace_id_);
+ workspace_types_[id] = W::TypeName();
+ std::vector<string> &names = workspace_names_[id];
+ for (int i = 0; i < names.size(); ++i) {
+ if (names[i] == name) return i;
+ }
+ names.push_back(name);
+ return names.size() - 1;
+ }
+
+ // Returns the maximum workspace id that has been registered.
+ int MaxId() const {
+ return max_workspace_id_;
+ }
+
+ const std::unordered_map<int, std::vector<string> > &WorkspaceNames()
+ const {
+ return workspace_names_;
+ }
+
+ // Returns a string describing the registered workspaces.
+ string DebugString() const;
+
+ private:
+ // Workspace type names, indexed as workspace_types_[typeid].
+ std::unordered_map<int, string> workspace_types_;
+
+ // Workspace names, indexed as workspace_names_[typeid][workspace].
+ std::unordered_map<int, std::vector<string> > workspace_names_;
+
+ // The maximum workspace id that has been registered.
+ int max_workspace_id_ = 0;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry);
+};
+
+// A typed collection of workspaces. The workspaces are indexed according to an
+// external WorkspaceRegistry. If the WorkspaceSet is const, the contents are
+// also immutable.
+class WorkspaceSet {
+ public:
+ ~WorkspaceSet() { Reset(WorkspaceRegistry()); }
+
+ // Returns true if a workspace has been set.
+ template <class W>
+ bool Has(int index) const {
+ const int id = TypeId<W>::type_id;
+ SAFTM_DCHECK_GE(id, 0);
+ SAFTM_DCHECK_LT(id, workspaces_.size());
+ SAFTM_DCHECK_GE(index, 0);
+ SAFTM_DCHECK_LT(index, workspaces_[id].size());
+ if (id >= workspaces_.size()) return false;
+ return workspaces_[id][index] != nullptr;
+ }
+
+ // Returns an indexed workspace; the workspace must have been set.
+ template <class W>
+ const W &Get(int index) const {
+ SAFTM_DCHECK(Has<W>(index));
+ const int id = TypeId<W>::type_id;
+ const Workspace *w = workspaces_[id][index];
+ return reinterpret_cast<const W &>(*w);
+ }
+
+ // Sets an indexed workspace; this takes ownership of the workspace, which
+ // must have been new-allocated. It is an error to set a workspace twice.
+ template <class W>
+ void Set(int index, W *workspace) {
+ const int id = TypeId<W>::type_id;
+ SAFTM_DCHECK_GE(id, 0);
+ SAFTM_DCHECK_LT(id, workspaces_.size());
+ SAFTM_DCHECK_GE(index, 0);
+ SAFTM_DCHECK_LT(index, workspaces_[id].size());
+ SAFTM_DCHECK(workspaces_[id][index] == nullptr);
+ SAFTM_DCHECK(workspace != nullptr);
+ workspaces_[id][index] = workspace;
+ }
+
+ void Reset(const WorkspaceRegistry ®istry) {
+ // Deallocate current workspaces.
+ for (auto &it : workspaces_) {
+ for (size_t index = 0; index < it.size(); ++index) {
+ delete it[index];
+ }
+ }
+ workspaces_.clear();
+ workspaces_.resize(registry.MaxId() + 1, std::vector<Workspace *>());
+ for (auto &it : registry.WorkspaceNames()) {
+ workspaces_[it.first].resize(it.second.size());
+ }
+ }
+
+ private:
+ // The set of workspaces, indexed as workspaces_[typeid][index].
+ std::vector<std::vector<Workspace *> > workspaces_;
+};
+
+// A workspace that wraps around a vector of int.
+class VectorIntWorkspace : public Workspace {
+ public:
+ // Creates a vector of the given size.
+ explicit VectorIntWorkspace(int size);
+
+ // Creates a vector initialized with the given array.
+ explicit VectorIntWorkspace(const std::vector<int> &elements);
+
+ // Creates a vector of the given size, with each element initialized to the
+ // given value.
+ VectorIntWorkspace(int size, int value);
+
+ // Returns the name of this type of workspace.
+ static string TypeName();
+
+ // Returns the i'th element.
+ int element(int i) const { return elements_[i]; }
+
+ // Sets the i'th element.
+ void set_element(int i, int value) { elements_[i] = value; }
+
+ // Returns the size of the underlying vector.
+ int size() const { return elements_.size(); }
+
+ private:
+ // The enclosed vector.
+ std::vector<int> elements_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
diff --git a/lang_id/common/file/file-utils.cc b/lang_id/common/file/file-utils.cc
new file mode 100644
index 0000000..108c7d5
--- /dev/null
+++ b/lang_id/common/file/file-utils.cc
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/file/file-utils.h"
+
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace file_utils {
+
+bool GetFileContent(const string &filename, string *content) {
+ ScopedMmap scoped_mmap(filename);
+ const MmapHandle &handle = scoped_mmap.handle();
+ if (!handle.ok()) {
+ SAFTM_LOG(ERROR) << "Error opening " << filename;
+ return false;
+ }
+ StringPiece sp = handle.to_stringpiece();
+ content->assign(sp.data(), sp.size());
+ return true;
+}
+
+bool FileExists(const string &filename) {
+ struct stat s = {0};
+ if (!stat(filename.c_str(), &s)) {
+ return s.st_mode & S_IFREG;
+ } else {
+ return false;
+ }
+}
+
+bool DirectoryExists(const string &dirpath) {
+ struct stat s = {0};
+ if (!stat(dirpath.c_str(), &s)) {
+ return s.st_mode & S_IFDIR;
+ } else {
+ return false;
+ }
+}
+
+} // namespace file_utils
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/file/file-utils.h b/lang_id/common/file/file-utils.h
new file mode 100644
index 0000000..6377d7a
--- /dev/null
+++ b/lang_id/common/file/file-utils.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
+
+#include <stddef.h>
+#include <string>
+
+#include "lang_id/common/file/mmap.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace file_utils {
+
+// Reads the entire content of a file into a string. Returns true on success,
+// false on error.
+bool GetFileContent(const string &filename, string *content);
+
+// Parses a proto from its serialized representation in memory. That
+// representation starts at address |data| and should contain exactly
+// |num_bytes| bytes. Returns true on success, false otherwise.
+template <class Proto>
+bool ParseProtoFromMemory(const char *data, size_t num_bytes, Proto *proto) {
+ if (data == nullptr) {
+ // Avoid passing a nullptr to ParseFromArray below.
+ return false;
+ }
+ return proto->ParseFromArray(data, num_bytes);
+}
+
+// Convenience StringPiece-based version of ParseProtoFromMemory.
+template <class Proto>
+inline bool ParseProtoFromMemory(StringPiece sp, Proto *proto) {
+ return ParseProtoFromMemory(sp.data(), sp.size(), proto);
+}
+
+// Parses a proto from a file. Returns true on success, false otherwise.
+//
+// Note: the entire content of the file should be the binary (not
+// human-readable) serialization of a protocol buffer.
+//
+// Note: when we compile for Android, the proto parsing methods need to know the
+// type of the message they are parsing. We use template polymorphism for that.
+template<class Proto>
+bool ReadProtoFromFile(const string &filename, Proto *proto) {
+ ScopedMmap scoped_mmap(filename);
+ const MmapHandle &handle = scoped_mmap.handle();
+ if (!handle.ok()) {
+ return false;
+ }
+ return ParseProtoFromMemory(handle.to_stringpiece(), proto);
+}
+
+// Returns true if filename is the name of an existing file, and false
+// otherwise.
+bool FileExists(const string &filename);
+
+// Returns true if dirpath is the path to an existing directory, and false
+// otherwise.
+bool DirectoryExists(const string &dirpath);
+
+} // namespace file_utils
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
diff --git a/lang_id/common/file/mmap.cc b/lang_id/common/file/mmap.cc
new file mode 100644
index 0000000..89efa99
--- /dev/null
+++ b/lang_id/common/file/mmap.cc
@@ -0,0 +1,133 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/file/mmap.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stdint.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace {
+inline string GetLastSystemError() {
+ return string(strerror(errno));
+}
+
+inline MmapHandle GetErrorMmapHandle() {
+ return MmapHandle(nullptr, 0);
+}
+
+class FileCloser {
+ public:
+ explicit FileCloser(int fd) : fd_(fd) {}
+ ~FileCloser() {
+ int result = close(fd_);
+ if (result != 0) {
+ const string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error closing file descriptor: " << last_error;
+ }
+ }
+ private:
+ const int fd_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(FileCloser);
+};
+} // namespace
+
+MmapHandle MmapFile(const string &filename) {
+ int fd = open(filename.c_str(), O_RDONLY);
+
+ if (fd < 0) {
+ const string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error opening " << filename << ": " << last_error;
+ return GetErrorMmapHandle();
+ }
+
+ // Make sure we close fd no matter how we exit this function. As the man page
+ // for mmap clearly states: "closing the file descriptor does not unmap the
+ // region." Hence, we can close fd as soon as we return from here.
+ FileCloser file_closer(fd);
+
+ return MmapFile(fd);
+}
+
+MmapHandle MmapFile(int fd) {
+ // Get file stats to obtain file size.
+ struct stat sb;
+ if (fstat(fd, &sb) != 0) {
+ const string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Unable to stat fd: " << last_error;
+ return GetErrorMmapHandle();
+ }
+ size_t file_size_in_bytes = static_cast<size_t>(sb.st_size);
+
+ // Perform actual mmap.
+ void *mmap_addr = mmap(
+
+ // Let system pick address for mmapp-ed data.
+ nullptr,
+
+ // Mmap all bytes from the file.
+ file_size_in_bytes,
+
+ // One can read / write the mapped data (but see MAP_PRIVATE below).
+ // Normally, we expect only to read it, but in the future, we may want to
+ // write it, to fix e.g., endianness differences.
+ PROT_READ | PROT_WRITE,
+
+ // Updates to mmaped data are *not* propagated to actual file.
+ // AFAIK(salcianu) that's anyway not possible on Android.
+ MAP_PRIVATE,
+
+ // Descriptor of file to mmap.
+ fd,
+
+ // Map bytes right from the beginning of the file. This, and
+ // file_size_in_bytes (2nd argument) means we map all bytes from the file.
+ 0);
+ if (mmap_addr == MAP_FAILED) {
+ const string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
+ return GetErrorMmapHandle();
+ }
+
+ return MmapHandle(mmap_addr, file_size_in_bytes);
+}
+
+bool Unmap(MmapHandle mmap_handle) {
+ if (!mmap_handle.ok()) {
+ // Unmapping something that hasn't been mapped is trivially successful.
+ return true;
+ }
+ if (munmap(mmap_handle.start(), mmap_handle.num_bytes()) != 0) {
+ const string last_error = GetLastSystemError();
+ SAFTM_LOG(ERROR) << "Error during Unmap / munmap: " << last_error;
+ return false;
+ }
+ return true;
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/file/mmap.h b/lang_id/common/file/mmap.h
new file mode 100644
index 0000000..6131803
--- /dev/null
+++ b/lang_id/common/file/mmap.h
@@ -0,0 +1,120 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
+
+#include <stddef.h>
+
+#include <string>
+
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Handle for a memory area where a file has been mmapped.
+//
+// Similar to a pointer: you "allocate" it using MmapFile(filename) and "delete"
+// it using Unmap(). Just like a pointer, it is passed around by value (see
+// signature of MmapFile and Unmap; fortunately, it's a small class, so there
+// shouldn't be any significant performance penalty) and its usage is not
+// necessarily scoped (that's why the destructor is not performing the unmap).
+//
+// Note: on program termination, each still unmapped file is automatically
+// unmapped. Hence, it is not an error if you don't call Unmap() (provided you
+// are ok keeping that file in memory the whole time).
+class MmapHandle {
+ public:
+ MmapHandle(void *start, size_t num_bytes)
+ : start_(start), num_bytes_(num_bytes) {}
+
+ // Returns start address for the memory area where a file has been mmapped.
+ void *start() const { return start_; }
+
+ // Returns number of bytes of the memory area from start().
+ size_t num_bytes() const { return num_bytes_; }
+
+ // Shortcut to simplify checking success of MmapFile(). See usage example
+ // from the doc of that function.
+ bool ok() const { return start() != nullptr; }
+
+ // Returns a StringPiece pointing to the same underlying bytes.
+ StringPiece to_stringpiece() const {
+ return StringPiece(reinterpret_cast<char *>(start_), num_bytes_);
+ }
+
+ private:
+ // See doc for start(). Not owned.
+ void *const start_;
+
+ // See doc for num_bytes().
+ const size_t num_bytes_;
+};
+
+// Maps the full content of a file in memory (using mmap).
+//
+// When done using the file content, one can unmap using Unmap(). Otherwise,
+// all mapped files are unmapped when the program terminates.
+//
+// Sample usage:
+//
+// MmapHandle mmap_handle = MmapFile(filename);
+// CHECK(mmap_handle.ok()) << "Unable to mmap " << filename;
+//
+// ... use data from addresses
+// ... [mmap_handle.start, mmap_handle.start + mmap_handle.num_bytes)
+//
+// Unmap(mmap_handle); // Unmap logs errors internally.
+//
+// Note: one can read *and* write the num_bytes bytes from start, but those
+// writes are not propagated to the underlying file, nor to other processes that
+// may have mmapped that file (all changes are local to current process).
+MmapHandle MmapFile(const string &filename);
+
+// Like MmapFile(const string &filename), but uses a file descriptor.
+MmapHandle MmapFile(int fd);
+
+// Unmaps a file mapped using MmapFile. Returns true on success, false
+// otherwise.
+bool Unmap(MmapHandle mmap_handle);
+
+// Scoped mmapping of a file. Mmaps a file on construction, unmaps it on
+// destruction.
+class ScopedMmap {
+ public:
+ explicit ScopedMmap(const string &filename)
+ : handle_(MmapFile(filename)) {}
+
+ explicit ScopedMmap(int fd)
+ : handle_(MmapFile(fd)) {}
+
+ ~ScopedMmap() {
+ if (handle_.ok()) {
+ Unmap(handle_);
+ }
+ }
+
+ const MmapHandle &handle() { return handle_; }
+
+ private:
+ MmapHandle handle_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
diff --git a/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc b/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
new file mode 100644
index 0000000..ee22420
--- /dev/null
+++ b/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
@@ -0,0 +1,449 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
+
+#include "lang_id/common/lite_base/endian.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace {
+// Returns true if and only if ptr points to a location inside allowed_range.
+bool IsPointerInRange(const char *ptr, StringPiece allowed_range) {
+ return (ptr >= allowed_range.data()) &&
+ (ptr < (allowed_range.data() + allowed_range.size()));
+}
+
+// Returns true if and only if the memory range [start, start +
+// range_size_in_bytes) is included inside allowed_range.
+//
+// Special case: if range_size_in_bytes == 0 (empty range) then we require that
+// start is nullptr or in the allowed_range.
+bool IsMemoryRangeValid(const void *start, int range_size_in_bytes,
+ StringPiece allowed_range) {
+ const char *begin = reinterpret_cast<const char *>(start);
+ if (range_size_in_bytes < 0) {
+ return false;
+ }
+ if (range_size_in_bytes == 0) {
+ return (start == nullptr) || IsPointerInRange(begin, allowed_range);
+ }
+ const char *inclusive_end = begin + (range_size_in_bytes - 1);
+ return (begin <= inclusive_end) && IsPointerInRange(begin, allowed_range) &&
+ IsPointerInRange(inclusive_end, allowed_range);
+}
+
+bool VerifyQuantizationScales(EmbeddingNetworkParams::Matrix matrix,
+ StringPiece bytes) {
+ if (matrix.quant_scales == nullptr) {
+ SAFTM_LOG(ERROR) << "Quantization type "
+ << static_cast<int>(matrix.quant_type)
+ << "; but no quantization scales";
+ return false;
+ }
+ bool valid_scales = IsMemoryRangeValid(matrix.quant_scales,
+ matrix.rows * sizeof(float16), bytes);
+ if (!valid_scales) {
+ SAFTM_LOG(ERROR) << "quantization scales not fully inside bytes";
+ return false;
+ }
+ return true;
+}
+
+// Returns false if we detect a problem with |matrix|, true otherwise. E.g., we
+// check that the array that starts at pointer matrix.elements is fully inside
+// |bytes| (the range of bytes passed to the
+// EmbeddingNetworkParamsFromFlatbuffer constructor).
+bool VerifyMatrix(EmbeddingNetworkParams::Matrix matrix, StringPiece bytes) {
+ if ((matrix.rows < 0) || (matrix.cols < 0)) {
+ SAFTM_LOG(ERROR) << "Wrong matrix geometry: " << matrix.rows << " x "
+ << matrix.cols;
+ return false;
+ }
+
+ const int num_elements = matrix.rows * matrix.cols;
+
+ // Number of bytes occupied by the num_elements elements that start at address
+ // matrix.elements.
+ int element_range_size_in_bytes = 0;
+ switch (matrix.quant_type) {
+ case QuantizationType::NONE:
+ element_range_size_in_bytes = num_elements * sizeof(float);
+ break;
+ case QuantizationType::UINT8: {
+ element_range_size_in_bytes = num_elements;
+ if (!VerifyQuantizationScales(matrix, bytes)) {
+ return false;
+ }
+ break;
+ }
+ case QuantizationType::UINT4: {
+ if (matrix.cols % 2 != 0) {
+ SAFTM_LOG(ERROR) << "UINT4 doesn't work with odd #cols" << matrix.cols;
+ return false;
+ }
+ element_range_size_in_bytes = num_elements / 2;
+ if (!VerifyQuantizationScales(matrix, bytes)) {
+ return false;
+ }
+ break;
+ }
+ case QuantizationType::FLOAT16: {
+ element_range_size_in_bytes = num_elements * sizeof(float16);
+
+ // No need to verify the scales: FLOAT16 quantization does not use scales.
+ break;
+ }
+ default:
+ SAFTM_LOG(ERROR) << "Unsupported quantization type "
+ << static_cast<int>(matrix.quant_type);
+ return false;
+ }
+ if (matrix.elements == nullptr) {
+ SAFTM_LOG(ERROR) << "matrix.elements == nullptr";
+ return false;
+ }
+ bool valid =
+ IsMemoryRangeValid(matrix.elements, element_range_size_in_bytes, bytes);
+ if (!valid) {
+ SAFTM_LOG(ERROR) << "elements not fully inside bytes";
+ return false;
+ }
+ return true;
+}
+
+// Checks the geometry of the network layer represented by |weights| and |bias|,
+// assuming the input to this layer has size |input_size|. Returns false if we
+// detect any problem, true otherwise.
+bool GoodLayerGeometry(int input_size,
+ const EmbeddingNetworkParams::Matrix &weights,
+ const EmbeddingNetworkParams::Matrix &bias) {
+ if (weights.rows != input_size) {
+ SAFTM_LOG(ERROR) << "#rows " << weights.rows << " != " << input_size;
+ return false;
+ }
+ if ((bias.rows != 1) && (bias.cols != 1)) {
+ SAFTM_LOG(ERROR) << "bad bias vector geometry: " << bias.rows << " x "
+ << bias.cols;
+ return false;
+ }
+ int bias_dimension = bias.rows * bias.cols;
+ if (weights.cols != bias_dimension) {
+ SAFTM_LOG(ERROR) << "#cols " << weights.cols << " != " << bias_dimension;
+ return false;
+ }
+ return true;
+}
+} // namespace
+
+EmbeddingNetworkParamsFromFlatbuffer::EmbeddingNetworkParamsFromFlatbuffer(
+ StringPiece bytes) {
+ // We expect valid_ to be initialized to false at this point. We set it to
+ // true only if we successfully complete all initialization. On error, we
+ // return early, leaving valid_ set to false.
+ SAFTM_DCHECK(!valid_);
+
+ // NOTE: current EmbeddingNetworkParams API works only on little-endian
+ // machines. Fortunately, all modern devices are little-endian so, instead of
+ // a costly API change, we support only the little-endian case.
+ //
+ // Technical explanation: for each Matrix, our API provides a pointer to the
+ // matrix elements (see Matrix field |elements|). For unquantized matrices,
+ // that's a const float *pointer; the client code (e.g., Neurosis) uses those
+ // floats directly. That is correct if the EmbeddingNetworkParams come from a
+ // proto, where the proto parsing already handled the endianness differences.
+ // But in the flatbuffer case, that's a pointer to floats in little-endian
+ // format (flatbuffers always use little-endian). If our API provided access
+ // to only one element at a time, the accessor method could swap the bytes "on
+ // the fly", using temporary variables. Instead, our API provides a pointer
+ // to all elements: as their number is variable (and underlying data is
+ // immutable), we can't ensure the bytes of all those elements are swapped
+ // without extra memory allocation to store the swapped bytes (which is what
+ // using flatbuffers is supposed to prevent).
+ if (!LittleEndian::IsLittleEndian()) {
+ SAFTM_LOG(INFO) << "Not a little-endian machine";
+ return;
+ }
+
+ const uint8_t *start = reinterpret_cast<const uint8_t *>(bytes.data());
+ if (start == nullptr) {
+ // Note: as |bytes| is expected to be a valid EmbeddingNetwork flatbuffer,
+ // it should contain the 4-char identifier "NS00" (or a later version). It
+ // can't be empty; hence StringPiece(nullptr, 0) is not legal here.
+ SAFTM_LOG(ERROR) << "nullptr bytes";
+ return;
+ }
+ flatbuffers::Verifier verifier(start, bytes.size());
+ if (!saft_fbs::VerifyEmbeddingNetworkBuffer(verifier)) {
+ SAFTM_LOG(ERROR) << "Not a valid EmbeddingNetwork flatbuffer";
+ return;
+ }
+ network_ = saft_fbs::GetEmbeddingNetwork(start);
+ if (network_ == nullptr) {
+ SAFTM_LOG(ERROR) << "Unable to interpret bytes as a flatbuffer";
+ return;
+ }
+
+ // Perform a few extra checks before declaring this object valid.
+ valid_ = ValidityChecking(bytes);
+}
+
+bool EmbeddingNetworkParamsFromFlatbuffer::ValidityChecking(
+ StringPiece bytes) const {
+ int input_size = 0;
+ for (int i = 0; i < embeddings_size(); ++i) {
+ Matrix embeddings = GetEmbeddingMatrix(i);
+ if (!VerifyMatrix(embeddings, bytes)) {
+ SAFTM_LOG(ERROR) << "Bad embedding matrix #" << i;
+ return false;
+ }
+ input_size += embedding_num_features(i) * embeddings.cols;
+ }
+ int current_size = input_size;
+ for (int i = 0; i < hidden_size(); ++i) {
+ Matrix weights = GetHiddenLayerMatrix(i);
+ if (!VerifyMatrix(weights, bytes)) {
+ SAFTM_LOG(ERROR) << "Bad weights matrix for hidden layer #" << i;
+ return false;
+ }
+ Matrix bias = GetHiddenLayerBias(i);
+ if (!VerifyMatrix(bias, bytes)) {
+ SAFTM_LOG(ERROR) << "Bad bias vector for hidden layer #" << i;
+ return false;
+ }
+ if (!GoodLayerGeometry(current_size, weights, bias)) {
+ SAFTM_LOG(ERROR) << "Bad geometry for hidden layer #" << i;
+ return false;
+ }
+ current_size = weights.cols;
+ }
+
+ if (HasSoftmax()) {
+ Matrix weights = GetSoftmaxMatrix();
+ if (!VerifyMatrix(weights, bytes)) {
+ SAFTM_LOG(ERROR) << "Bad weights matrix for softmax";
+ return false;
+ }
+ Matrix bias = GetSoftmaxBias();
+ if (!VerifyMatrix(bias, bytes)) {
+ SAFTM_LOG(ERROR) << "Bad bias vector for softmax";
+ return false;
+ }
+ if (!GoodLayerGeometry(current_size, weights, bias)) {
+ SAFTM_LOG(ERROR) << "Bad geometry for softmax layer";
+ return false;
+ }
+ }
+ return true;
+}
+
+// static
+bool EmbeddingNetworkParamsFromFlatbuffer::InRangeIndex(int index, int limit,
+ const char *info) {
+ if ((index >= 0) && (index < limit)) {
+ return true;
+ } else {
+ SAFTM_LOG(ERROR) << info << " index " << index << " outside range [0, "
+ << limit << ")";
+ return false;
+ }
+}
+
+int EmbeddingNetworkParamsFromFlatbuffer::SafeGetNumInputChunks() const {
+ const auto *input_chunks = network_->input_chunks();
+ if (input_chunks == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr input_chunks";
+ return 0;
+ }
+ return input_chunks->size();
+}
+
+const saft_fbs::InputChunk *
+EmbeddingNetworkParamsFromFlatbuffer::SafeGetInputChunk(int i) const {
+ if (!InRangeIndex(i, SafeGetNumInputChunks(), "input chunks")) {
+ return nullptr;
+ }
+ const auto *input_chunks = network_->input_chunks();
+ if (input_chunks == nullptr) {
+ // Execution should not reach this point, due to how SafeGetNumInputChunks()
+ // is implemented. Still, just to be sure:
+ SAFTM_LOG(ERROR) << "nullptr input_chunks";
+ return nullptr;
+ }
+ const saft_fbs::InputChunk *input_chunk = input_chunks->Get(i);
+ if (input_chunk == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr input chunk #" << i;
+ }
+ return input_chunk;
+}
+
+const saft_fbs::Matrix *
+EmbeddingNetworkParamsFromFlatbuffer::SafeGetEmbeddingMatrix(int i) const {
+ const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i);
+ if (input_chunk == nullptr) return nullptr;
+ const saft_fbs::Matrix *matrix = input_chunk->embedding();
+ if (matrix == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr embeding matrix #" << i;
+ }
+ return matrix;
+}
+
+int EmbeddingNetworkParamsFromFlatbuffer::SafeGetNumLayers() const {
+ const auto *layers = network_->layers();
+ if (layers == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr layers";
+ return 0;
+ }
+ return layers->size();
+}
+
+const saft_fbs::NeuralLayer *EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayer(
+ int i) const {
+ if (!InRangeIndex(i, SafeGetNumLayers(), "layer")) {
+ return nullptr;
+ }
+ const auto *layers = network_->layers();
+ if (layers == nullptr) {
+ // Execution should not reach this point, due to how SafeGetNumLayers()
+ // is implemented. Still, just to be sure:
+ SAFTM_LOG(ERROR) << "nullptr layers";
+ return nullptr;
+ }
+ const saft_fbs::NeuralLayer *layer = layers->Get(i);
+ if (layer == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr layer #" << i;
+ }
+ return layer;
+}
+
+const saft_fbs::Matrix *
+EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayerWeights(int i) const {
+ const saft_fbs::NeuralLayer *layer = SafeGetLayer(i);
+ if (layer == nullptr) return nullptr;
+ const saft_fbs::Matrix *weights = layer->weights();
+ if (weights == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr weights for layer #" << i;
+ }
+ return weights;
+}
+
+const saft_fbs::Matrix *EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayerBias(
+ int i) const {
+ const saft_fbs::NeuralLayer *layer = SafeGetLayer(i);
+ if (layer == nullptr) return nullptr;
+ const saft_fbs::Matrix *bias = layer->bias();
+ if (bias == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr bias for layer #" << i;
+ }
+ return bias;
+}
+
+// static
+const float *EmbeddingNetworkParamsFromFlatbuffer::SafeGetValues(
+ const saft_fbs::Matrix *matrix) {
+ if (matrix == nullptr) return nullptr;
+ const flatbuffers::Vector<float> *values = matrix->values();
+ if (values == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr values";
+ }
+ return values->data();
+}
+
+// static
+const uint8_t *EmbeddingNetworkParamsFromFlatbuffer::SafeGetQuantizedValues(
+ const saft_fbs::Matrix *matrix) {
+ if (matrix == nullptr) return nullptr;
+ const flatbuffers::Vector<uint8_t> *quantized_values =
+ matrix->quantized_values();
+ if (quantized_values == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr quantized_values";
+ }
+ return quantized_values->data();
+}
+
+// static
+const float16 *EmbeddingNetworkParamsFromFlatbuffer::SafeGetScales(
+ const saft_fbs::Matrix *matrix) {
+ if (matrix == nullptr) return nullptr;
+ const flatbuffers::Vector<uint16_t> *scales = matrix->scales();
+ if (scales == nullptr) {
+ SAFTM_LOG(ERROR) << "nullptr scales";
+ }
+ return scales->data();
+}
+
+const saft_fbs::NeuralLayer *
+EmbeddingNetworkParamsFromFlatbuffer::SafeGetSoftmaxLayer() const {
+ int num_layers = SafeGetNumLayers();
+ if (num_layers <= 0) {
+ SAFTM_LOG(ERROR) << "No softmax layer";
+ return nullptr;
+ }
+ return SafeGetLayer(num_layers - 1);
+}
+
+QuantizationType EmbeddingNetworkParamsFromFlatbuffer::SafeGetQuantizationType(
+ const saft_fbs::Matrix *matrix) const {
+ if (matrix == nullptr) {
+ return QuantizationType::NONE;
+ }
+ saft_fbs::QuantizationType quantization_type = matrix->quantization_type();
+
+ // Conversion from nlp_saft::saft_fbs::QuantizationType to
+ // nlp_saft::QuantizationType (due to legacy reasons, we have both).
+ switch (quantization_type) {
+ case saft_fbs::QuantizationType_NONE:
+ return QuantizationType::NONE;
+ case saft_fbs::QuantizationType_UINT8:
+ return QuantizationType::UINT8;
+ case saft_fbs::QuantizationType_UINT4:
+ return QuantizationType::UINT4;
+ case saft_fbs::QuantizationType_FLOAT16:
+ return QuantizationType::FLOAT16;
+ default:
+ SAFTM_LOG(ERROR) << "Unsupported quantization type "
+ << static_cast<int>(quantization_type);
+ return QuantizationType::NONE;
+ }
+}
+
+const void *EmbeddingNetworkParamsFromFlatbuffer::SafeGetValuesOfMatrix(
+ const saft_fbs::Matrix *matrix) const {
+ if (matrix == nullptr) {
+ return nullptr;
+ }
+ saft_fbs::QuantizationType quantization_type = matrix->quantization_type();
+ switch (quantization_type) {
+ case saft_fbs::QuantizationType_NONE:
+ return SafeGetValues(matrix);
+ case saft_fbs::QuantizationType_UINT8:
+ SAFTM_FALLTHROUGH_INTENDED;
+ case saft_fbs::QuantizationType_UINT4:
+ SAFTM_FALLTHROUGH_INTENDED;
+ case saft_fbs::QuantizationType_FLOAT16:
+ return SafeGetQuantizedValues(matrix);
+ default:
+ SAFTM_LOG(ERROR) << "Unsupported quantization type "
+ << static_cast<int>(quantization_type);
+ return nullptr;
+ }
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h b/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h
new file mode 100644
index 0000000..57d59c5
--- /dev/null
+++ b/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h
@@ -0,0 +1,285 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "lang_id/common/embedding-network-params.h"
+#include "lang_id/common/flatbuffers/embedding-network_generated.h"
+#include "lang_id/common/lite_base/float16.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// EmbeddingNetworkParams implementation backed by a flatbuffer.
+//
+// For info on our flatbuffer schema, see embedding-network.fbs.
+class EmbeddingNetworkParamsFromFlatbuffer : public EmbeddingNetworkParams {
+ public:
+ // Constructs an EmbeddingNetworkParamsFromFlatbuffer instance, using the
+ // flatbuffer from |bytes|.
+ //
+ // IMPORTANT #1: caller should make sure |bytes| are alive during the lifetime
+ // of this EmbeddingNetworkParamsFromFlatbuffer instance. To avoid overhead,
+ // this constructor does not copy |bytes|.
+ //
+ // IMPORTANT #2: immediately after this constructor returns, we suggest you
+ // call is_valid() on the newly-constructed object and do not call any other
+ // method if the answer is negative (false).
+ explicit EmbeddingNetworkParamsFromFlatbuffer(StringPiece bytes);
+
+ bool UpdateTaskContextParameters(mobile::TaskContext *task_context) override {
+ // This class does not provide access to the overall TaskContext. It
+ // provides only parameters for the Neurosis neural network.
+ SAFTM_LOG(DFATAL) << "Not supported";
+ return false;
+ }
+
+ bool is_valid() const override { return valid_; }
+
+ int embeddings_size() const override { return SafeGetNumInputChunks(); }
+
+ int embeddings_num_rows(int i) const override {
+ const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
+ return SafeGetNumRows(matrix);
+ }
+
+ int embeddings_num_cols(int i) const override {
+ const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
+ return SafeGetNumCols(matrix);
+ }
+
+ const void *embeddings_weights(int i) const override {
+ const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
+ return SafeGetValuesOfMatrix(matrix);
+ }
+
+ QuantizationType embeddings_quant_type(int i) const override {
+ const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
+ return SafeGetQuantizationType(matrix);
+ }
+
+ const float16 *embeddings_quant_scales(int i) const override {
+ const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
+ return SafeGetScales(matrix);
+ }
+
+ int hidden_size() const override {
+ // -1 because last layer is always the softmax layer.
+ return std::max(SafeGetNumLayers() - 1, 0);
+ }
+
+ int hidden_num_rows(int i) const override {
+ const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
+ return SafeGetNumRows(weights);
+ }
+
+ int hidden_num_cols(int i) const override {
+ const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
+ return SafeGetNumCols(weights);
+ }
+
+ QuantizationType hidden_weights_quant_type(int i) const override {
+ const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
+ return SafeGetQuantizationType(weights);
+ }
+
+ const void *hidden_weights(int i) const override {
+ const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
+ return SafeGetValuesOfMatrix(weights);
+ }
+
+ int hidden_bias_size() const override { return hidden_size(); }
+
+ int hidden_bias_num_rows(int i) const override {
+ const saft_fbs::Matrix *bias = SafeGetLayerBias(i);
+ return SafeGetNumRows(bias);
+ }
+
+ int hidden_bias_num_cols(int i) const override {
+ const saft_fbs::Matrix *bias = SafeGetLayerBias(i);
+ return SafeGetNumCols(bias);
+ }
+
+ const void *hidden_bias_weights(int i) const override {
+ const saft_fbs::Matrix *bias = SafeGetLayerBias(i);
+ return SafeGetValues(bias);
+ }
+
+ int softmax_size() const override { return (SafeGetNumLayers() > 0) ? 1 : 0; }
+
+ int softmax_num_rows(int i) const override {
+ const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
+ return SafeGetNumRows(weights);
+ }
+
+ int softmax_num_cols(int i) const override {
+ const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
+ return SafeGetNumCols(weights);
+ }
+
+ QuantizationType softmax_weights_quant_type(int i) const override {
+ const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
+ return SafeGetQuantizationType(weights);
+ }
+
+ const void *softmax_weights(int i) const override {
+ const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
+ return SafeGetValuesOfMatrix(weights);
+ }
+
+ int softmax_bias_size() const override { return softmax_size(); }
+
+ int softmax_bias_num_rows(int i) const override {
+ const saft_fbs::Matrix *bias = SafeGetSoftmaxBias();
+ return SafeGetNumRows(bias);
+ }
+
+ int softmax_bias_num_cols(int i) const override {
+ const saft_fbs::Matrix *bias = SafeGetSoftmaxBias();
+ return SafeGetNumCols(bias);
+ }
+
+ const void *softmax_bias_weights(int i) const override {
+ const saft_fbs::Matrix *bias = SafeGetSoftmaxBias();
+ return SafeGetValues(bias);
+ }
+
+ int embedding_num_features_size() const override {
+ return SafeGetNumInputChunks();
+ }
+
+ int embedding_num_features(int i) const override {
+ if (!InRangeIndex(i, embedding_num_features_size(),
+ "embedding num features")) {
+ return 0;
+ }
+ const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i);
+ if (input_chunk == nullptr) {
+ return 0;
+ }
+ return input_chunk->num_features();
+ }
+
+ bool has_is_precomputed() const override { return false; }
+ bool is_precomputed() const override { return false; }
+
+ private:
+ // Returns true if and only if index is in [0, limit). info should be a
+ // pointer to a zero-terminated array of chars (ideally a literal string,
+ // e.g. "layer") indicating what the index refers to; info is used to make log
+ // messages more informative.
+ static bool InRangeIndex(int index, int limit, const char *info);
+
+ // Returns network_->input_chunks()->size(), if all dereferences are safe
+ // (i.e., no nullptr); otherwise, returns 0.
+ int SafeGetNumInputChunks() const;
+
+ // Returns network_->input_chunks()->Get(i), if all dereferences are safe
+ // (i.e., no nullptr) otherwise, returns nullptr.
+ const saft_fbs::InputChunk *SafeGetInputChunk(int i) const;
+
+ // Returns network_->input_chunks()->Get(i)->embedding(), if all dereferences
+ // are safe (i.e., no nullptr); otherwise, returns nullptr.
+ const saft_fbs::Matrix *SafeGetEmbeddingMatrix(int i) const;
+
+ // Returns network_->layers()->size(), if all dereferences are safe (i.e., no
+ // nullptr); otherwise, returns 0.
+ int SafeGetNumLayers() const;
+
+ // Returns network_->layers()->Get(i), if all dereferences are safe
+ // (i.e., no nullptr); otherwise, returns nullptr.
+ const saft_fbs::NeuralLayer *SafeGetLayer(int i) const;
+
+ // Returns network_->layers()->Get(i)->weights(), if all dereferences are safe
+ // (i.e., no nullptr); otherwise, returns nullptr.
+ const saft_fbs::Matrix *SafeGetLayerWeights(int i) const;
+
+ // Returns network_->layers()->Get(i)->bias(), if all dereferences are safe
+ // (i.e., no nullptr); otherwise, returns nullptr.
+ const saft_fbs::Matrix *SafeGetLayerBias(int i) const;
+
+ static int SafeGetNumRows(const saft_fbs::Matrix *matrix) {
+ return (matrix == nullptr) ? 0 : matrix->rows();
+ }
+
+ static int SafeGetNumCols(const saft_fbs::Matrix *matrix) {
+ return (matrix == nullptr) ? 0 : matrix->cols();
+ }
+
+ // Returns matrix->values()->data() if all dereferences are safe (i.e., no
+ // nullptr); otherwise, returns nullptr.
+ static const float *SafeGetValues(const saft_fbs::Matrix *matrix);
+
+ // Returns matrix->quantized_values()->data() if all dereferences are safe
+ // (i.e., no nullptr); otherwise, returns nullptr.
+ static const uint8_t *SafeGetQuantizedValues(const saft_fbs::Matrix *matrix);
+
+ // Returns matrix->scales()->data() if all dereferences are safe (i.e., no
+ // nullptr); otherwise, returns nullptr.
+ static const float16 *SafeGetScales(const saft_fbs::Matrix *matrix);
+
+ // Returns network_->layers()->Get(last_index) with last_index =
+ // SafeGetNumLayers() - 1, if all dereferences are safe (i.e., no nullptr) and
+ // there exists at least one layer; otherwise, returns nullptr.
+ const saft_fbs::NeuralLayer *SafeGetSoftmaxLayer() const;
+
+ const saft_fbs::Matrix *SafeGetSoftmaxWeights() const {
+ const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer();
+ return (layer == nullptr) ? nullptr : layer->weights();
+ }
+
+ const saft_fbs::Matrix *SafeGetSoftmaxBias() const {
+ const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer();
+ return (layer == nullptr) ? nullptr : layer->bias();
+ }
+
+ // Returns the quantization type for |matrix|. Returns NONE in case of
+ // problems (e.g., matrix is nullptr or unknown quantization type).
+ QuantizationType SafeGetQuantizationType(
+ const saft_fbs::Matrix *matrix) const;
+
+ // Returns a pointer to the values (float, uint8, or float16, depending on
+ // quantization) from |matrix|, in row-major order. Returns nullptr in case
+ // of a problem.
+ const void *SafeGetValuesOfMatrix(const saft_fbs::Matrix *matrix) const;
+
+ // Performs some validity checks. E.g., check that dimensions of the network
+ // layers match. Also checks that all pointers we return are inside the
+ // |bytes| passed to the constructor, such that client that reads from those
+ // pointers will not run into troubles.
+ bool ValidityChecking(StringPiece bytes) const;
+
+ // True if these params are valid. May be false if the original proto was
+  // corrupted. We prefer setting this to false over CHECK-failing.
+ bool valid_ = false;
+
+ // EmbeddingNetwork flatbuffer from the bytes passed as parameter to the
+ // constructor; see constructor doc.
+ const saft_fbs::EmbeddingNetwork *network_ = nullptr;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
diff --git a/lang_id/common/flatbuffers/embedding-network.fbs b/lang_id/common/flatbuffers/embedding-network.fbs
new file mode 100644
index 0000000..1fde6a3
--- /dev/null
+++ b/lang_id/common/flatbuffers/embedding-network.fbs
@@ -0,0 +1,117 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Flatbuffer schema for Neurosis (FFNN with embeddings) parameters.
+//
+// Contains the same information as an EmbeddingNetworkProto.
+
+namespace libtextclassifier3.saft_fbs;
+
+// NS stands for NeurosiS. The next two digits are meant to identify
+// incompatible versions. Ideally, we'll never have to go beyond 00.
+file_identifier "NS00";
+
+// Should be kept in sync with the C++ enum nlp_saft::QuantizationType.
+enum QuantizationType : byte {
+ NONE = 0,
+ UINT8 = 1,
+ UINT4 = 2,
+ FLOAT16 = 3,
+}
+
+table Matrix {
+ // Number of rows of this matrix.
+ rows:int;
+
+ // Number of columns of this matrix.
+ cols:int;
+
+ // Type of quantization used for the values from this matrix.
+ //
+ // If this is QuantizationType_NONE, then the unquantized values should be
+ // stored in |values| below. Otherwise, the bytes of the quantized values
+ // should be stored in |quantized_values| and the float16 quantization scales
+ // should be stored in |scales|.
+ quantization_type:QuantizationType = NONE;
+
+ // Non-quantized matrix elements, in row-major order. See comments for
+ // |quantization_type|.
+ values:[float];
+
+ // Quantized matrix elements, in row-major order. See comments for
+ // |quantization_type|.
+ quantized_values:[ubyte];
+
+ // Quantization factors (float16), one per matrix row. There is no float16
+ // primitive type for flatbuffers, we just use another 16 bit type. See
+ // comments for |quantization_type|.
+ scales:[ushort];
+}
+
+// The input layer for a Neurosis network is composed of several parts (named
+// "chunks" below, "embedding spaces" in some other parts, etc). For each
+// chunk, we have |num_features| features that extract feature values in that
+// chunk. All values extracted by a feature get projected via the embedding
+// matrix |embedding| and summed together, producing a vector of
+// |embedding.cols| elements. The resulting vector gets concatenated with the
+// similar vectors for other |num_features| features, producing a "chunk" of
+// |num_features * embedding.cols| elements. This chunk gets concatenated with
+// the other chunks.
+//
+// Note: the specification that indicates what those |num_features| features are
+// is stored elsewhere (usually in a ModelParameter, see model.fbs). But we
+// need to know |num_features| here, in order to specify the geometry of the
+// Neurosis network.
+table InputChunk {
+ embedding:Matrix;
+ num_features:int;
+}
+
+// One layer of neurons from the Neurosis network. This table can represent a
+// hidden layer or the final (output / softmax) layer.
+//
+// Our formalism is a bit different, but equivalent to the usual description
+// from the literature:
+//
+// Technically, in Neurosis, each layer takes an input (a vector of floats); if
+// this is not the first layer, we apply a nonlinear function (ReLU); for the
+// first layer, we skip ReLU. Next, we multiply by |weights| and add |bias|,
+// get the input for the next level and so on. The output from the last layer
+// is generally used for softmax classification. That's why we say that the
+// last layer is the "softmax layer".
+table NeuralLayer {
+ // Weight matrix for this layer. Geometry: num_inputs x num_neurons, where
+ // num_inputs is the number of values produced by previous layer (which can be
+ // the input layer, or another hidden layer) and num_neurons is the number of
+ // neurons from this layer.
+ weights:Matrix;
+
+ // Bias vector for this layer.
+ //
+ // NOTE: right now, we accept both 1 x num_neurons and num_neurons x 1
+ // geometries: the layout of the elements is the same in both cases.
+ bias:Matrix;
+}
+
+table EmbeddingNetwork {
+ // Specification of the chunks that compose the input layer.
+ input_chunks:[InputChunk];
+
+ // Hidden layers, followed by the final (softmax) layer.
+ layers:[NeuralLayer];
+}
+
+root_type EmbeddingNetwork;
diff --git a/lang_id/common/flatbuffers/model-utils.cc b/lang_id/common/flatbuffers/model-utils.cc
new file mode 100644
index 0000000..2c57aa2
--- /dev/null
+++ b/lang_id/common/flatbuffers/model-utils.cc
@@ -0,0 +1,208 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/flatbuffers/model-utils.h"
+
+#include <string.h>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/math/checksum.h"
+
+namespace libtextclassifier3 {
+namespace saft_fbs {
+
+namespace {
+
+// Returns true if we have clear evidence that |model| fails its checksum.
+//
+// E.g., if |model| has the crc32 field, and the value of that field does not
+// match the checksum, then this function returns true. If there is no crc32
+// field, then we don't know what the original (at build time) checksum was, so
+// we don't know anything clear and this function returns false.
+bool ClearlyFailsChecksum(const Model &model) {
+ if (!flatbuffers::IsFieldPresent(&model, Model::VT_CRC32)) {
+ SAFTM_LOG(WARNING)
+ << "No CRC32, most likely an old model; skip CRC32 check";
+ return false;
+ }
+ const mobile::uint32 expected_crc32 = model.crc32();
+ const mobile::uint32 actual_crc32 = ComputeCrc2Checksum(&model);
+ if (actual_crc32 != expected_crc32) {
+ SAFTM_LOG(ERROR) << "Corrupt model: different CRC32: " << actual_crc32
+ << " vs " << expected_crc32;
+ return true;
+ }
+ SAFTM_LOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
+ return false;
+}
+} // namespace
+
+const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) {
+ if ((data == nullptr) || (num_bytes == 0)) {
+ SAFTM_LOG(ERROR) << "GetModel called on an empty sequence of bytes";
+ return nullptr;
+ }
+ const uint8_t *start = reinterpret_cast<const uint8_t *>(data);
+ flatbuffers::Verifier verifier(start, num_bytes);
+ if (!VerifyModelBuffer(verifier)) {
+ SAFTM_LOG(ERROR) << "Not a valid Model flatbuffer";
+ return nullptr;
+ }
+ const Model *model = GetModel(start);
+ if (model == nullptr) {
+ return nullptr;
+ }
+ if (ClearlyFailsChecksum(*model)) {
+ return nullptr;
+ }
+ return model;
+}
+
+const ModelInput *GetInputByName(const Model *model, const string &name) {
+ if (model == nullptr) {
+ SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr";
+ return nullptr;
+ }
+ const auto *inputs = model->inputs();
+ if (inputs == nullptr) {
+ // We should always have a list of inputs; maybe an empty one, if no inputs,
+ // but the list should be there.
+ SAFTM_LOG(ERROR) << "null inputs";
+ return nullptr;
+ }
+ for (const ModelInput *input : *inputs) {
+ if (input != nullptr) {
+ const flatbuffers::String *input_name = input->name();
+ if (input_name && input_name->str() == name) {
+ return input;
+ }
+ }
+ }
+ return nullptr;
+}
+
+mobile::StringPiece GetInputBytes(const ModelInput *input) {
+ if ((input == nullptr) || (input->data() == nullptr)) {
+ SAFTM_LOG(ERROR) << "ModelInput has no content";
+ return mobile::StringPiece(nullptr, 0);
+ }
+ const flatbuffers::Vector<uint8_t> *input_data = input->data();
+ if (input_data == nullptr) {
+ SAFTM_LOG(ERROR) << "null input data";
+ return mobile::StringPiece(nullptr, 0);
+ }
+ return mobile::StringPiece(reinterpret_cast<const char *>(input_data->data()),
+ input_data->size());
+}
+
+bool FillParameters(const Model &model, mobile::TaskContext *context) {
+ if (context == nullptr) {
+ SAFTM_LOG(ERROR) << "null context";
+ return false;
+ }
+ const auto *parameters = model.parameters();
+ if (parameters == nullptr) {
+ // We should always have a list of parameters; maybe an empty one, if no
+ // parameters, but the list should be there.
+ SAFTM_LOG(ERROR) << "null list of parameters";
+ return false;
+ }
+ for (const ModelParameter *p : *parameters) {
+ if (p == nullptr) {
+ SAFTM_LOG(ERROR) << "null parameter";
+ return false;
+ }
+ if (p->name() == nullptr) {
+ SAFTM_LOG(ERROR) << "null parameter name";
+ return false;
+ }
+ const string name = p->name()->str();
+ if (name.empty()) {
+ SAFTM_LOG(ERROR) << "empty parameter name";
+ return false;
+ }
+ if (p->value() == nullptr) {
+      SAFTM_LOG(ERROR) << "null parameter value";
+ return false;
+ }
+ context->SetParameter(name, p->value()->str());
+ }
+ return true;
+}
+
+namespace {
+// Updates |*crc| with the information from |s|. Auxiliary for
+// ComputeCrc2Checksum.
+//
+// The bytes from |info| are also used to update the CRC32 checksum. |info|
+// should be a brief tag that indicates what |s| represents. The idea is to add
+// some structure to the information that goes into the CRC32 computation.
+template <typename T>
+void UpdateCrc(mobile::Crc32 *crc, const flatbuffers::Vector<T> *s,
+ mobile::StringPiece info) {
+ crc->Update("|");
+ crc->Update(info.data(), info.size());
+ crc->Update(":");
+ if (s == nullptr) {
+ crc->Update("empty");
+ } else {
+ crc->Update(reinterpret_cast<const char *>(s->data()),
+ s->size() * sizeof(T));
+ }
+}
+} // namespace
+
+mobile::uint32 ComputeCrc2Checksum(const Model *model) {
+ // Implementation note: originally, I (salcianu@) thought we can just compute
+ // a CRC32 checksum of the model bytes. Unfortunately, the expected checksum
+ // is there too (and because we don't control the flatbuffer format, we can't
+ // "arrange" for it to be placed at the head / tail of those bytes). Instead,
+ // we traverse |model| and feed into the CRC32 computation those parts we are
+ // interested in (which excludes the crc32 field).
+ //
+ // Note: storing the checksum outside the Model would be too disruptive for
+ // the way we currently ship our models.
+ mobile::Crc32 crc;
+ if (model == nullptr) {
+ return crc.Get();
+ }
+ crc.Update("|Parameters:");
+ const auto *parameters = model->parameters();
+ if (parameters != nullptr) {
+ for (const ModelParameter *p : *parameters) {
+ if (p != nullptr) {
+ UpdateCrc(&crc, p->name(), "name");
+ UpdateCrc(&crc, p->value(), "value");
+ }
+ }
+ }
+ crc.Update("|Inputs:");
+ const auto *inputs = model->inputs();
+ if (inputs != nullptr) {
+ for (const ModelInput *input : *inputs) {
+ if (input != nullptr) {
+ UpdateCrc(&crc, input->name(), "name");
+ UpdateCrc(&crc, input->type(), "type");
+ UpdateCrc(&crc, input->sub_type(), "sub-type");
+ UpdateCrc(&crc, input->data(), "data");
+ }
+ }
+ }
+ return crc.Get();
+}
+
+} // namespace saft_fbs
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/flatbuffers/model-utils.h b/lang_id/common/flatbuffers/model-utils.h
new file mode 100644
index 0000000..5427f70
--- /dev/null
+++ b/lang_id/common/flatbuffers/model-utils.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
+
+#include <stddef.h>
+
+#include <string>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/flatbuffers/model_generated.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace saft_fbs {
+
+// Verifies that the |num_bytes| bytes that start at |data| represent a valid
+// Model flatbuffer. If so, returns that Model. Otherwise, returns nullptr.
+//
+// Note: if the Model has the crc32 field, this method checks that the Model
+// checksum matches that field; if they don't match, the Model is considered
+// invalid, and this function returns nullptr. The checksum test is in addition
+// to the standard flatbuffer validity checking.
+const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes);
+
+// Convenience StringPiece version of GetVerifiedModelFromBytes.
+inline const Model *GetVerifiedModelFromBytes(mobile::StringPiece bytes) {
+ return GetVerifiedModelFromBytes(bytes.data(), bytes.size());
+}
+
+// Returns the |model| input with specified |name|. Returns nullptr if no such
+// input exists. If |model| contains multiple inputs with that |name|, returns
+// the first one (model builders should avoid building such models).
+const ModelInput *GetInputByName(const Model *model, const string &name);
+
+// Returns a StringPiece pointing to the bytes for the content of |input|. In
+// case of errors, returns StringPiece(nullptr, 0).
+mobile::StringPiece GetInputBytes(const ModelInput *input);
+
+// Fills parameters from |context|, based on the parameters from |model|.
+// Returns false if any error is encountered, true otherwise. In the case of an
+// error, some parameters may have been added to |context| (e.g., if we find a
+// problem with the 3rd parameter, the first 2 have been added).
+bool FillParameters(const Model &model, mobile::TaskContext *context);
+
+// Returns the CRC32 checksum of |model|. This checksum is computed over the
+// entire information from the model (including the bytes of the inputs),
+// *except* the crc32 field.  Hence, when a model is built, one can store the
+// result of this function into that field; on the user side, one can check that
+// the result of this function matches the crc32 field, to guard against model
+// corruption. GetVerifiedModelFromBytes performs this check.
+mobile::uint32 ComputeCrc2Checksum(const Model *model);
+
+} // namespace saft_fbs
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
diff --git a/lang_id/common/flatbuffers/model.fbs b/lang_id/common/flatbuffers/model.fbs
new file mode 100644
index 0000000..41251e1
--- /dev/null
+++ b/lang_id/common/flatbuffers/model.fbs
@@ -0,0 +1,79 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Flatbuffer schema for SAFT models.
+//
+// For info on flatbuffers, see http://go/flatbuffers and
+// http://google.github.io/flatbuffers/, including info on writing schemas:
+// http://google.github.io/flatbuffers/flatbuffers_guide_writing_schema.html
+
+namespace libtextclassifier3.saft_fbs;
+
+// SM stands for Saft Model. The next two digits are meant to identify
+// incompatible versions. Ideally, we'll never have to go beyond 00.
+file_identifier "SM00";
+
+// Extension stands for Saft Model in FlatBuffer format.
+file_extension "smfb";
+
+table ModelParameter {
+ // Parameter name.
+ name:string;
+
+ // Parameter value.
+ value:string;
+}
+
+// Input for a SAFT model. Inputs usually provide extra resources: e.g., the
+// parameters for a Neurosis FFNN with embeddings, or a word cluster structure,
+// etc.
+table ModelInput {
+ // Name of this input. Different input of the same model should have
+ // different names, such that we can non-ambiguously look them up.
+ name:string;
+
+ // General description of the type of this input. Required to parse the
+ // content of this input (see |data| below). If |data| is a flatbuffer, use
+ // "flatbuffer". If |data| is a proto, use "proto". Otherwise, use your best
+ // judgment: use something human-readable, and look around to make sure you
+ // don't invent a new name for something that already exists.
+ type:string;
+
+ // More specific information about the type of this input. E.g., if |type| is
+ // "flatbuffer", this should be the name of the root_type we should parse from
+  // the input bytes, e.g., "EmbeddingNetwork".  If |type| is proto, this
+ // should be the name of the proto serialized as |data|, e.g.,
+ // "EmbeddingNetworkProto".
+ sub_type:string;
+
+ // The content of this input. With a generous alignment, such that we can
+ // accommodate mmap-friendly data structures. E.g., the word clusters used by
+ // the Translate team require 8-byte alignment.
+ data:[ubyte] (force_align: 16);
+}
+
+// A Saft model. A list of parameters with model settings (e.g., the
+// specification of the features to use) and a list of inputs.
+table Model {
+ parameters:[ModelParameter];
+ inputs:[ModelInput];
+
+ // Crc32 checksum of all parameters and inputs (including the bytes of the
+ // inputs). Used to check that the model has not been corrupted.
+ crc32:uint32;
+}
+
+root_type Model;
diff --git a/lang_id/common/lite_base/attributes.h b/lang_id/common/lite_base/attributes.h
new file mode 100644
index 0000000..f29e48f
--- /dev/null
+++ b/lang_id/common/lite_base/attributes.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Various macros related to function inlining.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ATTRIBUTES_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ATTRIBUTES_H_
+
+// SAFTM_HAVE_ATTRIBUTE
+//
+// A function-like feature checking macro that is a wrapper around
+// `__has_attribute`, which is defined by GCC 5+ and Clang and evaluates to a
+// nonzero constant integer if the attribute is supported or 0 if not.
+//
+// It evaluates to zero if `__has_attribute` is not defined by the compiler.
+//
+// GCC: https://gcc.gnu.org/gcc-5/changes.html
+// Clang: https://clang.llvm.org/docs/LanguageExtensions.html
+#ifdef __has_attribute
+#define SAFTM_HAVE_ATTRIBUTE(x) __has_attribute(x)
+#else
+#define SAFTM_HAVE_ATTRIBUTE(x) 0
+#endif
+
+// SAFTM_MUST_USE_RESULT
+//
+// Tells the compiler to warn about unused return values for functions declared
+// with this macro. The macro must appear as the very first part of a function
+// declaration or definition:
+//
+// Example:
+//
+// SAFTM_MUST_USE_RESULT Sprocket* AllocateSprocket();
+//
+// This placement has the broadest compatibility with GCC, Clang, and MSVC, with
+// both defs and decls, and with GCC-style attributes, MSVC declspec, C++11
+// and C++17 attributes.
+//
+// SAFTM_MUST_USE_RESULT allows using cast-to-void to suppress the unused result
+// warning. For that, warn_unused_result is used only for clang but not for gcc.
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=66425
+#if SAFTM_HAVE_ATTRIBUTE(nodiscard)
+#define SAFTM_MUST_USE_RESULT [[nodiscard]]
+#elif defined(__clang__) && SAFTM_HAVE_ATTRIBUTE(warn_unused_result)
+#define SAFTM_MUST_USE_RESULT __attribute__((warn_unused_result))
+#else
+#define SAFTM_MUST_USE_RESULT
+#endif
+
+#if defined(__GNUC__) && \
+ (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1))
+
+// For functions we want to force inline.
+// Introduced in gcc 3.1.
+#define SAFTM_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline))
+
+// For functions we don't want to inline, e.g., to keep code size small.
+#define SAFTM_ATTRIBUTE_NOINLINE __attribute__((noinline))
+
+#elif defined(_MSC_VER)
+#define SAFTM_ATTRIBUTE_ALWAYS_INLINE __forceinline
+#else
+
+// Other compilers will have to figure it out for themselves.
+#define SAFTM_ATTRIBUTE_ALWAYS_INLINE
+#define SAFTM_ATTRIBUTE_NOINLINE
+#endif // big condition on two lines.
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ATTRIBUTES_H_
diff --git a/lang_id/common/lite_base/casts.h b/lang_id/common/lite_base/casts.h
new file mode 100644
index 0000000..11a4ba2
--- /dev/null
+++ b/lang_id/common/lite_base/casts.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_CASTS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_CASTS_H_
+
+#include <string.h> // for memcpy
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// bit_cast<Dest, Source> is a template function that implements the equivalent
+// of "*reinterpret_cast<Dest*>(&source)". We need this in very low-level
+// functions like fast math support.
+//
+// float f = 3.14159265358979;
+// int i = bit_cast<int32>(f);
+// // i = 0x40490fdb
+//
+// The classical address-casting method is:
+//
+// // WRONG
+// float f = 3.14159265358979; // WRONG
+// int i = * reinterpret_cast<int*>(&f); // WRONG
+//
+// The address-casting method actually produces undefined behavior
+// according to ISO C++ specification section 3.10 -15 -. Roughly, this
+// section says: if an object in memory has one type, and a program
+// accesses it with a different type, then the result is undefined
+// behavior for most values of "different type".
+//
+// This is true for any cast syntax, either *(int*)&f or
+// *reinterpret_cast<int*>(&f). And it is particularly true for
+// conversions between integral lvalues and floating-point lvalues.
+//
+// The purpose of 3.10 -15- is to allow optimizing compilers to assume
+// that expressions with different types refer to different memory. gcc
+// 4.0.1 has an optimizer that takes advantage of this. So a
+// non-conforming program quietly produces wildly incorrect output.
+//
+// The problem is not the use of reinterpret_cast. The problem is type
+// punning: holding an object in memory of one type and reading its bits
+// back using a different type.
+//
+// The C++ standard is more subtle and complex than this, but that
+// is the basic idea.
+//
+// Anyways ...
+//
+// bit_cast<> calls memcpy() which is blessed by the standard, especially by the
+// example in section 3.9 . Also, of course, bit_cast<> wraps up the nasty
+// logic in one place.
+//
+// Fortunately memcpy() is very fast. In optimized mode, with a
+// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
+// code with the minimal amount of data movement. On a 32-bit system,
+// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
+// compiles to two loads and two stores.
+//
+// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
+//
+// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
+// is likely to surprise you.
+//
+// Props to Bill Gibbons for the compile time assertion technique and
+// Art Komninos and Igor Tandetnik for the msvc experiments.
+//
+// -- mec 2005-10-17
+
+template <class Dest, class Source>
+inline Dest bit_cast(const Source &source) {
+ static_assert(sizeof(Dest) == sizeof(Source), "Sizes do not match");
+
+ Dest dest;
+ memcpy(&dest, &source, sizeof(dest));
+ return dest;
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_CASTS_H_
diff --git a/lang_id/common/lite_base/compact-logging-levels.h b/lang_id/common/lite_base/compact-logging-levels.h
new file mode 100644
index 0000000..977f4da
--- /dev/null
+++ b/lang_id/common/lite_base/compact-logging-levels.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_LEVELS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_LEVELS_H_
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+enum LogSeverity {
+ FATAL = 0,
+ ERROR,
+ WARNING,
+ INFO,
+
+ // In debug mode, DFATAL has the same semantics as FATAL. Otherwise, it
+ // behaves like ERROR.
+ DFATAL,
+};
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_LEVELS_H_
diff --git a/lang_id/common/lite_base/compact-logging-raw.cc b/lang_id/common/lite_base/compact-logging-raw.cc
new file mode 100644
index 0000000..53dfc8e
--- /dev/null
+++ b/lang_id/common/lite_base/compact-logging-raw.cc
@@ -0,0 +1,103 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/lite_base/compact-logging-raw.h"
+
+#include <stdio.h>
+#include <string>
+
+// NOTE: this file contains two implementations: one for Android, one for all
+// other cases. We always build exactly one implementation.
+#if defined(__ANDROID__)
+
+// Compiled as part of Android.
+#include <android/log.h>
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+namespace {
+// Converts LogSeverity to level for __android_log_write.
+int GetAndroidLogLevel(LogSeverity severity) {
+ switch (severity) {
+ case FATAL:
+ return ANDROID_LOG_FATAL;
+ case ERROR:
+ return ANDROID_LOG_ERROR;
+ case WARNING:
+ return ANDROID_LOG_WARN;
+ case INFO:
+ return ANDROID_LOG_INFO;
+ default:
+ return ANDROID_LOG_DEBUG;
+ }
+}
+} // namespace
+
+void LowLevelLogging(LogSeverity severity, const string &tag,
+ const string &message) {
+ const int android_log_level = GetAndroidLogLevel(severity);
+#if !defined(SAFTM_DEBUG_LOGGING)
+ if (android_log_level != ANDROID_LOG_ERROR &&
+ android_log_level != ANDROID_LOG_FATAL) {
+ return;
+ }
+#endif
+ __android_log_write(android_log_level, tag.c_str(), message.c_str());
+}
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#else // if defined(__ANDROID__)
+
+// Not on Android: implement LowLevelLogging to print to stderr (see below).
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+namespace {
+// Converts LogSeverity to human-readable text.
+const char *LogSeverityToString(LogSeverity severity) {
+ switch (severity) {
+ case INFO:
+ return "INFO";
+ case WARNING:
+ return "WARNING";
+ case ERROR:
+ return "ERROR";
+ case FATAL:
+ return "FATAL";
+ default:
+ return "UNKNOWN";
+ }
+}
+} // namespace
+
+void LowLevelLogging(LogSeverity severity, const string &tag,
+ const string &message) {
+ fprintf(stderr, "[%s] %s : %s\n", LogSeverityToString(severity), tag.c_str(),
+ message.c_str());
+ fflush(stderr);
+}
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // if defined(__ANDROID__)
diff --git a/lang_id/common/lite_base/compact-logging-raw.h b/lang_id/common/lite_base/compact-logging-raw.h
new file mode 100644
index 0000000..f67287c
--- /dev/null
+++ b/lang_id/common/lite_base/compact-logging-raw.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
+
+#include <string>
+
+#include "lang_id/common/lite_base/compact-logging-levels.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+// Low-level logging primitive. Logs a message, with the indicated log
+// severity. From android/log.h: "the tag normally corresponds to the component
+// that emits the log message, and should be reasonably small".
+void LowLevelLogging(LogSeverity severity, const string &tag,
+ const string &message);
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
diff --git a/lang_id/common/lite_base/compact-logging.cc b/lang_id/common/lite_base/compact-logging.cc
new file mode 100644
index 0000000..99d60a3
--- /dev/null
+++ b/lang_id/common/lite_base/compact-logging.cc
@@ -0,0 +1,84 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/lite_base/compact-logging.h"
+
+#include <stdlib.h>
+
+#include <iostream>
+
+#include "lang_id/common/lite_base/compact-logging-raw.h"
+
+#ifndef SAFTM_LOGGING_TAG
+
+// Tag inserted in the prefix of the generated log messages. The user can
+// override this by defining this macro on the blaze build command-line.
+#define SAFTM_LOGGING_TAG "saftm"
+#endif
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+namespace {
+// Returns pointer to beginning of last /-separated token from file_name.
+// file_name should be a pointer to a zero-terminated array of chars.
+// E.g., "foo/bar.cc" -> "bar.cc", "foo/" -> "", "foo" -> "foo".
+const char *JumpToBasename(const char *file_name) {
+ if (file_name == nullptr) {
+ return nullptr;
+ }
+
+ // Points to the beginning of the last encountered token.
+ const char *last_token_start = file_name;
+ while (*file_name != '\0') {
+ if (*file_name == '/') {
+ // Found token separator. A new (potentially empty) token starts after
+ // this position. Notice that if file_name is a valid zero-terminated
+ // string, file_name + 1 is a valid pointer (there is at least one char
+ // after address file_name, the zero terminator).
+ last_token_start = file_name + 1;
+ }
+ file_name++;
+ }
+ return last_token_start;
+}
+} // namespace
+
+LogMessage::LogMessage(LogSeverity severity, const char *file_name,
+ int line_number)
+ : severity_(severity) {
+ stream_ << JumpToBasename(file_name) << ":" << line_number << ": ";
+}
+
+LogMessage::~LogMessage() {
+ LogSeverity level = severity_;
+ if (level == DFATAL) {
+#ifdef NDEBUG
+ level = ERROR;
+#else
+ level = FATAL;
+#endif
+ }
+ LowLevelLogging(level, /* tag = */ SAFTM_LOGGING_TAG, stream_.message);
+ if (level == FATAL) {
+ exit(1);
+ }
+}
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/common/lite_base/compact-logging.h b/lang_id/common/lite_base/compact-logging.h
new file mode 100644
index 0000000..eccb7d1
--- /dev/null
+++ b/lang_id/common/lite_base/compact-logging.h
@@ -0,0 +1,177 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
+
+#include <cassert>
+#include <string>
+
+#include "lang_id/common/lite_base/attributes.h"
+#include "lang_id/common/lite_base/compact-logging-levels.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+// A tiny code footprint string stream for assembling log messages.
+struct LoggingStringStream {
+ LoggingStringStream() {}
+ LoggingStringStream &stream() { return *this; }
+
+ // Needed for invocation in SAFTM_CHECK macro.
+ explicit operator bool() const { return true; }
+
+ string message;
+};
+
+template <typename T>
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const T &entry) {
+ stream.message.append(std::to_string(entry));
+ return stream;
+}
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const char *message) {
+ stream.message.append(message);
+ return stream;
+}
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const string &message) {
+ stream.message.append(message);
+ return stream;
+}
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ StringPiece sp) {
+ stream.message.append(sp.data(), sp.size());
+ return stream;
+}
+
+// The class that does all the work behind our SAFTM_LOG(severity) macros. Each
+// SAFTM_LOG(severity) << obj1 << obj2 << ...; logging statement creates a
+// LogMessage temporary object containing a stringstream. Each operator<< adds
+// info to that stringstream and the LogMessage destructor performs the actual
+// logging. The reason this works is that in C++, "all temporary objects are
+// destroyed as the last step in evaluating the full-expression that (lexically)
+// contains the point where they were created." For more info, see
+// http://en.cppreference.com/w/cpp/language/lifetime. Hence, the destructor is
+// invoked after the last << from that logging statement.
+class LogMessage {
+ public:
+ LogMessage(LogSeverity severity, const char *file_name,
+ int line_number) SAFTM_ATTRIBUTE_NOINLINE;
+
+ ~LogMessage() SAFTM_ATTRIBUTE_NOINLINE;
+
+ // Returns the stream associated with the logger object.
+ LoggingStringStream &stream() { return stream_; }
+
+ private:
+ const LogSeverity severity_;
+
+ // Stream that "prints" all info into a string (not to a file). We construct
+ // here the entire logging message and next print it in one operation.
+ LoggingStringStream stream_;
+};
+
+// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing
+// anything.
+class NullStream {
+ public:
+ NullStream() {}
+ NullStream &stream() { return *this; }
+};
+template <typename T>
+inline NullStream &operator<<(NullStream &str, const T &) {
+ return str;
+}
+
+} // namespace internal_logging
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#define SAFTM_LOG(severity) \
+ ::libtextclassifier3::mobile::internal_logging::LogMessage( \
+ ::libtextclassifier3::mobile::internal_logging::severity, __FILE__, __LINE__) \
+ .stream()
+
+// If condition x is true, does nothing.  Otherwise, crashes the program (like
+// LOG(FATAL)) with an informative message. Can be continued with extra
+// messages, via <<, like any logging macro, e.g.,
+//
+// SAFTM_CHECK(my_cond) << "I think we hit a problem";
+#define SAFTM_CHECK(x) \
+ (x) || SAFTM_LOG(FATAL) << __FILE__ << ":" << __LINE__ \
+ << ": check failed: \"" << #x
+
+#define SAFTM_CHECK_EQ(x, y) SAFTM_CHECK((x) == (y))
+#define SAFTM_CHECK_LT(x, y) SAFTM_CHECK((x) < (y))
+#define SAFTM_CHECK_GT(x, y) SAFTM_CHECK((x) > (y))
+#define SAFTM_CHECK_LE(x, y) SAFTM_CHECK((x) <= (y))
+#define SAFTM_CHECK_GE(x, y) SAFTM_CHECK((x) >= (y))
+#define SAFTM_CHECK_NE(x, y) SAFTM_CHECK((x) != (y))
+
+#define SAFTM_NULLSTREAM \
+ ::libtextclassifier3::mobile::internal_logging::NullStream().stream()
+
+// Debug checks: a SAFTM_DCHECK<suffix> macro should behave like
+// SAFTM_CHECK<suffix> in debug mode and not check / not print anything in
+// non-debug mode.
+#ifdef NDEBUG
+
+#define SAFTM_DCHECK(x) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_EQ(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_LT(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_GT(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_LE(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_GE(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_NE(x, y) SAFTM_NULLSTREAM
+
+// In non-debug mode, SAFTM_DLOG statements do not generate any logging.
+#define SAFTM_DLOG(severity) SAFTM_NULLSTREAM
+
+#else // NDEBUG
+
+// In debug mode, each SAFTM_DCHECK<suffix> is equivalent to
+// SAFTM_CHECK<suffix>, i.e., a real check that crashes when the condition is
+// not true.
+#define SAFTM_DCHECK(x) SAFTM_CHECK(x)
+#define SAFTM_DCHECK_EQ(x, y) SAFTM_CHECK_EQ(x, y)
+#define SAFTM_DCHECK_LT(x, y) SAFTM_CHECK_LT(x, y)
+#define SAFTM_DCHECK_GT(x, y) SAFTM_CHECK_GT(x, y)
+#define SAFTM_DCHECK_LE(x, y) SAFTM_CHECK_LE(x, y)
+#define SAFTM_DCHECK_GE(x, y) SAFTM_CHECK_GE(x, y)
+#define SAFTM_DCHECK_NE(x, y) SAFTM_CHECK_NE(x, y)
+
+// In debug mode, SAFTM_DLOG statements are like SAFTM_LOG.
+#define SAFTM_DLOG SAFTM_LOG
+
+#endif // NDEBUG
+
+#ifdef LIBTEXTCLASSIFIER_VLOG
+#define SAFTM_VLOG(severity) \
+ ::libtextclassifier3::mobile::internal_logging::LogMessage( \
+ ::libtextclassifier3::mobile::internal_logging::INFO, __FILE__, __LINE__) \
+ .stream()
+#else
+#define SAFTM_VLOG(severity) SAFTM_NULLSTREAM
+#endif
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
diff --git a/lang_id/common/lite_base/endian.h b/lang_id/common/lite_base/endian.h
new file mode 100644
index 0000000..16c2dca
--- /dev/null
+++ b/lang_id/common/lite_base/endian.h
@@ -0,0 +1,126 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ENDIAN_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ENDIAN_H_
+
+#include "lang_id/common/lite_base/integral-types.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+#if defined OS_LINUX || defined OS_CYGWIN || defined OS_ANDROID || \
+ defined(__ANDROID__)
+#include <endian.h>
+#endif
+
+// The following guarantees declaration of the byte swap functions, and
+// defines __BYTE_ORDER for MSVC
+#if defined(__GLIBC__) || defined(__CYGWIN__)
+#include <byteswap.h> // IWYU pragma: export
+
+#else
+#ifndef bswap_16
+static inline uint16 bswap_16(uint16 x) {
+ return (uint16)(((x & 0xFF) << 8) | ((x & 0xFF00) >> 8)); // NOLINT
+}
+#define bswap_16(x) bswap_16(x)
+#endif // bswap_16
+
+#ifndef bswap_32
+static inline uint32 bswap_32(uint32 x) {
+ return (((x & 0xFF) << 24) | ((x & 0xFF00) << 8) | ((x & 0xFF0000) >> 8) |
+ ((x & 0xFF000000) >> 24));
+}
+#define bswap_32(x) bswap_32(x)
+#endif // bswap_32
+
+#ifndef bswap_64
+#define SAFTM_GG_ULONGLONG(x) x##ULL
+static inline uint64 bswap_64(uint64 x) {
+ return (((x & SAFTM_GG_ULONGLONG(0xFF)) << 56) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF00)) << 40) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF0000)) << 24) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF000000)) << 8) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF00000000)) >> 8) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF0000000000)) >> 24) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF000000000000)) >> 40) |
+ ((x & SAFTM_GG_ULONGLONG(0xFF00000000000000)) >> 56));
+}
+#define bswap_64(x) bswap_64(x)
+#endif // bswap_64
+
+#endif
+
+// define the macros SAFTM_IS_LITTLE_ENDIAN or SAFTM_IS_BIG_ENDIAN using the
+// above endian definitions from endian.h if endian.h was included
+#ifdef __BYTE_ORDER
+#if __BYTE_ORDER == __LITTLE_ENDIAN
+#define SAFTM_IS_LITTLE_ENDIAN
+#endif
+
+#if __BYTE_ORDER == __BIG_ENDIAN
+#define SAFTM_IS_BIG_ENDIAN
+#endif
+
+#else // __BYTE_ORDER
+
+#if defined(__LITTLE_ENDIAN__)
+#define SAFTM_IS_LITTLE_ENDIAN
+#elif defined(__BIG_ENDIAN__)
+#define SAFTM_IS_BIG_ENDIAN
+#endif
+
+// there is also PDP endian ...
+
+#endif // __BYTE_ORDER
+
+class LittleEndian {
+ public:
+// Conversion functions.
+#ifdef SAFTM_IS_LITTLE_ENDIAN
+
+ static uint16 FromHost16(uint16 x) { return x; }
+ static uint16 ToHost16(uint16 x) { return x; }
+
+ static uint32 FromHost32(uint32 x) { return x; }
+ static uint32 ToHost32(uint32 x) { return x; }
+
+ static uint64 FromHost64(uint64 x) { return x; }
+ static uint64 ToHost64(uint64 x) { return x; }
+
+ static bool IsLittleEndian() { return true; }
+
+#elif defined SAFTM_IS_BIG_ENDIAN
+
+ static uint16 FromHost16(uint16 x) { return gbswap_16(x); }
+ static uint16 ToHost16(uint16 x) { return gbswap_16(x); }
+
+ static uint32 FromHost32(uint32 x) { return gbswap_32(x); }
+ static uint32 ToHost32(uint32 x) { return gbswap_32(x); }
+
+ static uint64 FromHost64(uint64 x) { return gbswap_64(x); }
+ static uint64 ToHost64(uint64 x) { return gbswap_64(x); }
+
+ static bool IsLittleEndian() { return false; }
+
+#endif /* ENDIAN */
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_ENDIAN_H_
diff --git a/lang_id/common/lite_base/float16.h b/lang_id/common/lite_base/float16.h
new file mode 100644
index 0000000..bc3fd21
--- /dev/null
+++ b/lang_id/common/lite_base/float16.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_FLOAT16_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_FLOAT16_H_
+
+#include "lang_id/common/lite_base/casts.h"
+#include "lang_id/common/lite_base/integral-types.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// 16 bit encoding of a float. NOTE: can't be used directly for computation:
+// one first needs to convert it to a normal float, using Float16To32.
+//
+// Compact 16-bit encoding of floating point numbers. This
+// representation uses 1 bit for the sign, 8 bits for the exponent and
+// 7 bits for the mantissa. It is assumed that floats are in IEEE 754
+// format so a float16 is just bits 16-31 of a single precision float.
+//
+// NOTE: The IEEE floating point standard defines a float16 format that
+// is different than this format (it has fewer bits of exponent and more
+// bits of mantissa). We don't use that format here because conversion
+// to/from 32-bit floats is more complex for that format, and the
+// conversion for this format is very simple.
+//
+// <---------float16------------>
+// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f
+// <------------------------------float-------------------------->
+// 3 3 2 2 1 1 0
+// 1 0 3 2 5 4 0
+
+typedef uint16 float16;
+
+static inline float16 Float32To16(float f) {
+ // Note that we just truncate the mantissa bits: we make no effort to
+ // do any smarter rounding.
+ return (bit_cast<uint32>(f) >> 16) & 0xffff;
+}
+
+static inline float Float16To32(float16 f) {
+ // We fill in the new mantissa bits with 0, and don't do anything smarter.
+ return bit_cast<float>(f << 16);
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_FLOAT16_H_
diff --git a/lang_id/common/lite_base/integral-types.h b/lang_id/common/lite_base/integral-types.h
new file mode 100644
index 0000000..4c3038c
--- /dev/null
+++ b/lang_id/common/lite_base/integral-types.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Basic integer type definitions.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_INTEGRAL_TYPES_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_INTEGRAL_TYPES_H_
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+typedef unsigned int uint32;
+typedef unsigned long long uint64;
+
+#ifndef SWIG
+typedef int int32;
+typedef unsigned char uint8; // NOLINT
+typedef unsigned short uint16; // NOLINT
+
+// A type to represent a Unicode code-point value. As of Unicode 4.0,
+// such values require up to 21 bits.
+// (For type-checking on pointers, make this explicitly signed,
+// and it should always be the signed version of whatever int32 is.)
+typedef signed int char32;
+#endif // SWIG
+
+#ifdef COMPILER_MSVC
+typedef __int64 int64;
+#else
+typedef long long int64; // NOLINT
+#endif // COMPILER_MSVC
+
+// Some compile-time assertions that our new types have the intended size.
+static_assert(sizeof(int) == 4, "Our typedefs depend on int being 32 bits");
+static_assert(sizeof(uint32) == 4, "wrong size");
+static_assert(sizeof(int32) == 4, "wrong size");
+static_assert(sizeof(uint8) == 1, "wrong size");
+static_assert(sizeof(uint16) == 2, "wrong size");
+static_assert(sizeof(char32) == 4, "wrong size");
+static_assert(sizeof(int64) == 8, "wrong size");
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_INTEGRAL_TYPES_H_
diff --git a/lang_id/common/lite_base/logging.h b/lang_id/common/lite_base/logging.h
new file mode 100644
index 0000000..88797cb
--- /dev/null
+++ b/lang_id/common/lite_base/logging.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_LOGGING_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_LOGGING_H_
+
+#ifdef SAFTM_COMPACT_LOGGING
+
+// One gets the compact logging only if one requests it explicitly, by passing
+// --define saftm_compact_logging=true on the blaze command-line.
+#include "lang_id/common/lite_base/compact-logging.h"
+
+#else
+
+// Otherwise, one gets the standard base/logging.h You should do so, unless you
+// have a really good reason to switch to the compact logging.
+#include "base/logging.h"
+
+#define SAFTM_LOG LOG
+#define SAFTM_CHECK CHECK
+#define SAFTM_CHECK_EQ CHECK_EQ
+#define SAFTM_CHECK_LT CHECK_LT
+#define SAFTM_CHECK_LE CHECK_LE
+#define SAFTM_CHECK_GT CHECK_GT
+#define SAFTM_CHECK_GE CHECK_GE
+#define SAFTM_CHECK_NE CHECK_NE
+
+#define SAFTM_DLOG DLOG
+#define SAFTM_DCHECK DCHECK
+#define SAFTM_DCHECK_EQ DCHECK_EQ
+#define SAFTM_DCHECK_LT DCHECK_LT
+#define SAFTM_DCHECK_LE DCHECK_LE
+#define SAFTM_DCHECK_GT DCHECK_GT
+#define SAFTM_DCHECK_GE DCHECK_GE
+#define SAFTM_DCHECK_NE DCHECK_NE
+
+#endif // SAFTM_COMPACT_LOGGING
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_LOGGING_H_
diff --git a/lang_id/common/lite_base/macros.h b/lang_id/common/lite_base/macros.h
new file mode 100644
index 0000000..8fe5e8a
--- /dev/null
+++ b/lang_id/common/lite_base/macros.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_MACROS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_MACROS_H_
+
+#define SAFTM_DISALLOW_COPY_AND_ASSIGN(TypeName) \
+ TypeName(const TypeName &) = delete; \
+ TypeName &operator=(const TypeName &) = delete
+
+// The SAFTM_FALLTHROUGH_INTENDED macro can be used to annotate implicit
+// fall-through between switch labels:
+//
+// switch (x) {
+// case 40:
+// case 41:
+// if (truth_is_out_there) {
+// ++x;
+// SAFTM_FALLTHROUGH_INTENDED; // Use instead of/along with annotations
+// // in comments.
+// } else {
+// return x;
+// }
+// case 42:
+// ...
+//
+// As shown in the example above, the SAFTM_FALLTHROUGH_INTENDED macro should
+// be followed by a semicolon. It is designed to mimic control-flow statements
+// like 'break;', so it can be placed in most places where 'break;' can, but
+// only if there are no statements on the execution path between it and the
+// next switch label.
+//
+// When compiled with clang, the SAFTM_FALLTHROUGH_INTENDED macro is expanded
+// to [[clang::fallthrough]] attribute, which is analysed when performing
+// switch labels fall-through diagnostic ('-Wimplicit-fallthrough'). See clang
+// documentation on language extensions for details:
+// http://clang.llvm.org/docs/AttributeReference.html#fallthrough-clang-fallthrough
+//
+// When used with unsupported compilers, the SAFTM_FALLTHROUGH_INTENDED macro
+// has no effect on diagnostics.
+//
+// In either case this macro has no effect on runtime behavior and performance
+// of code.
+#if defined(__clang__) && defined(__has_warning)
+#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough")
+#define SAFTM_FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT
+#endif
+#endif
+
+#ifndef SAFTM_FALLTHROUGH_INTENDED
+#define SAFTM_FALLTHROUGH_INTENDED \
+ do { \
+ } while (0)
+#endif
+
+// SAFTM_UNIQUE_ID(prefix) expands to a unique id that starts with prefix.
+//
+// The current implementation expands to prefix_<line_number>; hence, multiple
+// uses of this macro with the same prefix and on the same line will result in
+// the same identifier name. In those cases, if you need different ids, we
+// suggest you use different prefixes.
+//
+// Implementation is tricky; for more info, see
+// https://stackoverflow.com/questions/1597007/creating-c-macro-with-and-line-token-concatenation-with-positioning-macr
+#define SAFTM_UNIQUE_ID_INTERNAL2(x, y) x ## y
+#define SAFTM_UNIQUE_ID_INTERNAL(x, y) SAFTM_UNIQUE_ID_INTERNAL2(x, y)
+#define SAFTM_UNIQUE_ID(prefix) SAFTM_UNIQUE_ID_INTERNAL(prefix ## _, __LINE__)
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_MACROS_H_
diff --git a/lang_id/common/lite_strings/numbers.cc b/lang_id/common/lite_strings/numbers.cc
new file mode 100644
index 0000000..e0c66f3
--- /dev/null
+++ b/lang_id/common/lite_strings/numbers.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/lite_strings/numbers.h"
+
+#include <ctype.h>
+#include <stdlib.h>
+#include <climits>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Returns true if the characters that start at address ptr (inclusive) and stop
+// at the first '\0' consist of only whitespaces, as determined by isspace().
+// Note: this function returns false if ptr is nullptr.
+static bool OnlyWhitespaces(const char *ptr) {
+  if (ptr == nullptr) {
+    return false;
+  }
+  for (; *ptr != '\0'; ++ptr) {
+    // Cast to unsigned char: passing a plain (possibly negative) char to
+    // isspace() is undefined behavior per the C standard.
+    if (!isspace(static_cast<unsigned char>(*ptr))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool LiteAtoi(const char *c_str, int *value) {
+  if (c_str == nullptr) {
+    return false;
+  }
+
+  // Short version of man strtol:
+  //
+  // strtol parses some optional whitespaces, an optional +/- sign, and next a
+  // succession of digits. If it finds some digits, it sets temp to point to
+  // the first character after that succession of digits and returns the parsed
+  // integer.
+  //
+  // If there were no digits at all, strtol() sets temp to be c_str (the start
+  // address) and returns 0.
+  //
+  // Base 0 asks strtol to auto-detect the radix: "0x..." -> 16, a leading
+  // "0" -> 8, otherwise 10 (matches the header comment for this function).
+  char *temp = nullptr;
+  const long int parsed_value = strtol(c_str, &temp, 0);  // NOLINT
+
+  // Check for overflow. Note: to simplify the code, we assume that LONG_MIN /
+  // LONG_MAX means that strtol encountered an overflow (normally, in that case,
+  // one should also inspect errno). Hence, we maybe give up the possibility to
+  // parse one extreme value on each side (min/max). That should be ok.
+  if ((parsed_value == LONG_MIN) || (parsed_value == LONG_MAX) ||
+      (parsed_value < INT_MIN) || (parsed_value > INT_MAX)) {
+    return false;
+  }
+  *value = static_cast<int>(parsed_value);
+
+  // First part of the expression below means that the input string contained at
+  // least one digit. The other part checks that what remains after the number
+  // (if anything) consists only of whitespaces.
+  return (temp != c_str) && OnlyWhitespaces(temp);
+}
+
+bool LiteAtof(const char *c_str, float *value) {
+  if (c_str == nullptr) {
+    return false;
+  }
+
+  // strtof is similar to strtol, see more detailed comments inside LiteAtoi.
+  // NOTE: unlike LiteAtoi, no overflow/underflow detection is performed here
+  // (errno and HUGE_VALF are not inspected); see the TODO in the header.
+  char *temp = nullptr;
+  *value = strtof(c_str, &temp);
+  return (temp != c_str) && OnlyWhitespaces(temp);
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/lite_strings/numbers.h b/lang_id/common/lite_strings/numbers.h
new file mode 100644
index 0000000..4b3c93c
--- /dev/null
+++ b/lang_id/common/lite_strings/numbers.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
+
+#include <string>
+
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Parses an int from a C-style string; similar to absl::SimpleAtoi.
+//
+// c_str should point to a zero-terminated array of chars that contains the
+// number representation as (a) "<radix-10-number>" (e.g., "721"), (b)
+// "0x<radix-16-number>" (e.g., "0xa1"), or (c) "0<radix-8-number>" (e.g.,
+// "017201"). Whitespaces (as determined by isspace()) are allowed before and
+// after the number representation (but obviously not in the middle).
+//
+// Stores parsed number into *value. Returns true on success, false on error.
+// Note: presence of extra non-whitespace characters after the number counts as
+// an error: e.g., parsing "123a" will return false due to the extra "a" (which
+// is not a valid radix-10 digit). This function also returns false for strings
+// that do not contain any digit (e.g., ""), as well as for overflows /
+// underflows.
+bool LiteAtoi(const char *c_str, int *value);
+
+// Convenience overload: parses an int from a C++ string.
+inline bool LiteAtoi(const string &s, int *value) {
+  return LiteAtoi(s.c_str(), value);
+}
+
+// Convenience overload: parses an int from a StringPiece (copies it once).
+inline bool LiteAtoi(StringPiece sp, int *value) {
+  // Unfortunately, we can't directly call LiteAtoi(sp.data()): LiteAtoi(const
+  // char *) needs a zero-terminated string.
+  const string temp(sp.data(), sp.size());
+  return LiteAtoi(temp.c_str(), value);
+}
+
+// Like LiteAtoi, but for float; similar to absl::SimpleAtof.
+//
+// NOTE: currently, does not properly handle overflow / underflow.
+// TODO(salcianu): fix that.
+bool LiteAtof(const char *c_str, float *value);
+
+// Convenience overload: parses a float from a C++ string.
+inline bool LiteAtof(const string &s, float *value) {
+  return LiteAtof(s.c_str(), value);
+}
+
+// Convenience overload: parses a float from a StringPiece (copies it once).
+inline bool LiteAtof(StringPiece sp, float *value) {
+  // Unfortunately, we can't directly call LiteAtof(sp.data()): LiteAtof(const
+  // char *) needs a zero-terminated string.
+  const string temp(sp.data(), sp.size());
+  return LiteAtof(temp.c_str(), value);
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
diff --git a/lang_id/common/lite_strings/str-cat.h b/lang_id/common/lite_strings/str-cat.h
new file mode 100644
index 0000000..f24e6e6
--- /dev/null
+++ b/lang_id/common/lite_strings/str-cat.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
+
+// Less efficient but more compact versions of several absl string utils.
+//
+// "More compact" means "pulls in fewer code dependencies". That's useful if
+// one tries to minimize the code size.
+//
+// Note: the name and the signature of the functions from this header were
+// chosen to minimize the effort of converting code that uses absl::LiteStrCat &
+// co to our more compact functions.
+
+#include <string>
+
+#ifdef COMPILER_MSVC
+#include <sstream>
+#endif // COMPILER_MSVC
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Less efficient but more compact version of absl::StrCat().
+//
+// Given a value v (see supported types below) LiteStrCat(v) returns a new
+// string that contains the representation of v. For examples, see
+// str-cat_test.cc.
+template <typename T>
+inline string LiteStrCat(T v) {
+#ifdef COMPILER_MSVC
+  // Bug fix: this branch referenced an undeclared identifier "input"; the
+  // parameter is named v, so the MSVC build did not compile.
+  std::stringstream stream;
+  stream << v;
+  return stream.str();
+#else
+  return std::to_string(v);
+#endif
+}
+
+// Specialization for C-style strings: returns a string copy of |v|.
+// NOTE(review): assumes v != nullptr (string(const char*) on nullptr is UB).
+template <>
+inline string LiteStrCat(const char *v) {
+  return string(v);
+}
+
+// Specialization for strings: returns a copy of |v|.
+//
+// TODO(salcianu): use a reference type (const string &). For some reason, I
+// couldn't get that to work on a first try.
+template <>
+inline string LiteStrCat(string v) {
+  return v;
+}
+
+// Specialization for chars: returns a one-char string.
+template <>
+inline string LiteStrCat(char v) {
+  return string(1, v);
+}
+
+// Less efficient but more compact version of absl::StrAppend(): appends the
+// representation of v (as produced by LiteStrCat) to *dest.
+template <typename T>
+inline void LiteStrAppend(string *dest, T v) {
+  dest->append(LiteStrCat(v)); // NOLINT
+}
+
+// Two-value overload: appends LiteStrCat(v1) then LiteStrCat(v2).
+template <typename T1, typename T2>
+inline void LiteStrAppend(string *dest, T1 v1, T2 v2) {
+  dest->append(LiteStrCat(v1)); // NOLINT
+  dest->append(LiteStrCat(v2)); // NOLINT
+}
+
+// Three-value overload.
+template <typename T1, typename T2, typename T3>
+inline void LiteStrAppend(string *dest, T1 v1, T2 v2, T3 v3) {
+  LiteStrAppend(dest, v1, v2);
+  dest->append(LiteStrCat(v3)); // NOLINT
+}
+
+// Four-value overload.
+template <typename T1, typename T2, typename T3, typename T4>
+inline void LiteStrAppend(string *dest, T1 v1, T2 v2, T3 v3, T4 v4) {
+  LiteStrAppend(dest, v1, v2, v3);
+  dest->append(LiteStrCat(v4)); // NOLINT
+}
+
+// Five-value overload.
+template <typename T1, typename T2, typename T3, typename T4, typename T5>
+inline void LiteStrAppend(string *dest, T1 v1, T2 v2, T3 v3, T4 v4, T5 v5) {
+  LiteStrAppend(dest, v1, v2, v3, v4);
+  dest->append(LiteStrCat(v5)); // NOLINT
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
diff --git a/lang_id/common/lite_strings/str-split.cc b/lang_id/common/lite_strings/str-split.cc
new file mode 100644
index 0000000..199bb69
--- /dev/null
+++ b/lang_id/common/lite_strings/str-split.cc
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/lite_strings/str-split.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+std::vector<StringPiece> LiteStrSplit(StringPiece text, char delim) {
+  std::vector<StringPiece> result;
+  // size_t (was int): |token_start| is compared against / assigned from the
+  // size_t loop index, so int caused implicit narrowing and would overflow
+  // for inputs larger than INT_MAX bytes.
+  size_t token_start = 0;
+  if (!text.empty()) {
+    // Iterate one position past the end so the final token (after the last
+    // delimiter) is emitted by the i == text.size() case.
+    for (size_t i = 0; i < text.size() + 1; ++i) {
+      if ((i == text.size()) || (text[i] == delim)) {
+        result.emplace_back(text.data() + token_start, i - token_start);
+        token_start = i + 1;
+      }
+    }
+  }
+  return result;
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/lite_strings/str-split.h b/lang_id/common/lite_strings/str-split.h
new file mode 100644
index 0000000..300bc9f
--- /dev/null
+++ b/lang_id/common/lite_strings/str-split.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_SPLIT_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_SPLIT_H_
+
+#include <vector>
+
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Splits |text| on |delim|; similar to absl::StrSplit.
+//
+// Returns a list of tokens. Each token is represented by a StringPiece that
+// indicates a range of chars from |text|.
+//
+// Example: StrSplit("apple,orange", ',') returns two tokens: a StringPiece that
+// points to "apple", and another one for "orange".
+//
+// If one concatenates all returned tokens with |delim| in between, one gets the
+// original |text|. E.g., If we split "apple,orange," on ',', we get three
+// tokens: "apple", "orange" and "" (an empty token). We do not filter out
+// empty tokens. If necessary, the caller can do that.
+//
+// Note: if the input text is empty, we return an empty list of tokens. In
+// general, the number of returned tokens is 1 + the number of occurrences of
+// |delim| inside |text|.
+std::vector<StringPiece> LiteStrSplit(StringPiece text, char delim);
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_SPLIT_H_
diff --git a/lang_id/common/lite_strings/stringpiece.h b/lang_id/common/lite_strings/stringpiece.h
new file mode 100644
index 0000000..d19ea41
--- /dev/null
+++ b/lang_id/common/lite_strings/stringpiece.h
@@ -0,0 +1,88 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
+
+#include <stddef.h>
+#include <string.h>
+
+#include <ostream>
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Read-only "view" of a piece of data. Does not own the underlying data.
+// NOTE(review): callers must ensure the viewed data outlives the StringPiece;
+// no bounds or null checks are performed.
+class StringPiece {
+ public:
+  StringPiece() : StringPiece(nullptr, 0) {}
+
+  // Views a zero-terminated C string. Assumes str != nullptr (strlen on
+  // nullptr is undefined behavior) — TODO confirm callers never pass null.
+  StringPiece(const char *str)  // NOLINT
+      : start_(str), size_(strlen(str)) {}
+
+  StringPiece(const char *start, size_t size) : start_(start), size_(size) {}
+
+  // Intentionally no "explicit" keyword: in function calls, we want strings to
+  // be converted to StringPiece implicitly.
+  StringPiece(const string &s)  // NOLINT
+      : StringPiece(s.data(), s.size()) {}
+
+  StringPiece(const string &s, int offset, int len)
+      : StringPiece(s.data() + offset, len) {}
+
+  // Returns the i-th byte; no bounds checking (requires i < size()).
+  char operator[](size_t i) const { return start_[i]; }
+
+  // Returns start address of underlying data.
+  const char *data() const { return start_; }
+
+  // Returns number of bytes of underlying data.
+  size_t size() const { return size_; }
+  size_t length() const { return size_; }
+
+  // Returns true if this StringPiece does not refer to any characters.
+  bool empty() const { return size() == 0; }
+
+  // Explicit conversion to a basic_string with any allocator; yields an empty
+  // string when data() is null.
+  template <typename A>
+  explicit operator basic_string<char, std::char_traits<char>, A>() const {
+    if (!data()) return {};
+    return basic_string<char, std::char_traits<char>, A>(data(), size());
+  }
+
+ private:
+  const char *start_;  // Not owned.
+  size_t size_;
+};
+
+// Writes exactly sp.size() bytes of the viewed data to |out|.
+inline std::ostream &operator<<(std::ostream &out, StringPiece sp) {
+  return out.write(sp.data(), sp.size());
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
diff --git a/lang_id/common/math/algorithm.h b/lang_id/common/math/algorithm.h
new file mode 100644
index 0000000..a963807
--- /dev/null
+++ b/lang_id/common/math/algorithm.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Generic utils similar to those from the C++ header <algorithm>.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
+
+#include <algorithm>
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Returns index of max element from the vector |elements|. Returns 0 if
+// |elements| is empty. T should be a type that can be compared by operator<.
+// (For an empty vector, max_element returns elements.end(), whose distance
+// from elements.begin() is 0.)
+template<typename T>
+inline int GetArgMax(const std::vector<T> &elements) {
+  return std::distance(
+      elements.begin(),
+      std::max_element(elements.begin(), elements.end()));
+}
+
+// Returns index of min element from the vector |elements|. Returns 0 if
+// |elements| is empty. T should be a type that can be compared by operator<.
+// (For an empty vector, min_element returns elements.end(), whose distance
+// from elements.begin() is 0.)
+template<typename T>
+inline int GetArgMin(const std::vector<T> &elements) {
+  return std::distance(
+      elements.begin(),
+      std::min_element(elements.begin(), elements.end()));
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
diff --git a/lang_id/common/math/checksum.cc b/lang_id/common/math/checksum.cc
new file mode 100644
index 0000000..23d88bc
--- /dev/null
+++ b/lang_id/common/math/checksum.cc
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/math/checksum.h"
+
+// Though we use the same zlib header on all platforms, the implementation used
+// is from NDK on android and from third_party/zlib on iOS/linux. See BUILD
+// rule.
+#include <zlib.h>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// static
+uint32 Crc32::GetInitialCrc32() {
+  // crc32(0L, nullptr, 0) is the documented zlib idiom for obtaining the
+  // required initial CRC value; computed once and cached.
+  static const uint32 kCrcInitZero = crc32(0L, nullptr, 0);
+  return kCrcInitZero;
+}
+
+void Crc32::Update(const char *str, int len) {
+  // Guard against null input and non-positive lengths (was len == 0): zlib's
+  // crc32() takes an unsigned length, so a negative |len| would be converted
+  // to a huge positive value and cause an out-of-bounds read.
+  if (str == nullptr || len <= 0) {
+    return;
+  }
+  current_ = crc32(current_,
+                   reinterpret_cast<const unsigned char *>(str),
+                   len);
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/math/checksum.h b/lang_id/common/math/checksum.h
new file mode 100644
index 0000000..d62893f
--- /dev/null
+++ b/lang_id/common/math/checksum.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_CHECKSUM_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_CHECKSUM_H_
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Class to compute a 32bit Cyclic Redundancy Check (CRC) in a cumulative way.
+//
+// To use, create an instance of this class, repeatedly call Update() to "feed"
+// it your pieces of data, and, when done, call Get().
+class Crc32 {
+ public:
+  Crc32() : current_(GetInitialCrc32()) {}
+
+  // Updates current CRC32 code to also take into account the |len| bytes that
+  // start at address |str|.
+  void Update(const char *str, int len);
+
+  // Updates current CRC32 code to also take into account the bytes from |s|.
+  // NOTE(review): s.size() narrows from size_t to int here; assumes each
+  // piece is < 2GB — confirm with callers.
+  void Update(StringPiece s) { Update(s.data(), s.size()); }
+
+  // Returns the CRC32 code for the data so far.
+  uint32 Get() const { return current_; }
+
+ private:
+  // Returns the initial value for current_.
+  static uint32 GetInitialCrc32();
+
+  // CRC32 for the data so far.
+  uint32 current_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_CHECKSUM_H_
diff --git a/lang_id/common/math/fastexp.cc b/lang_id/common/math/fastexp.cc
new file mode 100644
index 0000000..44df91f
--- /dev/null
+++ b/lang_id/common/math/fastexp.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/math/fastexp.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Out-of-line definitions for the class's static constants (required when the
+// constants are odr-used, prior to C++17 inline variables).
+const int FastMathClass::kBits;
+const int FastMathClass::kMask1;
+const int FastMathClass::kMask2;
+constexpr float FastMathClass::kLogBase2OfE;
+
+// Global instance used by the free VeryFastExp() helper in fastexp.h.
+FastMathClass FastMathInstance;
+
+// Data taken from util/math/fastmath.cc: 2^kBits = 128 precomputed mantissa
+// correction entries for the exp2 approximation.
+const FastMathClass::Table FastMathClass::cache_ = {
+  {0, 45549, 91345, 137391, 183686, 230233, 277032, 324086, 371395, 418961,
+   466785, 514869, 563214, 611822, 660693, 709830, 759233, 808905, 858847,
+   909060, 959545, 1010305, 1061340, 1112652, 1164243, 1216114, 1268267,
+   1320703, 1373423, 1426430, 1479725, 1533309, 1587184, 1641351, 1695813,
+   1750570, 1805625, 1860979, 1916633, 1972590, 2028850, 2085416, 2142289,
+   2199470, 2256963, 2314767, 2372885, 2431319, 2490070, 2549140, 2608531,
+   2668245, 2728282, 2788646, 2849337, 2910358, 2971710, 3033396, 3095416,
+   3157773, 3220469, 3283505, 3346884, 3410606, 3474675, 3539091, 3603857,
+   3668975, 3734447, 3800274, 3866458, 3933002, 3999907, 4067176, 4134809,
+   4202810, 4271180, 4339922, 4409036, 4478526, 4548394, 4618640, 4689268,
+   4760280, 4831677, 4903462, 4975636, 5048203, 5121164, 5194520, 5268275,
+   5342431, 5416989, 5491952, 5567322, 5643101, 5719292, 5795897, 5872917,
+   5950356, 6028215, 6106497, 6185204, 6264338, 6343902, 6423898, 6504329,
+   6585196, 6666502, 6748250, 6830442, 6913080, 6996166, 7079704, 7163696,
+   7248143, 7333049, 7418416, 7504247, 7590543, 7677309, 7764545, 7852255,
+   7940441, 8029106, 8118253, 8207884, 8298001}
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/math/fastexp.h b/lang_id/common/math/fastexp.h
new file mode 100644
index 0000000..05b654a
--- /dev/null
+++ b/lang_id/common/math/fastexp.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Fast approximation for exp.
+//
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
+
+#include <cassert>
+#include <cmath>
+#include <limits>
+
+#include "lang_id/common/lite_base/casts.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+class FastMathClass {
+ private:
+  // Number of low fractional bits looked up in the correction table.
+  static const int kBits = 7;
+  static const int kMask1 = (1 << kBits) - 1;
+  static const int kMask2 = 0xFF << kBits;
+  // log2(e): used to reduce exp() to exp2().
+  static constexpr float kLogBase2OfE = 1.44269504088896340736f;
+
+  struct Table {
+    int32 exp1[1 << kBits];
+  };
+
+ public:
+  // Approximates 2^f. Requires |f| <= 126 (checked in debug builds) so the
+  // biased exponent constructed below stays in range for a float.
+  float VeryFastExp2(float f) const {
+    SAFTM_DCHECK_LE(fabs(f), 126);
+    // Adding this magic constant places the integer part of f into the
+    // exponent bits of g; the top fractional bits (kMask1) index the
+    // precomputed mantissa correction table.
+    const float g = f + (127 + (1 << (23 - kBits)));
+    const int32 x = bit_cast<int32>(g);
+    int32 ret = ((x & kMask2) << (23 - kBits))
+        | cache_.exp1[x & kMask1];
+    return bit_cast<float>(ret);
+  }
+
+  // Approximates e^f via VeryFastExp2(f * log2(e)).
+  float VeryFastExp(float f) const {
+    return VeryFastExp2(f * kLogBase2OfE);
+  }
+
+ private:
+  // Correction table, defined in fastexp.cc.
+  static const Table cache_;
+};
+
+// Single global instance, defined in fastexp.cc.
+extern FastMathClass FastMathInstance;
+
+// Convenience wrapper: approximates e^f using the global instance.
+inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); }
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
diff --git a/lang_id/common/math/hash.cc b/lang_id/common/math/hash.cc
new file mode 100644
index 0000000..d320428
--- /dev/null
+++ b/lang_id/common/math/hash.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/math/hash.h"
+
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+namespace {
+// Lower-level versions of Get... that read directly from a character buffer
+// without any bounds checking.
+// Decodes 4 bytes as a little-endian uint32, independent of host endianness.
+inline uint32 DecodeFixed32(const char *ptr) {
+  return ((static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) |
+          (static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) |
+          (static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) |
+          (static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24));
+}
+
+// 0xff is in case char is signed.
+static inline uint32 ByteAs32(char c) { return static_cast<uint32>(c) & 0xff; }
+}  // namespace
+
+uint32 Hash32(const char *data, size_t n, uint32 seed) {
+  // This is the MurmurHash2 mixing scheme (constants 0x5bd1e995, r = 24).
+  // 'm' and 'r' are mixing constants generated offline.
+  // They're not really 'magic', they just happen to work well.
+  const uint32 m = 0x5bd1e995;
+  const int r = 24;
+
+  // Initialize the hash to a 'random' value. Note: seed ^ n narrows n from
+  // size_t to uint32, so only the low 32 bits of the length participate.
+  uint32 h = seed ^ n;
+
+  // Mix 4 bytes at a time into the hash
+  while (n >= 4) {
+    uint32 k = DecodeFixed32(data);
+    k *= m;
+    k ^= k >> r;
+    k *= m;
+    h *= m;
+    h ^= k;
+    data += 4;
+    n -= 4;
+  }
+
+  // Handle the last few bytes of the input array. Each case deliberately
+  // falls through to the next (annotated for -Wimplicit-fallthrough).
+  switch (n) {
+    case 3:
+      h ^= ByteAs32(data[2]) << 16;
+      SAFTM_FALLTHROUGH_INTENDED;
+    case 2:
+      h ^= ByteAs32(data[1]) << 8;
+      SAFTM_FALLTHROUGH_INTENDED;
+    case 1:
+      h ^= ByteAs32(data[0]);
+      h *= m;
+  }
+
+  // Do a few final mixes of the hash to ensure the last few
+  // bytes are well-incorporated.
+  h ^= h >> 13;
+  h *= m;
+  h ^= h >> 15;
+  return h;
+}
+
+} // namespace utils
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/math/hash.h b/lang_id/common/math/hash.h
new file mode 100644
index 0000000..08c32be
--- /dev/null
+++ b/lang_id/common/math/hash.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
+
+#include <string>
+
+#include "lang_id/common/lite_base/integral-types.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+// Returns a 32 bit hash of the |n| bytes that start at |data|, using |seed| for
+// internal initialization. By changing the seed, one effectively gets
+// different hash functions.
+//
+// NOTE: this function is guaranteed not to change in the future.
+//
+// IMPORTANT: for speed reasons, this method does not check its parameters
+// |data| and |n|. The caller should ensure that n >= 0 and that one can read
+// from the memory area [data, data + n).
+uint32 Hash32(const char *data, size_t n, uint32 seed);
+
+// Hash32 with the fixed default seed 0xBEEF.
+static inline uint32 Hash32WithDefaultSeed(const char *data, size_t n) {
+  return Hash32(data, n, 0xBEEF);
+}
+
+// Convenience overload: hashes the bytes of |input| with the default seed.
+static inline uint32 Hash32WithDefaultSeed(const string &input) {
+  return Hash32WithDefaultSeed(input.data(), input.size());
+}
+
+} // namespace utils
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
diff --git a/lang_id/common/math/softmax.cc b/lang_id/common/math/softmax.cc
new file mode 100644
index 0000000..c21f843
--- /dev/null
+++ b/lang_id/common/math/softmax.cc
@@ -0,0 +1,102 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/math/softmax.h"
+
+#include <limits>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/math/fastexp.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) {
+ if ((label < 0) || (label >= scores.size())) {
+ SAFTM_LOG(ERROR) << "label " << label << " outside range "
+ << "[0, " << scores.size() << ")";
+ return 0.0f;
+ }
+
+ // Standard softmax formula for label's probability is
+ //
+ // exp(scores[label]) / sum_i exp(scores[i])
+ //
+ // We compute the mathematically equivalent
+ //
+ // 1 / (1 + sum_{i != label} exp(scores[i] - scores[label]))
+ //
+ // which saves two calls to exp().
+ const float label_score = scores[label];
+ float denominator = 1.0f; // Contribution of i == label.
+ for (int i = 0; i < scores.size(); ++i) {
+ if (i == label) continue;
+ const float delta_score = scores[i] - label_score;
+
+ // TODO(salcianu): one can optimize the test below, to avoid any float
+ // operation: extract exponent (via bit mask + shift) and check it's >= 4.
+ if (fabs(delta_score) >= 16.0f) {
+ if (delta_score > 0.0f) {
+ // If delta_score >= 16, the denominator (e^delta_score + other positive
+ // terms) is very big and its inverse can be approximated with 0.
+ return 0.0f;
+ } else {
+ // If delta_score <= -16, then e^delta_score < 1.2e-7. Even if we have
+ // 1000 such labels i, their sum is < 1.2e-4 (which gets summed with
+ // 1.0f for i == label). Hence, we can approximate each such label with
+ // 0 and skip the call to VeryFastExp and the update to denominator.
+ continue;
+ }
+ }
+
+ // At this point, delta_score is in (-16.0, 16.0). For such values, vfexp
+ // works fine: no under/overflows (we have tests for that in fastexp_test).
+ // Also, even for 1000 labels, denominator will not overflow.
+ denominator += VeryFastExp(delta_score);
+ }
+ return 1.0f / denominator;
+}
+
+std::vector<float> ComputeSoftmax(const std::vector<float> &scores,
+ float alpha) {
+ std::vector<float> softmax;
+ std::vector<float> exp_scores;
+ exp_scores.reserve(scores.size());
+ softmax.reserve(scores.size());
+
+ // Find max value in "scores" vector and rescale to avoid overflows.
+ float max = std::numeric_limits<float>::lowest();
+ for (const auto &score : scores) {
+ if (score > max) max = score;
+ }
+ float denominator = 0;
+ for (auto &score : scores) {
+ // See comments above in ComputeSoftmaxProbability for the reasoning behind
+ // this approximation.
+ const float delta_score = alpha * (score - max);
+ const float exp_score = delta_score < -16.0f ? 0 : VeryFastExp(delta_score);
+ exp_scores.push_back(exp_score);
+ denominator += exp_score;
+ }
+
+ for (int i = 0; i < scores.size(); ++i) {
+ softmax.push_back(exp_scores[i] / denominator);
+ }
+ return softmax;
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/math/softmax.h b/lang_id/common/math/softmax.h
new file mode 100644
index 0000000..0100e59
--- /dev/null
+++ b/lang_id/common/math/softmax.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_SOFTMAX_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_SOFTMAX_H_
+
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Computes probability of a softmax label. Parameter "scores" is the vector of
+// softmax logits. Returns 0.0f if "label" is outside the range [0,
+// scores.size()).
+float ComputeSoftmaxProbability(const std::vector<float> &scores, int label);
+
+// Computes and returns a softmax for a given vector of floats. Parameter
+// "scores" is the vector of softmax logits.
+//
+// The alpha parameter is a scaling factor on the logits.
+std::vector<float> ComputeSoftmax(const std::vector<float> &scores,
+ float alpha = 1.0f);
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_SOFTMAX_H_
diff --git a/lang_id/common/registry.h b/lang_id/common/registry.h
new file mode 100644
index 0000000..d2c5271
--- /dev/null
+++ b/lang_id/common/registry.h
@@ -0,0 +1,321 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Mechanism to instantiate classes by name.
+//
+// This mechanism is useful if the concrete classes to be instantiated are not
+// statically known (e.g., if their names are read from a dynamically-provided
+// config).
+//
+// In that case, the first step is to define the API implemented by the
+// instantiated classes. E.g.,
+//
+// // In a header file function.h:
+//
+// // Abstract function that takes a double and returns a double.
+// class Function : public RegisterableClass<Function> {
+// public:
+// virtual ~Function() {}
+// virtual double Evaluate(double x) = 0;
+// };
+//
+// // Should be inside namespace libtextclassifier3::mobile.
+// SAFTM_DECLARE_CLASS_REGISTRY_NAME(Function);
+//
+// Notice the inheritance from RegisterableClass<Function>. RegisterableClass
+// is defined by this file (registry.h). Under the hood, this inheritance
+// defines a "registry" that maps names (zero-terminated arrays of chars) to
+// factory methods that create Functions. You should give a human-readable name
+// to this registry. To do that, use the following macro in a .cc file (it has
+// to be a .cc file, as it defines some static data):
+//
+// // Inside function.cc
+// // Should be inside namespace libtextclassifier3::mobile.
+// SAFTM_DEFINE_CLASS_REGISTRY_NAME("function", Function);
+//
+// Now, let's define a few concrete Functions: e.g.,
+//
+// class Cos : public Function {
+// public:
+// double Evaluate(double x) override { return cos(x); }
+// SAFTM_DEFINE_REGISTRATION_METHOD("cos", Cos);
+// };
+//
+// class Exp : public Function {
+// public:
+// double Evaluate(double x) override { return exp(x); }
+// SAFTM_DEFINE_REGISTRATION_METHOD("exp", Exp);
+// };
+//
+// Each concrete Function implementation should have (in the public section) the
+// macro
+//
+// SAFTM_DEFINE_REGISTRATION_METHOD("name", implementation_class);
+//
+// This defines a RegisterClass static method that, when invoked, associates
+// "name" with a factory method that creates instances of implementation_class.
+//
+// Before instantiating Functions by name, we need to tell our system which
+// Functions we may be interested in. This is done by calling the
+// Foo::RegisterClass() for each relevant Foo implementation of Function. It is
+// ok to call Foo::RegisterClass() multiple times (even in parallel): only the
+// first call will perform something, the others will return immediately.
+//
+// Cos::RegisterClass();
+// Exp::RegisterClass();
+//
+// Now, let's instantiate a Function based on its name.  This gets a lot more
+// interesting if the Function name is not statically known (e.g., it is
+// read from an input proto):
+//
+// std::unique_ptr<Function> f(Function::Create("cos"));
+// double result = f->Evaluate(arg);
+//
+// NOTE: the same binary can use this mechanism for different APIs. E.g., one
+// can also have (in the binary with Function, Cos, Exp, etc):
+//
+// class IntFunction : public RegisterableClass<IntFunction> {
+// public:
+// virtual ~IntFunction() {}
+// virtual int Evaluate(int k) = 0;
+// };
+//
+// SAFTM_DECLARE_CLASS_REGISTRY_NAME(IntFunction);
+//
+// SAFTM_DEFINE_CLASS_REGISTRY_NAME("int function", IntFunction);
+//
+// class Inc : public IntFunction {
+// public:
+// int Evaluate(int k) override { return k + 1; }
+// SAFTM_DEFINE_REGISTRATION_METHOD("inc", Inc);
+// };
+//
+// RegisterableClass<Function> and RegisterableClass<IntFunction> define their
+// own registries: each maps string names to implementation of the corresponding
+// API.
+//
+// NOTE: the mechanism described above requires you to explicitly call
+// RegisterClass() for all relevant classes before instantiating them. You can
+// do this in the main() function or in any other function that is guaranteed to
+// run before the code that instantiates those classes. Alternatively, you can
+// use the macro SAFTM_STATIC_REGISTRATION to perform this registration in a
+// decentralized fashion. Just use that macro in a .cc file, outside any
+// function / class, e.g.,
+//
+// SAFTM_STATIC_REGISTRATION(Cos);
+//
+// and make sure you link in all symbols from that .cc file; e.g., in bazel, use
+// alwayslink = 1 for the corresponding cc_library. Still, please be aware that
+// using alwayslink = 1 limits the ability of the linker to perform dead code
+// elimination.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
+
+#include <stdlib.h>
+#include <string.h>
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace internal {
+// Registry that associates keys (zero-terminated array of chars) with values.
+// Values are pointers to type T (the template parameter). This is used to
+// store the association between component names and factory methods that
+// produce those components; the error messages are focused on that case.
+//
+// Internally, this registry uses a linked list of (key, value) pairs. We do
+// not use an STL map, list, etc because we aim for small code size.
+template <class T>
+class ComponentRegistry {
+ public:
+ explicit ComponentRegistry(const char *name) : name_(name), head_(nullptr) {}
+
+  // Adds the (key, value) pair to this registry (if the key does not already
+  // exist in this registry) and returns true.  If the registry already has a
+  // mapping for key, returns false and does not modify the registry.  NOTE: the
+  // error (false) case happens even if the existing value for key is equal to
+  // the new one.
+ //
+ // This method does not take ownership of key, nor of value.
+ bool Add(const char *key, T *value) {
+ const Cell *old_cell = FindCell(key);
+ if (old_cell != nullptr) {
+ SAFTM_LOG(ERROR) << "Duplicate component: " << key;
+ return false;
+ }
+ Cell *new_cell = new Cell(key, value, head_);
+ head_ = new_cell;
+ return true;
+ }
+
+ // Returns the value attached to a key in this registry. Returns nullptr on
+ // error (e.g., unknown key).
+ T *Lookup(const char *key) const {
+ const Cell *cell = FindCell(key);
+ if (cell == nullptr) {
+ SAFTM_LOG(ERROR) << "Unknown " << name() << " component: " << key;
+ }
+ return (cell == nullptr) ? nullptr : cell->value();
+ }
+
+ T *Lookup(const string &key) const { return Lookup(key.c_str()); }
+
+ // Returns name of this ComponentRegistry.
+ const char *name() const { return name_; }
+
+ // Fills *names with names of all components registered in this
+ // ComponentRegistry. Previous content of *names is cleared out.
+ void GetComponentNames(std::vector<string> *names) {
+ names->clear();
+ for (const Cell *c = head_; c!= nullptr; c = c->next()) {
+ names->emplace_back(c->key());
+ }
+ }
+
+ private:
+ // Cell for the singly-linked list underlying this ComponentRegistry. Each
+ // cell contains a key, the value for that key, as well as a pointer to the
+ // next Cell from the list.
+ class Cell {
+ public:
+ // Constructs a new Cell.
+ Cell(const char *key, T *value, Cell *next)
+ : key_(key), value_(value), next_(next) {}
+
+ const char *key() const { return key_; }
+ T *value() const { return value_; }
+ Cell *next() const { return next_; }
+
+ private:
+ const char *const key_;
+ T *const value_;
+ Cell *const next_;
+ };
+
+ // Finds Cell for indicated key in the singly-linked list pointed to by head_.
+ // Returns pointer to that first Cell with that key, or nullptr if no such
+ // Cell (i.e., unknown key).
+ //
+ // Caller does NOT own the returned pointer.
+ const Cell *FindCell(const char *key) const {
+ const Cell *c = head_;
+ while (c != nullptr && strcmp(key, c->key()) != 0) {
+ c = c->next();
+ }
+ return c;
+ }
+
+ // Human-readable description for this ComponentRegistry. For debug purposes.
+ const char *const name_;
+
+ // Pointer to the first Cell from the underlying list of (key, value) pairs.
+ Cell *head_;
+};
+} // namespace internal
+
+// Base class for registerable classes.
+template <class T>
+class RegisterableClass {
+ public:
+ // Factory function type.
+ typedef T *(Factory)();
+
+ // Registry type.
+ typedef internal::ComponentRegistry<Factory> Registry;
+
+ // Creates a new instance of T. Returns pointer to new instance or nullptr in
+ // case of errors (e.g., unknown component).
+ //
+ // Passes ownership of the returned pointer to the caller.
+ static T *Create(const string &name) { // NOLINT
+ auto *factory = registry()->Lookup(name);
+ if (factory == nullptr) {
+ SAFTM_LOG(ERROR) << "Unknown RegisterableClass " << name;
+ return nullptr;
+ }
+ return factory();
+ }
+
+ // Returns registry for class.
+ static Registry *registry() {
+ static Registry *registry_for_type_t = new Registry(kRegistryName);
+ return registry_for_type_t;
+ }
+
+ protected:
+ // Factory method for subclass ComponentClass. Used internally by the static
+ // method RegisterClass() defined by SAFTM_DEFINE_REGISTRATION_METHOD.
+ template <class ComponentClass>
+ static T *_internal_component_factory() {
+ return new ComponentClass();
+ }
+
+ private:
+ // Human-readable name for the registry for this class.
+ static const char kRegistryName[];
+};
+
+// Defines the static method component_class::RegisterClass() that should be
+// called before trying to instantiate component_class by name. Should be used
+// inside the public section of the declaration of component_class. See
+// comments at the top-level of this file.
+#define SAFTM_DEFINE_REGISTRATION_METHOD(component_name, component_class) \
+ static void RegisterClass() { \
+ static bool once = registry()->Add( \
+ component_name, &_internal_component_factory<component_class>); \
+ if (!once) { \
+ SAFTM_LOG(ERROR) << "Problem registering " << component_name; \
+ } \
+ SAFTM_DCHECK(once); \
+ }
+
+// Defines the human-readable name of the registry associated with base_class.
+#define SAFTM_DECLARE_CLASS_REGISTRY_NAME(base_class) \
+ template <> \
+ const char ::libtextclassifier3::mobile::RegisterableClass<base_class>::kRegistryName[]
+
+// Defines the human-readable name of the registry associated with base_class.
+#define SAFTM_DEFINE_CLASS_REGISTRY_NAME(registry_name, base_class) \
+ template <> \
+ const char \
+ ::libtextclassifier3::mobile::RegisterableClass<base_class>::kRegistryName[] \
+ = registry_name
+
+// Register component_name, by calling component_class::RegisterClass() on
+// program start-up, before main. NOTE: this macro should be used in
+// conjunction with something like alwayslink = 1 from bazel. That is
+// discouraged, as it prevents the linker from doing dead code elimination, so
+// please use this macro only in special cases. Instead, if you care about code
+// size, then you should aim to explicitly call RegisterClass from your code
+// (e.g., from the main method, or from the constructor of the class that may
+// need those registered components).
+#define SAFTM_STATIC_REGISTRATION(component_class) \
+ static bool SAFTM_UNIQUE_ID(_kRegistrationDummy) = [] { \
+ component_class::RegisterClass(); \
+ return true; \
+ }()
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
diff --git a/lang_id/common/stl-util.h b/lang_id/common/stl-util.h
new file mode 100644
index 0000000..95d8d3b
--- /dev/null
+++ b/lang_id/common/stl-util.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_STL_UTIL_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_STL_UTIL_H_
+
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+// Deletes all the elements in an STL container and clears the container. This
+// function is suitable for use with a vector, set, hash_set, or any other STL
+// container which defines sensible begin(), end(), and clear() methods. If
+// container is NULL, this function is a no-op.
+template <typename T>
+void STLDeleteElements(T *container) {
+ if (!container) return;
+ auto it = container->begin();
+ while (it != container->end()) {
+ auto temp = it;
+ ++it;
+ delete *temp;
+ }
+ container->clear();
+}
+
+} // namespace utils
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_STL_UTIL_H_
diff --git a/lang_id/common/utf8.cc b/lang_id/common/utf8.cc
new file mode 100644
index 0000000..ef00145
--- /dev/null
+++ b/lang_id/common/utf8.cc
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/utf8.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+const char *GetSafeEndOfUtf8String(const char *data, size_t size) {
+ const char *const hard_end = data + size;
+ const char *curr = data;
+ while (curr < hard_end && *curr) {
+ int num_bytes = utils::OneCharLen(curr);
+ const char *new_curr = curr + num_bytes;
+ if (new_curr > hard_end) {
+ return curr;
+ }
+ curr = new_curr;
+ }
+ return curr;
+}
+
+} // namespace utils
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/common/utf8.h b/lang_id/common/utf8.h
new file mode 100644
index 0000000..2365429
--- /dev/null
+++ b/lang_id/common/utf8.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
+
+#include <stddef.h>
+
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+// Returns the length (number of bytes) of the UTF8 code point starting at src,
+// by reading only the byte from address src.
+//
+// The result is a number from the set {1, 2, 3, 4}.
+static inline int OneCharLen(const char *src) {
+  // The signedness of plain char is implementation-defined (it differs across
+  // platforms/ABIs); the cast below makes sure we always read *src as unsigned.
+ return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"
+ [(*(reinterpret_cast<const unsigned char *>(src)) & 0xFF) >> 4];
+}
+
+// Returns a pointer "end" inside [data, data + size) such that the prefix from
+// [data, end) is the largest one that does not contain '\0' and offers the
+// following guarantee: if one starts with
+//
+// curr = text.data()
+//
+// and keeps executing
+//
+// curr += OneCharLen(curr)
+//
+// one would eventually reach curr == end (the pointer returned by this
+// function) without accessing data outside the string. This guards against
+// scenarios like a broken UTF8 string which has only e.g., the first 2 bytes
+// from a 3-byte UTF8 sequence.
+//
+// Preconditions: data != nullptr.
+const char *GetSafeEndOfUtf8String(const char *data, size_t size);
+
+static inline const char *GetSafeEndOfUtf8String(const string &text) {
+ return GetSafeEndOfUtf8String(text.data(), text.size());
+}
+
+} // namespace utils
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
diff --git a/lang_id/custom-tokenizer.cc b/lang_id/custom-tokenizer.cc
new file mode 100644
index 0000000..5a6b997
--- /dev/null
+++ b/lang_id/custom-tokenizer.cc
@@ -0,0 +1,159 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/custom-tokenizer.h"
+
+#include <ctype.h>
+
+#include <string>
+
+#include "lang_id/common/lite_base/attributes.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/utf8.h"
+#include "utf.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+namespace {
+inline bool IsTokenSeparator(int num_bytes, const char *curr) {
+ if (num_bytes != 1) {
+ return false;
+ }
+ return !isalpha(*curr);
+}
+
+// Appends to *word the UTF8 encoding for the lowercase version of the UTF8
+// character that starts at |curr| and has |num_bytes| bytes.
+//
+// NOTE: if the current UTF8 character does not have a lowercase version, then
+// we append the original UTF8 character.
+inline SAFTM_ATTRIBUTE_ALWAYS_INLINE void AppendLowerCase(const char *curr,
+ int num_bytes,
+ string *word) {
+ if (num_bytes == 1) {
+ // Optimize the ASCII case.
+ word->push_back(tolower(*curr));
+ return;
+ }
+
+ // Harder, general case.
+ //
+ // NOTE: for lowercasing, we use the utils from utf.h:
+ // charntorune + tolowerrune + runetochar. Unfortunately, that library does
+ // not contain any fast util for determining the number of bytes for the UTF8
+ // character that starts at a given address *without* converting to a full
+ // codepoint (like our utils::OneCharLen, which is used intensively by the
+ // rest of our code, including by the performance-critical char ngram
+ // feature). Hence, the rest of our code continues to use utils::OneCharLen,
+ // and here, when we append the bytes to *word, we make sure that's consistent
+ // with utils::OneCharLen.
+
+ // charntorune() below reads the UTF8 character that starts at curr (using at
+ // most num_bytes bytes) and stores the corresponding codepoint into rune.
+ Rune rune;
+ charntorune(&rune, curr, num_bytes);
+ if (rune != Runeerror) {
+ Rune lower = tolowerrune(rune);
+ char lower_buf[UTFmax];
+ runetochar(lower_buf, &lower);
+
+ // When appending the UTF8 bytes to word, we do not use the number of bytes
+ // returned by runetochar(); instead, we use utils::OneCharLen(), the same
+ // method used by the char ngram feature. We expect them to be equal, but
+ // just in case.
+ int lower_num_bytes = utils::OneCharLen(lower_buf);
+
+ // Using lower_num_bytes below is safe, because, by definition of UTFmax,
+ SAFTM_DCHECK_GE(UTFmax, 4);
+
+ // And, by implementation of utils::OneCharLen():
+ SAFTM_DCHECK_GT(lower_num_bytes, 0);
+ SAFTM_DCHECK_LE(lower_num_bytes, 4);
+ word->append(lower_buf, lower_num_bytes);
+ } else {
+ // There are sequences of bytes that charntorune() can't convert into a
+ // valid Rune (a special case is [0xEF, 0xBF, 0xBD], the UTF8 encoding for
+ // the U+FFFD special Unicode character, which is also the value of
+ // Runeerror). We keep those bytes unchanged.
+ word->append(curr, num_bytes);
+ }
+}
+} // namespace
+
+void TokenizerForLangId::Setup(TaskContext *context) {
+ lowercase_input_ = context->Get("lang_id_lowercase_input", false);
+}
+
+void TokenizerForLangId::Tokenize(StringPiece text,
+ LightSentence *sentence) const {
+ const char *const start = text.data();
+ const char *curr = start;
+ const char *end = utils::GetSafeEndOfUtf8String(start, text.size());
+
+ // Corner case: the safe part of the text is empty ("").
+ if (curr >= end) {
+ return;
+ }
+
+ // Number of bytes for UTF8 character starting at *curr. Note: the loop below
+ // is guaranteed to terminate because in each iteration, we move curr by at
+ // least num_bytes, and num_bytes is guaranteed to be > 0.
+ int num_bytes = utils::OneCharLen(curr);
+ while (curr < end) {
+ // Jump over consecutive token separators.
+ while (IsTokenSeparator(num_bytes, curr)) {
+ curr += num_bytes;
+ if (curr >= end) {
+ return;
+ }
+ num_bytes = utils::OneCharLen(curr);
+ }
+
+ // If control reaches this point, we are at beginning of a non-empty token.
+ sentence->emplace_back();
+ string *word = &(sentence->back());
+
+ // Add special token-start character.
+ word->push_back('^');
+
+ // Add UTF8 characters to word, until we hit the end of the safe text or a
+ // token separator.
+ while (true) {
+ if (lowercase_input_) {
+ AppendLowerCase(curr, num_bytes, word);
+ } else {
+ word->append(curr, num_bytes);
+ }
+ curr += num_bytes;
+ if (curr >= end) {
+ break;
+ }
+ num_bytes = utils::OneCharLen(curr);
+ if (IsTokenSeparator(num_bytes, curr)) {
+ curr += num_bytes;
+ num_bytes = utils::OneCharLen(curr);
+ break;
+ }
+ }
+ word->push_back('$');
+ }
+}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/custom-tokenizer.h b/lang_id/custom-tokenizer.h
new file mode 100644
index 0000000..6fab796
--- /dev/null
+++ b/lang_id/custom-tokenizer.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_CUSTOM_TOKENIZER_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_CUSTOM_TOKENIZER_H_
+
+#include <string>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+#include "lang_id/light-sentence.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Custom tokenizer for the LangId model.
+class TokenizerForLangId {
+ public:
+ void Setup(TaskContext *context);
+
+ // Tokenizes |text|, placing the tokens into |sentence|. Customized for
+ // LangId. Currently (Sep 15, 2016) we tokenize on space, newline, tab, and
+ // any other 1-byte UTF8 character which is not a letter, ignore all empty
+ // tokens, and (for each of the remaining tokens) prepend "^" (special token
+ // begin marker) and append "$" (special token end marker).
+ //
+ // Tokens are stored into the "repeated Token token;" field of *sentence.
+ void Tokenize(StringPiece text, LightSentence *sentence) const;
+
+ private:
+ // If true, during tokenization, we use the lowercase version of each Unicode
+ // character from the text to tokenize. E.g., if this is true, the text "Foo
+ // bar" is tokenized as ["foo", "bar"]; otherwise, we get ["Foo", "bar"].
+ bool lowercase_input_ = false;
+};
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_CUSTOM_TOKENIZER_H_
diff --git a/lang_id/fb_model/lang-id-from-fb.cc b/lang_id/fb_model/lang-id-from-fb.cc
new file mode 100644
index 0000000..f8e39d7
--- /dev/null
+++ b/lang_id/fb_model/lang-id-from-fb.cc
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/fb_model/lang-id-from-fb.h"
+
+#include "lang_id/fb_model/model-provider-from-fb.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(const string &filename) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(filename));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(int fd) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(fd));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
+std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
+ size_t num_bytes) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(data, num_bytes));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/fb_model/lang-id-from-fb.h b/lang_id/fb_model/lang-id-from-fb.h
new file mode 100644
index 0000000..51bcffe
--- /dev/null
+++ b/lang_id/fb_model/lang-id-from-fb.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
+
+#include <stddef.h>
+
+#include <memory>
+#include <string>
+
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Returns a LangId built using the SAFT model in flatbuffer format from
+// |filename|.
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(const string &filename);
+
+// Returns a LangId built using the SAFT model in flatbuffer format from
+// given file descriptor.
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(int fd);
+
+// Returns a LangId built using the SAFT model in flatbuffer format from
+// the |num_bytes| bytes that start at address |data|.
+//
+// IMPORTANT: the model bytes must be alive during the lifetime of the returned
+// LangId. To avoid overhead (e.g., heap allocation), this method does not make
+// a private copy of the model bytes. Avoiding overhead is the main reason we
+// use flatbuffers.
+std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
+ size_t num_bytes);
+
+// Convenience string-based version of GetLangIdFromFlatbufferBytes.
+//
+// IMPORTANT: |bytes| must be alive during the lifetime of the returned LangId.
+inline std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(
+ const string &bytes) {
+ return GetLangIdFromFlatbufferBytes(bytes.data(), bytes.size());
+}
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
diff --git a/lang_id/fb_model/model-provider-from-fb.cc b/lang_id/fb_model/model-provider-from-fb.cc
new file mode 100644
index 0000000..3357963
--- /dev/null
+++ b/lang_id/fb_model/model-provider-from-fb.cc
@@ -0,0 +1,102 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/fb_model/model-provider-from-fb.h"
+
+#include "lang_id/common/file/file-utils.h"
+#include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
+#include "lang_id/common/flatbuffers/model-utils.h"
+#include "lang_id/common/lite_strings/str-split.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(const string &filename)
+
+ // Using mmap as a fast way to read the model bytes. As the file is
+ // unmapped only when the field scoped_mmap_ is destructed, the model bytes
+ // stay alive for the entire lifetime of this object.
+ : scoped_mmap_(new ScopedMmap(filename)) {
+ Initialize(scoped_mmap_->handle().to_stringpiece());
+}
+
+ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(int fd)
+
+ // Using mmap as a fast way to read the model bytes. As the file is
+ // unmapped only when the field scoped_mmap_ is destructed, the model bytes
+ // stay alive for the entire lifetime of this object.
+ : scoped_mmap_(new ScopedMmap(fd)) {
+ Initialize(scoped_mmap_->handle().to_stringpiece());
+}
+
+void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
+ // Note: valid_ was initialized to false. In the code below, we set valid_ to
+ // true only if all initialization steps completed successfully. Otherwise,
+ // we return early, leaving valid_ to its default value false.
+ model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
+ if (model_ == nullptr) {
+ SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
+ return;
+ }
+
+ // Initialize context_ parameters.
+ if (!saft_fbs::FillParameters(*model_, &context_)) {
+ // FillParameters already performs error logging.
+ return;
+ }
+
+ // Init languages_.
+ const string known_languages_str = context_.Get("supported_languages", "");
+ for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
+ languages_.emplace_back(sp);
+ }
+ if (languages_.empty()) {
+ SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
+ return;
+ }
+
+ // Init nn_params_.
+ if (!InitNetworkParams()) {
+ // InitNetworkParams already performs error logging.
+ return;
+ }
+
+ // Everything looks fine.
+ valid_ = true;
+}
+
+bool ModelProviderFromFlatbuffer::InitNetworkParams() {
+ const string kInputName = "language-identifier-network";
+ StringPiece bytes =
+ saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
+ if ((bytes.data() == nullptr) || bytes.empty()) {
+ SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
+ return false;
+ }
+ std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
+ new EmbeddingNetworkParamsFromFlatbuffer(bytes));
+ if (!nn_params_from_fb->is_valid()) {
+ SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
+ return false;
+ }
+ nn_params_ = std::move(nn_params_from_fb);
+ return true;
+}
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/fb_model/model-provider-from-fb.h b/lang_id/fb_model/model-provider-from-fb.h
new file mode 100644
index 0000000..d25c903
--- /dev/null
+++ b/lang_id/fb_model/model-provider-from-fb.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
+
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/file/mmap.h"
+#include "lang_id/common/flatbuffers/model_generated.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+#include "lang_id/model-provider.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// ModelProvider for LangId, based on a SAFT model in flatbuffer format.
+class ModelProviderFromFlatbuffer : public ModelProvider {
+ public:
+ // Constructs a model provider based on a flatbuffer-format SAFT model from
+ // |filename|.
+ explicit ModelProviderFromFlatbuffer(const string &filename);
+
+ // Constructs a model provider based on a flatbuffer-format SAFT model from
+ // file descriptor |fd|.
+ explicit ModelProviderFromFlatbuffer(int fd);
+
+ // Constructs a model provider from a flatbuffer-format SAFT model the bytes
+ // of which are already in RAM (size bytes starting from address data).
+ // Useful if you "transport" these bytes otherwise than via a normal file
+ // (e.g., if you embed them somehow in your binary).
+ //
+ // IMPORTANT: |data| should be alive during the lifetime of the
+ // newly-constructed ModelProviderFromFlatbuffer. This is trivial to ensure
+ // for data that's statically embedded in your binary, but more complex in
+ // other cases. To avoid overhead (e.g., heap allocation), this method does
+ // not make a private copy of the data. In general, the ownership of the
+ // newly-constructed ModelProviderFromFlatbuffer is immediately passed to a
+ // LangId object (which doesn't pass it further); hence, one needs to make
+ // sure |data| is alive during the lifetime of that LangId object.
+ ModelProviderFromFlatbuffer(const char *data, std::size_t size) {
+ StringPiece model_bytes(data, size);
+ Initialize(model_bytes);
+ }
+
+ ~ModelProviderFromFlatbuffer() override = default;
+
+ const TaskContext *GetTaskContext() const override {
+ return &context_;
+ }
+
+ const EmbeddingNetworkParams *GetNnParams() const override {
+ return nn_params_.get();
+ }
+
+ std::vector<string> GetLanguages() const override {
+ return languages_;
+ }
+
+ private:
+ // Initializes the fields of this class based on the flatbuffer from
+ // |model_bytes|. These bytes are supposed to be the representation of a
+ // Model flatbuffer and should be alive during the lifetime of this object.
+ void Initialize(StringPiece model_bytes);
+
+ // Initializes nn_params_ based on model_.
+ bool InitNetworkParams();
+
+ // If filename-based constructor is used, scoped_mmap_ keeps the file mmapped
+ // during the lifetime of this object, such that references inside the Model
+ // flatbuffer from those bytes remain valid.
+ const std::unique_ptr<ScopedMmap> scoped_mmap_;
+
+ // Pointer to the flatbuffer from
+ //
+ // (a) [if filename constructor was used:] the bytes mmapped by scoped_mmap_
+ // (for safety considerations, see comment for that field), or
+ //
+  // (b) [if (data, size) constructor was used:] the bytes from [data,
+ // data+size). Please read carefully the doc for that constructor.
+ const saft_fbs::Model *model_;
+
+ // Context returned by this model provider. We set its parameters based on
+ // model_, at construction time.
+ TaskContext context_;
+
+ // List of supported languages, see GetLanguages(). We expect this list to be
+ // specified by the ModelParameter named "supported_languages" from model_.
+ std::vector<string> languages_;
+
+ // EmbeddingNetworkParams, see GetNnParams(). Set based on the ModelInput
+ // named "language-identifier-network" from model_.
+ std::unique_ptr<EmbeddingNetworkParams> nn_params_;
+};
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
diff --git a/lang_id/features/char-ngram-feature.cc b/lang_id/features/char-ngram-feature.cc
new file mode 100644
index 0000000..e52b2f2
--- /dev/null
+++ b/lang_id/features/char-ngram-feature.cc
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/features/char-ngram-feature.h"
+
+#include <utility>
+#include <vector>
+
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/math/hash.h"
+#include "lang_id/common/utf8.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
+ // Parameters in the feature function descriptor.
+ bool include_terminators = GetBoolParameter("include_terminators", false);
+ if (!include_terminators) {
+ SAFTM_LOG(ERROR) << "No support for include_terminators=true";
+ return false;
+ }
+
+ bool include_spaces = GetBoolParameter("include_spaces", false);
+ if (include_spaces) {
+ SAFTM_LOG(ERROR) << "No support for include_spaces=true";
+ return false;
+ }
+
+ bool use_equal_ngram_weight = GetBoolParameter("use_equal_weight", false);
+ if (use_equal_ngram_weight) {
+ SAFTM_LOG(ERROR) << "No support for use_equal_weight=true";
+ return false;
+ }
+
+ ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
+ ngram_size_ = GetIntParameter("size", 3);
+
+ counts_.assign(ngram_id_dimension_, 0);
+ return true;
+}
+
+bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
+ set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
+ return true;
+}
+
+int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
+ const LightSentence &sentence) const {
+ SAFTM_CHECK_EQ(counts_.size(), ngram_id_dimension_);
+ SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0);
+
+ int total_count = 0;
+
+ for (const string &word : sentence) {
+ const char *const word_end = word.data() + word.size();
+
+ // Set ngram_start at the start of the current token (word).
+ const char *ngram_start = word.data();
+
+ // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each
+ // UTF8 character contains between 1 and 4 bytes.
+ const char *ngram_end = ngram_start;
+ int num_utf8_chars = 0;
+ do {
+ ngram_end += utils::OneCharLen(ngram_end);
+ num_utf8_chars++;
+ } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));
+
+ if (num_utf8_chars < ngram_size_) {
+ // Current token is so small, it does not contain a single ngram of
+ // ngram_size UTF8 characters. Not much we can do in this case ...
+ continue;
+ }
+
+ // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
+ // UTF8 characters from current token.
+ while (true) {
+ // Compute ngram id: hash(ngram) % ngram_id_dimension
+ int ngram_id = (
+ utils::Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start)
+ % ngram_id_dimension_);
+
+ // Use a reference to the actual count, such that we can both test whether
+      // the count was 0 and increment it without performing two lookups.
+ int &ref_to_count_for_ngram = counts_[ngram_id];
+ if (ref_to_count_for_ngram == 0) {
+ non_zero_count_indices_.push_back(ngram_id);
+ }
+ ref_to_count_for_ngram++;
+ total_count++;
+ if (ngram_end >= word_end) {
+ break;
+ }
+
+ // Advance both ngram_start and ngram_end by one UTF8 character. This
+ // way, the number of UTF8 characters between them remains constant
+ // (ngram_size).
+ ngram_start += utils::OneCharLen(ngram_start);
+ ngram_end += utils::OneCharLen(ngram_end);
+ }
+ } // end of loop over tokens.
+
+ return total_count;
+}
+
+void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
+ const LightSentence &sentence,
+ FeatureVector *result) const {
+ // Find the char ngram counts.
+ int total_count = ComputeNgramCounts(sentence);
+
+ // Populate the feature vector.
+ const float norm = static_cast<float>(total_count);
+
+ // TODO(salcianu): explore treating dense vectors (i.e., many non-zero
+ // elements) separately.
+ for (int ngram_id : non_zero_count_indices_) {
+ const float weight = counts_[ngram_id] / norm;
+ FloatFeatureValue value(ngram_id, weight);
+ result->add(feature_type(), value.discrete_value);
+
+ // Clear up counts_, for the next invocation of Evaluate().
+ counts_[ngram_id] = 0;
+ }
+
+ // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
+ non_zero_count_indices_.clear();
+}
+
+SAFTM_STATIC_REGISTRATION(ContinuousBagOfNgramsFunction);
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/features/char-ngram-feature.h b/lang_id/features/char-ngram-feature.h
new file mode 100644
index 0000000..8280bca
--- /dev/null
+++ b/lang_id/features/char-ngram-feature.h
@@ -0,0 +1,93 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_CHAR_NGRAM_FEATURE_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_CHAR_NGRAM_FEATURE_H_
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/features/light-sentence-features.h"
+#include "lang_id/light-sentence.h"
+
+// TODO(abakalov): Add a test.
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Class for computing continuous char ngram features.
+//
+// Feature function descriptor parameters:
+// include_terminators(bool, false):
+// If 'true', then splits the text based on spaces to get tokens, adds "^"
+// to the beginning of each token, and adds "$" to the end of each token.
+// NOTE: currently, we support only include_terminators=true.
+// include_spaces(bool, false):
+// If 'true', then includes char ngrams containing spaces.
+// NOTE: currently, we support only include_spaces=false.
+// use_equal_weight(bool, false):
+// If 'true', then weighs each unique ngram by 1.0 / (number of unique
+// ngrams in the input). Otherwise, weighs each unique ngram by (ngram
+// count) / (total number of ngrams).
+// NOTE: currently, we support only use_equal_weight=false.
+// id_dim(int, 10000):
+// The integer id of each char ngram is computed as follows:
+// Hash32WithDefault(char ngram) % id_dim.
+// size(int, 3):
+// Only ngrams of this size will be extracted.
+//
+// NOTE: this class is not thread-safe. TODO(salcianu): make it thread-safe.
+class ContinuousBagOfNgramsFunction : public LightSentenceFeature {
+ public:
+ bool Setup(TaskContext *context) override;
+ bool Init(TaskContext *context) override;
+
+ // Appends the features computed from the sentence to the feature vector.
+ void Evaluate(const WorkspaceSet &workspaces, const LightSentence &sentence,
+ FeatureVector *result) const override;
+
+ SAFTM_DEFINE_REGISTRATION_METHOD("continuous-bag-of-ngrams",
+ ContinuousBagOfNgramsFunction);
+
+ private:
+ // Auxiliary for Evaluate(). Fills counts_ and non_zero_count_indices_ (see
+ // below), and returns the total ngram count.
+ int ComputeNgramCounts(const LightSentence &sentence) const;
+
+ // counts_[i] is the count of all ngrams with id i. Work data for Evaluate().
+ // NOTE: we declare this vector as a field, such that its underlying capacity
+ // stays allocated in between calls to Evaluate().
+ mutable std::vector<int> counts_;
+
+ // Indices of non-zero elements of counts_. See comments for counts_.
+ mutable std::vector<int> non_zero_count_indices_;
+
+ // The integer id of each char ngram is computed as follows:
+ // Hash32WithDefaultSeed(char_ngram) % ngram_id_dimension_.
+ int ngram_id_dimension_;
+
+ // Only ngrams of size ngram_size_ will be extracted.
+ int ngram_size_;
+};
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_CHAR_NGRAM_FEATURE_H_
diff --git a/lang_id/features/light-sentence-features.cc b/lang_id/features/light-sentence-features.cc
new file mode 100644
index 0000000..7f1d878
--- /dev/null
+++ b/lang_id/features/light-sentence-features.cc
@@ -0,0 +1,27 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/features/light-sentence-features.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Registry for the features on whole light sentences.
+SAFTM_DEFINE_CLASS_REGISTRY_NAME("light sentence feature function",
+ lang_id::LightSentenceFeature);
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/features/light-sentence-features.h b/lang_id/features/light-sentence-features.h
new file mode 100644
index 0000000..cc85878
--- /dev/null
+++ b/lang_id/features/light-sentence-features.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_LIGHT_SENTENCE_FEATURES_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_LIGHT_SENTENCE_FEATURES_H_
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/registry.h"
+#include "lang_id/light-sentence.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Feature function that extracts features from LightSentences.
+typedef FeatureFunction<LightSentence> LightSentenceFeature;
+
+// Feature extractor for LightSentences.
+typedef FeatureExtractor<LightSentence> LightSentenceExtractor;
+
+} // namespace lang_id
+
+SAFTM_DECLARE_CLASS_REGISTRY_NAME(lang_id::LightSentenceFeature);
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_LIGHT_SENTENCE_FEATURES_H_
diff --git a/lang_id/features/relevant-script-feature.cc b/lang_id/features/relevant-script-feature.cc
new file mode 100644
index 0000000..0fde87b
--- /dev/null
+++ b/lang_id/features/relevant-script-feature.cc
@@ -0,0 +1,110 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/features/relevant-script-feature.h"
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/utf8.h"
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+bool RelevantScriptFeature::Setup(TaskContext *context) {
+ string script_detector_name = GetParameter(
+ "script_detector_name", /* default_value = */ "tiny-script-detector");
+
+ // We don't use absl::WrapUnique, nor the rest of absl, see http://b/71873194
+ script_detector_.reset(ScriptDetector::Create(script_detector_name));
+ if (script_detector_ == nullptr) {
+ // This means ScriptDetector::Create() could not find the requested
+ // script_detector_name. In that case, Create() already logged an error
+ // message.
+ return false;
+ }
+
+ // We use default value 172 because this is the number of scripts supported by
+ // the first model we trained with this feature. See http://b/70617713.
+ // Newer models may support more scripts.
+ num_supported_scripts_ = GetIntParameter("num_supported_scripts", 172);
+ return true;
+}
+
+bool RelevantScriptFeature::Init(TaskContext *context) {
+ set_feature_type(new NumericFeatureType(name(), num_supported_scripts_));
+ return true;
+}
+
+void RelevantScriptFeature::Evaluate(
+ const WorkspaceSet &workspaces, const LightSentence &sentence,
+ FeatureVector *result) const {
+ // counts[s] is the number of characters with script s.
+ std::vector<int> counts(num_supported_scripts_);
+ int total_count = 0;
+ for (const string &word : sentence) {
+ const char *const word_end = word.data() + word.size();
+ const char *curr = word.data();
+
+ // Skip over token start '^'.
+ SAFTM_DCHECK_EQ(*curr, '^');
+ curr += utils::OneCharLen(curr);
+ while (true) {
+ const int num_bytes = utils::OneCharLen(curr);
+
+ int script = script_detector_->GetScript(curr, num_bytes);
+
+ // We do this update and the if (...) break below *before* incrementing
+ // counts[script] in order to skip the token end '$'.
+ curr += num_bytes;
+ if (curr >= word_end) {
+ SAFTM_DCHECK_EQ(*(curr - num_bytes), '$');
+ break;
+ }
+ SAFTM_DCHECK_GE(script, 0);
+
+ if (script < num_supported_scripts_) {
+ counts[script]++;
+ total_count++;
+ } else {
+ // Unsupported script: this usually indicates a script that is
+ // recognized by newer versions of the code, after the model was
+ // trained. E.g., new code running with old model.
+ }
+ }
+ }
+
+ for (int script_id = 0; script_id < num_supported_scripts_; ++script_id) {
+ int count = counts[script_id];
+ if (count > 0) {
+ const float weight = static_cast<float>(count) / total_count;
+ FloatFeatureValue value(script_id, weight);
+ result->add(feature_type(), value.discrete_value);
+ }
+ }
+}
+
+SAFTM_STATIC_REGISTRATION(RelevantScriptFeature);
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/features/relevant-script-feature.h b/lang_id/features/relevant-script-feature.h
new file mode 100644
index 0000000..57c5a1f
--- /dev/null
+++ b/lang_id/features/relevant-script-feature.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_RELEVANT_SCRIPT_FEATURE_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_RELEVANT_SCRIPT_FEATURE_H_
+
+#include <memory>
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/features/light-sentence-features.h"
+#include "lang_id/light-sentence.h"
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Given a sentence, generates one FloatFeatureValue for each "relevant" Unicode
+// script (see below): each such feature indicates the script and the ratio of
+// UTF8 characters in that script, in the given sentence.
+//
+// What is a relevant script? Recognizing all 100+ Unicode scripts would
+// require too much code size and runtime. Instead, we focus only on a few
+// scripts that communicate a lot of language information: e.g., the use of
+// Hiragana characters almost always indicates Japanese, so Hiragana is a
+// "relevant" script for us. The Latin script is used by dozens of language, so
+// Latin is not relevant in this context.
+class RelevantScriptFeature : public LightSentenceFeature {
+ public:
+ bool Setup(TaskContext *context) override;
+ bool Init(TaskContext *context) override;
+
+ // Appends the features computed from the sentence to the feature vector.
+ void Evaluate(const WorkspaceSet &workspaces,
+ const LightSentence &sentence,
+ FeatureVector *result) const override;
+
+ SAFTM_DEFINE_REGISTRATION_METHOD("continuous-bag-of-relevant-scripts",
+ RelevantScriptFeature);
+
+ private:
+ // Detects script of individual UTF8 characters.
+ std::unique_ptr<ScriptDetector> script_detector_;
+
+ // Current model supports scripts in [0, num_supported_scripts_).
+ int num_supported_scripts_ = 0;
+};
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace nlp_saft
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FEATURES_RELEVANT_SCRIPT_FEATURE_H_
diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc
new file mode 100644
index 0000000..ebc88ec
--- /dev/null
+++ b/lang_id/lang-id.cc
@@ -0,0 +1,280 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/lang-id.h"
+
+#include <stdio.h>
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "lang_id/common/embedding-feature-interface.h"
+#include "lang_id/common/embedding-network-params.h"
+#include "lang_id/common/embedding-network.h"
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+#include "lang_id/common/lite_strings/str-split.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+#include "lang_id/common/math/algorithm.h"
+#include "lang_id/common/math/softmax.h"
+#include "lang_id/custom-tokenizer.h"
+#include "lang_id/features/light-sentence-features.h"
+#include "lang_id/light-sentence.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+namespace {
+// Default value for the confidence threshold. If the confidence of the top
+// prediction is below this threshold, then FindLanguage() returns
+// LangId::kUnknownLanguageCode. Note: this is just a default value; if the
+// TaskSpec from the model specifies a "reliability_thresh" parameter, then we
+// use that value instead. Note: for legacy reasons, our code and comments use
+// the terms "confidence", "probability" and "reliability" equivalently.
+static const float kDefaultConfidenceThreshold = 0.50f;
+} // namespace
+
+// Class that performs all work behind LangId.
+class LangIdImpl {
+ public:
+ explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
+ : model_provider_(std::move(model_provider)),
+ lang_id_brain_interface_("language_identifier") {
+ // Note: in the code below, we set valid_ to true only if all initialization
+ // steps completed successfully. Otherwise, we return early, leaving valid_
+ // to its default value false.
+ if (!model_provider_ || !model_provider_->is_valid()) {
+ SAFTM_LOG(ERROR) << "Invalid model provider";
+ return;
+ }
+
+ auto *nn_params = model_provider_->GetNnParams();
+ if (!nn_params) {
+ SAFTM_LOG(ERROR) << "No NN params";
+ return;
+ }
+ network_.reset(new EmbeddingNetwork(nn_params));
+
+ languages_ = model_provider_->GetLanguages();
+ if (languages_.empty()) {
+ SAFTM_LOG(ERROR) << "No known languages";
+ return;
+ }
+
+ TaskContext context = *model_provider_->GetTaskContext();
+ if (!Setup(&context)) {
+ SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
+ return;
+ }
+ if (!Init(&context)) {
+ SAFTM_LOG(ERROR) << "Unable to Init() LangId";
+ return;
+ }
+ valid_ = true;
+ }
+
+ string FindLanguage(StringPiece text) const {
+ // NOTE: it would be wasteful to implement this method in terms of
+ // FindLanguages(). We just need the most likely language and its
+ // probability; no need to compute (and allocate) a vector of pairs for all
+ // languages, nor to compute probabilities for all non-top languages.
+ if (!is_valid()) {
+ return LangId::kUnknownLanguageCode;
+ }
+
+ std::vector<float> scores;
+ ComputeScores(text, &scores);
+
+ int prediction_id = GetArgMax(scores);
+ const string language = GetLanguageForSoftmaxLabel(prediction_id);
+ float probability = ComputeSoftmaxProbability(scores, prediction_id);
+ SAFTM_DLOG(INFO) << "Predicted " << language
+ << " with prob: " << probability << " for \"" << text
+ << "\"";
+
+ // Find confidence threshold for language.
+ float threshold = default_threshold_;
+ auto it = per_lang_thresholds_.find(language);
+ if (it != per_lang_thresholds_.end()) {
+ threshold = it->second;
+ }
+ if (probability < threshold) {
+ SAFTM_DLOG(INFO) << " below threshold => "
+ << LangId::kUnknownLanguageCode;
+ return LangId::kUnknownLanguageCode;
+ }
+ return language;
+ }
+
+ void FindLanguages(StringPiece text, LangIdResult *result) const {
+ if (result == nullptr) return;
+
+ result->predictions.clear();
+ if (!is_valid()) {
+ result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
+ return;
+ }
+
+ std::vector<float> scores;
+ ComputeScores(text, &scores);
+
+ // Compute and sort softmax in descending order by probability and convert
+ // IDs to language code strings. When probabilities are equal, we sort by
+ // language code string in ascending order.
+ std::vector<float> softmax = ComputeSoftmax(scores);
+
+ for (int i = 0; i < softmax.size(); ++i) {
+ result->predictions.emplace_back(GetLanguageForSoftmaxLabel(i),
+ softmax[i]);
+ }
+
+ // Sort the resulting language predictions by probability in descending
+ // order.
+ std::sort(result->predictions.begin(), result->predictions.end(),
+ [](const std::pair<string, float> &a,
+ const std::pair<string, float> &b) {
+ if (a.second == b.second) {
+ return a.first.compare(b.first) < 0;
+ } else {
+ return a.second > b.second;
+ }
+ });
+ }
+
+ bool is_valid() const { return valid_; }
+
+ int GetModelVersion() const { return model_version_; }
+
+ private:
+ bool Setup(TaskContext *context) {
+ tokenizer_.Setup(context);
+ if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
+ default_threshold_ = context->Get(
+ "reliability_thresh", kDefaultConfidenceThreshold);
+
+ // Parse task parameter "per_lang_reliability_thresholds", fill
+ // per_lang_thresholds_.
+ const string thresholds_str =
+ context->Get("per_lang_reliability_thresholds", "");
+ std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
+ for (const auto &token : tokens) {
+ if (token.empty()) continue;
+ std::vector<StringPiece> parts = LiteStrSplit(token, '=');
+ float threshold = 0.0f;
+ if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
+ per_lang_thresholds_[string(parts[0])] = threshold;
+ } else {
+ SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
+ }
+ }
+ model_version_ = context->Get("model_version", model_version_);
+ return true;
+ }
+
+ bool Init(TaskContext *context) {
+ return lang_id_brain_interface_.InitForProcessing(context);
+ }
+
+ // Extracts features for |text|, runs them through the feed-forward neural
+ // network, and computes the output scores (activations from the last layer).
+ // These scores can be used to compute the softmax probabilities for our
+ // labels (in this case, the languages).
+ void ComputeScores(StringPiece text, std::vector<float> *scores) const {
+ // Create a Sentence storing the input text.
+ LightSentence sentence;
+ tokenizer_.Tokenize(text, &sentence);
+
+ std::vector<FeatureVector> features =
+ lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
+
+ // Run feed-forward neural network to compute scores.
+ network_->ComputeFinalScores(features, scores);
+ }
+
+ // Returns language code for a softmax label. See comments for languages_
+ // field. If label is out of range, returns LangId::kUnknownLanguageCode.
+ string GetLanguageForSoftmaxLabel(int label) const {
+ if ((label >= 0) && (label < languages_.size())) {
+ return languages_[label];
+ } else {
+ SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
+ << languages_.size() << ")";
+ return LangId::kUnknownLanguageCode;
+ }
+ }
+
+ std::unique_ptr<ModelProvider> model_provider_;
+
+ TokenizerForLangId tokenizer_;
+
+ EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
+ lang_id_brain_interface_;
+
+ // Neural network to use for scoring.
+ std::unique_ptr<EmbeddingNetwork> network_;
+
+ // True if this object is ready to perform language predictions.
+ bool valid_ = false;
+
+ // Only predictions with a probability (confidence) above this threshold are
+ // reported. Otherwise, we report LangId::kUnknownLanguageCode.
+ float default_threshold_ = kDefaultConfidenceThreshold;
+
+ std::unordered_map<string, float> per_lang_thresholds_;
+
+ // Recognized languages: softmax label i means languages_[i] (something like
+ // "en", "fr", "ru", etc).
+ std::vector<string> languages_;
+
+ // Version of the model used by this LangIdImpl object. Zero means that the
+ // model version could not be determined.
+ int model_version_ = 0;
+};
+
+const char LangId::kUnknownLanguageCode[] = "und";
+
+LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
+ : pimpl_(new LangIdImpl(std::move(model_provider))) {
+}
+
+LangId::~LangId() = default;
+
+string LangId::FindLanguage(const char *data, size_t num_bytes) const {
+ StringPiece text(data, num_bytes);
+ return pimpl_->FindLanguage(text);
+}
+
+void LangId::FindLanguages(const char *data, size_t num_bytes,
+ LangIdResult *result) const {
+ SAFTM_DCHECK(result) << "LangIdResult must not be null.";
+ StringPiece text(data, num_bytes);
+ pimpl_->FindLanguages(text, result);
+}
+
+bool LangId::is_valid() const {
+ return pimpl_->is_valid();
+}
+
+int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/lang_id/lang-id.h b/lang_id/lang-id.h
new file mode 100644
index 0000000..3f656f2
--- /dev/null
+++ b/lang_id/lang-id.h
@@ -0,0 +1,134 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
+
+
+#include <stddef.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lang_id/common/lite_base/macros.h"
+#include "lang_id/model-provider.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Forward-declaration of the class that performs all underlying work.
+class LangIdImpl;
+
+struct LangIdResult {
+ // An n-best list of possible language codes for a given input sorted in
+ // descending order according to each code's respective probability.
+ //
+ // This list is guaranteed to be non-empty after calling
+ // LangId::FindLanguages. The most likely language code is always the first
+ // item in this array.
+ //
+ // If the model cannot make a prediction, this array contains a single result:
+ // a language code LangId::kUnknownLanguageCode with probability 1.
+ std::vector<std::pair<string, float>> predictions;
+};
+
+// Class for detecting the language of a document.
+//
+// Note: this class does not handle the details of loading the actual model.
+// Those details have been "outsourced" to the ModelProvider class.
+//
+// Note: this class is thread-unsafe.
+class LangId {
+ public:
+ // Standard BCP-47 language code for Unknown/Undetermined language.
+ static const char kUnknownLanguageCode[];
+
+ // Constructs a LangId object, based on |model_provider|.
+ //
+ // Note: we don't crash if we detect a problem at construction time (e.g., the
+ // model provider can't read an underlying file). Instead, we mark the
+ // newly-constructed object as invalid; clients can invoke FindLanguage() on
+ // an invalid object: nothing crashes, but accuracy will be bad.
+ explicit LangId(std::unique_ptr<ModelProvider> model_provider);
+
+ virtual ~LangId();
+
+  // Computes an n-best list of language codes and probabilities
+ // corresponding to the most likely languages the given input text is written
+ // in. The list is sorted in descending order by language probability.
+ //
+ // The input text consists of the |num_bytes| bytes that starts at |data|.
+ //
+ // Note: If this LangId object is not valid (see is_valid()) or if this LangId
+ // object can't make a prediction, this method sets the LangIdResult to
+ // contain a single entry with kUnknownLanguageCode with probability 1.
+ void FindLanguages(const char *data, size_t num_bytes,
+ LangIdResult *result) const;
+
+ // Convenience version of FindLanguages(const char *, size_t, LangIdResult *).
+ void FindLanguages(const string &text, LangIdResult *result) const {
+ FindLanguages(text.data(), text.size(), result);
+ }
+
+ // Returns language code for the most likely language for a piece of text.
+ //
+ // The input text consists of the |num_bytes| bytes that start at |data|.
+ //
+ // Note: this method reports the most likely (1-best) language only if its
+ // probability is high enough; otherwise, it returns
+ // LangId::kUnknownLanguageCode. The specific probability threshold is tuned
+ // to the needs of an early client. If you need a different threshold, you
+ // can use FindLanguages (plural) to get the full LangIdResult, and apply your
+ // own threshold.
+ //
+ // Note: if this LangId object is not valid (see is_valid()) or if this LangId
+ // object can't make a prediction, then this method returns
+ // LangId::kUnknownLanguageCode.
+ //
+ string FindLanguage(const char *data, size_t num_bytes) const;
+
+ // Convenience version of FindLanguage(const char *, size_t).
+ string FindLanguage(const string &text) const {
+ return FindLanguage(text.data(), text.size());
+ }
+
+ // Returns true if this object has been correctly initialized and is ready to
+ // perform predictions. For more info, see doc for LangId
+ // constructor above.
+ bool is_valid() const;
+
+ // Returns the version of the model used by this LangId object. On success,
+ // the returned version number is a strictly positive integer. Returns 0 if
+ // the model version can not be determined (e.g., for old models that do not
+ // specify a version number).
+ int GetModelVersion() const;
+
+ private:
+ // Pimpl ("pointer to implementation") pattern, to hide all internals from our
+ // clients.
+ std::unique_ptr<LangIdImpl> pimpl_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(LangId);
+};
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
diff --git a/lang_id/lang-id_jni.cc b/lang_id/lang-id_jni.cc
new file mode 100644
index 0000000..7026417
--- /dev/null
+++ b/lang_id/lang-id_jni.cc
@@ -0,0 +1,125 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/lang-id_jni.h"
+
+#include <jni.h>
+#include <type_traits>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/java/scoped_local_ref.h"
+#include "lang_id/fb_model/lang-id-from-fb.h"
+#include "lang_id/lang-id.h"
+
+using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::ToStlString;
+using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
+using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
+using libtextclassifier3::mobile::lang_id::LangId;
+using libtextclassifier3::mobile::lang_id::LangIdResult;
+
+namespace {
+jobjectArray LangIdResultToJObjectArray(JNIEnv* env,
+ const LangIdResult& lang_id_result) {
+ const ScopedLocalRef<jclass> result_class(
+ env->FindClass(TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR
+ "$LanguageResult"),
+ env);
+ if (!result_class) {
+ TC3_LOG(ERROR) << "Couldn't find LanguageResult class.";
+ return nullptr;
+ }
+
+ // clang-format off
+ const std::vector<std::pair<std::string, float>>& predictions =
+ lang_id_result.predictions;
+ // clang-format on
+ const jmethodID result_class_constructor =
+ env->GetMethodID(result_class.get(), "<init>", "(Ljava/lang/String;F)V");
+ const jobjectArray results =
+ env->NewObjectArray(predictions.size(), result_class.get(), nullptr);
+ for (int i = 0; i < predictions.size(); i++) {
+ ScopedLocalRef<jobject> result(
+ env->NewObject(result_class.get(), result_class_constructor,
+ env->NewStringUTF(predictions[i].first.c_str()),
+ static_cast<jfloat>(predictions[i].second)));
+ env->SetObjectArrayElement(results, i, result.get());
+ }
+ return results;
+}
+} // namespace
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
+(JNIEnv* env, jobject thiz, jint fd) {
+ std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
+ if (!lang_id->is_valid()) {
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ return reinterpret_cast<jlong>(lang_id.release());
+}
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
+(JNIEnv* env, jobject thiz, jstring path) {
+ const std::string path_str = ToStlString(env, path);
+ std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
+ if (!lang_id->is_valid()) {
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ return reinterpret_cast<jlong>(lang_id.release());
+}
+
+TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
+(JNIEnv* env, jobject clazz, jlong ptr, jstring text) {
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ if (!model) {
+ return nullptr;
+ }
+
+ const std::string text_str = ToStlString(env, text);
+ LangIdResult result;
+ model->FindLanguages(text_str, &result);
+
+ return LangIdResultToJObjectArray(env, result);
+}
+
+TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
+(JNIEnv* env, jobject clazz, jlong ptr) {
+ if (!ptr) {
+ TC3_LOG(ERROR) << "Trying to close null LangId.";
+ return;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ delete model;
+}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jlong ptr) {
+ if (!ptr) {
+ return -1;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ return model->GetModelVersion();
+}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
+(JNIEnv* env, jobject clazz, jint fd) {
+ std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
+ if (!lang_id->is_valid()) {
+ return -1;
+ }
+ return lang_id->GetModelVersion();
+}
diff --git a/lang_id/lang-id_jni.h b/lang_id/lang-id_jni.h
new file mode 100644
index 0000000..74a7e2d
--- /dev/null
+++ b/lang_id/lang-id_jni.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// JNI wrapper for LangId.
+
+#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
+#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
+
+#include <jni.h>
+#include <string>
+#include "utils/java/jni-base.h"
+
+#ifndef TC3_LANG_ID_CLASS_NAME
+#define TC3_LANG_ID_CLASS_NAME LangIdModel
+#endif
+
+#define TC3_LANG_ID_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_LANG_ID_CLASS_NAME)
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
+(JNIEnv* env, jobject clazz, jstring path);
+
+TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
+(JNIEnv* env, jobject clazz, jlong ptr, jstring text);
+
+TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
+(JNIEnv* env, jobject clazz, jlong ptr);
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jlong ptr);
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
+(JNIEnv* env, jobject clazz, jint fd);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
diff --git a/lang_id/light-sentence.h b/lang_id/light-sentence.h
new file mode 100644
index 0000000..2937549
--- /dev/null
+++ b/lang_id/light-sentence.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
+
+#include <string>
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Very simplified alternative to heavy sentence.proto, for the purpose of
+// LangId. It turns out that in this case, all we need is a vector of strings,
+// which uses a lot less code size than a Sentence proto.
+using LightSentence = std::vector<string>;
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
diff --git a/lang_id/model-provider.h b/lang_id/model-provider.h
new file mode 100644
index 0000000..a076871
--- /dev/null
+++ b/lang_id/model-provider.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/embedding-network-params.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Interface for accessing parameters for the LangId model.
+//
+// Note: some clients prefer to include the model parameters in the binary,
+// others prefer loading them from a separate file. This file provides a common
+// interface for these alternative mechanisms.
+class ModelProvider {
+ public:
+ virtual ~ModelProvider() = default;
+
+  // Returns true if this ModelProvider has been successfully constructed (e.g.,
+ // can return false if an underlying model file could not be read). Clients
+ // should not use invalid ModelProviders.
+ bool is_valid() { return valid_; }
+
+ // Returns the TaskContext with parameters for the LangId model. E.g., one
+ // important parameter specifies the features to use.
+ virtual const TaskContext *GetTaskContext() const = 0;
+
+ // Returns parameters for the underlying Neurosis feed-forward neural network.
+ virtual const EmbeddingNetworkParams *GetNnParams() const = 0;
+
+ // Returns list of languages recognized by the model. Each element of the
+ // returned vector should be a BCP-47 language code (e.g., "en", "ro", etc).
+ // Language at index i from the returned vector corresponds to softmax label
+ // i.
+ virtual std::vector<string> GetLanguages() const = 0;
+
+ protected:
+ bool valid_ = false;
+};
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
diff --git a/lang_id/script/approx-script-data.cc b/lang_id/script/approx-script-data.cc
new file mode 100755
index 0000000..1ac5cb6
--- /dev/null
+++ b/lang_id/script/approx-script-data.cc
@@ -0,0 +1,1122 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Internal data for approx-script.cc; see approx-script-data.h
+//
+// DO NOT EDIT BY HAND
+//
+// Generated by
+// lang_id/script/update-script-data.sh
+
+#include "lang_id/script/approx-script-data.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace approx_script_internal {
+
+const int kNumRanges = 359;
+
+const uint32 kRangeFirst[] = {
+ 65, // Range #0: [65, 90, Latin]
+ 97, // Range #1: [97, 122, Latin]
+ 170, // Range #2: [170, 170, Latin]
+ 186, // Range #3: [186, 186, Latin]
+ 192, // Range #4: [192, 214, Latin]
+ 216, // Range #5: [216, 246, Latin]
+ 248, // Range #6: [248, 696, Latin]
+ 736, // Range #7: [736, 740, Latin]
+ 746, // Range #8: [746, 747, Bopomofo]
+ 880, // Range #9: [880, 883, Greek]
+ 885, // Range #10: [885, 893, Greek]
+ 895, // Range #11: [895, 900, Greek]
+ 902, // Range #12: [902, 902, Greek]
+ 904, // Range #13: [904, 993, Greek]
+ 994, // Range #14: [994, 1007, Coptic]
+ 1008, // Range #15: [1008, 1023, Greek]
+ 1024, // Range #16: [1024, 1156, Cyrillic]
+ 1159, // Range #17: [1159, 1327, Cyrillic]
+ 1329, // Range #18: [1329, 1416, Armenian]
+ 1418, // Range #19: [1418, 1423, Armenian]
+ 1425, // Range #20: [1425, 1479, Hebrew]
+ 1488, // Range #21: [1488, 1524, Hebrew]
+ 1536, // Range #22: [1536, 1540, Arabic]
+ 1542, // Range #23: [1542, 1547, Arabic]
+ 1549, // Range #24: [1549, 1562, Arabic]
+ 1564, // Range #25: [1564, 1566, Arabic]
+ 1568, // Range #26: [1568, 1599, Arabic]
+ 1601, // Range #27: [1601, 1610, Arabic]
+ 1622, // Range #28: [1622, 1647, Arabic]
+ 1649, // Range #29: [1649, 1756, Arabic]
+ 1758, // Range #30: [1758, 1791, Arabic]
+ 1792, // Range #31: [1792, 1871, Syriac]
+ 1872, // Range #32: [1872, 1919, Arabic]
+ 1920, // Range #33: [1920, 1969, Thaana]
+ 1984, // Range #34: [1984, 2047, Nko]
+ 2048, // Range #35: [2048, 2110, Samaritan]
+ 2112, // Range #36: [2112, 2142, Mandaic]
+ 2144, // Range #37: [2144, 2154, Syriac]
+ 2208, // Range #38: [2208, 2237, Arabic]
+ 2259, // Range #39: [2259, 2273, Arabic]
+ 2275, // Range #40: [2275, 2303, Arabic]
+ 2304, // Range #41: [2304, 2384, Devanagari]
+ 2387, // Range #42: [2387, 2403, Devanagari]
+ 2406, // Range #43: [2406, 2431, Devanagari]
+ 2432, // Range #44: [2432, 2510, Bengali]
+ 2519, // Range #45: [2519, 2558, Bengali]
+ 2561, // Range #46: [2561, 2641, Gurmukhi]
+ 2649, // Range #47: [2649, 2654, Gurmukhi]
+ 2662, // Range #48: [2662, 2678, Gurmukhi]
+ 2689, // Range #49: [2689, 2768, Gujarati]
+ 2784, // Range #50: [2784, 2801, Gujarati]
+ 2809, // Range #51: [2809, 2815, Gujarati]
+ 2817, // Range #52: [2817, 2893, Oriya]
+ 2902, // Range #53: [2902, 2935, Oriya]
+ 2946, // Range #54: [2946, 3024, Tamil]
+ 3031, // Range #55: [3031, 3031, Tamil]
+ 3046, // Range #56: [3046, 3066, Tamil]
+ 3072, // Range #57: [3072, 3149, Telugu]
+ 3157, // Range #58: [3157, 3162, Telugu]
+ 3168, // Range #59: [3168, 3183, Telugu]
+ 3192, // Range #60: [3192, 3199, Telugu]
+ 3200, // Range #61: [3200, 3277, Kannada]
+ 3285, // Range #62: [3285, 3286, Kannada]
+ 3294, // Range #63: [3294, 3314, Kannada]
+ 3328, // Range #64: [3328, 3455, Malayalam]
+ 3458, // Range #65: [3458, 3551, Sinhala]
+ 3558, // Range #66: [3558, 3572, Sinhala]
+ 3585, // Range #67: [3585, 3642, Thai]
+ 3648, // Range #68: [3648, 3675, Thai]
+ 3713, // Range #69: [3713, 3725, Lao]
+ 3732, // Range #70: [3732, 3807, Lao]
+ 3840, // Range #71: [3840, 4052, Tibetan]
+ 4057, // Range #72: [4057, 4058, Tibetan]
+ 4096, // Range #73: [4096, 4255, Myanmar]
+ 4256, // Range #74: [4256, 4295, Georgian]
+ 4301, // Range #75: [4301, 4346, Georgian]
+ 4348, // Range #76: [4348, 4351, Georgian]
+ 4352, // Range #77: [4352, 4607, Hangul]
+ 4608, // Range #78: [4608, 5017, Ethiopic]
+ 5024, // Range #79: [5024, 5117, Cherokee]
+ 5120, // Range #80: [5120, 5759, Canadian_Aboriginal]
+ 5760, // Range #81: [5760, 5788, Ogham]
+ 5792, // Range #82: [5792, 5866, Runic]
+ 5870, // Range #83: [5870, 5880, Runic]
+ 5888, // Range #84: [5888, 5908, Tagalog]
+ 5920, // Range #85: [5920, 5940, Hanunoo]
+ 5952, // Range #86: [5952, 5971, Buhid]
+ 5984, // Range #87: [5984, 6003, Tagbanwa]
+ 6016, // Range #88: [6016, 6121, Khmer]
+ 6128, // Range #89: [6128, 6137, Khmer]
+ 6144, // Range #90: [6144, 6145, Mongolian]
+ 6148, // Range #91: [6148, 6148, Mongolian]
+ 6150, // Range #92: [6150, 6169, Mongolian]
+ 6176, // Range #93: [6176, 6264, Mongolian]
+ 6272, // Range #94: [6272, 6314, Mongolian]
+ 6320, // Range #95: [6320, 6389, Canadian_Aboriginal]
+ 6400, // Range #96: [6400, 6479, Limbu]
+ 6480, // Range #97: [6480, 6516, Tai_Le]
+ 6528, // Range #98: [6528, 6601, New_Tai_Lue]
+ 6608, // Range #99: [6608, 6623, New_Tai_Lue]
+ 6624, // Range #100: [6624, 6655, Khmer]
+ 6656, // Range #101: [6656, 6687, Buginese]
+ 6688, // Range #102: [6688, 6793, Tai_Tham]
+ 6800, // Range #103: [6800, 6809, Tai_Tham]
+ 6816, // Range #104: [6816, 6829, Tai_Tham]
+ 6912, // Range #105: [6912, 7036, Balinese]
+ 7040, // Range #106: [7040, 7103, Sundanese]
+ 7104, // Range #107: [7104, 7155, Batak]
+ 7164, // Range #108: [7164, 7167, Batak]
+ 7168, // Range #109: [7168, 7247, Lepcha]
+ 7248, // Range #110: [7248, 7295, Ol_Chiki]
+ 7296, // Range #111: [7296, 7304, Cyrillic]
+ 7312, // Range #112: [7312, 7359, Georgian]
+ 7360, // Range #113: [7360, 7367, Sundanese]
+ 7424, // Range #114: [7424, 7461, Latin]
+ 7462, // Range #115: [7462, 7466, Greek]
+ 7467, // Range #116: [7467, 7467, Cyrillic]
+ 7468, // Range #117: [7468, 7516, Latin]
+ 7517, // Range #118: [7517, 7521, Greek]
+ 7522, // Range #119: [7522, 7525, Latin]
+ 7526, // Range #120: [7526, 7530, Greek]
+ 7531, // Range #121: [7531, 7543, Latin]
+ 7544, // Range #122: [7544, 7544, Cyrillic]
+ 7545, // Range #123: [7545, 7614, Latin]
+ 7615, // Range #124: [7615, 7615, Greek]
+ 7680, // Range #125: [7680, 7935, Latin]
+ 7936, // Range #126: [7936, 8190, Greek]
+ 8305, // Range #127: [8305, 8305, Latin]
+ 8319, // Range #128: [8319, 8319, Latin]
+ 8336, // Range #129: [8336, 8348, Latin]
+ 8486, // Range #130: [8486, 8486, Greek]
+ 8490, // Range #131: [8490, 8491, Latin]
+ 8498, // Range #132: [8498, 8498, Latin]
+ 8526, // Range #133: [8526, 8526, Latin]
+ 8544, // Range #134: [8544, 8584, Latin]
+ 10240, // Range #135: [10240, 10495, Braille]
+ 11264, // Range #136: [11264, 11358, Glagolitic]
+ 11360, // Range #137: [11360, 11391, Latin]
+ 11392, // Range #138: [11392, 11507, Coptic]
+ 11513, // Range #139: [11513, 11519, Coptic]
+ 11520, // Range #140: [11520, 11559, Georgian]
+ 11565, // Range #141: [11565, 11565, Georgian]
+ 11568, // Range #142: [11568, 11623, Tifinagh]
+ 11631, // Range #143: [11631, 11632, Tifinagh]
+ 11647, // Range #144: [11647, 11647, Tifinagh]
+ 11648, // Range #145: [11648, 11670, Ethiopic]
+ 11680, // Range #146: [11680, 11742, Ethiopic]
+ 11744, // Range #147: [11744, 11775, Cyrillic]
+ 11904, // Range #148: [11904, 12019, Han]
+ 12032, // Range #149: [12032, 12245, Han]
+ 12293, // Range #150: [12293, 12293, Han]
+ 12295, // Range #151: [12295, 12295, Han]
+ 12321, // Range #152: [12321, 12329, Han]
+ 12334, // Range #153: [12334, 12335, Hangul]
+ 12344, // Range #154: [12344, 12347, Han]
+ 12353, // Range #155: [12353, 12438, Hiragana]
+ 12445, // Range #156: [12445, 12447, Hiragana]
+ 12449, // Range #157: [12449, 12538, Katakana]
+ 12541, // Range #158: [12541, 12543, Katakana]
+ 12549, // Range #159: [12549, 12591, Bopomofo]
+ 12593, // Range #160: [12593, 12686, Hangul]
+ 12704, // Range #161: [12704, 12730, Bopomofo]
+ 12784, // Range #162: [12784, 12799, Katakana]
+ 12800, // Range #163: [12800, 12830, Hangul]
+ 12896, // Range #164: [12896, 12926, Hangul]
+ 13008, // Range #165: [13008, 13143, Katakana]
+ 13312, // Range #166: [13312, 19893, Han]
+ 19968, // Range #167: [19968, 40943, Han]
+ 40960, // Range #168: [40960, 42182, Yi]
+ 42192, // Range #169: [42192, 42239, Lisu]
+ 42240, // Range #170: [42240, 42539, Vai]
+ 42560, // Range #171: [42560, 42655, Cyrillic]
+ 42656, // Range #172: [42656, 42743, Bamum]
+ 42786, // Range #173: [42786, 42887, Latin]
+ 42891, // Range #174: [42891, 42937, Latin]
+ 42999, // Range #175: [42999, 43007, Latin]
+ 43008, // Range #176: [43008, 43051, Syloti_Nagri]
+ 43072, // Range #177: [43072, 43127, Phags_Pa]
+ 43136, // Range #178: [43136, 43205, Saurashtra]
+ 43214, // Range #179: [43214, 43225, Saurashtra]
+ 43232, // Range #180: [43232, 43263, Devanagari]
+ 43264, // Range #181: [43264, 43309, Kayah_Li]
+ 43311, // Range #182: [43311, 43311, Kayah_Li]
+ 43312, // Range #183: [43312, 43347, Rejang]
+ 43359, // Range #184: [43359, 43359, Rejang]
+ 43360, // Range #185: [43360, 43388, Hangul]
+ 43392, // Range #186: [43392, 43469, Javanese]
+ 43472, // Range #187: [43472, 43487, Javanese]
+ 43488, // Range #188: [43488, 43518, Myanmar]
+ 43520, // Range #189: [43520, 43574, Cham]
+ 43584, // Range #190: [43584, 43615, Cham]
+ 43616, // Range #191: [43616, 43647, Myanmar]
+ 43648, // Range #192: [43648, 43714, Tai_Viet]
+ 43739, // Range #193: [43739, 43743, Tai_Viet]
+ 43744, // Range #194: [43744, 43766, Meetei_Mayek]
+ 43777, // Range #195: [43777, 43798, Ethiopic]
+ 43808, // Range #196: [43808, 43822, Ethiopic]
+ 43824, // Range #197: [43824, 43866, Latin]
+ 43868, // Range #198: [43868, 43876, Latin]
+ 43877, // Range #199: [43877, 43877, Greek]
+ 43888, // Range #200: [43888, 43967, Cherokee]
+ 43968, // Range #201: [43968, 44025, Meetei_Mayek]
+ 44032, // Range #202: [44032, 55203, Hangul]
+ 55216, // Range #203: [55216, 55291, Hangul]
+ 63744, // Range #204: [63744, 64217, Han]
+ 64256, // Range #205: [64256, 64262, Latin]
+ 64275, // Range #206: [64275, 64279, Armenian]
+ 64285, // Range #207: [64285, 64335, Hebrew]
+ 64336, // Range #208: [64336, 64449, Arabic]
+ 64467, // Range #209: [64467, 64829, Arabic]
+ 64848, // Range #210: [64848, 64967, Arabic]
+ 65008, // Range #211: [65008, 65021, Arabic]
+ 65070, // Range #212: [65070, 65071, Cyrillic]
+ 65136, // Range #213: [65136, 65276, Arabic]
+ 65313, // Range #214: [65313, 65338, Latin]
+ 65345, // Range #215: [65345, 65370, Latin]
+ 65382, // Range #216: [65382, 65391, Katakana]
+ 65393, // Range #217: [65393, 65437, Katakana]
+ 65440, // Range #218: [65440, 65500, Hangul]
+ 65536, // Range #219: [65536, 65629, Linear_B]
+ 65664, // Range #220: [65664, 65786, Linear_B]
+ 65856, // Range #221: [65856, 65934, Greek]
+ 65952, // Range #222: [65952, 65952, Greek]
+ 66176, // Range #223: [66176, 66204, Lycian]
+ 66208, // Range #224: [66208, 66256, Carian]
+ 66304, // Range #225: [66304, 66339, Old_Italic]
+ 66349, // Range #226: [66349, 66351, Old_Italic]
+ 66352, // Range #227: [66352, 66378, Gothic]
+ 66384, // Range #228: [66384, 66426, Old_Permic]
+ 66432, // Range #229: [66432, 66463, Ugaritic]
+ 66464, // Range #230: [66464, 66517, Old_Persian]
+ 66560, // Range #231: [66560, 66639, Deseret]
+ 66640, // Range #232: [66640, 66687, Shavian]
+ 66688, // Range #233: [66688, 66729, Osmanya]
+ 66736, // Range #234: [66736, 66811, Osage]
+ 66816, // Range #235: [66816, 66855, Elbasan]
+ 66864, // Range #236: [66864, 66915, Caucasian_Albanian]
+ 66927, // Range #237: [66927, 66927, Caucasian_Albanian]
+ 67072, // Range #238: [67072, 67382, Linear_A]
+ 67392, // Range #239: [67392, 67413, Linear_A]
+ 67424, // Range #240: [67424, 67431, Linear_A]
+ 67584, // Range #241: [67584, 67647, Cypriot]
+ 67648, // Range #242: [67648, 67679, Imperial_Aramaic]
+ 67680, // Range #243: [67680, 67711, Palmyrene]
+ 67712, // Range #244: [67712, 67742, Nabataean]
+ 67751, // Range #245: [67751, 67759, Nabataean]
+ 67808, // Range #246: [67808, 67829, Hatran]
+ 67835, // Range #247: [67835, 67839, Hatran]
+ 67840, // Range #248: [67840, 67871, Phoenician]
+ 67872, // Range #249: [67872, 67897, Lydian]
+ 67903, // Range #250: [67903, 67903, Lydian]
+ 67968, // Range #251: [67968, 67999, Meroitic_Hieroglyphs]
+ 68000, // Range #252: [68000, 68095, Meroitic_Cursive]
+ 68096, // Range #253: [68096, 68102, Kharoshthi]
+ 68108, // Range #254: [68108, 68168, Kharoshthi]
+ 68176, // Range #255: [68176, 68184, Kharoshthi]
+ 68192, // Range #256: [68192, 68223, Old_South_Arabian]
+ 68224, // Range #257: [68224, 68255, Old_North_Arabian]
+ 68288, // Range #258: [68288, 68342, Manichaean]
+ 68352, // Range #259: [68352, 68415, Avestan]
+ 68416, // Range #260: [68416, 68447, Inscriptional_Parthian]
+ 68448, // Range #261: [68448, 68466, Inscriptional_Pahlavi]
+ 68472, // Range #262: [68472, 68479, Inscriptional_Pahlavi]
+ 68480, // Range #263: [68480, 68497, Psalter_Pahlavi]
+ 68505, // Range #264: [68505, 68508, Psalter_Pahlavi]
+ 68521, // Range #265: [68521, 68527, Psalter_Pahlavi]
+ 68608, // Range #266: [68608, 68680, Old_Turkic]
+ 68736, // Range #267: [68736, 68786, Old_Hungarian]
+ 68800, // Range #268: [68800, 68850, Old_Hungarian]
+ 68858, // Range #269: [68858, 68863, Old_Hungarian]
+ 68864, // Range #270: [68864, 68903, Hanifi_Rohingya]
+ 68912, // Range #271: [68912, 68921, Hanifi_Rohingya]
+ 69216, // Range #272: [69216, 69246, Arabic]
+ 69376, // Range #273: [69376, 69415, Old_Sogdian]
+ 69424, // Range #274: [69424, 69465, Sogdian]
+ 69632, // Range #275: [69632, 69743, Brahmi]
+ 69759, // Range #276: [69759, 69759, Brahmi]
+ 69760, // Range #277: [69760, 69825, Kaithi]
+ 69837, // Range #278: [69837, 69837, Kaithi]
+ 69840, // Range #279: [69840, 69864, Sora_Sompeng]
+ 69872, // Range #280: [69872, 69881, Sora_Sompeng]
+ 69888, // Range #281: [69888, 69958, Chakma]
+ 69968, // Range #282: [69968, 70006, Mahajani]
+ 70016, // Range #283: [70016, 70111, Sharada]
+ 70113, // Range #284: [70113, 70132, Sinhala]
+ 70144, // Range #285: [70144, 70206, Khojki]
+ 70272, // Range #286: [70272, 70313, Multani]
+ 70320, // Range #287: [70320, 70378, Khudawadi]
+ 70384, // Range #288: [70384, 70393, Khudawadi]
+ 70400, // Range #289: [70400, 70457, Grantha]
+ 70460, // Range #290: [70460, 70480, Grantha]
+ 70487, // Range #291: [70487, 70487, Grantha]
+ 70493, // Range #292: [70493, 70516, Grantha]
+ 70656, // Range #293: [70656, 70750, Newa]
+ 70784, // Range #294: [70784, 70855, Tirhuta]
+ 70864, // Range #295: [70864, 70873, Tirhuta]
+ 71040, // Range #296: [71040, 71133, Siddham]
+ 71168, // Range #297: [71168, 71236, Modi]
+ 71248, // Range #298: [71248, 71257, Modi]
+ 71264, // Range #299: [71264, 71276, Mongolian]
+ 71296, // Range #300: [71296, 71351, Takri]
+ 71360, // Range #301: [71360, 71369, Takri]
+ 71424, // Range #302: [71424, 71487, Ahom]
+ 71680, // Range #303: [71680, 71739, Dogra]
+ 71840, // Range #304: [71840, 71922, Warang_Citi]
+ 71935, // Range #305: [71935, 71935, Warang_Citi]
+ 72192, // Range #306: [72192, 72263, Zanabazar_Square]
+ 72272, // Range #307: [72272, 72354, Soyombo]
+ 72384, // Range #308: [72384, 72440, Pau_Cin_Hau]
+ 72704, // Range #309: [72704, 72773, Bhaiksuki]
+ 72784, // Range #310: [72784, 72812, Bhaiksuki]
+ 72816, // Range #311: [72816, 72886, Marchen]
+ 72960, // Range #312: [72960, 73031, Masaram_Gondi]
+ 73040, // Range #313: [73040, 73049, Masaram_Gondi]
+ 73056, // Range #314: [73056, 73112, Gunjala_Gondi]
+ 73120, // Range #315: [73120, 73129, Gunjala_Gondi]
+ 73440, // Range #316: [73440, 73464, Makasar]
+ 73728, // Range #317: [73728, 74649, Cuneiform]
+ 74752, // Range #318: [74752, 74868, Cuneiform]
+ 74880, // Range #319: [74880, 75075, Cuneiform]
+ 77824, // Range #320: [77824, 78894, Egyptian_Hieroglyphs]
+ 82944, // Range #321: [82944, 83526, Anatolian_Hieroglyphs]
+ 92160, // Range #322: [92160, 92728, Bamum]
+ 92736, // Range #323: [92736, 92783, Mro]
+ 92880, // Range #324: [92880, 92917, Bassa_Vah]
+ 92928, // Range #325: [92928, 92997, Pahawh_Hmong]
+ 93008, // Range #326: [93008, 93047, Pahawh_Hmong]
+ 93053, // Range #327: [93053, 93071, Pahawh_Hmong]
+ 93760, // Range #328: [93760, 93850, Medefaidrin]
+ 93952, // Range #329: [93952, 94020, Miao]
+ 94032, // Range #330: [94032, 94078, Miao]
+ 94095, // Range #331: [94095, 94111, Miao]
+ 94176, // Range #332: [94176, 94176, Tangut]
+ 94177, // Range #333: [94177, 94177, Nushu]
+ 94208, // Range #334: [94208, 100337, Tangut]
+ 100352, // Range #335: [100352, 101106, Tangut]
+ 110592, // Range #336: [110592, 110592, Katakana]
+ 110593, // Range #337: [110593, 110878, Hiragana]
+ 110960, // Range #338: [110960, 111355, Nushu]
+ 113664, // Range #339: [113664, 113770, Duployan]
+ 113776, // Range #340: [113776, 113800, Duployan]
+ 113808, // Range #341: [113808, 113823, Duployan]
+ 119296, // Range #342: [119296, 119365, Greek]
+ 120832, // Range #343: [120832, 121483, SignWriting]
+ 121499, // Range #344: [121499, 121519, SignWriting]
+ 122880, // Range #345: [122880, 122922, Glagolitic]
+ 124928, // Range #346: [124928, 125142, Mende_Kikakui]
+ 125184, // Range #347: [125184, 125258, Adlam]
+ 125264, // Range #348: [125264, 125279, Adlam]
+ 126464, // Range #349: [126464, 126523, Arabic]
+ 126530, // Range #350: [126530, 126619, Arabic]
+ 126625, // Range #351: [126625, 126651, Arabic]
+ 126704, // Range #352: [126704, 126705, Arabic]
+ 127488, // Range #353: [127488, 127488, Hiragana]
+ 131072, // Range #354: [131072, 173782, Han]
+ 173824, // Range #355: [173824, 177972, Han]
+ 177984, // Range #356: [177984, 183969, Han]
+ 183984, // Range #357: [183984, 191456, Han]
+ 194560, // Range #358: [194560, 195101, Han]
+};
+
+const uint16 kRangeSizeMinusOne[] = {
+ 25, // Range #0: [65, 90, Latin]
+ 25, // Range #1: [97, 122, Latin]
+ 0, // Range #2: [170, 170, Latin]
+ 0, // Range #3: [186, 186, Latin]
+ 22, // Range #4: [192, 214, Latin]
+ 30, // Range #5: [216, 246, Latin]
+ 448, // Range #6: [248, 696, Latin]
+ 4, // Range #7: [736, 740, Latin]
+ 1, // Range #8: [746, 747, Bopomofo]
+ 3, // Range #9: [880, 883, Greek]
+ 8, // Range #10: [885, 893, Greek]
+ 5, // Range #11: [895, 900, Greek]
+ 0, // Range #12: [902, 902, Greek]
+ 89, // Range #13: [904, 993, Greek]
+ 13, // Range #14: [994, 1007, Coptic]
+ 15, // Range #15: [1008, 1023, Greek]
+ 132, // Range #16: [1024, 1156, Cyrillic]
+ 168, // Range #17: [1159, 1327, Cyrillic]
+ 87, // Range #18: [1329, 1416, Armenian]
+ 5, // Range #19: [1418, 1423, Armenian]
+ 54, // Range #20: [1425, 1479, Hebrew]
+ 36, // Range #21: [1488, 1524, Hebrew]
+ 4, // Range #22: [1536, 1540, Arabic]
+ 5, // Range #23: [1542, 1547, Arabic]
+ 13, // Range #24: [1549, 1562, Arabic]
+ 2, // Range #25: [1564, 1566, Arabic]
+ 31, // Range #26: [1568, 1599, Arabic]
+ 9, // Range #27: [1601, 1610, Arabic]
+ 25, // Range #28: [1622, 1647, Arabic]
+ 107, // Range #29: [1649, 1756, Arabic]
+ 33, // Range #30: [1758, 1791, Arabic]
+ 79, // Range #31: [1792, 1871, Syriac]
+ 47, // Range #32: [1872, 1919, Arabic]
+ 49, // Range #33: [1920, 1969, Thaana]
+ 63, // Range #34: [1984, 2047, Nko]
+ 62, // Range #35: [2048, 2110, Samaritan]
+ 30, // Range #36: [2112, 2142, Mandaic]
+ 10, // Range #37: [2144, 2154, Syriac]
+ 29, // Range #38: [2208, 2237, Arabic]
+ 14, // Range #39: [2259, 2273, Arabic]
+ 28, // Range #40: [2275, 2303, Arabic]
+ 80, // Range #41: [2304, 2384, Devanagari]
+ 16, // Range #42: [2387, 2403, Devanagari]
+ 25, // Range #43: [2406, 2431, Devanagari]
+ 78, // Range #44: [2432, 2510, Bengali]
+ 39, // Range #45: [2519, 2558, Bengali]
+ 80, // Range #46: [2561, 2641, Gurmukhi]
+ 5, // Range #47: [2649, 2654, Gurmukhi]
+ 16, // Range #48: [2662, 2678, Gurmukhi]
+ 79, // Range #49: [2689, 2768, Gujarati]
+ 17, // Range #50: [2784, 2801, Gujarati]
+ 6, // Range #51: [2809, 2815, Gujarati]
+ 76, // Range #52: [2817, 2893, Oriya]
+ 33, // Range #53: [2902, 2935, Oriya]
+ 78, // Range #54: [2946, 3024, Tamil]
+ 0, // Range #55: [3031, 3031, Tamil]
+ 20, // Range #56: [3046, 3066, Tamil]
+ 77, // Range #57: [3072, 3149, Telugu]
+ 5, // Range #58: [3157, 3162, Telugu]
+ 15, // Range #59: [3168, 3183, Telugu]
+ 7, // Range #60: [3192, 3199, Telugu]
+ 77, // Range #61: [3200, 3277, Kannada]
+ 1, // Range #62: [3285, 3286, Kannada]
+ 20, // Range #63: [3294, 3314, Kannada]
+ 127, // Range #64: [3328, 3455, Malayalam]
+ 93, // Range #65: [3458, 3551, Sinhala]
+ 14, // Range #66: [3558, 3572, Sinhala]
+ 57, // Range #67: [3585, 3642, Thai]
+ 27, // Range #68: [3648, 3675, Thai]
+ 12, // Range #69: [3713, 3725, Lao]
+ 75, // Range #70: [3732, 3807, Lao]
+ 212, // Range #71: [3840, 4052, Tibetan]
+ 1, // Range #72: [4057, 4058, Tibetan]
+ 159, // Range #73: [4096, 4255, Myanmar]
+ 39, // Range #74: [4256, 4295, Georgian]
+ 45, // Range #75: [4301, 4346, Georgian]
+ 3, // Range #76: [4348, 4351, Georgian]
+ 255, // Range #77: [4352, 4607, Hangul]
+ 409, // Range #78: [4608, 5017, Ethiopic]
+ 93, // Range #79: [5024, 5117, Cherokee]
+ 639, // Range #80: [5120, 5759, Canadian_Aboriginal]
+ 28, // Range #81: [5760, 5788, Ogham]
+ 74, // Range #82: [5792, 5866, Runic]
+ 10, // Range #83: [5870, 5880, Runic]
+ 20, // Range #84: [5888, 5908, Tagalog]
+ 20, // Range #85: [5920, 5940, Hanunoo]
+ 19, // Range #86: [5952, 5971, Buhid]
+ 19, // Range #87: [5984, 6003, Tagbanwa]
+ 105, // Range #88: [6016, 6121, Khmer]
+ 9, // Range #89: [6128, 6137, Khmer]
+ 1, // Range #90: [6144, 6145, Mongolian]
+ 0, // Range #91: [6148, 6148, Mongolian]
+ 19, // Range #92: [6150, 6169, Mongolian]
+ 88, // Range #93: [6176, 6264, Mongolian]
+ 42, // Range #94: [6272, 6314, Mongolian]
+ 69, // Range #95: [6320, 6389, Canadian_Aboriginal]
+ 79, // Range #96: [6400, 6479, Limbu]
+ 36, // Range #97: [6480, 6516, Tai_Le]
+ 73, // Range #98: [6528, 6601, New_Tai_Lue]
+ 15, // Range #99: [6608, 6623, New_Tai_Lue]
+ 31, // Range #100: [6624, 6655, Khmer]
+ 31, // Range #101: [6656, 6687, Buginese]
+ 105, // Range #102: [6688, 6793, Tai_Tham]
+ 9, // Range #103: [6800, 6809, Tai_Tham]
+ 13, // Range #104: [6816, 6829, Tai_Tham]
+ 124, // Range #105: [6912, 7036, Balinese]
+ 63, // Range #106: [7040, 7103, Sundanese]
+ 51, // Range #107: [7104, 7155, Batak]
+ 3, // Range #108: [7164, 7167, Batak]
+ 79, // Range #109: [7168, 7247, Lepcha]
+ 47, // Range #110: [7248, 7295, Ol_Chiki]
+ 8, // Range #111: [7296, 7304, Cyrillic]
+ 47, // Range #112: [7312, 7359, Georgian]
+ 7, // Range #113: [7360, 7367, Sundanese]
+ 37, // Range #114: [7424, 7461, Latin]
+ 4, // Range #115: [7462, 7466, Greek]
+ 0, // Range #116: [7467, 7467, Cyrillic]
+ 48, // Range #117: [7468, 7516, Latin]
+ 4, // Range #118: [7517, 7521, Greek]
+ 3, // Range #119: [7522, 7525, Latin]
+ 4, // Range #120: [7526, 7530, Greek]
+ 12, // Range #121: [7531, 7543, Latin]
+ 0, // Range #122: [7544, 7544, Cyrillic]
+ 69, // Range #123: [7545, 7614, Latin]
+ 0, // Range #124: [7615, 7615, Greek]
+ 255, // Range #125: [7680, 7935, Latin]
+ 254, // Range #126: [7936, 8190, Greek]
+ 0, // Range #127: [8305, 8305, Latin]
+ 0, // Range #128: [8319, 8319, Latin]
+ 12, // Range #129: [8336, 8348, Latin]
+ 0, // Range #130: [8486, 8486, Greek]
+ 1, // Range #131: [8490, 8491, Latin]
+ 0, // Range #132: [8498, 8498, Latin]
+ 0, // Range #133: [8526, 8526, Latin]
+ 40, // Range #134: [8544, 8584, Latin]
+ 255, // Range #135: [10240, 10495, Braille]
+ 94, // Range #136: [11264, 11358, Glagolitic]
+ 31, // Range #137: [11360, 11391, Latin]
+ 115, // Range #138: [11392, 11507, Coptic]
+ 6, // Range #139: [11513, 11519, Coptic]
+ 39, // Range #140: [11520, 11559, Georgian]
+ 0, // Range #141: [11565, 11565, Georgian]
+ 55, // Range #142: [11568, 11623, Tifinagh]
+ 1, // Range #143: [11631, 11632, Tifinagh]
+ 0, // Range #144: [11647, 11647, Tifinagh]
+ 22, // Range #145: [11648, 11670, Ethiopic]
+ 62, // Range #146: [11680, 11742, Ethiopic]
+ 31, // Range #147: [11744, 11775, Cyrillic]
+ 115, // Range #148: [11904, 12019, Han]
+ 213, // Range #149: [12032, 12245, Han]
+ 0, // Range #150: [12293, 12293, Han]
+ 0, // Range #151: [12295, 12295, Han]
+ 8, // Range #152: [12321, 12329, Han]
+ 1, // Range #153: [12334, 12335, Hangul]
+ 3, // Range #154: [12344, 12347, Han]
+ 85, // Range #155: [12353, 12438, Hiragana]
+ 2, // Range #156: [12445, 12447, Hiragana]
+ 89, // Range #157: [12449, 12538, Katakana]
+ 2, // Range #158: [12541, 12543, Katakana]
+ 42, // Range #159: [12549, 12591, Bopomofo]
+ 93, // Range #160: [12593, 12686, Hangul]
+ 26, // Range #161: [12704, 12730, Bopomofo]
+ 15, // Range #162: [12784, 12799, Katakana]
+ 30, // Range #163: [12800, 12830, Hangul]
+ 30, // Range #164: [12896, 12926, Hangul]
+ 135, // Range #165: [13008, 13143, Katakana]
+ 6581, // Range #166: [13312, 19893, Han]
+ 20975, // Range #167: [19968, 40943, Han]
+ 1222, // Range #168: [40960, 42182, Yi]
+ 47, // Range #169: [42192, 42239, Lisu]
+ 299, // Range #170: [42240, 42539, Vai]
+ 95, // Range #171: [42560, 42655, Cyrillic]
+ 87, // Range #172: [42656, 42743, Bamum]
+ 101, // Range #173: [42786, 42887, Latin]
+ 46, // Range #174: [42891, 42937, Latin]
+ 8, // Range #175: [42999, 43007, Latin]
+ 43, // Range #176: [43008, 43051, Syloti_Nagri]
+ 55, // Range #177: [43072, 43127, Phags_Pa]
+ 69, // Range #178: [43136, 43205, Saurashtra]
+ 11, // Range #179: [43214, 43225, Saurashtra]
+ 31, // Range #180: [43232, 43263, Devanagari]
+ 45, // Range #181: [43264, 43309, Kayah_Li]
+ 0, // Range #182: [43311, 43311, Kayah_Li]
+ 35, // Range #183: [43312, 43347, Rejang]
+ 0, // Range #184: [43359, 43359, Rejang]
+ 28, // Range #185: [43360, 43388, Hangul]
+ 77, // Range #186: [43392, 43469, Javanese]
+ 15, // Range #187: [43472, 43487, Javanese]
+ 30, // Range #188: [43488, 43518, Myanmar]
+ 54, // Range #189: [43520, 43574, Cham]
+ 31, // Range #190: [43584, 43615, Cham]
+ 31, // Range #191: [43616, 43647, Myanmar]
+ 66, // Range #192: [43648, 43714, Tai_Viet]
+ 4, // Range #193: [43739, 43743, Tai_Viet]
+ 22, // Range #194: [43744, 43766, Meetei_Mayek]
+ 21, // Range #195: [43777, 43798, Ethiopic]
+ 14, // Range #196: [43808, 43822, Ethiopic]
+ 42, // Range #197: [43824, 43866, Latin]
+ 8, // Range #198: [43868, 43876, Latin]
+ 0, // Range #199: [43877, 43877, Greek]
+ 79, // Range #200: [43888, 43967, Cherokee]
+ 57, // Range #201: [43968, 44025, Meetei_Mayek]
+ 11171, // Range #202: [44032, 55203, Hangul]
+ 75, // Range #203: [55216, 55291, Hangul]
+ 473, // Range #204: [63744, 64217, Han]
+ 6, // Range #205: [64256, 64262, Latin]
+ 4, // Range #206: [64275, 64279, Armenian]
+ 50, // Range #207: [64285, 64335, Hebrew]
+ 113, // Range #208: [64336, 64449, Arabic]
+ 362, // Range #209: [64467, 64829, Arabic]
+ 119, // Range #210: [64848, 64967, Arabic]
+ 13, // Range #211: [65008, 65021, Arabic]
+ 1, // Range #212: [65070, 65071, Cyrillic]
+ 140, // Range #213: [65136, 65276, Arabic]
+ 25, // Range #214: [65313, 65338, Latin]
+ 25, // Range #215: [65345, 65370, Latin]
+ 9, // Range #216: [65382, 65391, Katakana]
+ 44, // Range #217: [65393, 65437, Katakana]
+ 60, // Range #218: [65440, 65500, Hangul]
+ 93, // Range #219: [65536, 65629, Linear_B]
+ 122, // Range #220: [65664, 65786, Linear_B]
+ 78, // Range #221: [65856, 65934, Greek]
+ 0, // Range #222: [65952, 65952, Greek]
+ 28, // Range #223: [66176, 66204, Lycian]
+ 48, // Range #224: [66208, 66256, Carian]
+ 35, // Range #225: [66304, 66339, Old_Italic]
+ 2, // Range #226: [66349, 66351, Old_Italic]
+ 26, // Range #227: [66352, 66378, Gothic]
+ 42, // Range #228: [66384, 66426, Old_Permic]
+ 31, // Range #229: [66432, 66463, Ugaritic]
+ 53, // Range #230: [66464, 66517, Old_Persian]
+ 79, // Range #231: [66560, 66639, Deseret]
+ 47, // Range #232: [66640, 66687, Shavian]
+ 41, // Range #233: [66688, 66729, Osmanya]
+ 75, // Range #234: [66736, 66811, Osage]
+ 39, // Range #235: [66816, 66855, Elbasan]
+ 51, // Range #236: [66864, 66915, Caucasian_Albanian]
+ 0, // Range #237: [66927, 66927, Caucasian_Albanian]
+ 310, // Range #238: [67072, 67382, Linear_A]
+ 21, // Range #239: [67392, 67413, Linear_A]
+ 7, // Range #240: [67424, 67431, Linear_A]
+ 63, // Range #241: [67584, 67647, Cypriot]
+ 31, // Range #242: [67648, 67679, Imperial_Aramaic]
+ 31, // Range #243: [67680, 67711, Palmyrene]
+ 30, // Range #244: [67712, 67742, Nabataean]
+ 8, // Range #245: [67751, 67759, Nabataean]
+ 21, // Range #246: [67808, 67829, Hatran]
+ 4, // Range #247: [67835, 67839, Hatran]
+ 31, // Range #248: [67840, 67871, Phoenician]
+ 25, // Range #249: [67872, 67897, Lydian]
+ 0, // Range #250: [67903, 67903, Lydian]
+ 31, // Range #251: [67968, 67999, Meroitic_Hieroglyphs]
+ 95, // Range #252: [68000, 68095, Meroitic_Cursive]
+ 6, // Range #253: [68096, 68102, Kharoshthi]
+ 60, // Range #254: [68108, 68168, Kharoshthi]
+ 8, // Range #255: [68176, 68184, Kharoshthi]
+ 31, // Range #256: [68192, 68223, Old_South_Arabian]
+ 31, // Range #257: [68224, 68255, Old_North_Arabian]
+ 54, // Range #258: [68288, 68342, Manichaean]
+ 63, // Range #259: [68352, 68415, Avestan]
+ 31, // Range #260: [68416, 68447, Inscriptional_Parthian]
+ 18, // Range #261: [68448, 68466, Inscriptional_Pahlavi]
+ 7, // Range #262: [68472, 68479, Inscriptional_Pahlavi]
+ 17, // Range #263: [68480, 68497, Psalter_Pahlavi]
+ 3, // Range #264: [68505, 68508, Psalter_Pahlavi]
+ 6, // Range #265: [68521, 68527, Psalter_Pahlavi]
+ 72, // Range #266: [68608, 68680, Old_Turkic]
+ 50, // Range #267: [68736, 68786, Old_Hungarian]
+ 50, // Range #268: [68800, 68850, Old_Hungarian]
+ 5, // Range #269: [68858, 68863, Old_Hungarian]
+ 39, // Range #270: [68864, 68903, Hanifi_Rohingya]
+ 9, // Range #271: [68912, 68921, Hanifi_Rohingya]
+ 30, // Range #272: [69216, 69246, Arabic]
+ 39, // Range #273: [69376, 69415, Old_Sogdian]
+ 41, // Range #274: [69424, 69465, Sogdian]
+ 111, // Range #275: [69632, 69743, Brahmi]
+ 0, // Range #276: [69759, 69759, Brahmi]
+ 65, // Range #277: [69760, 69825, Kaithi]
+ 0, // Range #278: [69837, 69837, Kaithi]
+ 24, // Range #279: [69840, 69864, Sora_Sompeng]
+ 9, // Range #280: [69872, 69881, Sora_Sompeng]
+ 70, // Range #281: [69888, 69958, Chakma]
+ 38, // Range #282: [69968, 70006, Mahajani]
+ 95, // Range #283: [70016, 70111, Sharada]
+ 19, // Range #284: [70113, 70132, Sinhala]
+ 62, // Range #285: [70144, 70206, Khojki]
+ 41, // Range #286: [70272, 70313, Multani]
+ 58, // Range #287: [70320, 70378, Khudawadi]
+ 9, // Range #288: [70384, 70393, Khudawadi]
+ 57, // Range #289: [70400, 70457, Grantha]
+ 20, // Range #290: [70460, 70480, Grantha]
+ 0, // Range #291: [70487, 70487, Grantha]
+ 23, // Range #292: [70493, 70516, Grantha]
+ 94, // Range #293: [70656, 70750, Newa]
+ 71, // Range #294: [70784, 70855, Tirhuta]
+ 9, // Range #295: [70864, 70873, Tirhuta]
+ 93, // Range #296: [71040, 71133, Siddham]
+ 68, // Range #297: [71168, 71236, Modi]
+ 9, // Range #298: [71248, 71257, Modi]
+ 12, // Range #299: [71264, 71276, Mongolian]
+ 55, // Range #300: [71296, 71351, Takri]
+ 9, // Range #301: [71360, 71369, Takri]
+ 63, // Range #302: [71424, 71487, Ahom]
+ 59, // Range #303: [71680, 71739, Dogra]
+ 82, // Range #304: [71840, 71922, Warang_Citi]
+ 0, // Range #305: [71935, 71935, Warang_Citi]
+ 71, // Range #306: [72192, 72263, Zanabazar_Square]
+ 82, // Range #307: [72272, 72354, Soyombo]
+ 56, // Range #308: [72384, 72440, Pau_Cin_Hau]
+ 69, // Range #309: [72704, 72773, Bhaiksuki]
+ 28, // Range #310: [72784, 72812, Bhaiksuki]
+ 70, // Range #311: [72816, 72886, Marchen]
+ 71, // Range #312: [72960, 73031, Masaram_Gondi]
+ 9, // Range #313: [73040, 73049, Masaram_Gondi]
+ 56, // Range #314: [73056, 73112, Gunjala_Gondi]
+ 9, // Range #315: [73120, 73129, Gunjala_Gondi]
+ 24, // Range #316: [73440, 73464, Makasar]
+ 921, // Range #317: [73728, 74649, Cuneiform]
+ 116, // Range #318: [74752, 74868, Cuneiform]
+ 195, // Range #319: [74880, 75075, Cuneiform]
+ 1070, // Range #320: [77824, 78894, Egyptian_Hieroglyphs]
+ 582, // Range #321: [82944, 83526, Anatolian_Hieroglyphs]
+ 568, // Range #322: [92160, 92728, Bamum]
+ 47, // Range #323: [92736, 92783, Mro]
+ 37, // Range #324: [92880, 92917, Bassa_Vah]
+ 69, // Range #325: [92928, 92997, Pahawh_Hmong]
+ 39, // Range #326: [93008, 93047, Pahawh_Hmong]
+ 18, // Range #327: [93053, 93071, Pahawh_Hmong]
+ 90, // Range #328: [93760, 93850, Medefaidrin]
+ 68, // Range #329: [93952, 94020, Miao]
+ 46, // Range #330: [94032, 94078, Miao]
+ 16, // Range #331: [94095, 94111, Miao]
+ 0, // Range #332: [94176, 94176, Tangut]
+ 0, // Range #333: [94177, 94177, Nushu]
+ 6129, // Range #334: [94208, 100337, Tangut]
+ 754, // Range #335: [100352, 101106, Tangut]
+ 0, // Range #336: [110592, 110592, Katakana]
+ 285, // Range #337: [110593, 110878, Hiragana]
+ 395, // Range #338: [110960, 111355, Nushu]
+ 106, // Range #339: [113664, 113770, Duployan]
+ 24, // Range #340: [113776, 113800, Duployan]
+ 15, // Range #341: [113808, 113823, Duployan]
+ 69, // Range #342: [119296, 119365, Greek]
+ 651, // Range #343: [120832, 121483, SignWriting]
+ 20, // Range #344: [121499, 121519, SignWriting]
+ 42, // Range #345: [122880, 122922, Glagolitic]
+ 214, // Range #346: [124928, 125142, Mende_Kikakui]
+ 74, // Range #347: [125184, 125258, Adlam]
+ 15, // Range #348: [125264, 125279, Adlam]
+ 59, // Range #349: [126464, 126523, Arabic]
+ 89, // Range #350: [126530, 126619, Arabic]
+ 26, // Range #351: [126625, 126651, Arabic]
+ 1, // Range #352: [126704, 126705, Arabic]
+ 0, // Range #353: [127488, 127488, Hiragana]
+ 42710, // Range #354: [131072, 173782, Han]
+ 4148, // Range #355: [173824, 177972, Han]
+ 5985, // Range #356: [177984, 183969, Han]
+ 7472, // Range #357: [183984, 191456, Han]
+ 541, // Range #358: [194560, 195101, Han]
+};
+
+const uint8 kRangeScript[] = {
+ 25, // Range #0: [65, 90, Latin]
+ 25, // Range #1: [97, 122, Latin]
+ 25, // Range #2: [170, 170, Latin]
+ 25, // Range #3: [186, 186, Latin]
+ 25, // Range #4: [192, 214, Latin]
+ 25, // Range #5: [216, 246, Latin]
+ 25, // Range #6: [248, 696, Latin]
+ 25, // Range #7: [736, 740, Latin]
+ 5, // Range #8: [746, 747, Bopomofo]
+ 14, // Range #9: [880, 883, Greek]
+ 14, // Range #10: [885, 893, Greek]
+ 14, // Range #11: [895, 900, Greek]
+ 14, // Range #12: [902, 902, Greek]
+ 14, // Range #13: [904, 993, Greek]
+ 7, // Range #14: [994, 1007, Coptic]
+ 14, // Range #15: [1008, 1023, Greek]
+ 8, // Range #16: [1024, 1156, Cyrillic]
+ 8, // Range #17: [1159, 1327, Cyrillic]
+ 3, // Range #18: [1329, 1416, Armenian]
+ 3, // Range #19: [1418, 1423, Armenian]
+ 19, // Range #20: [1425, 1479, Hebrew]
+ 19, // Range #21: [1488, 1524, Hebrew]
+ 2, // Range #22: [1536, 1540, Arabic]
+ 2, // Range #23: [1542, 1547, Arabic]
+ 2, // Range #24: [1549, 1562, Arabic]
+ 2, // Range #25: [1564, 1566, Arabic]
+ 2, // Range #26: [1568, 1599, Arabic]
+ 2, // Range #27: [1601, 1610, Arabic]
+ 2, // Range #28: [1622, 1647, Arabic]
+ 2, // Range #29: [1649, 1756, Arabic]
+ 2, // Range #30: [1758, 1791, Arabic]
+ 34, // Range #31: [1792, 1871, Syriac]
+ 2, // Range #32: [1872, 1919, Arabic]
+ 37, // Range #33: [1920, 1969, Thaana]
+ 87, // Range #34: [1984, 2047, Nko]
+ 126, // Range #35: [2048, 2110, Samaritan]
+ 84, // Range #36: [2112, 2142, Mandaic]
+ 34, // Range #37: [2144, 2154, Syriac]
+ 2, // Range #38: [2208, 2237, Arabic]
+ 2, // Range #39: [2259, 2273, Arabic]
+ 2, // Range #40: [2275, 2303, Arabic]
+ 10, // Range #41: [2304, 2384, Devanagari]
+ 10, // Range #42: [2387, 2403, Devanagari]
+ 10, // Range #43: [2406, 2431, Devanagari]
+ 4, // Range #44: [2432, 2510, Bengali]
+ 4, // Range #45: [2519, 2558, Bengali]
+ 16, // Range #46: [2561, 2641, Gurmukhi]
+ 16, // Range #47: [2649, 2654, Gurmukhi]
+ 16, // Range #48: [2662, 2678, Gurmukhi]
+ 15, // Range #49: [2689, 2768, Gujarati]
+ 15, // Range #50: [2784, 2801, Gujarati]
+ 15, // Range #51: [2809, 2815, Gujarati]
+ 31, // Range #52: [2817, 2893, Oriya]
+ 31, // Range #53: [2902, 2935, Oriya]
+ 35, // Range #54: [2946, 3024, Tamil]
+ 35, // Range #55: [3031, 3031, Tamil]
+ 35, // Range #56: [3046, 3066, Tamil]
+ 36, // Range #57: [3072, 3149, Telugu]
+ 36, // Range #58: [3157, 3162, Telugu]
+ 36, // Range #59: [3168, 3183, Telugu]
+ 36, // Range #60: [3192, 3199, Telugu]
+ 21, // Range #61: [3200, 3277, Kannada]
+ 21, // Range #62: [3285, 3286, Kannada]
+ 21, // Range #63: [3294, 3314, Kannada]
+ 26, // Range #64: [3328, 3455, Malayalam]
+ 33, // Range #65: [3458, 3551, Sinhala]
+ 33, // Range #66: [3558, 3572, Sinhala]
+ 38, // Range #67: [3585, 3642, Thai]
+ 38, // Range #68: [3648, 3675, Thai]
+ 24, // Range #69: [3713, 3725, Lao]
+ 24, // Range #70: [3732, 3807, Lao]
+ 39, // Range #71: [3840, 4052, Tibetan]
+ 39, // Range #72: [4057, 4058, Tibetan]
+ 28, // Range #73: [4096, 4255, Myanmar]
+ 12, // Range #74: [4256, 4295, Georgian]
+ 12, // Range #75: [4301, 4346, Georgian]
+ 12, // Range #76: [4348, 4351, Georgian]
+ 18, // Range #77: [4352, 4607, Hangul]
+ 11, // Range #78: [4608, 5017, Ethiopic]
+ 6, // Range #79: [5024, 5117, Cherokee]
+ 40, // Range #80: [5120, 5759, Canadian_Aboriginal]
+ 29, // Range #81: [5760, 5788, Ogham]
+ 32, // Range #82: [5792, 5866, Runic]
+ 32, // Range #83: [5870, 5880, Runic]
+ 42, // Range #84: [5888, 5908, Tagalog]
+ 43, // Range #85: [5920, 5940, Hanunoo]
+ 44, // Range #86: [5952, 5971, Buhid]
+ 45, // Range #87: [5984, 6003, Tagbanwa]
+ 23, // Range #88: [6016, 6121, Khmer]
+ 23, // Range #89: [6128, 6137, Khmer]
+ 27, // Range #90: [6144, 6145, Mongolian]
+ 27, // Range #91: [6148, 6148, Mongolian]
+ 27, // Range #92: [6150, 6169, Mongolian]
+ 27, // Range #93: [6176, 6264, Mongolian]
+ 27, // Range #94: [6272, 6314, Mongolian]
+ 40, // Range #95: [6320, 6389, Canadian_Aboriginal]
+ 48, // Range #96: [6400, 6479, Limbu]
+ 52, // Range #97: [6480, 6516, Tai_Le]
+ 59, // Range #98: [6528, 6601, New_Tai_Lue]
+ 59, // Range #99: [6608, 6623, New_Tai_Lue]
+ 23, // Range #100: [6624, 6655, Khmer]
+ 55, // Range #101: [6656, 6687, Buginese]
+ 106, // Range #102: [6688, 6793, Tai_Tham]
+ 106, // Range #103: [6800, 6809, Tai_Tham]
+ 106, // Range #104: [6816, 6829, Tai_Tham]
+ 62, // Range #105: [6912, 7036, Balinese]
+ 113, // Range #106: [7040, 7103, Sundanese]
+ 63, // Range #107: [7104, 7155, Batak]
+ 63, // Range #108: [7164, 7167, Batak]
+ 82, // Range #109: [7168, 7247, Lepcha]
+ 109, // Range #110: [7248, 7295, Ol_Chiki]
+ 8, // Range #111: [7296, 7304, Cyrillic]
+ 12, // Range #112: [7312, 7359, Georgian]
+ 113, // Range #113: [7360, 7367, Sundanese]
+ 25, // Range #114: [7424, 7461, Latin]
+ 14, // Range #115: [7462, 7466, Greek]
+ 8, // Range #116: [7467, 7467, Cyrillic]
+ 25, // Range #117: [7468, 7516, Latin]
+ 14, // Range #118: [7517, 7521, Greek]
+ 25, // Range #119: [7522, 7525, Latin]
+ 14, // Range #120: [7526, 7530, Greek]
+ 25, // Range #121: [7531, 7543, Latin]
+ 8, // Range #122: [7544, 7544, Cyrillic]
+ 25, // Range #123: [7545, 7614, Latin]
+ 14, // Range #124: [7615, 7615, Greek]
+ 25, // Range #125: [7680, 7935, Latin]
+ 14, // Range #126: [7936, 8190, Greek]
+ 25, // Range #127: [8305, 8305, Latin]
+ 25, // Range #128: [8319, 8319, Latin]
+ 25, // Range #129: [8336, 8348, Latin]
+ 14, // Range #130: [8486, 8486, Greek]
+ 25, // Range #131: [8490, 8491, Latin]
+ 25, // Range #132: [8498, 8498, Latin]
+ 25, // Range #133: [8526, 8526, Latin]
+ 25, // Range #134: [8544, 8584, Latin]
+ 46, // Range #135: [10240, 10495, Braille]
+ 56, // Range #136: [11264, 11358, Glagolitic]
+ 25, // Range #137: [11360, 11391, Latin]
+ 7, // Range #138: [11392, 11507, Coptic]
+ 7, // Range #139: [11513, 11519, Coptic]
+ 12, // Range #140: [11520, 11559, Georgian]
+ 12, // Range #141: [11565, 11565, Georgian]
+ 60, // Range #142: [11568, 11623, Tifinagh]
+ 60, // Range #143: [11631, 11632, Tifinagh]
+ 60, // Range #144: [11647, 11647, Tifinagh]
+ 11, // Range #145: [11648, 11670, Ethiopic]
+ 11, // Range #146: [11680, 11742, Ethiopic]
+ 8, // Range #147: [11744, 11775, Cyrillic]
+ 17, // Range #148: [11904, 12019, Han]
+ 17, // Range #149: [12032, 12245, Han]
+ 17, // Range #150: [12293, 12293, Han]
+ 17, // Range #151: [12295, 12295, Han]
+ 17, // Range #152: [12321, 12329, Han]
+ 18, // Range #153: [12334, 12335, Hangul]
+ 17, // Range #154: [12344, 12347, Han]
+ 20, // Range #155: [12353, 12438, Hiragana]
+ 20, // Range #156: [12445, 12447, Hiragana]
+ 22, // Range #157: [12449, 12538, Katakana]
+ 22, // Range #158: [12541, 12543, Katakana]
+ 5, // Range #159: [12549, 12591, Bopomofo]
+ 18, // Range #160: [12593, 12686, Hangul]
+ 5, // Range #161: [12704, 12730, Bopomofo]
+ 22, // Range #162: [12784, 12799, Katakana]
+ 18, // Range #163: [12800, 12830, Hangul]
+ 18, // Range #164: [12896, 12926, Hangul]
+ 22, // Range #165: [13008, 13143, Katakana]
+ 17, // Range #166: [13312, 19893, Han]
+ 17, // Range #167: [19968, 40943, Han]
+ 41, // Range #168: [40960, 42182, Yi]
+ 131, // Range #169: [42192, 42239, Lisu]
+ 99, // Range #170: [42240, 42539, Vai]
+ 8, // Range #171: [42560, 42655, Cyrillic]
+ 130, // Range #172: [42656, 42743, Bamum]
+ 25, // Range #173: [42786, 42887, Latin]
+ 25, // Range #174: [42891, 42937, Latin]
+ 25, // Range #175: [42999, 43007, Latin]
+ 58, // Range #176: [43008, 43051, Syloti_Nagri]
+ 90, // Range #177: [43072, 43127, Phags_Pa]
+ 111, // Range #178: [43136, 43205, Saurashtra]
+ 111, // Range #179: [43214, 43225, Saurashtra]
+ 10, // Range #180: [43232, 43263, Devanagari]
+ 79, // Range #181: [43264, 43309, Kayah_Li]
+ 79, // Range #182: [43311, 43311, Kayah_Li]
+ 110, // Range #183: [43312, 43347, Rejang]
+ 110, // Range #184: [43359, 43359, Rejang]
+ 18, // Range #185: [43360, 43388, Hangul]
+ 78, // Range #186: [43392, 43469, Javanese]
+ 78, // Range #187: [43472, 43487, Javanese]
+ 28, // Range #188: [43488, 43518, Myanmar]
+ 66, // Range #189: [43520, 43574, Cham]
+ 66, // Range #190: [43584, 43615, Cham]
+ 28, // Range #191: [43616, 43647, Myanmar]
+ 127, // Range #192: [43648, 43714, Tai_Viet]
+ 127, // Range #193: [43739, 43743, Tai_Viet]
+ 115, // Range #194: [43744, 43766, Meetei_Mayek]
+ 11, // Range #195: [43777, 43798, Ethiopic]
+ 11, // Range #196: [43808, 43822, Ethiopic]
+ 25, // Range #197: [43824, 43866, Latin]
+ 25, // Range #198: [43868, 43876, Latin]
+ 14, // Range #199: [43877, 43877, Greek]
+ 6, // Range #200: [43888, 43967, Cherokee]
+ 115, // Range #201: [43968, 44025, Meetei_Mayek]
+ 18, // Range #202: [44032, 55203, Hangul]
+ 18, // Range #203: [55216, 55291, Hangul]
+ 17, // Range #204: [63744, 64217, Han]
+ 25, // Range #205: [64256, 64262, Latin]
+ 3, // Range #206: [64275, 64279, Armenian]
+ 19, // Range #207: [64285, 64335, Hebrew]
+ 2, // Range #208: [64336, 64449, Arabic]
+ 2, // Range #209: [64467, 64829, Arabic]
+ 2, // Range #210: [64848, 64967, Arabic]
+ 2, // Range #211: [65008, 65021, Arabic]
+ 8, // Range #212: [65070, 65071, Cyrillic]
+ 2, // Range #213: [65136, 65276, Arabic]
+ 25, // Range #214: [65313, 65338, Latin]
+ 25, // Range #215: [65345, 65370, Latin]
+ 22, // Range #216: [65382, 65391, Katakana]
+ 22, // Range #217: [65393, 65437, Katakana]
+ 18, // Range #218: [65440, 65500, Hangul]
+ 49, // Range #219: [65536, 65629, Linear_B]
+ 49, // Range #220: [65664, 65786, Linear_B]
+ 14, // Range #221: [65856, 65934, Greek]
+ 14, // Range #222: [65952, 65952, Greek]
+ 107, // Range #223: [66176, 66204, Lycian]
+ 104, // Range #224: [66208, 66256, Carian]
+ 30, // Range #225: [66304, 66339, Old_Italic]
+ 30, // Range #226: [66349, 66351, Old_Italic]
+ 13, // Range #227: [66352, 66378, Gothic]
+ 89, // Range #228: [66384, 66426, Old_Permic]
+ 53, // Range #229: [66432, 66463, Ugaritic]
+ 61, // Range #230: [66464, 66517, Old_Persian]
+ 9, // Range #231: [66560, 66639, Deseret]
+ 51, // Range #232: [66640, 66687, Shavian]
+ 50, // Range #233: [66688, 66729, Osmanya]
+ 171, // Range #234: [66736, 66811, Osage]
+ 136, // Range #235: [66816, 66855, Elbasan]
+ 159, // Range #236: [66864, 66915, Caucasian_Albanian]
+ 159, // Range #237: [66927, 66927, Caucasian_Albanian]
+ 83, // Range #238: [67072, 67382, Linear_A]
+ 83, // Range #239: [67392, 67413, Linear_A]
+ 83, // Range #240: [67424, 67431, Linear_A]
+ 47, // Range #241: [67584, 67647, Cypriot]
+ 116, // Range #242: [67648, 67679, Imperial_Aramaic]
+ 144, // Range #243: [67680, 67711, Palmyrene]
+ 143, // Range #244: [67712, 67742, Nabataean]
+ 143, // Range #245: [67751, 67759, Nabataean]
+ 162, // Range #246: [67808, 67829, Hatran]
+ 162, // Range #247: [67835, 67839, Hatran]
+ 91, // Range #248: [67840, 67871, Phoenician]
+ 108, // Range #249: [67872, 67897, Lydian]
+ 108, // Range #250: [67903, 67903, Lydian]
+ 86, // Range #251: [67968, 67999, Meroitic_Hieroglyphs]
+ 141, // Range #252: [68000, 68095, Meroitic_Cursive]
+ 57, // Range #253: [68096, 68102, Kharoshthi]
+ 57, // Range #254: [68108, 68168, Kharoshthi]
+ 57, // Range #255: [68176, 68184, Kharoshthi]
+ 133, // Range #256: [68192, 68223, Old_South_Arabian]
+ 142, // Range #257: [68224, 68255, Old_North_Arabian]
+ 121, // Range #258: [68288, 68342, Manichaean]
+ 117, // Range #259: [68352, 68415, Avestan]
+ 125, // Range #260: [68416, 68447, Inscriptional_Parthian]
+ 122, // Range #261: [68448, 68466, Inscriptional_Pahlavi]
+ 122, // Range #262: [68472, 68479, Inscriptional_Pahlavi]
+ 123, // Range #263: [68480, 68497, Psalter_Pahlavi]
+ 123, // Range #264: [68505, 68508, Psalter_Pahlavi]
+ 123, // Range #265: [68521, 68527, Psalter_Pahlavi]
+ 88, // Range #266: [68608, 68680, Old_Turkic]
+ 76, // Range #267: [68736, 68786, Old_Hungarian]
+ 76, // Range #268: [68800, 68850, Old_Hungarian]
+ 76, // Range #269: [68858, 68863, Old_Hungarian]
+ 182, // Range #270: [68864, 68903, Hanifi_Rohingya]
+ 182, // Range #271: [68912, 68921, Hanifi_Rohingya]
+ 2, // Range #272: [69216, 69246, Arabic]
+ 184, // Range #273: [69376, 69415, Old_Sogdian]
+ 183, // Range #274: [69424, 69465, Sogdian]
+ 65, // Range #275: [69632, 69743, Brahmi]
+ 65, // Range #276: [69759, 69759, Brahmi]
+ 120, // Range #277: [69760, 69825, Kaithi]
+ 120, // Range #278: [69837, 69837, Kaithi]
+ 152, // Range #279: [69840, 69864, Sora_Sompeng]
+ 152, // Range #280: [69872, 69881, Sora_Sompeng]
+ 118, // Range #281: [69888, 69958, Chakma]
+ 160, // Range #282: [69968, 70006, Mahajani]
+ 151, // Range #283: [70016, 70111, Sharada]
+ 33, // Range #284: [70113, 70132, Sinhala]
+ 157, // Range #285: [70144, 70206, Khojki]
+ 164, // Range #286: [70272, 70313, Multani]
+ 145, // Range #287: [70320, 70378, Khudawadi]
+ 145, // Range #288: [70384, 70393, Khudawadi]
+ 137, // Range #289: [70400, 70457, Grantha]
+ 137, // Range #290: [70460, 70480, Grantha]
+ 137, // Range #291: [70487, 70487, Grantha]
+ 137, // Range #292: [70493, 70516, Grantha]
+ 170, // Range #293: [70656, 70750, Newa]
+ 158, // Range #294: [70784, 70855, Tirhuta]
+ 158, // Range #295: [70864, 70873, Tirhuta]
+ 166, // Range #296: [71040, 71133, Siddham]
+ 163, // Range #297: [71168, 71236, Modi]
+ 163, // Range #298: [71248, 71257, Modi]
+ 27, // Range #299: [71264, 71276, Mongolian]
+ 153, // Range #300: [71296, 71351, Takri]
+ 153, // Range #301: [71360, 71369, Takri]
+ 161, // Range #302: [71424, 71487, Ahom]
+ 178, // Range #303: [71680, 71739, Dogra]
+ 146, // Range #304: [71840, 71922, Warang_Citi]
+ 146, // Range #305: [71935, 71935, Warang_Citi]
+ 177, // Range #306: [72192, 72263, Zanabazar_Square]
+ 176, // Range #307: [72272, 72354, Soyombo]
+ 165, // Range #308: [72384, 72440, Pau_Cin_Hau]
+ 168, // Range #309: [72704, 72773, Bhaiksuki]
+ 168, // Range #310: [72784, 72812, Bhaiksuki]
+ 169, // Range #311: [72816, 72886, Marchen]
+ 175, // Range #312: [72960, 73031, Masaram_Gondi]
+ 175, // Range #313: [73040, 73049, Masaram_Gondi]
+ 179, // Range #314: [73056, 73112, Gunjala_Gondi]
+ 179, // Range #315: [73120, 73129, Gunjala_Gondi]
+ 180, // Range #316: [73440, 73464, Makasar]
+ 101, // Range #317: [73728, 74649, Cuneiform]
+ 101, // Range #318: [74752, 74868, Cuneiform]
+ 101, // Range #319: [74880, 75075, Cuneiform]
+ 71, // Range #320: [77824, 78894, Egyptian_Hieroglyphs]
+ 156, // Range #321: [82944, 83526, Anatolian_Hieroglyphs]
+ 130, // Range #322: [92160, 92728, Bamum]
+ 149, // Range #323: [92736, 92783, Mro]
+ 134, // Range #324: [92880, 92917, Bassa_Vah]
+ 75, // Range #325: [92928, 92997, Pahawh_Hmong]
+ 75, // Range #326: [93008, 93047, Pahawh_Hmong]
+ 75, // Range #327: [93053, 93071, Pahawh_Hmong]
+ 181, // Range #328: [93760, 93850, Medefaidrin]
+ 92, // Range #329: [93952, 94020, Miao]
+ 92, // Range #330: [94032, 94078, Miao]
+ 92, // Range #331: [94095, 94111, Miao]
+ 154, // Range #332: [94176, 94176, Tangut]
+ 150, // Range #333: [94177, 94177, Nushu]
+ 154, // Range #334: [94208, 100337, Tangut]
+ 154, // Range #335: [100352, 101106, Tangut]
+ 22, // Range #336: [110592, 110592, Katakana]
+ 20, // Range #337: [110593, 110878, Hiragana]
+ 150, // Range #338: [110960, 111355, Nushu]
+ 135, // Range #339: [113664, 113770, Duployan]
+ 135, // Range #340: [113776, 113800, Duployan]
+ 135, // Range #341: [113808, 113823, Duployan]
+ 14, // Range #342: [119296, 119365, Greek]
+ 112, // Range #343: [120832, 121483, SignWriting]
+ 112, // Range #344: [121499, 121519, SignWriting]
+ 56, // Range #345: [122880, 122922, Glagolitic]
+ 140, // Range #346: [124928, 125142, Mende_Kikakui]
+ 167, // Range #347: [125184, 125258, Adlam]
+ 167, // Range #348: [125264, 125279, Adlam]
+ 2, // Range #349: [126464, 126523, Arabic]
+ 2, // Range #350: [126530, 126619, Arabic]
+ 2, // Range #351: [126625, 126651, Arabic]
+ 2, // Range #352: [126704, 126705, Arabic]
+ 20, // Range #353: [127488, 127488, Hiragana]
+ 17, // Range #354: [131072, 173782, Han]
+ 17, // Range #355: [173824, 177972, Han]
+ 17, // Range #356: [177984, 183969, Han]
+ 17, // Range #357: [183984, 191456, Han]
+ 17, // Range #358: [194560, 195101, Han]
+};
+
+const uint8 kMaxScript = 184;
+
+} // namespace approx_script_internal
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/script/approx-script-data.h b/lang_id/script/approx-script-data.h
new file mode 100644
index 0000000..3eceed8
--- /dev/null
+++ b/lang_id/script/approx-script-data.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_DATA_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_DATA_H_
+
+#include "lang_id/common/lite_base/integral-types.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace approx_script_internal {
+
+// Number of contiguous ranges of same-script codepoints (see below).
+extern const int kNumRanges;
+
+// Non-overlapping ranges of unicode characters. Characters from each range have
+// the same script (see kRangeScript below). Multiple ranges may have the same
+// script. Note: we represent the kNumRanges ranges as an array with their
+// first codepoints, and a separate array with their sizes (see kRangeSizeMinusOne
+// below). This leads to better memory locality during the binary search (which
+// uses only the first codepoints, up until the very end).
+//
+// kRangeFirst[i] = first codepoint from range #i, \forall 0 <= i < kNumRanges.
+extern const uint32 kRangeFirst[];
+
+// kRangeSize[i] > 0 is the number of consecutive codepoints in range #i *minus*
+// 1, \forall 0 <= i < kNumRanges. I.e., 0 means that the range contains 1
+// codepoint. Since we don't have empty ranges, this "minus one" convention
+// allows us to use all 2^16 values here.
+extern const uint16 kRangeSizeMinusOne[];
+
+// Scripts for the ranges defined by kRangeFirst. For each i such that 0 <= i <
+// kNumRanges, the range #i has the script kRangeScript[i]. Each uint8 element
+// can be casted to an UScriptCode enum value (see
+// unicode/uscript.h).
+//
+// NOTE: we don't use directly UScriptCode here, as that requires a full int
+// (due to USCRIPT_INVALID_CODE = -1). uint8 is enough for us (and shorter!)
+extern const uint8 kRangeScript[];
+
+// Max value from kRangeScript[]. Scripts are guaranteed to be in the interval
+// [0, kMaxScript] (inclusive on both sides). Can be used to e.g., set the
+// number of rows in an embedding table for a script-based feature.
+extern const uint8 kMaxScript;
+
+} // namespace approx_script_internal
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_DATA_H_
diff --git a/lang_id/script/approx-script.cc b/lang_id/script/approx-script.cc
new file mode 100644
index 0000000..10afa9c
--- /dev/null
+++ b/lang_id/script/approx-script.cc
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/script/approx-script.h"
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/utf8.h"
+#include "lang_id/script/approx-script-data.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// int value of USCRIPT_UNKNOWN from enum UScriptCode (from
+// unicode/uscript.h). Note: we do have a test that
+// USCRIPT_UNKNOWN evaluates to 103.
+const int kUnknownUscript = 103;
+
+namespace {
+using approx_script_internal::kNumRanges;
+using approx_script_internal::kRangeFirst;
+using approx_script_internal::kRangeScript;
+using approx_script_internal::kRangeSizeMinusOne;
+
+uint32 Utf8ToCodepoint(const unsigned char *s, int num_bytes) {
+ switch (num_bytes) {
+ case 1:
+ return s[0];
+ case 2:
+ return ((s[0] & 0x1F) << 6) | (s[1] & 0x3F);
+ case 3:
+ return (((s[0] & 0x0F) << 12) | ((s[1] & 0x3F) << 6) | (s[2] & 0x3F));
+ case 4:
+ return (((s[0] & 0x07) << 18) | ((s[1] & 0x3F) << 12) |
+ ((s[2] & 0x3F) << 6) | (s[3] & 0x3F));
+ default:
+ SAFTM_DLOG(FATAL) << "Illegal num_bytes: " << num_bytes;
+ return 0;
+ }
+}
+
+inline int BinarySearch(uint32 codepoint, int start, int end) {
+ while (end > start + 1) {
+ // Due to the while loop condition, middle > start and middle < end. Hence,
+ // on both branches of the if below, we strictly reduce the end - start
+ // value, so we eventually get that difference below 1 and complete the
+ // while loop.
+ int middle = (start + end) / 2;
+ if (codepoint < kRangeFirst[middle]) {
+ end = middle;
+ } else {
+ start = middle;
+ }
+ }
+
+ if (end == start + 1) {
+ const uint32 range_start = kRangeFirst[start];
+ if ((codepoint >= range_start) &&
+ (codepoint <= range_start + kRangeSizeMinusOne[start])) {
+ return kRangeScript[start];
+ }
+ }
+
+ return kUnknownUscript;
+}
+} // namespace
+
+int GetApproxScript(const unsigned char *s, int num_bytes) {
+ SAFTM_DCHECK_NE(s, nullptr);
+ SAFTM_DCHECK_EQ(num_bytes,
+ utils::OneCharLen(reinterpret_cast<const char *>(s)));
+ uint32 codepoint = Utf8ToCodepoint(s, num_bytes);
+ return BinarySearch(codepoint, 0, kNumRanges);
+}
+
+int GetMaxApproxScriptResult() { return approx_script_internal::kMaxScript; }
+
+SAFTM_STATIC_REGISTRATION(ApproxScriptDetector);
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/script/approx-script.h b/lang_id/script/approx-script.h
new file mode 100644
index 0000000..2472e86
--- /dev/null
+++ b/lang_id/script/approx-script.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_H_
+
+#include "lang_id/common/utf8.h"
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Returns script for the UTF-8 character that starts at address |s| and has
+// |num_bytes| bytes. Note: behavior is unspecified if s points to a UTF-8
+// character that has a different number of bytes. If you don't know
+// |num_bytes|, call GetApproxScript(const char *s).
+//
+// NOTE: to keep BUILD deps small, this function returns an int, but you can
+// assume it's an enum UScriptCode (unicode/uscript.h)
+//
+// If unable to determine the script, this function returns kUnknownUscript, the
+// int value of USCRIPT_UNKNOWN from enum UScriptCode.
+int GetApproxScript(const unsigned char *s, int num_bytes);
+
+// See comments for GetApproxScript() above.
+extern const int kUnknownUscript;
+
+// Same as before, but s is a const char *pointer (no unsigned). Internally, we
+// prefer "unsigned char" (the signed status of char is ambiguous), so we cast
+// and call the previous version (with const unsigned char *).
+inline int GetApproxScript(const char *s, int num_bytes) {
+ return GetApproxScript(reinterpret_cast<const unsigned char *>(s), num_bytes);
+}
+
+// Returns script for the UTF-8 character that starts at address |s|. NOTE:
+// UTF-8 is a var-length encoding, taking between 1 and 4 bytes per Unicode
+// character. We infer the number of bytes based on s[0]. If that number is k,
+// we expect to be able to read k bytes starting from address |s|. I.e., do not
+// call this function on broken UTF-8.
+inline int GetApproxScript(const char *s) {
+ return GetApproxScript(s, utils::OneCharLen(s));
+}
+
+// Returns max value returned by the GetApproxScript() functions.
+int GetMaxApproxScriptResult();
+
+class ApproxScriptDetector : public ScriptDetector {
+ public:
+ ~ApproxScriptDetector() override = default;
+
+ // Note: the int result of this method is actually a UScriptCode enum value.
+ // We return int to match the general case from the base class ScriptDetector
+ // (some script detectors do not use UScriptCode).
+ int GetScript(const char *s, int num_bytes) const override {
+ return GetApproxScript(s, num_bytes);
+ }
+
+ int GetMaxScript() const override {
+ return GetMaxApproxScriptResult();
+ }
+
+ SAFTM_DEFINE_REGISTRATION_METHOD("approx-unicode-script-detector",
+ ApproxScriptDetector);
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_APPROX_SCRIPT_H_
diff --git a/lang_id/script/script-detector.cc b/lang_id/script/script-detector.cc
new file mode 100644
index 0000000..6c19883
--- /dev/null
+++ b/lang_id/script/script-detector.cc
@@ -0,0 +1,25 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+SAFTM_DEFINE_CLASS_REGISTRY_NAME("script detector", ScriptDetector);
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/script/script-detector.h b/lang_id/script/script-detector.h
new file mode 100644
index 0000000..12a7888
--- /dev/null
+++ b/lang_id/script/script-detector.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_SCRIPT_DETECTOR_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_SCRIPT_DETECTOR_H_
+
+#include "lang_id/common/registry.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Base class for Unicode script detectors. Individual detectors may differ in
+// code size, speed, precision, etc. You can use the registration mechanism to
+// get the ScriptDetector that's most appropriate to your application.
+class ScriptDetector : public RegisterableClass<ScriptDetector> {
+ public:
+ virtual ~ScriptDetector() = default;
+
+ // Returns a number between 0 and GetMaxScript() (inclusive on both ends) that
+ // indicates the script of the UTF8 character that starts at address |s| and
+ // has |num_bytes|.
+ virtual int GetScript(const char *s, int num_bytes) const = 0;
+
+ // Returns max result that can be returned by GetScript().
+ virtual int GetMaxScript() const = 0;
+};
+
+SAFTM_DECLARE_CLASS_REGISTRY_NAME(ScriptDetector);
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_SCRIPT_DETECTOR_H_
diff --git a/lang_id/script/tiny-script-detector.cc b/lang_id/script/tiny-script-detector.cc
new file mode 100644
index 0000000..2f0dd98
--- /dev/null
+++ b/lang_id/script/tiny-script-detector.cc
@@ -0,0 +1,27 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/script/tiny-script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+SAFTM_STATIC_REGISTRATION(TinyScriptDetector);
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/script/tiny-script-detector.h b/lang_id/script/tiny-script-detector.h
new file mode 100644
index 0000000..a55da04
--- /dev/null
+++ b/lang_id/script/tiny-script-detector.h
@@ -0,0 +1,181 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_TINY_SCRIPT_DETECTOR_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_TINY_SCRIPT_DETECTOR_H_
+
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Unicode scripts we care about. To get compact and fast code, we detect only
+// a few Unicode scripts that offer a strong indication about the language of
+// the text (e.g., Hiragana -> Japanese).
+enum Script {
+ // Special value to indicate internal errors in the script detection code.
+ kScriptError,
+
+ // Special values for all Unicode scripts that we do not detect. One special
+ // value for Unicode characters of 1, 2, 3, and 4 bytes, respectively (as we
+ // already have that information, we use it). kScriptOtherUtf8OneByte means
+ // ~Latin and kScriptOtherUtf8FourBytes means ~Han.
+ kScriptOtherUtf8OneByte,
+ kScriptOtherUtf8TwoBytes,
+ kScriptOtherUtf8ThreeBytes,
+ kScriptOtherUtf8FourBytes,
+
+ kScriptGreek,
+ kScriptCyrillic,
+ kScriptHebrew,
+ kScriptArabic,
+ kScriptHangulJamo, // Used primarily for Korean.
+ kScriptHiragana, // Used primarily for Japanese.
+ kScriptKatakana, // Used primarily for Japanese.
+
+ // Add new scripts here.
+
+ // Do not add any script after kNumRelevantScripts. This value indicates the
+ // number of elements in this enum Script (except this value) such that we can
+ // easily iterate over the scripts.
+ kNumRelevantScripts,
+};
+
+template<typename IntType>
+inline bool InRange(IntType value, IntType low, IntType hi) {
+ return (value >= low) && (value <= hi);
+}
+
+// Returns Script for the UTF8 character that starts at address p.
+// Precondition: p points to a valid UTF8 character of num_bytes bytes.
+inline Script GetScript(const unsigned char *p, int num_bytes) {
+ switch (num_bytes) {
+ case 1:
+ return kScriptOtherUtf8OneByte;
+
+ case 2: {
+ // 2-byte UTF8 characters have 11 bits of information. unsigned int has
+ // at least 16 bits (http://en.cppreference.com/w/cpp/language/types) so
+ // it's enough. It's also usually the fastest int type on the current
+ // CPU, so it's better to use than int32.
+ static const unsigned int kGreekStart = 0x370;
+
+ // Commented out (unused in the code): kGreekEnd = 0x3FF;
+ static const unsigned int kCyrillicStart = 0x400;
+ static const unsigned int kCyrillicEnd = 0x4FF;
+ static const unsigned int kHebrewStart = 0x590;
+
+ // Commented out (unused in the code): kHebrewEnd = 0x5FF;
+ static const unsigned int kArabicStart = 0x600;
+ static const unsigned int kArabicEnd = 0x6FF;
+ const unsigned int codepoint = ((p[0] & 0x1F) << 6) | (p[1] & 0x3F);
+ if (codepoint > kCyrillicEnd) {
+ if (codepoint >= kArabicStart) {
+ if (codepoint <= kArabicEnd) {
+ return kScriptArabic;
+ }
+ } else {
+ // At this point, codepoint < kArabicStart = kHebrewEnd + 1, so
+ // codepoint <= kHebrewEnd.
+ if (codepoint >= kHebrewStart) {
+ return kScriptHebrew;
+ }
+ }
+ } else {
+ if (codepoint >= kCyrillicStart) {
+ return kScriptCyrillic;
+ } else {
+ // At this point, codepoint < kCyrillicStart = kGreekEnd + 1, so
+ // codepoint <= kGreekEnd.
+ if (codepoint >= kGreekStart) {
+ return kScriptGreek;
+ }
+ }
+ }
+ return kScriptOtherUtf8TwoBytes;
+ }
+
+ case 3: {
+ // 3-byte UTF8 characters have 16 bits of information. unsigned int has
+ // at least 16 bits.
+ static const unsigned int kHangulJamoStart = 0x1100;
+ static const unsigned int kHangulJamoEnd = 0x11FF;
+ static const unsigned int kHiraganaStart = 0x3041;
+ static const unsigned int kHiraganaEnd = 0x309F;
+
+ // Commented out (unused in the code): kKatakanaStart = 0x30A0;
+ static const unsigned int kKatakanaEnd = 0x30FF;
+ const unsigned int codepoint =
+ ((p[0] & 0x0F) << 12) | ((p[1] & 0x3F) << 6) | (p[2] & 0x3F);
+ if (codepoint > kHiraganaEnd) {
+ // On this branch, codepoint > kHiraganaEnd = kKatakanaStart - 1, so
+ // codepoint >= kKatakanaStart.
+ if (codepoint <= kKatakanaEnd) {
+ return kScriptKatakana;
+ }
+ } else {
+ if (codepoint >= kHiraganaStart) {
+ return kScriptHiragana;
+ } else {
+ if (InRange(codepoint, kHangulJamoStart, kHangulJamoEnd)) {
+ return kScriptHangulJamo;
+ }
+ }
+ }
+ return kScriptOtherUtf8ThreeBytes;
+ }
+
+ case 4:
+ return kScriptOtherUtf8FourBytes;
+
+ default:
+ return kScriptError;
+ }
+}
+
+// Returns Script for the UTF8 character that starts at address p. Similar to
+// the previous version of GetScript, except for "char" vs "unsigned char".
+// Most code works with "char *" pointers, ignoring the fact that char is
+// unsigned (by default) on most platforms, but signed on iOS. This code takes
+// care of making sure we always treat chars as unsigned.
+inline Script GetScript(const char *p, int num_bytes) {
+ return GetScript(reinterpret_cast<const unsigned char *>(p),
+ num_bytes);
+}
+
+class TinyScriptDetector : public ScriptDetector {
+ public:
+ ~TinyScriptDetector() override = default;
+
+ int GetScript(const char *s, int num_bytes) const override {
+ // Add the namespace to indicate that we want to call the method outside
+ // this class, instead of performing an infinite recursive call.
+ return libtextclassifier3::mobile::lang_id::GetScript(s, num_bytes);
+ }
+
+ int GetMaxScript() const override {
+ return kNumRelevantScripts - 1;
+ }
+
+ SAFTM_DEFINE_REGISTRATION_METHOD("tiny-script-detector", TinyScriptDetector);
+};
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_SCRIPT_TINY_SCRIPT_DETECTOR_H_
diff --git a/models/actions_suggestions.model b/models/actions_suggestions.model
new file mode 100644
index 0000000..ee60ce2
--- /dev/null
+++ b/models/actions_suggestions.model
Binary files differ
diff --git a/models/lang_id.model b/models/lang_id.model
new file mode 100644
index 0000000..e577a69
--- /dev/null
+++ b/models/lang_id.model
Binary files differ
diff --git a/utils/sentencepiece/double_array_trie.cc b/utils/sentencepiece/double_array_trie.cc
new file mode 100644
index 0000000..4a6fb3c
--- /dev/null
+++ b/utils/sentencepiece/double_array_trie.cc
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+void DoubleArrayTrie::GatherPrefixMatches(
+ StringPiece input, const std::function<void(TrieMatch)>& update_fn) const {
+ int pos = 0;
+ TC3_CHECK(pos >= 0 && pos < nodes_length_);
+ pos = offset(0);
+ for (int i = 0; i < input.size(); i++) {
+ pos ^= input[i];
+ TC3_CHECK(pos >= 0 && pos < nodes_length_);
+ if (label(pos) != input[i]) {
+ break;
+ }
+ const bool node_has_leaf = has_leaf(pos);
+ pos ^= offset(pos);
+ TC3_CHECK(pos >= 0 && pos < nodes_length_);
+ if (node_has_leaf) {
+ update_fn(TrieMatch(/*id=*/value(pos), /*match_length=*/i + 1));
+ }
+ }
+}
+
+std::vector<TrieMatch> DoubleArrayTrie::FindAllPrefixMatches(
+ StringPiece input) const {
+ std::vector<TrieMatch> result;
+ GatherPrefixMatches(
+ input, [&result](const TrieMatch match) { result.push_back(match); });
+ return result;
+}
+
+TrieMatch DoubleArrayTrie::LongestPrefixMatch(StringPiece input) const {
+ TrieMatch longest_match;
+ GatherPrefixMatches(input, [&longest_match](const TrieMatch match) {
+ longest_match = match;
+ });
+ return longest_match;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/double_array_trie.h b/utils/sentencepiece/double_array_trie.h
new file mode 100644
index 0000000..050c466
--- /dev/null
+++ b/utils/sentencepiece/double_array_trie.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
+
+#include <functional>
+#include <vector>
+
+#include "utils/sentencepiece/matcher.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// A trie node specifies a node in the tree, either an intermediate node or
+// a leaf node.
+// A leaf node contains the id as an int of the string match. This id is encoded
+// in the lower 30 bits, thus the number of distinct ids is 2^30.
+// An intermediate node has an associated label and an offset to its children.
+// The label is encoded in the least significant byte and must match the input
+// character during matching.
+typedef unsigned int TrieNode;
+
+// A memory mappable trie, compatible with Darts::DoubleArray.
+class DoubleArrayTrie : public SentencePieceMatcher {
+ public:
+ // nodes and nodes_length specify the array of the nodes of the trie.
+ DoubleArrayTrie(const TrieNode* nodes, const int nodes_length)
+ : nodes_(nodes), nodes_length_(nodes_length) {}
+
+ // Find matches that are prefixes of a string.
+ std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const override;
+
+ // Find the longest prefix match of a string.
+ TrieMatch LongestPrefixMatch(StringPiece input) const override;
+
+ private:
+ // Returns whether a node has a leaf as a child.
+ bool has_leaf(int i) const { return nodes_[i] & 0x100; }
+
+ // Available when a node is a leaf.
+ int value(int i) const { return static_cast<int>(nodes_[i] & 0x7fffffff); }
+
+ // Label associated with a node.
+ // A leaf node will have the MSB set and thus return an invalid label.
+ unsigned int label(int i) const { return nodes_[i] & 0x800000ff; }
+
+ // Returns offset to children.
+ unsigned int offset(int i) const {
+ return (nodes_[i] >> 10) << ((nodes_[i] & 0x200) >> 6);
+ }
+
+ void GatherPrefixMatches(
+ StringPiece input, const std::function<void(TrieMatch)>& update_fn) const;
+
+ const TrieNode* nodes_;
+ const int nodes_length_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
diff --git a/utils/sentencepiece/double_array_trie_test.cc b/utils/sentencepiece/double_array_trie_test.cc
new file mode 100644
index 0000000..99fc6d0
--- /dev/null
+++ b/utils/sentencepiece/double_array_trie_test.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "utils/sentencepiece/double_array_trie.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+std::string GetTestConfigPath() {
+ return "";
+}
+
+TEST(DoubleArrayTest, Lookup) {
+ // Test trie that contains pieces "hell", "hello", "o", "there".
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ DoubleArrayTrie trie(reinterpret_cast<const TrieNode*>(config.data()),
+ config.size() / sizeof(TrieNode));
+
+ auto matches = trie.FindAllPrefixMatches("hello there");
+ EXPECT_EQ(matches.size(), 2);
+ EXPECT_EQ(matches[0].id, 0 /*hell*/);
+ EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
+ EXPECT_EQ(matches[1].id, 1 /*hello*/);
+ EXPECT_EQ(matches[1].match_length, 5 /*hello*/);
+
+ matches = trie.FindAllPrefixMatches("he");
+ EXPECT_EQ(matches.size(), 0);
+
+ matches = trie.FindAllPrefixMatches("abcd");
+ EXPECT_EQ(matches.size(), 0);
+
+ matches = trie.FindAllPrefixMatches("");
+ EXPECT_EQ(matches.size(), 0);
+
+ EXPECT_THAT(trie.FindAllPrefixMatches("hi there"), testing::IsEmpty());
+
+ EXPECT_EQ(trie.LongestPrefixMatch("hella there").id, 0 /*hell*/);
+ EXPECT_EQ(trie.LongestPrefixMatch("hello there").id, 1 /*hello*/);
+ EXPECT_EQ(trie.LongestPrefixMatch("abcd").id, -1);
+ EXPECT_EQ(trie.LongestPrefixMatch("").id, -1);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/encoder.cc b/utils/sentencepiece/encoder.cc
new file mode 100644
index 0000000..8f218ec
--- /dev/null
+++ b/utils/sentencepiece/encoder.cc
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/encoder.h"
+
+namespace libtextclassifier3 {
+
+std::vector<int> Encoder::Encode(StringPiece normalized_text) const {
+ const int len = normalized_text.size();
+ if (len <= 0) {
+ return {start_code_, end_code_};
+ }
+ // We use `previous_pos` to indicate whether a dynamic programming state was
+ // reachable.
+ std::vector<SegmentationEntry> segmentation(
+ len + 1, {/*score=*/0, /*previous_pos=*/-1, /*piece_id=*/-1,
+ /*num_pieces=*/0});
+ for (int i = 0; i < len; i++) {
+ // State couldn't be reached.
+ if (i > 0 && segmentation[i].previous_pos < 0) {
+ // Advance position.
+ normalized_text.RemovePrefix(1);
+ continue;
+ }
+ // Check whether we can use the unknown token.
+ if (unknown_code_ >= 0) {
+ const int pos = i + 1;
+ const float unknown_penalty = segmentation[i].score + unknown_score_;
+ if (segmentation[pos].previous_pos < 0 ||
+ segmentation[pos].score < unknown_penalty) {
+ // Merge multiple unknown tokens into one.
+ if (segmentation[i].piece_id == unknown_code_) {
+ segmentation[pos] = {/*score=*/unknown_penalty,
+ /*previous_pos=*/segmentation[i].previous_pos,
+ /*piece_id=*/unknown_code_,
+ /*num_pieces=*/segmentation[i].num_pieces};
+ } else {
+ segmentation[pos] = {/*score=*/unknown_penalty,
+ /*previous_pos=*/i,
+ /*piece_id=*/unknown_code_,
+ /*num_pieces=*/segmentation[i].num_pieces + 1};
+ }
+ }
+ }
+ for (const auto& match : matcher_->FindAllPrefixMatches(normalized_text)) {
+ TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
+ const int pos = i + match.match_length;
+ const float candidate_score = segmentation[i].score + scores_[match.id];
+ if (segmentation[pos].previous_pos < 0 ||
+ segmentation[pos].score < candidate_score) {
+ segmentation[pos] = {/*score=*/candidate_score, /*previous_pos=*/i,
+ /*piece_id=*/match.id + encoding_offset_,
+ /*num_pieces=*/segmentation[i].num_pieces + 1};
+ }
+ }
+ // Advance position.
+ normalized_text.RemovePrefix(1);
+ }
+ if (segmentation[len].num_pieces <= 0) {
+ return {start_code_, end_code_};
+ }
+ const int num_pieces = segmentation[len].num_pieces;
+ std::vector<int> result(num_pieces + 2);
+ result[num_pieces + 1] = end_code_;
+ int pos = len;
+ for (int i = num_pieces; i > 0; i--) {
+ result[i] = segmentation[pos].piece_id;
+ pos = segmentation[pos].previous_pos;
+ }
+ result[0] = start_code_;
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/encoder.h b/utils/sentencepiece/encoder.h
new file mode 100644
index 0000000..0f1bfd3
--- /dev/null
+++ b/utils/sentencepiece/encoder.h
@@ -0,0 +1,88 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
+
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/sentencepiece/matcher.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Encoder to segment/tokenize strings into pieces such that the sum of the
+// scores of the pieces used is maximized.
+class Encoder {
+ public:
+ // matcher: the list of valid sentence pieces represented as a matcher, e.g.
+ // a trie.
+ // num_pieces: the number of pieces in the trie.
+ // pieces_scores: the scores of the individual pieces.
+ // start_code: code that is used as encoding of the start of input.
+ // end_code: code that is used as encoding of the end of input.
+ // encoding_offset: value added to the sentence piece ids to make them
+ // not intersecting with start_code and end_code.
+ // unknown_code: code that is used for out-of-dictionary characters.
+ // unknown_score: the penalty score associated with the unknown code.
+ Encoder(const SentencePieceMatcher* matcher, const int num_pieces,
+ const float* pieces_scores, int start_code = 0, int end_code = 1,
+ int encoding_offset = 2, int unknown_code = -1,
+ float unknown_score = 0.f)
+ : num_pieces_(num_pieces),
+ scores_(pieces_scores),
+ matcher_(matcher),
+ start_code_(start_code),
+ end_code_(end_code),
+ encoding_offset_(encoding_offset),
+ unknown_code_(unknown_code),
+ unknown_score_(unknown_score) {}
+
+ // Segment the input so that the total score of the pieces used is maximized.
+ // This is a simplified implementation of the general Viterbi algorithm,
+ // assuming independence between individual pieces.
+ std::vector<int> Encode(StringPiece normalized_text) const;
+
+ private:
+ // State in the dynamic programming algorithm.
+ struct SegmentationEntry {
+ // Accumulated score.
+ float score;
+
+ // Position before last piece.
+ int previous_pos;
+
+ // Last piece used.
+ int piece_id;
+
+ // Total number of pieces used.
+ int num_pieces;
+ };
+
+ const int num_pieces_;
+ const float* scores_;
+ const SentencePieceMatcher* matcher_;
+ const int start_code_;
+ const int end_code_;
+ const int encoding_offset_;
+ const int unknown_code_;
+ // Must be float: the constructor receives a float penalty (e.g. -0.5);
+ // storing it in an int silently truncates the score used by Encode().
+ const float unknown_score_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
diff --git a/utils/sentencepiece/encoder_test.cc b/utils/sentencepiece/encoder_test.cc
new file mode 100644
index 0000000..6bc9aeb
--- /dev/null
+++ b/utils/sentencepiece/encoder_test.cc
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "utils/sentencepiece/encoder.h"
+#include "utils/sentencepiece/sorted_strings_table.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+using testing::IsEmpty;
+
+TEST(EncoderTest, SimpleTokenization) {
+ const char pieces[] = "hell\0hello\0o\0there\0";
+ const int offsets[] = {0, 5, 11, 13};
+ float scores[] = {-0.5, -1.0, -10.0, -1.0};
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores);
+
+ EXPECT_THAT(encoder.Encode("hellothere"), ElementsAre(0, 3, 5, 1));
+
+ // Make probability of hello very low:
+ // hello gets now tokenized as hell + o.
+ scores[1] = -100.0;
+ EXPECT_THAT(encoder.Encode("hellothere"), ElementsAre(0, 2, 4, 5, 1));
+}
+
+TEST(EncoderTest, HandlesEdgeCases) {
+ const char pieces[] = "hell\0hello\0o\0there\0";
+ const int offsets[] = {0, 5, 11, 13};
+ float scores[] = {-0.5, -1.0, -10.0, -1.0};
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores);
+ EXPECT_THAT(encoder.Encode("hellhello"), ElementsAre(0, 2, 3, 1));
+ EXPECT_THAT(encoder.Encode("hellohell"), ElementsAre(0, 3, 2, 1));
+ EXPECT_THAT(encoder.Encode(""), ElementsAre(0, 1));
+ EXPECT_THAT(encoder.Encode("hellathere"), ElementsAre(0, 1));
+}
+
+TEST(EncoderTest, HandlesOutOfDictionary) {
+ const char pieces[] = "hell\0hello\0o\0there\0";
+ const int offsets[] = {0, 5, 11, 13};
+ float scores[] = {-0.5, -1.0, -10.0, -1.0};
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores,
+ /*start_code=*/0, /*end_code=*/1,
+ /*encoding_offset=*/3, /*unknown_code=*/2,
+ /*unknown_score=*/-100.0);
+ EXPECT_THAT(encoder.Encode("hellhello"), ElementsAre(0, 3, 4, 1));
+ EXPECT_THAT(encoder.Encode("hellohell"), ElementsAre(0, 4, 3, 1));
+ EXPECT_THAT(encoder.Encode(""), ElementsAre(0, 1));
+ EXPECT_THAT(encoder.Encode("hellathere"),
+ ElementsAre(0, /*hell*/ 3, /*unknown*/ 2, /*there*/ 6, 1));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/matcher.h b/utils/sentencepiece/matcher.h
new file mode 100644
index 0000000..b538d69
--- /dev/null
+++ b/utils/sentencepiece/matcher.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
+
+#include <vector>
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+struct TrieMatch {
+ TrieMatch() {}
+ TrieMatch(int id, int match_length) : id(id), match_length(match_length) {}
+ int id = -1;
+ int match_length = -1;
+};
+
+class SentencePieceMatcher {
+ public:
+ virtual ~SentencePieceMatcher() {}
+
+ // Find matches that are prefixes of a string.
+ virtual std::vector<TrieMatch> FindAllPrefixMatches(
+ StringPiece input) const = 0;
+
+ // Find the longest prefix match of a string.
+ virtual TrieMatch LongestPrefixMatch(StringPiece input) const = 0;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
diff --git a/utils/sentencepiece/normalizer.cc b/utils/sentencepiece/normalizer.cc
new file mode 100644
index 0000000..1dd20da
--- /dev/null
+++ b/utils/sentencepiece/normalizer.cc
@@ -0,0 +1,137 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/normalizer.h"
+
+#include "utils/base/logging.h"
+#include "utils/strings/utf8.h"
+
+namespace libtextclassifier3 {
+
+std::string SentencePieceNormalizer::Normalize(StringPiece input) const {
+ std::string normalized;
+
+ // Ignores leading space.
+ if (remove_extra_whitespaces_) {
+ while (!input.empty()) {
+ const auto suffix_and_length = NormalizePrefix(input);
+ if (suffix_and_length.second <= 0) {
+ TC3_LOG(ERROR) << "Consumed string is empty.";
+ return normalized;
+ }
+ if (suffix_and_length.first.size() != 1 ||
+ suffix_and_length.first[0] != ' ') {
+ break;
+ }
+ input.RemovePrefix(suffix_and_length.second);
+ }
+ }
+
+ if (input.empty()) {
+ return normalized;
+ }
+
+ // Reserves the output buffer to avoid re-allocations.
+ const int kReservedSize = input.size() * 3;
+ normalized.reserve(kReservedSize);
+
+ // Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK)
+ // if escape_whitespaces() is set (default = true).
+ const StringPiece kSpaceSymbol = "\xe2\x96\x81";
+
+ // Adds a space symbol as a prefix (default is true)
+ // With this prefix, "world" and "hello world" are converted into
+ // "_world" and "_hello_world", which help the trainer to extract
+ // "_world" as one symbol.
+ if (add_dummy_prefix_) {
+ if (escape_whitespaces_) {
+ normalized.append(kSpaceSymbol.data(), kSpaceSymbol.size());
+ } else {
+ normalized.append(" ");
+ }
+ }
+
+ bool is_prev_space = remove_extra_whitespaces_;
+ while (!input.empty()) {
+ auto p = NormalizePrefix(input);
+ if (p.second <= 0) {
+ TC3_LOG(ERROR) << "Consumed string is empty.";
+ return normalized;
+ }
+
+ StringPiece sp = p.first;
+
+ // Removes leading spaces in the sentence piece,
+ // if the previous sentence piece ends with whitespace.
+ while (is_prev_space && ConsumePrefix(&sp, " ")) {
+ }
+
+ if (!sp.empty()) {
+ const char *data = sp.data();
+ for (int n = 0; n < sp.size(); ++n) {
+ if (escape_whitespaces_ && data[n] == ' ') {
+ normalized.append(kSpaceSymbol.data(), kSpaceSymbol.size());
+ } else {
+ normalized += data[n];
+ }
+ }
+ // Checks whether the last character of sp is whitespace.
+ is_prev_space = EndsWith(sp, " ");
+ }
+ input.RemovePrefix(p.second);
+ is_prev_space = is_prev_space && remove_extra_whitespaces_;
+ }
+
+ // Ignores trailing space.
+ if (remove_extra_whitespaces_) {
+ const StringPiece space = escape_whitespaces_ ? kSpaceSymbol : " ";
+ while (EndsWith(normalized, space)) {
+ const int length = normalized.size() - space.size();
+ normalized.resize(length);
+ }
+ }
+ return normalized;
+}
+
+std::pair<StringPiece, int> SentencePieceNormalizer::NormalizePrefix(
+ StringPiece input) const {
+ std::pair<StringPiece, int> result;
+ if (input.empty()) return result;
+ const TrieMatch match = charsmap_trie_.LongestPrefixMatch(input);
+ const bool no_match = match.match_length <= 0;
+ if (no_match) {
+ const int char_length = ValidUTF8CharLength(input.data(), input.size());
+ if (char_length <= 0) {
+ // Found a malformed utf8.
+ // The rune is set to be 0xFFFD (REPLACEMENT CHARACTER),
+ // which is a valid Unicode of three bytes in utf8,
+ // but here we only consume one byte.
+ static const char kReplacementChar[] = "\xEF\xBF\xBD";
+ result.first = StringPiece(kReplacementChar, 3);
+ result.second = 1; // Consumes 1 byte, but emits 0xFFFD.
+ } else {
+ result.first = StringPiece(input.data(), char_length);
+ result.second = char_length;
+ }
+ } else {
+ TC3_CHECK(match.id >= 0 && match.id < charsmap_normalized_.size());
+ result.first = StringPiece(&charsmap_normalized_.data()[match.id]);
+ result.second = match.match_length;
+ }
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/normalizer.h b/utils/sentencepiece/normalizer.h
new file mode 100644
index 0000000..227e09b
--- /dev/null
+++ b/utils/sentencepiece/normalizer.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
+
+#include <memory>
+#include <string>
+
+#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Normalizer implements a simple text normalizer with user-defined
+// string-to-string rules and leftmost longest matching.
+class SentencePieceNormalizer {
+ public:
+ // charsmap_trie and charsmap_normalized specify the normalization/replacement
+ // string-to-string rules in the following way:
+ // A match in the trie for a string will return the offset in
+ // charsmap_normalized that contains the replacement string.
+ //
+ // add_dummy_prefix: Whether to add dummy whitespace at the beginning of the
+ // text in order to treat "world" in "world" and "hello world" uniformly.
+ //
+ // remove_extra_whitespaces: Whether to remove leading, trailing and duplicate
+ // internal whitespace.
+ //
+ // escape_whitespaces: Whether to replace whitespace with a meta symbol.
+ SentencePieceNormalizer(const DoubleArrayTrie &charsmap_trie,
+ StringPiece charsmap_normalized,
+ bool add_dummy_prefix = true,
+ bool remove_extra_whitespaces = true,
+ bool escape_whitespaces = true)
+ : charsmap_trie_(charsmap_trie),
+ charsmap_normalized_(charsmap_normalized),
+ add_dummy_prefix_(add_dummy_prefix),
+ remove_extra_whitespaces_(remove_extra_whitespaces),
+ escape_whitespaces_(escape_whitespaces) {}
+
+ // Normalizes a plain utf8 string into an internal representation for
+ // Sentencepiece model.
+ std::string Normalize(StringPiece input) const;
+
+ private:
+ // Normalizes the prefix of `input` and returns the pair of
+ // normalized prefix and the length of the prefix of `input` processed in the
+ // normalization.
+ std::pair<StringPiece, int> NormalizePrefix(StringPiece input) const;
+
+ // Internal trie for efficient longest prefix string matching.
+ DoubleArrayTrie charsmap_trie_;
+
+ // "\0" delimitered concatenated normalized strings.
+ // the value of `charsmap_trie_` stores offsets into this string.
+ StringPiece charsmap_normalized_;
+
+ const bool add_dummy_prefix_;
+ const bool remove_extra_whitespaces_;
+ const bool escape_whitespaces_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
diff --git a/utils/sentencepiece/normalizer_test.cc b/utils/sentencepiece/normalizer_test.cc
new file mode 100644
index 0000000..f6018ab
--- /dev/null
+++ b/utils/sentencepiece/normalizer_test.cc
@@ -0,0 +1,130 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/sentencepiece/normalizer.h"
+#include "utils/sentencepiece/test_utils.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Returns the path of the serialized normalization config used by the tests
// below.
// NOTE(review): this returns the empty string (the test-data path appears to
// have been stripped in this drop), so the ifstream reads no data and the
// tests cannot pass as-is — confirm the intended config path.
std::string GetTestConfigPath() {
  return "";
}
+
// Verifies normalization with all options on: a dummy prefix is added,
// redundant whitespace is collapsed, and whitespace is escaped to the
// meta symbol (U+2581, shown as "▁" in the expectations below).
TEST(NormalizerTest, NormalizesAsReferenceNormalizer) {
  // Load the serialized normalization spec from disk.
  std::ifstream test_config_stream(GetTestConfigPath());
  std::string config((std::istreambuf_iterator<char>(test_config_stream)),
                     (std::istreambuf_iterator<char>()));
  SentencePieceNormalizer normalizer =
      NormalizerFromSpec(config, /*add_dummy_prefix=*/true,
                         /*remove_extra_whitespaces=*/true,
                         /*escape_whitespaces=*/true);

  EXPECT_EQ(normalizer.Normalize("hello there"), "▁hello▁there");

  // Redundant whitespace.
  EXPECT_EQ(normalizer.Normalize("when is the world cup?"),
            "▁when▁is▁the▁world▁cup?");

  // Different whitespace.
  EXPECT_EQ(normalizer.Normalize("general\tkenobi"), "▁general▁kenobi");

  // NFKC char to multi-char normalization.
  EXPECT_EQ(normalizer.Normalize("㍿"), "▁株式会社");

  // Half width katakana, character composition happens.
  EXPECT_EQ(normalizer.Normalize(" グーグル "), "▁グーグル");

  // NFKC char to char normalization.
  EXPECT_EQ(normalizer.Normalize("①②③"), "▁123");
}
+
// Same inputs as NormalizesAsReferenceNormalizer, but with the leading dummy
// prefix disabled: outputs must not start with the meta symbol.
TEST(NormalizerTest, NoDummyPrefix) {
  std::ifstream test_config_stream(GetTestConfigPath());
  std::string config((std::istreambuf_iterator<char>(test_config_stream)),
                     (std::istreambuf_iterator<char>()));
  SentencePieceNormalizer normalizer =
      NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
                         /*remove_extra_whitespaces=*/true,
                         /*escape_whitespaces=*/true);

  EXPECT_EQ(normalizer.Normalize("hello there"), "hello▁there");

  // Redundant whitespace.
  EXPECT_EQ(normalizer.Normalize("when is the world cup?"),
            "when▁is▁the▁world▁cup?");

  // Different whitespace.
  EXPECT_EQ(normalizer.Normalize("general\tkenobi"), "general▁kenobi");

  // NFKC char to multi-char normalization.
  EXPECT_EQ(normalizer.Normalize("㍿"), "株式会社");

  // Half width katakana, character composition happens.
  EXPECT_EQ(normalizer.Normalize(" グーグル "), "グーグル");

  // NFKC char to char normalization.
  EXPECT_EQ(normalizer.Normalize("①②③"), "123");
}
+
// With whitespace collapsing disabled, repeated spaces in the input must
// survive as repeated meta symbols in the output.
TEST(NormalizerTest, NoRemoveExtraWhitespace) {
  std::ifstream test_config_stream(GetTestConfigPath());
  std::string config((std::istreambuf_iterator<char>(test_config_stream)),
                     (std::istreambuf_iterator<char>()));
  SentencePieceNormalizer normalizer =
      NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
                         /*remove_extra_whitespaces=*/false,
                         /*escape_whitespaces=*/true);

  EXPECT_EQ(normalizer.Normalize("hello there"), "hello▁there");

  // Redundant whitespace.
  // NOTE(review): the expected "▁▁" runs imply the input literal contains
  // double spaces — verify the input string wasn't whitespace-collapsed in
  // transit.
  EXPECT_EQ(normalizer.Normalize("when is the world cup?"),
            "when▁is▁▁the▁▁world▁cup?");

  // Different whitespace.
  EXPECT_EQ(normalizer.Normalize("general\tkenobi"), "general▁kenobi");
}
+
// With whitespace escaping disabled, whitespace is kept (normalized to plain
// spaces) rather than replaced by the meta symbol.
TEST(NormalizerTest, NoEscapeWhitespaces) {
  std::ifstream test_config_stream(GetTestConfigPath());
  std::string config((std::istreambuf_iterator<char>(test_config_stream)),
                     (std::istreambuf_iterator<char>()));
  SentencePieceNormalizer normalizer =
      NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
                         /*remove_extra_whitespaces=*/false,
                         /*escape_whitespaces=*/false);

  EXPECT_EQ(normalizer.Normalize("hello there"), "hello there");

  // Redundant whitespace.
  EXPECT_EQ(normalizer.Normalize("when is the world cup?"),
            "when is the world cup?");

  // Different whitespace.
  EXPECT_EQ(normalizer.Normalize("general\tkenobi"), "general kenobi");
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/sorted_strings_table.cc b/utils/sentencepiece/sorted_strings_table.cc
new file mode 100644
index 0000000..332ce46
--- /dev/null
+++ b/utils/sentencepiece/sorted_strings_table.cc
@@ -0,0 +1,107 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/sorted_strings_table.h"
+
+#include <algorithm>
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+void SortedStringsTable::GatherPrefixMatches(
+ StringPiece input, const std::function<void(TrieMatch)>& update_fn) const {
+ int left = 0;
+ int right = num_pieces_;
+ int span_size = right - left;
+ int match_length = 0;
+
+ // Loop invariant:
+ // at the ith iteration, all strings from `left` ... `right` match the input
+ // on the first `match_length` characters.
+ while (span_size > use_linear_scan_threshold_) {
+ if (match_length >= input.length()) {
+ return;
+ }
+
+ // We find the possible range of pieces in `left` ... `right` matching the
+ // `match_length` + 1 character with two binary searches:
+ // `lower_bound` to find the start of the range of matching pieces.
+ // `upper_bound` to find the non-inclusive end of the range.
+ left = (std::lower_bound(
+ offsets_ + left, offsets_ + right, input[match_length],
+ [this, match_length](int piece_offset, int c) -> bool {
+ return pieces_[piece_offset + match_length] < c;
+ }) -
+ offsets_);
+ right = (std::upper_bound(
+ offsets_ + left, offsets_ + right, input[match_length],
+ [this, match_length](int c, int piece_offset) -> bool {
+ return c < pieces_[piece_offset + match_length];
+ }) -
+ offsets_);
+ span_size = right - left;
+ if (span_size <= 0) {
+ return;
+ }
+ ++match_length;
+
+ // Due to the loop invariant and the fact that the strings are sorted, there
+ // can only be one piece matching completely now, namely at left.
+ if (pieces_[offsets_[left] + match_length] == 0) {
+ update_fn(TrieMatch(/*id=*/left,
+ /*match_length=*/match_length));
+ left++;
+ }
+ }
+
+ // Use linear scan for small problem instances.
+ // By the loop invariant characters 0...`match_length` of all pieces in
+ // in `left`...`right` match the input on 0...`match_length`.
+ for (int i = left; i < right; i++) {
+ bool matches = true;
+ int piece_match_length = match_length;
+ for (int k = offsets_[i] + piece_match_length; pieces_[k] != 0; k++) {
+ if (match_length >= input.size() ||
+ input[piece_match_length] != pieces_[k]) {
+ matches = false;
+ break;
+ }
+ piece_match_length++;
+ }
+ if (matches) {
+ update_fn(TrieMatch(/*id=*/i,
+ /*match_length=*/piece_match_length));
+ }
+ }
+}
+
+std::vector<TrieMatch> SortedStringsTable::FindAllPrefixMatches(
+ StringPiece input) const {
+ std::vector<TrieMatch> result;
+ GatherPrefixMatches(
+ input, [&result](const TrieMatch match) { result.push_back(match); });
+ return result;
+}
+
+TrieMatch SortedStringsTable::LongestPrefixMatch(StringPiece input) const {
+ TrieMatch longest_match;
+ GatherPrefixMatches(input, [&longest_match](const TrieMatch match) {
+ longest_match = match;
+ });
+ return longest_match;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/sorted_strings_table.h b/utils/sentencepiece/sorted_strings_table.h
new file mode 100644
index 0000000..82cda5c
--- /dev/null
+++ b/utils/sentencepiece/sorted_strings_table.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_SORTED_STRINGS_TABLE_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_SORTED_STRINGS_TABLE_H_
+
+#include <functional>
+#include <vector>
+
+#include "utils/sentencepiece/matcher.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// A matcher to find string pieces matching prefixes of an input string.
+// The list of reference strings are kept in sorted order in a zero separated
+// string.
+// Binary search is used to find all prefix matches.
+// num_pieces: Number of sentence pieces.
+// offsets: Offsets into `pieces` where a string starts.
+// pieces: String pieces, concatenated in sorted order and zero byte separated.
+// use_linear_scan_threshold: Minimum size of binary search range before
+// switching to a linear sweep for prefix match testing.
class SortedStringsTable : public SentencePieceMatcher {
 public:
  // Does not take ownership of `offsets` or `pieces`; both must outlive this
  // table. `offsets[i]` is the start of the i-th (sorted) string inside the
  // zero-byte-separated blob `pieces`.
  SortedStringsTable(const int num_pieces, const int* offsets,
                     StringPiece pieces,
                     const int use_linear_scan_threshold = 10)
      : num_pieces_(num_pieces),
        offsets_(offsets),
        pieces_(pieces),
        use_linear_scan_threshold_(use_linear_scan_threshold) {}

  // Find matches that are prefixes of a string.
  std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const override;

  // Find the longest prefix match of a string.
  TrieMatch LongestPrefixMatch(StringPiece input) const override;

 private:
  // Shared worker for the two public lookups; invokes `update_fn` for each
  // piece that is a prefix of `input`.
  void GatherPrefixMatches(
      StringPiece input, const std::function<void(TrieMatch)>& update_fn) const;

  // Number of strings in the table.
  const int num_pieces_;
  // Start offsets into `pieces_`, one per string, sorted by string value.
  const int* offsets_;
  // Concatenated, zero-separated, sorted string data (non-owning view).
  const StringPiece pieces_;
  // Range size below which lookup falls back to a linear sweep.
  const int use_linear_scan_threshold_;
};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_SORTED_STRINGS_TABLE_H_
diff --git a/utils/sentencepiece/sorted_strings_table_test.cc b/utils/sentencepiece/sorted_strings_table_test.cc
new file mode 100644
index 0000000..61a0ef4
--- /dev/null
+++ b/utils/sentencepiece/sorted_strings_table_test.cc
@@ -0,0 +1,59 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "utils/sentencepiece/sorted_strings_table.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Exercises prefix lookup on a tiny sorted table:
// pieces are {"hell", "hello", "o", "there"} at offsets {0, 5, 11, 13}.
// NOTE(review): the StringPiece length 18 stops just before the final '\0';
// the matcher terminates on the zero bytes in the underlying array, so this
// works, but 19 would describe the data more faithfully — confirm intent.
TEST(SortedStringsTest, Lookup) {
  const char pieces[] = "hell\0hello\0o\0there\0";
  const int offsets[] = {0, 5, 11, 13};

  // Threshold of 1 forces the binary-search path to be exercised.
  SortedStringsTable table(/*num_pieces=*/4, offsets, StringPiece(pieces, 18),
                           /*use_linear_scan_threshold=*/1);

  auto matches = table.FindAllPrefixMatches("hello there");
  EXPECT_EQ(matches.size(), 2);
  EXPECT_EQ(matches[0].id, 0 /*hell*/);
  EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
  EXPECT_EQ(matches[1].id, 1 /*hello*/);
  EXPECT_EQ(matches[1].match_length, 5 /*hello*/);

  matches = table.FindAllPrefixMatches("he");
  EXPECT_EQ(matches.size(), 0);

  matches = table.FindAllPrefixMatches("abcd");
  EXPECT_EQ(matches.size(), 0);

  matches = table.FindAllPrefixMatches("");
  EXPECT_EQ(matches.size(), 0);

  EXPECT_THAT(table.FindAllPrefixMatches("hi there"), testing::IsEmpty());

  EXPECT_EQ(table.LongestPrefixMatch("hella there").id, 0 /*hell*/);
  EXPECT_EQ(table.LongestPrefixMatch("hello there").id, 1 /*hello*/);
  EXPECT_EQ(table.LongestPrefixMatch("abcd").id, -1);
  EXPECT_EQ(table.LongestPrefixMatch("").id, -1);
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/test_utils.cc b/utils/sentencepiece/test_utils.cc
new file mode 100644
index 0000000..1ed2bf3
--- /dev/null
+++ b/utils/sentencepiece/test_utils.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/test_utils.h"
+
+#include <memory>
+
+#include "utils/base/integral_types.h"
+#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
// Builds a SentencePieceNormalizer from a serialized test spec with layout:
//   [uint32 trie_blob_size][trie_blob_size bytes of TrieNode][charsmap blob].
// The returned normalizer holds views into `spec`; the caller must keep the
// spec buffer alive.
// NOTE(review): the reinterpret_casts assume the buffer is suitably aligned
// for uint32/TrieNode and host-endian — acceptable for test data generated on
// the same platform, but confirm for cross-platform fixtures.
SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
                                           bool add_dummy_prefix,
                                           bool remove_extra_whitespaces,
                                           bool escape_whitespaces) {
  const uint32 trie_blob_size = reinterpret_cast<const uint32*>(spec.data())[0];
  spec.RemovePrefix(sizeof(trie_blob_size));
  const TrieNode* trie_blob = reinterpret_cast<const TrieNode*>(spec.data());
  spec.RemovePrefix(trie_blob_size);
  const int num_nodes = trie_blob_size / sizeof(TrieNode);
  return SentencePieceNormalizer(
      DoubleArrayTrie(trie_blob, num_nodes),
      /*charsmap_normalized=*/StringPiece(spec.data(), spec.size()),
      add_dummy_prefix, remove_extra_whitespaces, escape_whitespaces);
}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/test_utils.h b/utils/sentencepiece/test_utils.h
new file mode 100644
index 0000000..0c833da
--- /dev/null
+++ b/utils/sentencepiece/test_utils.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "utils/sentencepiece/normalizer.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
// Test helper: builds a SentencePieceNormalizer from a serialized spec
// ([uint32 trie size][trie nodes][charsmap]); see test_utils.cc for details.
SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
                                           bool add_dummy_prefix,
                                           bool remove_extra_whitespaces,
                                           bool escape_whitespaces);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
diff --git a/utils/tflite-model-executor.cc b/utils/tflite-model-executor.cc
index 4eb0449..4877d4a 100644
--- a/utils/tflite-model-executor.cc
+++ b/utils/tflite-model-executor.cc
@@ -30,10 +30,17 @@
} // namespace ops
} // namespace tflite
+#ifdef TC3_WITH_ACTIONS_OPS
+#include "utils/tflite/dist_diversification.h"
+#include "utils/tflite/text_encoder.h"
+// This function is defined in the file generated by :smart_reply_ops target.
+void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
+#else
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {
resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
::tflite::ops::builtin::Register_FULLY_CONNECTED());
}
+#endif // TC3_WITH_ACTIONS_OPS
namespace libtextclassifier3 {
@@ -51,6 +58,12 @@
#else
std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
new tflite::ops::builtin::BuiltinOpResolver);
+#ifdef TC3_WITH_ACTIONS_OPS
+ resolver->AddCustom("DistanceDiversification",
+ tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
+ resolver->AddCustom("TextEncoder",
+ tflite::ops::custom::Register_TEXT_ENCODER());
+#endif // TC3_WITH_ACTIONS_OPS
#endif
return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
}
diff --git a/utils/tflite/dist_diversification.cc b/utils/tflite/dist_diversification.cc
new file mode 100644
index 0000000..faf9be0
--- /dev/null
+++ b/utils/tflite/dist_diversification.cc
@@ -0,0 +1,155 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/dist_diversification.h"
+
+#include <algorithm>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Returns a vector of row indices in a distance matrix.
// The first row (index 0) is always selected; every further selected index is
// increasing and has distance >= `min_distance` to all previously selected
// indices. At most `max_num_results` indices are returned.
// distance_matrix: callable (row, col) -> float distance.
// Returns an empty vector for an empty matrix or a non-positive result budget;
// previously index 0 was emitted unconditionally, which let the TFLite Eval
// below write past a zero-sized output tensor.
template <typename DistanceMatrixType>
std::vector<int> DiversifyByDistance(const DistanceMatrixType& distance_matrix,
                                     const int matrix_size,
                                     const float min_distance,
                                     const int max_num_results) {
  if (matrix_size <= 0 || max_num_results <= 0) {
    return {};
  }
  std::vector<int> result{0};
  result.reserve(max_num_results);
  int index = 1;
  while (static_cast<int>(result.size()) < max_num_results &&
         index < matrix_size) {
    for (; index < matrix_size; ++index) {
      // A candidate is accepted only if it is far enough from everything
      // selected so far.
      bool too_close = false;
      for (const int selected_index : result) {
        if (distance_matrix(index, selected_index) < min_distance) {
          too_close = true;
          break;
        }
      }
      if (!too_close) {
        result.push_back(index);
        ++index;
        break;
      }
    }
  }
  return result;
}
+
// Input parameters for the op.
enum DistDiversificationInputs {
  DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX = 0,  // float [n, n]
  DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE = 1,     // float scalar
  DIST_DIVERSIFICATION_INPUT_NUM_RESULTS = 2       // int32 scalar
};

// Output parameters for the op.
enum DistDiversificationOutputs {
  DIST_DIVERSIFICATION_OUTPUT_INDICES = 0,  // int32 [num_results], -1 padded
  DIST_DIVERSIFICATION_OUTPUT_LENGTH = 1,   // int32 scalar
};
+
+TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
+ TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
+ int index = 0;
+ for (const int size : sizes) {
+ array_size->data[index++] = size;
+ }
+ return array_size;
+}
+
// Resizes the indices output tensor to [num_results], where num_results is
// read from the corresponding input tensor.
TfLiteStatus AllocateOutputIndexes(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor& num_results =
      context
          ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
  TfLiteTensor& output_indices =
      context
          ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
  return context->ResizeTensor(context, &output_indices,
                               CreateSizeArray({num_results.data.i32[0]}));
}
+
// Standard TFLite Prepare: sizes the outputs. The indices output can only be
// sized now if num_results is a constant tensor; otherwise it is marked
// dynamic and sized at Eval time.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor& num_results =
      context
          ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
  if (tflite::IsConstantTensor(&num_results)) {
    TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
  } else {
    TfLiteTensor& output_indices =
        context
            ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
    tflite::SetTensorToDynamic(&output_indices);
  }
  // The length output is always a single int32.
  TfLiteTensor& output_length =
      context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_length,
                                                   CreateSizeArray({1})));
  return kTfLiteOk;
}
+
// Standard TFLite Eval: runs DiversifyByDistance over the input distance
// matrix and writes the selected indices (padded with -1 up to num_results)
// plus the number of valid indices.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor& output_indices =
      context
          ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
  // Late allocation for the dynamic-shape case (see Prepare).
  if (tflite::IsDynamicTensor(&output_indices)) {
    TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
  }
  const TfLiteTensor& distance_matrix =
      context->tensors[node->inputs
                           ->data[DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX]];
  // The matrix is square; dims->data[0] is its row (and column) count.
  const int distance_matrix_dim = distance_matrix.dims->data[0];
  const float min_distance =
      context
          ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE]]
          .data.f[0];
  const int num_results =
      context
          ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]]
          .data.i32[0];
  // Adapt the flat float buffer to the (row, col) accessor the selection
  // algorithm expects.
  const auto indices = DiversifyByDistance(
      [&](int row, int col) {
        return distance_matrix.data.f[row * distance_matrix_dim + col];
      },
      distance_matrix_dim, min_distance, num_results);
  // Copy the selected indices and pad the remainder with -1 sentinels.
  std::copy(indices.begin(), indices.end(), output_indices.data.i32);
  std::fill_n(output_indices.data.i32 + indices.size(),
              num_results - indices.size(), -1);
  TfLiteTensor& output_length =
      context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
  *output_length.data.i32 = indices.size();
  return kTfLiteOk;
}
+
+} // namespace
+} // namespace libtextclassifier3
+
namespace tflite {
namespace ops {
namespace custom {
// Registration entry point for the custom op. Only prepare/invoke are
// supplied; init/free are nullptr because the op keeps no per-node state.
TfLiteRegistration* Register_DISTANCE_DIVERSIFICATION() {
  static TfLiteRegistration r = {nullptr, nullptr, libtextclassifier3::Prepare,
                                 libtextclassifier3::Eval};
  return &r;
}
}  // namespace custom
}  // namespace ops
}  // namespace tflite
diff --git a/utils/tflite/dist_diversification.h b/utils/tflite/dist_diversification.h
new file mode 100644
index 0000000..924186d
--- /dev/null
+++ b/utils/tflite/dist_diversification.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_DIST_DIVERSIFICATION_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_DIST_DIVERSIFICATION_H_
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_DISTANCE_DIVERSIFICATION();
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_DIST_DIVERSIFICATION_H_
diff --git a/utils/tflite/dist_diversification_test.cc b/utils/tflite/dist_diversification_test.cc
new file mode 100644
index 0000000..6ed578c
--- /dev/null
+++ b/utils/tflite/dist_diversification_test.cc
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/dist_diversification.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Test harness wrapping the DistanceDiversification custom op in a
// single-op TFLite model. The int members are tensor indices assigned by
// SingleOpModel's AddInput/AddOutput.
class DistanceDiversificationOpModel : public tflite::SingleOpModel {
 public:
  explicit DistanceDiversificationOpModel(int matrix_rows);
  // Fills the [rows, rows] distance matrix input (row-major values).
  void SetDistanceMatrix(const std::initializer_list<float>& values) {
    PopulateTensor(distance_matrix_, values);
  }
  // Sets the requested number of output indices.
  void SetNumOutput(int length) { PopulateTensor(num_results_, {length}); }
  // Sets the minimum pairwise distance between selected indices.
  void SetMinDistance(float min_distance) {
    PopulateTensor(min_distance_, {min_distance});
  }
  // Number of valid (non-padding) indices produced by the op.
  int GetOutputLen() { return ExtractVector<int>(output_len_).front(); }
  // The first `output_length` selected indices (drops the -1 padding).
  std::vector<int> GetOutputIndexes(int output_length) {
    auto res = ExtractVector<int>(output_indexes_);
    res.resize(output_length);
    return res;
  }

 private:
  // Input tensor indices.
  int distance_matrix_;
  int num_results_;
  int min_distance_;

  // Output tensor indices.
  int output_len_;
  int output_indexes_;
};
+
// Wires up the single-op graph: declares the three inputs and two outputs,
// registers the custom op, and builds an interpreter for a
// [matrix_rows, matrix_rows] distance matrix.
DistanceDiversificationOpModel::DistanceDiversificationOpModel(
    int matrix_rows) {
  distance_matrix_ = AddInput(tflite::TensorType_FLOAT32);
  min_distance_ = AddInput(tflite::TensorType_FLOAT32);
  num_results_ = AddInput(tflite::TensorType_INT32);

  output_indexes_ = AddOutput(tflite::TensorType_INT32);
  output_len_ = AddOutput(tflite::TensorType_INT32);
  SetCustomOp("DistanceDiversification", {},
              tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION);
  BuildInterpreter({{matrix_rows, matrix_rows}, {1}, {1}});
}
+
// End-to-end check on a 5x5 matrix with distance(i, j) = |i - j| * 0.1:
// with min_distance 0.21 only indices 0 and 3 qualify, even though 3
// results were requested.
TEST(DistanceDiversificationOp, Simple) {
  DistanceDiversificationOpModel m(5);
  m.SetDistanceMatrix({0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.0, 0.1, 0.2,
                       0.3, 0.2, 0.1, 0.0, 0.1, 0.2, 0.3, 0.2, 0.1,
                       0.0, 0.1, 0.4, 0.3, 0.2, 0.1, 0.0});
  m.SetMinDistance(0.21);
  m.SetNumOutput(3);
  m.Invoke();
  const int output_length = m.GetOutputLen();
  EXPECT_EQ(output_length, 2);
  EXPECT_THAT(m.GetOutputIndexes(output_length), testing::ElementsAre(0, 3));
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/tflite/text_encoder.cc b/utils/tflite/text_encoder.cc
new file mode 100644
index 0000000..abc472e
--- /dev/null
+++ b/utils/tflite/text_encoder.cc
@@ -0,0 +1,377 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/sentencepiece/encoder.h"
+#include "utils/sentencepiece/normalizer.h"
+#include "utils/sentencepiece/sorted_strings_table.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/tflite/text_encoder.h"
+#include "utils/tflite/text_encoder_config_generated.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+struct TextEncoderOp {
+ std::unique_ptr<SentencePieceNormalizer> normalizer;
+ std::unique_ptr<Encoder> encoder;
+ std::unique_ptr<SentencePieceMatcher> matcher;
+};
+
+// Input parameters for the op.
+enum TextEncoderInputs {
+ TEXT_ENCODER_INPUT_TEXTS = 0,
+ TEXT_ENCODER_INPUT_NUM_TEXTS = 1,
+ TEXT_ENCODER_INPUT_MAX_LENGTH = 2,
+ TEXT_ENCODER_INPUT_ATTR = 3
+};
+
+// Output parameters for the op.
+enum SmartReplyModelOutputs {
+ TEXT_ENCODER_OUTPUT_ENCODED = 0,
+ TEXT_ENCODER_OUTPUT_POSITION = 1,
+ TEXT_ENCODER_OUTPUT_LENGTHS = 2,
+ TEXT_ENCODER_OUTPUT_ATTR = 3,
+};
+
+const char kTextEncoderConfigAttr[] = "text_encoder_config";
+
+// Input rank is 2 since there is a dummy batch dimension of 1.
+const int kInputRank = 2;
+const int kBatchSize = 1;
+
+// Initializes text encoder object from serialized options:
+// The options are a flexbuffers attribute map that contain the op config
+// with the key `text_encoder_config` as `TextEncoderConfig`.
+void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
+ const flexbuffers::Map& attr_map =
+ flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
+ .AsMap();
+ const flexbuffers::Blob serialized_config =
+ attr_map[kTextEncoderConfigAttr].AsBlob();
+ const TextEncoderConfig* config =
+ flatbuffers::GetRoot<TextEncoderConfig>(serialized_config.data());
+
+ std::unique_ptr<TextEncoderOp> encoder_op(new TextEncoderOp());
+
+ // Create normalizer from options.
+ const TrieNode* charsmap_trie_nodes = reinterpret_cast<const TrieNode*>(
+ config->normalization_charsmap()->Data());
+ const int charsmap_trie_nodes_length =
+ config->normalization_charsmap()->Length() / sizeof(TrieNode);
+ encoder_op->normalizer.reset(new SentencePieceNormalizer(
+ DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length),
+ StringPiece(config->normalization_charsmap_values()->data(),
+ config->normalization_charsmap_values()->size()),
+ config->add_dummy_prefix(), config->remove_extra_whitespaces(),
+ config->escape_whitespaces()));
+
+ const int num_pieces = config->pieces_scores()->Length();
+
+ switch (config->matcher_type()) {
+ case SentencePieceMatcherType_MAPPED_TRIE: {
+ const TrieNode* pieces_trie_nodes =
+ reinterpret_cast<const TrieNode*>(config->pieces()->Data());
+ const int pieces_trie_nodes_length =
+ config->pieces()->Length() / sizeof(TrieNode);
+ encoder_op->matcher.reset(
+ new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
+ break;
+ }
+ case SentencePieceMatcherType_SORTED_STRING_TABLE: {
+ encoder_op->matcher.reset(new SortedStringsTable(
+ num_pieces, config->pieces_offsets()->data(),
+ StringPiece(config->pieces()->data(), config->pieces()->Length())));
+ break;
+ }
+ default: {
+ TC3_LOG(ERROR) << "Unknown sentence piece matcher type.";
+ return nullptr;
+ }
+ }
+ encoder_op->encoder.reset(new Encoder(
+ encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
+ config->start_code(), config->end_code(), config->encoding_offset(),
+ config->unknown_code(), config->unknown_score()));
+ return encoder_op.release();
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<TextEncoderOp*>(buffer);
+}
+
+namespace {
+TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
+ TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
+ int index = 0;
+ for (const int size : sizes) {
+ array_size->data[index++] = size;
+ }
+ return array_size;
+}
+
+// Copies attribute values according to the encoding_end_offsets of every string.
+TfLiteStatus CopyAttribute(const TfLiteTensor& in,
+ const std::vector<int>& encoding_end_offsets,
+ int start_offset, TfLiteContext* context,
+ TfLiteTensor* out) {
+ TF_LITE_ENSURE_EQ(context, in.dims->size, kInputRank);
+ TF_LITE_ENSURE_EQ(context, in.dims->data[0], kBatchSize);
+ const int output_size = out->dims->data[1];
+ int output_offset = 0;
+ for (int value_index = 0;
+ value_index < encoding_end_offsets.size() && output_offset < output_size;
+ ++value_index) {
+ // Calculate how many elements need to be set with this value.
+ // The low bound depends on the offset from the beginning. If this is 0, it
+ // means that this value is truncated.
+ // The upper bound depends on how many elements are in the output tensor.
+ const int from_this_element =
+ std::min(std::max(0, encoding_end_offsets[value_index] - start_offset -
+ output_offset),
+ output_size - output_offset);
+ if (from_this_element == 0) {
+ continue;
+ }
+
+ switch (in.type) {
+ case kTfLiteInt32: {
+ std::fill(out->data.i32 + output_offset,
+ out->data.i32 + output_offset + from_this_element,
+ in.data.i32[value_index]);
+ } break;
+ case kTfLiteFloat32: {
+ std::fill(out->data.f + output_offset,
+ out->data.f + output_offset + from_this_element,
+ in.data.f[value_index]);
+ } break;
+ default:
+ context->ReportError(
+ (context), __FILE__ " Not supported attribute type %d", in.type);
+ return kTfLiteError;
+ }
+ output_offset += from_this_element;
+ }
+ // Do final padding.
+ switch (in.type) {
+ case kTfLiteInt32: {
+ const int32_t value =
+ (output_offset > 0) ? out->data.i32[output_offset - 1] : 0;
+ std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
+ value);
+ } break;
+ case kTfLiteFloat32: {
+ const float value =
+ (output_offset > 0) ? out->data.f[output_offset - 1] : 0;
+ std::fill(out->data.f + output_offset, out->data.f + output_size, value);
+ } break;
+ default:
+ break;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
+ int max_output_length) {
+ TfLiteTensor& output_encoded =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ENCODED]];
+
+ TF_LITE_ENSURE_OK(
+ context,
+ context->ResizeTensor(context, &output_encoded,
+ CreateSizeArray({kBatchSize, max_output_length})));
+
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
+
+ TF_LITE_ENSURE_OK(
+ context,
+ context->ResizeTensor(context, &output_positions,
+ CreateSizeArray({kBatchSize, max_output_length})));
+
+ const int num_output_attrs = node->outputs->size - TEXT_ENCODER_OUTPUT_ATTR;
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteTensor& output =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ATTR + i]];
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(
+ context, &output,
+ CreateSizeArray({kBatchSize, max_output_length})));
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Check that the batch dimension is kBatchSize.
+ const TfLiteTensor& input_text =
+ context->tensors[node->inputs->data[TEXT_ENCODER_INPUT_TEXTS]];
+ TF_LITE_ENSURE_EQ(context, input_text.dims->size, kInputRank);
+ TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kBatchSize);
+
+ TfLiteTensor& output_lengths =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_LENGTHS]];
+ TfLiteTensor& output_encoded =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ENCODED]];
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
+
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, &output_lengths,
+ CreateSizeArray({kBatchSize})));
+
+ // Check that there are enough outputs for attributes.
+ const int num_output_attrs = node->outputs->size - TEXT_ENCODER_OUTPUT_ATTR;
+ TF_LITE_ENSURE_EQ(context, node->inputs->size - TEXT_ENCODER_INPUT_ATTR,
+ num_output_attrs);
+
+ // Copy attribute types from input to output tensors.
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteTensor& input =
+ context->tensors[node->inputs->data[TEXT_ENCODER_INPUT_ATTR + i]];
+ TfLiteTensor& output =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ATTR + i]];
+ output.type = input.type;
+ }
+
+ const TfLiteTensor& output_length =
+ context->tensors[node->inputs->data[TEXT_ENCODER_INPUT_MAX_LENGTH]];
+
+ if (tflite::IsConstantTensor(&output_length)) {
+ return ResizeOutputTensors(context, node, output_length.data.i64[0]);
+ } else {
+ tflite::SetTensorToDynamic(&output_encoded);
+ tflite::SetTensorToDynamic(&output_positions);
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteTensor& output_attr =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ATTR + i]];
+ tflite::SetTensorToDynamic(&output_attr);
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ if (node->user_data == nullptr) {
+ return kTfLiteError;
+ }
+ const TextEncoderOp* encoder_op =
+ reinterpret_cast<TextEncoderOp*>(node->user_data);
+ const TfLiteTensor& input_text =
+ context->tensors[node->inputs->data[TEXT_ENCODER_INPUT_TEXTS]];
+ const int num_strings = tflite::GetStringCount(&input_text);
+ // Check that the number of strings matches the length parameter.
+ const int num_strings_param =
+ context->tensors[node->inputs->data[TEXT_ENCODER_INPUT_NUM_TEXTS]]
+ .data.i32[0];
+ TF_LITE_ENSURE_EQ(context, num_strings, num_strings_param);
+
+ TfLiteTensor& output_encoded =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ENCODED]];
+ if (tflite::IsDynamicTensor(&output_encoded)) {
+ const TfLiteTensor& output_length =
+ context->tensors[node->inputs->data[TEXT_ENCODER_INPUT_MAX_LENGTH]];
+ TF_LITE_ENSURE_OK(
+ context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
+ }
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
+
+ std::vector<int> encoded_total;
+ std::vector<int> encoded_offsets;
+ std::vector<int> encoded_positions;
+ encoded_offsets.reserve(num_strings);
+ const int max_output_length = output_encoded.dims->data[1];
+ const int max_encoded_position = max_output_length;
+
+ for (int i = 0; i < num_strings; ++i) {
+ const auto& strref = tflite::GetString(&input_text, i);
+ const std::vector<int> encoded = encoder_op->encoder->Encode(
+ encoder_op->normalizer->Normalize(StringPiece(strref.str, strref.len)));
+ encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
+ encoded_offsets.push_back(encoded_total.size());
+ for (int i = 0; i < encoded.size(); i++) {
+ encoded_positions.push_back(std::min(i, max_encoded_position - 1));
+ }
+ }
+
+ // Copy encoding to output tensor.
+ const int start_offset =
+ std::max(0, static_cast<int>(encoded_total.size()) - max_output_length);
+ int output_offset = 0;
+ int32_t* output_buffer = output_encoded.data.i32;
+ int32_t* output_positions_buffer = output_positions.data.i32;
+ for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) {
+ output_buffer[output_offset] = encoded_total[i];
+ output_positions_buffer[output_offset] = encoded_positions[i];
+ }
+
+ // Save output encoded length.
+ TfLiteTensor& output_lengths =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_LENGTHS]];
+ output_lengths.data.i32[0] = output_offset;
+
+ // Do padding.
+ for (; output_offset < max_output_length; ++output_offset) {
+ output_buffer[output_offset] = encoded_total.back();
+ output_positions_buffer[output_offset] = max_encoded_position;
+ }
+
+ // Process attributes, all checks of sizes and types are done in Prepare.
+ const int num_output_attrs = node->outputs->size - TEXT_ENCODER_OUTPUT_ATTR;
+ TF_LITE_ENSURE_EQ(context, node->inputs->size - TEXT_ENCODER_INPUT_ATTR,
+ num_output_attrs);
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteStatus attr_status = CopyAttribute(
+ context->tensors[node->inputs->data[TEXT_ENCODER_INPUT_ATTR + i]],
+ encoded_offsets, start_offset, context,
+ &context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ATTR + i]]);
+ if (attr_status != kTfLiteOk) {
+ return attr_status;
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace libtextclassifier3
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_TEXT_ENCODER() {
+ static TfLiteRegistration registration = {
+ libtextclassifier3::Initialize, libtextclassifier3::Free,
+ libtextclassifier3::Prepare, libtextclassifier3::Eval};
+ return &registration;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/utils/tflite/text_encoder.h b/utils/tflite/text_encoder.h
new file mode 100644
index 0000000..1143031
--- /dev/null
+++ b/utils/tflite/text_encoder.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER_H_
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_TEXT_ENCODER();
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER_H_
diff --git a/utils/tflite/text_encoder_config.fbs b/utils/tflite/text_encoder_config.fbs
new file mode 100644
index 0000000..8ae8fc5
--- /dev/null
+++ b/utils/tflite/text_encoder_config.fbs
@@ -0,0 +1,65 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Configuration for the text encoder op.
+
+namespace libtextclassifier3;
+
+enum SentencePieceMatcherType : byte {
+ MAPPED_TRIE = 0,
+ SORTED_STRING_TABLE = 1,
+}
+
+table TextEncoderConfig {
+ // Code that is used as encoding of the start code.
+ start_code:int32 = 0;
+
+ // Code that is used as encoding of the end code.
+ end_code:int32 = 1;
+
+ // This value is added to all codes to make them not intersect with
+ // `start_code` and `end_code`.
+ encoding_offset:int32 = 2;
+
+ // Code that is used for out-of-dictionary characters.
+ unknown_code:int32 = -1;
+
+ // Penalty associated with the unknown code.
+ unknown_score:float;
+
+ // Normalization options.
+ // Serialized normalization charsmap.
+ normalization_charsmap:string;
+ normalization_charsmap_values:string;
+
+ // Whether to add dummy whitespace at the beginning of the text in order to
+ // treat "world" in "world" and "hello world" uniformly.
+ add_dummy_prefix:bool = true;
+
+ // Whether to remove leading, trailing and duplicate internal whitespace.
+ remove_extra_whitespaces:bool = true;
+
+ // Whether to replace whitespace with a meta symbol.
+ escape_whitespaces:bool = true;
+
+ // Sentence pieces scores.
+ pieces_scores:[float];
+
+ // Serialized sentence pieces.
+ pieces:string;
+ pieces_offsets:[int32];
+ matcher_type: SentencePieceMatcherType = MAPPED_TRIE;
+}
diff --git a/utils/tflite/text_encoder_test.cc b/utils/tflite/text_encoder_test.cc
new file mode 100644
index 0000000..0cd67ce
--- /dev/null
+++ b/utils/tflite/text_encoder_test.cc
@@ -0,0 +1,170 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "utils/tflite/text_encoder.h"
+#include "gtest/gtest.h"
+#include "third_party/absl/flags/flag.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+std::string GetTestConfigPath() {
+ return "";
+}
+
+class TextEncoderOpModel : public tflite::SingleOpModel {
+ public:
+ TextEncoderOpModel(std::initializer_list<int> input_strings_shape,
+ std::initializer_list<int> attribute_shape);
+ void SetInputText(const std::initializer_list<string>& strings) {
+ PopulateStringTensor(input_string_, strings);
+ PopulateTensor(input_length_, {static_cast<int32_t>(strings.size())});
+ }
+ void SetMaxOutputLength(int length) {
+ PopulateTensor(input_output_maxlength_, {length});
+ }
+ void SetInt32Attribute(const std::initializer_list<int>& attribute) {
+ PopulateTensor(input_attributes_int32_, attribute);
+ }
+ void SetFloatAttribute(const std::initializer_list<float>& attribute) {
+ PopulateTensor(input_attributes_float_, attribute);
+ }
+
+ std::vector<int> GetOutputEncoding() {
+ return ExtractVector<int>(output_encoding_);
+ }
+ std::vector<int> GetOutputPositions() {
+ return ExtractVector<int>(output_positions_);
+ }
+ std::vector<int> GetOutputAttributeInt32() {
+ return ExtractVector<int>(output_attributes_int32_);
+ }
+ std::vector<float> GetOutputAttributeFloat() {
+ return ExtractVector<float>(output_attributes_float_);
+ }
+ int GetEncodedLength() { return ExtractVector<int>(output_length_)[0]; }
+
+ private:
+ int input_string_;
+ int input_length_;
+ int input_output_maxlength_;
+ int input_attributes_int32_;
+ int input_attributes_float_;
+
+ int output_encoding_;
+ int output_positions_;
+ int output_length_;
+ int output_attributes_int32_;
+ int output_attributes_float_;
+};
+
+TextEncoderOpModel::TextEncoderOpModel(
+ std::initializer_list<int> input_strings_shape,
+ std::initializer_list<int> attribute_shape) {
+ input_string_ = AddInput(tflite::TensorType_STRING);
+ input_length_ = AddInput(tflite::TensorType_INT32);
+ input_output_maxlength_ = AddInput(tflite::TensorType_INT32);
+ input_attributes_int32_ = AddInput(tflite::TensorType_INT32);
+ input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32);
+
+ output_encoding_ = AddOutput(tflite::TensorType_INT32);
+ output_positions_ = AddOutput(tflite::TensorType_INT32);
+ output_length_ = AddOutput(tflite::TensorType_INT32);
+ output_attributes_int32_ = AddOutput(tflite::TensorType_INT32);
+ output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32);
+
+ std::ifstream test_config_stream(GetTestConfigPath());
+ std::string config((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+ flexbuffers::Builder builder;
+ builder.Map([&]() { builder.String("text_encoder_config", config); });
+ builder.Finish();
+ SetCustomOp("TextEncoder", builder.GetBuffer(),
+ tflite::ops::custom::Register_TEXT_ENCODER);
+ BuildInterpreter(
+ {input_strings_shape, {1}, {1}, attribute_shape, attribute_shape});
+}
+
+// Tests
+TEST(TextEncoderTest, SimpleEncoder) {
+ TextEncoderOpModel m({1, 1}, {1, 1});
+ m.SetInputText({"Hello"});
+ m.SetMaxOutputLength(10);
+ m.SetInt32Attribute({7});
+ m.SetFloatAttribute({3.f});
+ m.Invoke();
+ EXPECT_EQ(m.GetEncodedLength(), 5);
+ EXPECT_THAT(m.GetOutputEncoding(),
+ testing::ElementsAre(1, 90, 547, 58, 2, 2, 2, 2, 2, 2));
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(0, 1, 2, 3, 4, 10, 10, 10, 10, 10));
+ EXPECT_THAT(m.GetOutputAttributeInt32(),
+ testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
+ EXPECT_THAT(
+ m.GetOutputAttributeFloat(),
+ testing::ElementsAre(3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f));
+}
+
+TEST(TextEncoderTest, ManyStrings) {
+ TextEncoderOpModel m({1, 3}, {1, 3});
+ m.SetInt32Attribute({1, 2, 3});
+ m.SetFloatAttribute({5.f, 4.f, 3.f});
+ m.SetInputText({"Hello", "Hi", "Bye"});
+ m.SetMaxOutputLength(10);
+ m.Invoke();
+ EXPECT_EQ(m.GetEncodedLength(), 10);
+ EXPECT_THAT(m.GetOutputEncoding(),
+ testing::ElementsAre(547, 58, 2, 1, 862, 2, 1, 1919, 19, 2));
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(2, 3, 4, 0, 1, 2, 0, 1, 2, 3));
+ EXPECT_THAT(m.GetOutputAttributeInt32(),
+ testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
+ EXPECT_THAT(
+ m.GetOutputAttributeFloat(),
+ testing::ElementsAre(5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 3.f));
+}
+
+TEST(TextEncoderTest, LongStrings) {
+ TextEncoderOpModel m({1, 4}, {1, 4});
+ m.SetInt32Attribute({1, 2, 3, 4});
+ m.SetFloatAttribute({5.f, 4.f, 3.f, 2.f});
+ m.SetInputText({"Hello", "Hi", "Bye", "Hi"});
+ m.SetMaxOutputLength(9);
+ m.Invoke();
+ EXPECT_EQ(m.GetEncodedLength(), 9);
+ EXPECT_THAT(m.GetOutputEncoding(),
+ testing::ElementsAre(862, 2, 1, 1919, 19, 2, 1, 862, 2));
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(1, 2, 0, 1, 2, 3, 0, 1, 2));
+ EXPECT_THAT(m.GetOutputAttributeInt32(),
+ testing::ElementsAre(2, 2, 3, 3, 3, 3, 4, 4, 4));
+ EXPECT_THAT(
+ m.GetOutputAttributeFloat(),
+ testing::ElementsAre(4.f, 4.f, 3.f, 3.f, 3.f, 3.f, 2.f, 2.f, 2.f));
+}
+
+} // namespace
+} // namespace libtextclassifier3